diff --git a/src/trezor/wire/__init__.py b/src/trezor/wire/__init__.py index 475a41c071..04898e916a 100644 --- a/src/trezor/wire/__init__.py +++ b/src/trezor/wire/__init__.py @@ -61,12 +61,14 @@ async def call(session_id, pbuf_msg, *response_types): class FailureError(Exception): - def __init__(self, code, message): - super(FailureError, self).__init__(code, message) - def to_protobuf(self): from trezor.messages.Failure import Failure - return Failure(code=self.args[0], message=self.args[1]) + code, message = self.args + return Failure(code=code, message=message) + + +class CloseWorkflow(Exception): + pass def protobuf_workflow(session_id, msg_type, data_len, callback, *args): @@ -83,6 +85,9 @@ async def _wrap_protobuf_workflow(wf, session_id): try: result = await wf + except CloseWorkflow: + return + except FailureError as e: await write(session_id, e.to_protobuf()) raise @@ -115,15 +120,14 @@ def _handle_response(session_id, msg_type, data_len, response_types, signal): if msg_type in response_types: return _build_protobuf(msg_type, signal.send) else: - from trezor.messages.FailureType import UnexpectedMessage - signal.send(FailureError(UnexpectedMessage, 'Unexpected message')) + signal.send(CloseWorkflow()) return _handle_workflow(session_id, msg_type, data_len) def _handle_workflow(session_id, msg_type, data_len): if msg_type in _workflow_callbacks: - args = _workflow_args[msg_type] callback = _workflow_callbacks[msg_type] + args = _workflow_args[msg_type] return callback(session_id, msg_type, data_len, *args) else: return _handle_unexpected(session_id, msg_type, data_len) @@ -133,12 +137,20 @@ def _handle_unexpected(session_id, msg_type, data_len): log.warning( __name__, 'session %x: skip type %d, len %d', session_id, msg_type, data_len) + # read the message in full try: while True: yield except EOFError: pass + # respond with an unknown message error + from trezor.messages.Failure import Failure + from trezor.messages.FailureType import UnexpectedMessage + failure = Failure(code=UnexpectedMessage, message='Unexpected message') + sessions.get_codec(session_id).encode( + session_id, failure.MESSAGE_WIRE_TYPE, failure.dumps(), _write_report) + def _write_report(report): if __debug__: