mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-05-03 07:29:14 +00:00
207 lines
6.7 KiB
Python
207 lines
6.7 KiB
Python
from protobuf import build_protobuf_message
|
|
|
|
from trezor.loop import schedule_task, Future
|
|
from trezor.crypto import random
|
|
from trezor.messages import get_protobuf_type
|
|
from trezor.workflow import start_workflow
|
|
from trezor import log
|
|
|
|
from .io import read_report_stream, write_report_stream
|
|
from .dispatcher import dispatch_reports_by_session
|
|
from .codec import \
|
|
decode_wire_stream, encode_wire_message, \
|
|
encode_session_open_message, encode_session_close_message
|
|
from .codec_v1 import \
|
|
SESSION_V1, decode_wire_v1_stream, encode_wire_v1_message
|
|
|
|
# Module-level registries shared by the session/workflow helpers below.
_session_handlers = dict()  # session id -> generator consuming that session's reports
_workflow_genfuncs = dict()  # wire type -> (generator function, extra args tuple)
_opened_sessions = set()  # ids of currently open sessions

# TODO: get rid of this, use callbacks instead
# One shared writer generator, primed so it is parked at its first yield; it
# is handed to every encode_* call in this module.
report_writer = write_report_stream()
report_writer.send(None)
|
|
|
|
|
|
def generate_session_id():
    """Return a random session id that is not already in ``_opened_sessions``.

    Candidates are drawn as ``random.uniform(0xffffffff) + 1`` and redrawn
    until one does not collide with an open session.
    NOTE(review): assuming ``uniform(n)`` yields an int in ``[0, n)``, ids fall
    in ``[1, 0xffffffff]`` -- confirm against trezor.crypto.random.
    """
    candidate = random.uniform(0xffffffff) + 1
    while candidate in _opened_sessions:
        candidate = random.uniform(0xffffffff) + 1
    return candidate
|
|
|
|
|
|
def open_session(session_id=None):
    """Mark a session as open, log it, and return its id.

    :param session_id: id to open; when None a fresh one is obtained from
        ``generate_session_id()``.
    :returns: the id that was opened.
    """
    sid = generate_session_id() if session_id is None else session_id
    _opened_sessions.add(sid)
    log.info(__name__, 'session %d: open', sid)
    return sid
|
|
|
|
|
|
def close_session(session_id):
    """Forget ``session_id`` and drop any handler registered for it.

    Closing an unknown or already-closed session is not an error.
    """
    _opened_sessions.discard(session_id)
    if session_id in _session_handlers:
        del _session_handlers[session_id]
    log.info(__name__, 'session %d: close', session_id)
|
|
|
|
|
|
def register_type(wire_type, genfunc, *args):
    """Register ``genfunc`` as the handler for messages of ``wire_type``.

    The handler is later invoked as
    ``genfunc(msg_type, data_len, session_id, *args)``.

    :raises KeyError: if ``wire_type`` already has a handler.
    """
    if wire_type not in _workflow_genfuncs:
        log.info(__name__, 'register type %d', wire_type)
        _workflow_genfuncs[wire_type] = (genfunc, args)
        return
    raise KeyError('message of type %d already registered' % wire_type)
|
|
|
|
|
|
def register_session(session_id, handler):
    """Install ``handler`` as the consumer of reports for an open session.

    :raises KeyError: if the session is not open, or if it already has a
        handler installed.
    """
    problem = None
    if session_id not in _opened_sessions:
        problem = 'session %d is unknown'
    elif session_id in _session_handlers:
        problem = 'session %d is already being listened on'
    if problem is not None:
        raise KeyError(problem % session_id)
    log.info(__name__, 'session %d: listening', session_id)
    _session_handlers[session_id] = handler
|
|
|
|
|
|
def setup():
    """Start report reading and attach a default decoder to the v1 session."""
    # Route incoming reports by session; the three callbacks cover
    # session-open, session-close and reports for unknown sessions.
    dispatcher = dispatch_reports_by_session(
        _session_handlers,
        _handle_open_session,
        _handle_close_session,
        _handle_unknown_session)
    dispatcher.send(None)
    schedule_task(read_report_stream(dispatcher))

    # The v1 session is opened up front and decodes into registered types.
    v1_decoder = decode_wire_v1_stream(_handle_registered_type, SESSION_V1)
    v1_decoder.send(None)
    open_session(SESSION_V1)
    register_session(SESSION_V1, v1_decoder)
|
|
|
|
|
|
async def read_message(session_id, *exp_types):
    """Wait for the next message on ``session_id`` and return it decoded.

    A one-shot decoder is registered for the session.
    ``_dispatch_and_build_protobuf`` then handles each message:
    - if its type is in ``exp_types``, the future resolves with the built
      protobuf;
    - otherwise, it resolves with a ``FailureError(UnexpectedMessage, ...)``.

    :raises KeyError: from ``register_session`` if the session is not open
        or already has a handler.
    """
    log.info(__name__, 'session %d: read types %s', session_id, exp_types)
    future = Future()
    if session_id == SESSION_V1:
        make_decoder = decode_wire_v1_stream
    else:
        make_decoder = decode_wire_stream
    decoder = make_decoder(
        _dispatch_and_build_protobuf, session_id, exp_types, future)
    decoder.send(None)
    register_session(session_id, decoder)
    return await future
|
|
|
|
|
|
async def write_message(session_id, pbuf_message):
    """Serialize ``pbuf_message`` and send it on ``session_id``.

    The v1 session uses the v1 wire encoding. All other sessions use the
    session-tagged encoding. Both write through ``report_writer``.
    """
    log.info(__name__, 'session %d: write %s', session_id, pbuf_message)
    message_class = pbuf_message.__class__
    payload = message_class.dumps(pbuf_message)
    wire_type = message_class.MESSAGE_WIRE_TYPE

    if session_id != SESSION_V1:
        encode_wire_message(wire_type, payload, session_id, report_writer)
    else:
        encode_wire_v1_message(wire_type, payload, report_writer)
|
|
|
|
|
|
async def reply_message(session_id, pbuf_message, *exp_types):
    """Send ``pbuf_message``, then wait for and return the next message.

    ``exp_types`` has the same meaning as in ``read_message``.
    """
    await write_message(session_id, pbuf_message)
    reply = await read_message(session_id, *exp_types)
    return reply
|
|
|
|
|
|
class FailureError(Exception):
    """Error that is reported to the host as a ``Failure`` message.

    The exception's ``args`` are ``(code, message)``.
    """

    def __init__(self, code, message):
        super(FailureError, self).__init__(code, message)

    def to_protobuf(self):
        """Return a ``Failure`` protobuf carrying this error's code and text."""
        from trezor.messages.Failure import Failure
        failure_code = self.args[0]
        failure_text = self.args[1]
        return Failure(code=failure_code, message=failure_text)
|
|
|
|
|
|
async def monitor_workflow(workflow, session_id):
    """Run ``workflow`` and report its outcome on ``session_id``.

    - ``FailureError``: its ``to_protobuf()`` form is written to the session,
      then the exception is re-raised.
    - Any other ``Exception``: a generic ``Failure(code=FirmwareError,
      message='Firmware Error')`` is written, then the exception is re-raised.
    - Normal completion: a non-None result is written back as the reply,
      and the result is returned.

    In every case, if the session is still open afterwards, a fresh decoder
    that dispatches to registered message types (``_handle_registered_type``)
    is installed, so the session keeps accepting new requests.
    """
    try:
        result = await workflow
    except FailureError as e:
        await write_message(session_id, e.to_protobuf())
        raise
    except Exception:
        # Only a generic failure is reported; the exception's details are
        # not sent to the host.
        from trezor.messages.Failure import Failure
        from trezor.messages.FailureType import FirmwareError
        await write_message(session_id,
                            Failure(code=FirmwareError,
                                    message='Firmware Error'))
        raise
    else:
        if result is not None:
            await write_message(session_id, result)
        return result
    finally:
        if session_id in _opened_sessions:
            if session_id == SESSION_V1:
                wire_decoder = decode_wire_v1_stream(
                    _handle_registered_type, session_id)
            else:
                wire_decoder = decode_wire_stream(
                    _handle_registered_type, session_id)
            wire_decoder.send(None)
            # NOTE(review): register_session raises KeyError if a handler is
            # still installed for this session; raised from this finally it
            # would replace any exception propagating from above -- confirm
            # the dispatcher always removes the previous handler first.
            register_session(session_id, wire_decoder)
