|
|
|
@ -11,6 +11,10 @@ from .wire_dispatcher import dispatch_reports_by_session
|
|
|
|
|
from .wire_codec import \
|
|
|
|
|
decode_wire_stream, encode_wire_message, \
|
|
|
|
|
encode_session_open_message, encode_session_close_message
|
|
|
|
|
from .wire_codec_v1 import \
|
|
|
|
|
SESSION_V1, \
|
|
|
|
|
decode_wire_v1_stream, \
|
|
|
|
|
encode_wire_v1_message
|
|
|
|
|
|
|
|
|
|
_session_handlers = {} # session id -> generator
|
|
|
|
|
_workflow_genfuncs = {} # wire type -> (generator function, args)
|
|
|
|
@ -28,8 +32,9 @@ def generate_session_id():
|
|
|
|
|
return session_id
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def open_session():
|
|
|
|
|
session_id = generate_session_id()
|
|
|
|
|
def open_session(session_id=None):
|
|
|
|
|
if session_id is None:
|
|
|
|
|
session_id = generate_session_id()
|
|
|
|
|
_opened_sessions.add(session_id)
|
|
|
|
|
log.info(__name__, 'opened session %d: %s', session_id, _opened_sessions)
|
|
|
|
|
return session_id
|
|
|
|
@ -66,13 +71,17 @@ def setup():
|
|
|
|
|
session_dispatcher.send(None)
|
|
|
|
|
schedule_task(read_report_stream(session_dispatcher))
|
|
|
|
|
|
|
|
|
|
v1_handler = decode_wire_v1_stream(_handle_registered_type, SESSION_V1)
|
|
|
|
|
v1_handler.send(None)
|
|
|
|
|
open_session(SESSION_V1)
|
|
|
|
|
register_session(SESSION_V1, v1_handler)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def read_message(session_id, *exp_types):
|
|
|
|
|
log.info(__name__, 'reading message of types %s', exp_types)
|
|
|
|
|
future = Future()
|
|
|
|
|
wire_decoder = decode_wire_stream(
|
|
|
|
|
_dispatch_and_build_protobuf, session_id, exp_types, future)
|
|
|
|
|
wire_decoder.send(None)
|
|
|
|
|
register_session(session_id, wire_decoder)
|
|
|
|
|
return await future
|
|
|
|
|
|
|
|
|
@ -81,7 +90,11 @@ async def write_message(session_id, pbuf_message):
|
|
|
|
|
log.info(__name__, 'writing message %s', pbuf_message)
|
|
|
|
|
msg_data = await pbuf_message.dumps()
|
|
|
|
|
msg_type = pbuf_message.message_type.wire_type
|
|
|
|
|
encode_wire_message(msg_type, msg_data, session_id, report_writer)
|
|
|
|
|
|
|
|
|
|
if session_id == SESSION_V1:
|
|
|
|
|
encode_wire_v1_message(msg_type, msg_data, report_writer)
|
|
|
|
|
else:
|
|
|
|
|
encode_wire_message(msg_type, msg_data, session_id, report_writer)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def reply_message(session_id, pbuf_message, *exp_types):
|
|
|
|
@ -123,8 +136,12 @@ async def monitor_workflow(workflow, session_id):
|
|
|
|
|
|
|
|
|
|
finally:
|
|
|
|
|
if session_id in _opened_sessions:
|
|
|
|
|
wire_decoder = decode_wire_stream(
|
|
|
|
|
_handle_registered_type, session_id)
|
|
|
|
|
if session_id == SESSION_V1:
|
|
|
|
|
wire_decoder = decode_wire_v1_stream(_handle_registered_type,
|
|
|
|
|
SESSION_V1)
|
|
|
|
|
else:
|
|
|
|
|
wire_decoder = decode_wire_stream(
|
|
|
|
|
_handle_registered_type, session_id)
|
|
|
|
|
wire_decoder.send(None)
|
|
|
|
|
register_session(session_id, wire_decoder)
|
|
|
|
|
|
|
|
|
|