|
|
|
|
|
|
def protobuf_handler(msg_type, data_len, session_id, callback, *args):
    """Return a sink that decodes one ``msg_type`` message and starts a workflow.

    The signature matches how ``_handle_registered_type`` invokes registered
    generator functions. Once the message is fully built,
    ``callback(message, session_id, *args)`` is wrapped in
    ``monitor_workflow`` and passed to ``start_workflow``.
    ``data_len`` is not used here.
    """
    pbuf_type = get_protobuf_type(msg_type)

    def on_message(message):
        monitored = monitor_workflow(
            callback(message, session_id, *args), session_id)
        start_workflow(monitored)

    builder = build_protobuf_message(pbuf_type, on_message)
    builder.send(None)
    return pbuf_type.load(target=builder)
|
|
|
|
|
|
def _handle_open_session():
    """Open a new session, attach a default decoder, and announce its id.

    The decoder dispatches to registered message types. The open message
    is written through ``report_writer``.
    """
    new_id = open_session()
    decoder = decode_wire_stream(_handle_registered_type, new_id)
    decoder.send(None)
    register_session(new_id, decoder)
    encode_session_open_message(new_id, report_writer)
|
|
|
|
|
|
def _handle_close_session(session_id):
    """Close ``session_id``, then write the session-close message.

    The message is written through ``report_writer``.
    """
    close_session(session_id)
    encode_session_close_message(session_id, report_writer)
|
|
|
|
|
|
def _handle_unknown_session(session_id, report_data):
    """Handle a report the dispatcher could not route to a session.

    This is passed to ``dispatch_reports_by_session`` in ``setup()``.
    NOTE(review): presumably it is called for reports whose session id has
    no registered handler -- confirm in the dispatcher.

    The report is dropped, as before. A warning is now logged so that the
    drop is visible instead of silent.
    """
    # TODO: decide whether the host should be notified about the unknown session
    log.warning(__name__, 'session %d: unknown, report dropped', session_id)
|
|
|
|
|
|
def _dispatch_and_build_protobuf(msg_type, data_len, session_id, exp_types, future):
    """Decode a message awaited by ``read_message``, or reject it.

    If ``msg_type`` is not in ``exp_types``, the waiting future resolves with
    a ``FailureError(UnexpectedMessage, 'Unexpected message')``, and the
    message is handed to the normal registered-type dispatch. Otherwise, the
    message is built into a protobuf that resolves ``future`` when complete.
    """
    if msg_type not in exp_types:
        from trezor.messages.FailureType import UnexpectedMessage
        future.resolve(FailureError(UnexpectedMessage, 'Unexpected message'))
        return _handle_registered_type(msg_type, data_len, session_id)

    pbuf_type = get_protobuf_type(msg_type)
    builder = build_protobuf_message(pbuf_type, future.resolve)
    builder.send(None)
    return pbuf_type.load(target=builder)
|
|
|
|
|
|
def _handle_registered_type(msg_type, data_len, session_id):
    """Invoke the generator function registered for ``msg_type``.

    Types never passed to ``register_type`` fall back to
    ``_handle_unexpected_type``, which discards the payload.
    """
    if msg_type in _workflow_genfuncs:
        genfunc, extra_args = _workflow_genfuncs[msg_type]
    else:
        genfunc, extra_args = _handle_unexpected_type, ()
    return genfunc(msg_type, data_len, session_id, *extra_args)
|
|
|
|
|
|
def _handle_unexpected_type(msg_type, data_len, session_id):
    """Generator that swallows the payload of a message nobody handles.

    It logs a warning when first advanced, ignores every value sent to it,
    and finishes quietly once ``EOFError`` is thrown into it.
    """
    log.warning(__name__, 'session %d: skip type %d, len %d',
                session_id, msg_type, data_len)
    while True:
        try:
            yield
        except EOFError:
            return
|