diff --git a/src/trezor/wire/__init__.py b/src/trezor/wire/__init__.py index 9fec18ab8..a5fb53c97 100644 --- a/src/trezor/wire/__init__.py +++ b/src/trezor/wire/__init__.py @@ -1,109 +1,62 @@ -from protobuf import build_protobuf_message +import ubinascii +import protobuf -from trezor.loop import schedule_task, Future -from trezor.crypto import random -from trezor.messages import get_protobuf_type -from trezor.workflow import start_workflow from trezor import log +from trezor import loop +from trezor import messages +from trezor import msg +from trezor import workflow -from .io import read_report_stream, write_report_stream -from .dispatcher import dispatch_reports_by_session -from .codec import \ - decode_wire_stream, encode_wire_message, \ - encode_session_open_message, encode_session_close_message -from .codec_v1 import \ - SESSION_V1, decode_wire_v1_stream, encode_wire_v1_message +from . import codec_v1 +from . import codec_v2 +from . import sessions -_session_handlers = {} # session id -> generator -_workflow_genfuncs = {} # wire type -> (generator function, args) -_opened_sessions = set() # session ids +_interface = None -# TODO: get rid of this, use callbacks instead -report_writer = write_report_stream() -report_writer.send(None) +_workflow_callbacks = {} # wire type -> function returning workflow +_workflow_args = {} # wire type -> args -def generate_session_id(): - while True: - session_id = random.uniform(0xffffffff) + 1 - if session_id not in _opened_sessions: - return session_id +def register(wire_type, callback, *args): + if wire_type in _workflow_callbacks: + raise KeyError('Message %d already registered' % wire_type) + _workflow_callbacks[wire_type] = callback + _workflow_args[wire_type] = args -def open_session(session_id=None): - if session_id is None: - session_id = generate_session_id() - _opened_sessions.add(session_id) - log.info(__name__, 'session %d: open', session_id) - return session_id +def setup(iface): + global _interface + + # setup wire interface for reading and writing + _interface = iface + + # implicitly register v1 codec on its session. v2 sessions are + # opened/closed explicitely through session control messages. + _session_open(codec_v1.SESSION) + + # run session dispatcher + loop.schedule_task(_dispatch_reports()) -def close_session(session_id): - _opened_sessions.discard(session_id) - _session_handlers.pop(session_id, None) - log.info(__name__, 'session %d: close', session_id) - - -def register_type(wire_type, genfunc, *args): - if wire_type in _workflow_genfuncs: - raise KeyError('message of type %d already registered' % wire_type) - log.info(__name__, 'register type %d', wire_type) - _workflow_genfuncs[wire_type] = (genfunc, args) - - -def register_session(session_id, handler): - if session_id not in _opened_sessions: - raise KeyError('session %d is unknown' % session_id) - if session_id in _session_handlers: - raise KeyError('session %d is already being listened on' % session_id) - log.info(__name__, 'session %d: listening', session_id) - _session_handlers[session_id] = handler - - -def setup(): - session_dispatcher = dispatch_reports_by_session( - _session_handlers, - _handle_open_session, - _handle_close_session, - _handle_unknown_session) - session_dispatcher.send(None) - schedule_task(read_report_stream(session_dispatcher)) - - v1_handler = decode_wire_v1_stream(_handle_registered_type, SESSION_V1) - v1_handler.send(None) - open_session(SESSION_V1) - register_session(SESSION_V1, v1_handler) - - -async def read_message(session_id, *exp_types): - log.info(__name__, 'session %d: read types %s', session_id, exp_types) - signal = Signal() - if session_id == SESSION_V1: - wire_decoder = decode_wire_v1_stream( - _dispatch_and_build_protobuf, session_id, exp_types, signal) - else: - wire_decoder = decode_wire_stream( - _dispatch_and_build_protobuf, session_id, exp_types, signal) - wire_decoder.send(None) - register_session(session_id, wire_decoder) +async def read(session_id, *wire_types): + log.info(__name__, 'session %d: read types %s', session_id, wire_types) + signal = loop.Signal() + sessions.listen(session_id, _handle_response, wire_types, signal) return await signal -async def write_message(session_id, pbuf_message): - log.info(__name__, 'session %d: write %s', session_id, pbuf_message) - pbuf_type = pbuf_message.__class__ - msg_data = pbuf_type.dumps(pbuf_message) +async def write(session_id, pbuf_msg): + log.info(__name__, 'session %d: write %s', session_id, pbuf_msg) + pbuf_type = pbuf_msg.__class__ + msg_data = pbuf_type.dumps(pbuf_msg) msg_type = pbuf_type.MESSAGE_WIRE_TYPE - - if session_id == SESSION_V1: - encode_wire_v1_message(msg_type, msg_data, report_writer) - else: - encode_wire_message(msg_type, msg_data, session_id, report_writer) + sessions.get_codec(session_id).encode( + session_id, msg_type, msg_data, _write_report) -async def reply_message(session_id, pbuf_message, *exp_types): - await write_message(session_id, pbuf_message) - return await read_message(session_id, *exp_types) +async def call(session_id, pbuf_msg, *response_types): + await write(session_id, pbuf_msg) + return await read(session_id, *response_types) class FailureError(Exception): @@ -113,94 +66,107 @@ class FailureError(Exception): def to_protobuf(self): from trezor.messages.Failure import Failure - return Failure(code=self.args[0], - message=self.args[1]) + return Failure(code=self.args[0], message=self.args[1]) -async def monitor_workflow(workflow, session_id): +def protobuf_workflow(session_id, msg_type, data_len, callback, *args): + return _build_protobuf(msg_type, _start_protobuf_workflow, session_id, callback, args) + + +def _start_protobuf_workflow(pbuf_msg, session_id, callback, args): + wf = callback(session_id, pbuf_msg, *args) + wf = _wrap_protobuf_workflow(wf, session_id) + workflow.start(wf) + + +async def _wrap_protobuf_workflow(wf, session_id): try: - result = await workflow + result = await wf except FailureError as e: - await write_message(session_id, e.to_protobuf()) + await write(session_id, e.to_protobuf()) raise except Exception as e: from trezor.messages.Failure import Failure from trezor.messages.FailureType import FirmwareError - await write_message(session_id, - Failure(code=FirmwareError, - message='Firmware Error')) + await write(session_id, Failure( + code=FirmwareError, message='Firmware Error')) raise else: if result is not None: - await write_message(session_id, result) + await write(session_id, result) return result finally: - if session_id in _opened_sessions: - if session_id == SESSION_V1: - wire_decoder = decode_wire_v1_stream( - _handle_registered_type, session_id) - else: - wire_decoder = decode_wire_stream( - _handle_registered_type, session_id) - wire_decoder.send(None) - register_session(session_id, wire_decoder) + if session_id in sessions.opened: + sessions.listen(session_id, _handle_workflow) -def protobuf_handler(msg_type, data_len, session_id, callback, *args): - def finalizer(message): - workflow = callback(message, session_id, *args) - monitored = monitor_workflow(workflow, session_id) - start_workflow(monitored) - pbuf_type = get_protobuf_type(msg_type) - builder = build_protobuf_message(pbuf_type, finalizer) +def _build_protobuf(msg_type, callback, *args): + pbuf_type = messages.get_protobuf_type(msg_type) + builder = protobuf.build_message(pbuf_type, callback, *args) builder.send(None) return pbuf_type.load(target=builder) -def _handle_open_session(): - session_id = open_session() - wire_decoder = decode_wire_stream(_handle_registered_type, session_id) - wire_decoder.send(None) - register_session(session_id, wire_decoder) - encode_session_open_message(session_id, report_writer) - - -def _handle_close_session(session_id): - close_session(session_id) - encode_session_close_message(session_id, report_writer) - - -def _handle_unknown_session(session_id, report_data): - pass # TODO - - -def _dispatch_and_build_protobuf(msg_type, data_len, session_id, exp_types, signal): - if msg_type in exp_types: - pbuf_type = get_protobuf_type(msg_type) - builder = build_protobuf_message(pbuf_type, signal.send) - builder.send(None) - return pbuf_type.load(target=builder) +def _handle_response(session_id, msg_type, data_len, response_types, signal): + if msg_type in response_types: + return _build_protobuf(msg_type, signal.send) else: from trezor.messages.FailureType import UnexpectedMessage signal.send(FailureError(UnexpectedMessage, 'Unexpected message')) - return _handle_registered_type(msg_type, data_len, session_id) + return _handle_workflow(session_id, msg_type, data_len) -def _handle_registered_type(msg_type, data_len, session_id): - fallback = (_handle_unexpected_type, ()) - genfunc, args = _workflow_genfuncs.get(msg_type, fallback) - return genfunc(msg_type, data_len, session_id, *args) +def _handle_workflow(session_id, msg_type, data_len): + if msg_type in _workflow_callbacks: + args = _workflow_args[msg_type] + callback = _workflow_callbacks[msg_type] + return callback(session_id, msg_type, data_len, *args) + else: + return _handle_unexpected(session_id, msg_type, data_len) -def _handle_unexpected_type(msg_type, data_len, session_id): - log.warning(__name__, 'session %d: skip type %d, len %d', - session_id, msg_type, data_len) +def _handle_unexpected(session_id, msg_type, data_len): + log.warning( + __name__, 'session %d: skip type %d, len %d', session_id, msg_type, data_len) try: while True: yield except EOFError: pass + + +def _write_report(report): + if __debug__: + log.info(__name__, 'write report %s', ubinascii.hexlify(report)) + msg.send(_interface, report) + + +def _dispatch_reports(): + while True: + report, = yield loop.Select(_interface) + report = memoryview(report) + if __debug__: + log.debug(__name__, 'read report %s', ubinascii.hexlify(report)) + sessions.dispatch( + report, _session_open, _session_close, _session_unknown) + + +def _session_open(session_id=None): + session_id = sessions.open(session_id) + sessions.listen(session_id, _handle_workflow) + sessions.get_codec(session_id).encode_session_open( + session_id, _write_report) + + +def _session_close(session_id): + sessions.close(session_id) + sessions.get_codec(session_id).encode_session_close( + session_id, _write_report) + + +def _session_unknown(session_id, report_data): + pass diff --git a/src/trezor/wire/codec_v1.py b/src/trezor/wire/codec_v1.py index 8e6b0ac6b..55684ffd0 100644 --- a/src/trezor/wire/codec_v1.py +++ b/src/trezor/wire/codec_v1.py @@ -1,50 +1,50 @@ from micropython import const + import ustruct -SESSION_V1 = const(0) -REP_MARKER_V1 = const(63) # ord('?') -REP_MARKER_V1_LEN = const(1) # len('?') +SESSION = const(0) +REP_MARKER = const(63) # ord('?') +REP_MARKER_LEN = const(1) # len('?') _REP_LEN = const(64) _MSG_HEADER_MAGIC = const(35) # org('#') -_MSG_HEADER_V1 = '>BBHL' # magic, magic, wire type, data length -_MSG_HEADER_V1_LEN = ustruct.calcsize(_MSG_HEADER_V1) +_MSG_HEADER = '>BBHL' # magic, magic, wire type, data length +_MSG_HEADER_LEN = ustruct.calcsize(_MSG_HEADER) -def detect_v1(data): - return (data[0] == REP_MARKER_V1) +def detect(data): + return data[0] == REP_MARKER -def parse_report_v1(data): +def parse_report(data): if len(data) != _REP_LEN: raise ValueError('Invalid buffer size') - return None, SESSION_V1, data[1:] + return None, SESSION, data[1:] def parse_message(data): - magic1, magic2, msg_type, data_len = ustruct.unpack(_MSG_HEADER_V1, data) + magic1, magic2, msg_type, data_len = ustruct.unpack(_MSG_HEADER, data) if magic1 != _MSG_HEADER_MAGIC or magic2 != _MSG_HEADER_MAGIC: raise ValueError('Corrupted magic bytes') - return msg_type, data_len, data[_MSG_HEADER_V1_LEN:] + return msg_type, data_len, data[_MSG_HEADER_LEN:] def serialize_message_header(data, msg_type, msg_len): - if len(data) < REP_MARKER_V1_LEN + _MSG_HEADER_V1_LEN: + if len(data) < REP_MARKER_LEN + _MSG_HEADER_LEN: raise ValueError('Invalid buffer size') if msg_type < 0 or msg_type > 65535: raise ValueError('Value is out of range') ustruct.pack_into( - _MSG_HEADER_V1, data, REP_MARKER_V1_LEN, + _MSG_HEADER, data, REP_MARKER_LEN, _MSG_HEADER_MAGIC, _MSG_HEADER_MAGIC, msg_type, msg_len) -def decode_wire_v1_stream(genfunc, session_id, *args): +def decode_stream(session_id, callback, *args): '''Decode a v1 wire message from the report data and stream it to target. Receives report payloads. After first report, creates target by calling -`genfunc(msg_type, data_len, session_id, *args)` and sends chunks of message -data. -Throws `EOFError` to target after last data chunk. +`callback(session_id, msg_type, data_len, *args)` and sends chunks of message +data. Throws `EOFError` to target after last data chunk. Pass report payloads as `memoryview` for cheaper slicing. ''' @@ -52,7 +52,7 @@ Pass report payloads as `memoryview` for cheaper slicing. message = yield # read first report msg_type, data_len, data = parse_message(message) - target = genfunc(msg_type, data_len, session_id, *args) + target = callback(session_id, msg_type, data_len, *args) target.send(None) while data_len > 0: @@ -68,18 +68,18 @@ Pass report payloads as `memoryview` for cheaper slicing. target.throw(EOFError()) -def encode_wire_v1_message(msg_type, msg_data, target): - '''Encode a full v1 wire message directly to reports and stream it to target. +def encode(session_id, msg_type, msg_data, callback): + '''Encode a full v1 wire message directly to reports and stream it to callback. -Target receives `memoryview`s of HID reports which are valid until the targets -`send()` method returns. - ''' +Callback receives `memoryview`s of HID reports which are valid until the +callback returns. +''' report = memoryview(bytearray(_REP_LEN)) - report[0] = REP_MARKER_V1 + report[0] = REP_MARKER serialize_message_header(report, msg_type, len(msg_data)) source_data = memoryview(msg_data) - target_data = report[REP_MARKER_V1_LEN + _MSG_HEADER_V1_LEN:] + target_data = report[REP_MARKER_LEN + _MSG_HEADER_LEN:] while True: # move as much as possible from source to target @@ -95,10 +95,20 @@ Target receives `memoryview`s of HID reports which are valid until the targets target_data[x] = 0 x += 1 - target.send(report) + callback(report) if not source_data: break # reset to skip the magic, not the whole header anymore - target_data = report[REP_MARKER_V1_LEN:] + target_data = report[REP_MARKER_LEN:] + + +def encode_session_open(session_id, callback): + # v1 codec does not have explicit session support + pass + + +def encode_session_close(session_id, callback): + # v1 codec does not have explicit session support + pass \ No newline at end of file diff --git a/src/trezor/wire/codec.py b/src/trezor/wire/codec_v2.py similarity index 90% rename from src/trezor/wire/codec.py rename to src/trezor/wire/codec_v2.py index bde3539f1..63f7fd90c 100644 --- a/src/trezor/wire/codec.py +++ b/src/trezor/wire/codec_v2.py @@ -81,11 +81,11 @@ class MessageChecksumError(Exception): pass -def decode_wire_stream(genfunc, session_id, *args): +def decode_stream(session_id, callback, *args): '''Decode a wire message from the report data and stream it to target. Receives report payloads. After first report, creates target by calling -`genfunc(msg_type, data_len, session_id, *args)` and sends chunks of message +`callback(session_id, msg_type, data_len, *args)` and sends chunks of message data. Throws `EOFError` to target after last data chunk, in case of valid checksum. Throws `MessageChecksumError` to target if data doesn't match the checksum. @@ -95,7 +95,7 @@ Pass report payloads as `memoryview` for cheaper slicing. message = yield # read first report msg_type, data_len, data_tail = parse_message(message) - target = genfunc(msg_type, data_len, session_id, *args) + target = callback(session_id, msg_type, data_len, *args) target.send(None) checksum = 0 # crc32 @@ -126,11 +126,11 @@ Pass report payloads as `memoryview` for cheaper slicing. target.throw(EOFError()) -def encode_wire_message(msg_type, msg_data, session_id, target): - '''Encode a full wire message directly to reports and stream it to target. +def encode(session_id, msg_type, msg_data, callback): + '''Encode a full wire message directly to reports and stream it to callback. -Target receives `memoryview`s of HID reports which are valid until the targets -`send()` method returns. +Callback receives `memoryview`s of HID reports which are valid until the +callback returns. ''' report = memoryview(bytearray(_REP_LEN)) serialize_report_header(report, REP_MARKER_HEADER, session_id) @@ -166,7 +166,7 @@ Target receives `memoryview`s of HID reports which are valid until the targets target_data[x] = 0 x += 1 - target.send(report) + callback(report) if not source_data and not msg_footer: break @@ -178,13 +178,13 @@ Target receives `memoryview`s of HID reports which are valid until the targets target_data = report[_REP_HEADER_LEN:] -def encode_session_open_message(session_id, target): +def encode_session_open(session_id, callback): report = bytearray(_REP_LEN) serialize_report_header(report, REP_MARKER_OPEN, session_id) - target.send(report) + callback(report) -def encode_session_close_message(session_id, target): +def encode_session_close(session_id, callback): report = bytearray(_REP_LEN) serialize_report_header(report, REP_MARKER_CLOSE, session_id) - target.send(report) + callback(report) diff --git a/src/trezor/wire/dispatcher.py b/src/trezor/wire/dispatcher.py deleted file mode 100644 index 55eae3d42..000000000 --- a/src/trezor/wire/dispatcher.py +++ /dev/null @@ -1,46 +0,0 @@ -from trezor import log -from .codec import parse_report, REP_MARKER_OPEN, REP_MARKER_CLOSE -from .codec_v1 import detect_v1, parse_report_v1 - - -def dispatch_reports_by_session(handlers, - open_callback, - close_callback, - unknown_callback): - ''' - Consumes reports adhering to the wire codec and dispatches the report - payloads by between the passed handlers. - ''' - - while True: - - data = (yield) - if detect_v1(data): - marker, session_id, report_data = parse_report_v1(data) - else: - marker, session_id, report_data = parse_report(data) - - if marker == REP_MARKER_OPEN: - log.debug(__name__, 'request for new session') - open_callback() - continue - - elif marker == REP_MARKER_CLOSE: - log.debug(__name__, 'request for closing session %d', session_id) - close_callback(session_id) - continue - - elif session_id not in handlers: - log.debug(__name__, 'report on unknown session %d', session_id) - unknown_callback(session_id, report_data) - continue - - log.debug(__name__, 'report on session %d', session_id) - handler = handlers[session_id] - - try: - handler.send(report_data) - except StopIteration: - handlers.pop(session_id) - except Exception as e: - log.exception(__name__, e) diff --git a/src/trezor/wire/io.py b/src/trezor/wire/io.py deleted file mode 100644 index e148ca5fb..000000000 --- a/src/trezor/wire/io.py +++ /dev/null @@ -1,19 +0,0 @@ -from micropython import const -from ubinascii import hexlify -from trezor import msg, loop, log - -_DEFAULT_IFACE = const(0xFF00) # TODO: use proper interface - - -def read_report_stream(target, iface=_DEFAULT_IFACE): - while True: - report, = yield loop.Select(iface) - log.debug(__name__, 'read report %s', hexlify(report)) - target.send(memoryview(report)) - - -def write_report_stream(iface=_DEFAULT_IFACE): - while True: - report = yield - log.info(__name__, 'write report %s', hexlify(report)) - msg.send(iface, report) diff --git a/src/trezor/wire/sessions.py b/src/trezor/wire/sessions.py new file mode 100644 index 000000000..aea4f7d56 --- /dev/null +++ b/src/trezor/wire/sessions.py @@ -0,0 +1,82 @@ +from trezor import log +from trezor.crypto import random + +from . import codec_v1 +from . import codec_v2 + +opened = set() # opened session ids +readers = {} # session id -> generator + + +def generate(): + while True: + session_id = random.uniform(0xffffffff) + 1 + if session_id not in opened: + return session_id + + +def open(session_id=None): + if session_id is None: + session_id = generate() + log.info(__name__, 'session %d: open', session_id) + opened.add(session_id) + return session_id + + +def close(session_id): + log.info(__name__, 'session %d: close', session_id) + opened.discard(session_id) + readers.pop(session_id, None) + + +def get_codec(session_id): + if session_id == codec_v1.SESSION: + return codec_v1 + else: + return codec_v2 + + +def listen(session_id, handler, *args): + if session_id not in opened: + raise KeyError('Session %d is unknown' % session_id) + if session_id in readers: + raise KeyError('Session %d is already being listened on' % session_id) + log.info(__name__, 'session %d: listening', session_id) + decoder = get_codec(session_id).decode_stream(session_id, handler, *args) + decoder.send(None) + readers[session_id] = decoder + + +def dispatch(report, open_callback, close_callback, unknown_callback): + ''' + Dispatches payloads of reports adhering to one of the wire codecs. + ''' + + if codec_v1.detect(report): + marker, session_id, report_data = codec_v1.parse_report(report) + else: + marker, session_id, report_data = codec_v2.parse_report(report) + + if marker == codec_v2.REP_MARKER_OPEN: + log.debug(__name__, 'request for new session') + open_callback() + return + elif marker == codec_v2.REP_MARKER_CLOSE: + log.debug(__name__, 'request for closing session %d', session_id) + close_callback(session_id) + return + + if session_id not in readers: + log.warning(__name__, 'report on unknown session %d', session_id) + unknown_callback(session_id, report_data) + return + + log.debug(__name__, 'report on session %d', session_id) + reader = readers[session_id] + + try: + reader.send(report_data) + except StopIteration: + readers.pop(session_id) + except Exception as e: + log.exception(__name__, e) diff --git a/tests/test_apps.wallet.signtx.py b/tests/test_apps.wallet.signtx.py new file mode 100644 index 000000000..2e395ed4b --- /dev/null +++ b/tests/test_apps.wallet.signtx.py @@ -0,0 +1,111 @@ +from common import * + +from trezor.utils import chunks +from trezor.crypto import bip32, bip39 +from trezor.messages.SignTx import SignTx +from trezor.messages.TxInputType import TxInputType +from trezor.messages.TxOutputType import TxOutputType +from trezor.messages.TxOutputBinType import TxOutputBinType +from trezor.messages.TxRequest import TxRequest +from trezor.messages.TxAck import TxAck +from trezor.messages.TransactionType import TransactionType +from trezor.messages.RequestType import TXINPUT, TXOUTPUT, TXMETA, TXFINISHED +from trezor.messages.TxRequestDetailsType import TxRequestDetailsType +from trezor.messages.TxRequestSerializedType import TxRequestSerializedType +from trezor.messages import OutputScriptType + +from apps.common import coins +from apps.wallet.sign_tx import signing + +class TestSignTx(unittest.TestCase): + # pylint: disable=C0301 + + def test_one_one_fee(self): + # tx: d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882 + # input 0: 0.0039 BTC + + coin_bitcoin = coins.by_name('Bitcoin') + + ptx1 = TransactionType(version=1, lock_time=0, inputs_cnt=2, outputs_cnt=1) + pinp1 = TxInputType(script_sig=unhexlify('483045022072ba61305fe7cb542d142b8f3299a7b10f9ea61f6ffaab5dca8142601869d53c0221009a8027ed79eb3b9bc13577ac2853269323434558528c6b6a7e542be46e7e9a820141047a2d177c0f3626fc68c53610b0270fa6156181f46586c679ba6a88b34c6f4874686390b4d92e5769fbb89c8050b984f4ec0b257a0e5c4ff8bd3b035a51709503'), + prev_hash=unhexlify('c16a03f1cf8f99f6b5297ab614586cacec784c2d259af245909dedb0e39eddcf'), + prev_index=1, + script_type=None, + sequence=None) + pinp2 = TxInputType(script_sig=unhexlify('48304502200fd63adc8f6cb34359dc6cca9e5458d7ea50376cbd0a74514880735e6d1b8a4c0221008b6ead7fe5fbdab7319d6dfede3a0bc8e2a7c5b5a9301636d1de4aa31a3ee9b101410486ad608470d796236b003635718dfc07c0cac0cfc3bfc3079e4f491b0426f0676e6643a39198e8e7bdaffb94f4b49ea21baa107ec2e237368872836073668214'), + prev_hash=unhexlify('1ae39a2f8d59670c8fc61179148a8e61e039d0d9e8ab08610cb69b4a19453eaf'), + prev_index=1, + script_type=None, + sequence=None) + pout1 = TxOutputBinType(script_pubkey=unhexlify('76a91424a56db43cf6f2b02e838ea493f95d8d6047423188ac'), + amount=390000, + address_n=None) + + inp1 = TxInputType(address_n=[0], # 14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e + # amount=390000, + prev_hash=unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882'), + prev_index=0, + script_type=None, + sequence=None) + out1 = TxOutputType(address='1MJ2tj2ThBE62zXbBYA5ZaN3fdve5CPAz1', + amount=390000 - 10000, + script_type=OutputScriptType.PAYTOADDRESS, + address_n=None) + tx = SignTx(coin_name=None, version=None, lock_time=None, inputs_count=1, outputs_count=1) + + messages = [ + None, + TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None)), + TxAck(tx=TransactionType(inputs=[inp1])), + TxRequest(request_type=TXMETA, details=TxRequestDetailsType(request_index=None, tx_hash=unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882')), serialized=None), + TxAck(tx=ptx1), + TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882')), serialized=None), + TxAck(tx=TransactionType(inputs=[pinp1])), + TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=1, tx_hash=unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882')), serialized=None), + TxAck(tx=TransactionType(inputs=[pinp2])), + TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882')), serialized=None), + TxAck(tx=TransactionType(bin_outputs=[pout1])), + TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None), + TxAck(tx=TransactionType(outputs=[out1])), + signing.UiConfirmOutput(out1, coin_bitcoin), + True, + signing.UiConfirmTotal(380000, 10000, coin_bitcoin), + True, + # ButtonRequest(code=ButtonRequest_ConfirmOutput), + # ButtonRequest(code=ButtonRequest_SignTx), + TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None), + TxAck(tx=TransactionType(inputs=[inp1])), + TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None), + TxAck(tx=TransactionType(outputs=[out1])), + TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=TxRequestSerializedType( + signature_index=0, + signature=unhexlify('30450221009a0b7be0d4ed3146ee262b42202841834698bb3ee39c24e7437df208b8b7077102202b79ab1e7736219387dffe8d615bbdba87e11477104b867ef47afed1a5ede781'), + serialized_tx=unhexlify('010000000182488650ef25a58fef6788bd71b8212038d7f2bbe4750bc7bcb44701e85ef6d5000000006b4830450221009a0b7be0d4ed3146ee262b42202841834698bb3ee39c24e7437df208b8b7077102202b79ab1e7736219387dffe8d615bbdba87e11477104b867ef47afed1a5ede7810121023230848585885f63803a0a8aecdd6538792d5c539215c91698e315bf0253b43dffffffff'))), + TxAck(tx=TransactionType(outputs=[out1])), + TxRequest(request_type=TXFINISHED, details=None, serialized=TxRequestSerializedType( + signature_index=None, + signature=None, + serialized_tx=unhexlify('0160cc0500000000001976a914de9b2a8da088824e8fe51debea566617d851537888ac00000000'), + )), + ] + + seed = bip39.seed('alcohol woman abuse must during monitor noble actual mixed trade anger aisle', '') + root = bip32.from_seed(seed, 'secp256k1') + + signer = signing.sign_tx(tx, root) + for request, response in chunks(messages, 2): + self.assertEqualEx(signer.send(request), response) + with self.assertRaises(StopIteration): + signer.send(None) + + def assertEqualEx(self, a, b): + # hack to avoid adding __eq__ to signing.Ui* classes + if ((isinstance(a, signing.UiConfirmOutput) and isinstance(b, signing.UiConfirmOutput)) or + (isinstance(a, signing.UiConfirmTotal) and isinstance(b, signing.UiConfirmTotal))): + return self.assertEqual(a.__dict__, b.__dict__) + else: + return self.assertEqual(a, b) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_trezor.wire.codec_v1.py b/tests/test_trezor.wire.codec_v1.py index 73b377b39..4638f1443 100644 --- a/tests/test_trezor.wire.codec_v1.py +++ b/tests/test_trezor.wire.codec_v1.py @@ -13,16 +13,16 @@ class TestWireCodecV1(unittest.TestCase): def test_detect(self): for i in range(0, 256): if i == ord(b'?'): - self.assertTrue(codec_v1.detect_v1(bytes([i]) + b'\x00' * 63)) + self.assertTrue(codec_v1.detect(bytes([i]) + b'\x00' * 63)) else: - self.assertFalse(codec_v1.detect_v1(bytes([i]) + b'\x00' * 63)) + self.assertFalse(codec_v1.detect(bytes([i]) + b'\x00' * 63)) def test_parse(self): d = bytes(range(0, 55)) m = b'##\x00\x00\x00\x00\x00\x37' + d r = b'?' + m - rm, rs, rd = codec_v1.parse_report_v1(r) + rm, rs, rd = codec_v1.parse_report(r) self.assertEqual(rm, None) self.assertEqual(rs, 0) self.assertEqual(rd, m) @@ -35,7 +35,7 @@ class TestWireCodecV1(unittest.TestCase): for i in range(0, 1024): if i != 64: with self.assertRaises(ValueError): - codec_v1.parse_report_v1(bytes(range(0, i))) + codec_v1.parse_report(bytes(range(0, i))) for hx in range(0, 256): for hy in range(0, 256): @@ -61,8 +61,8 @@ class TestWireCodecV1(unittest.TestCase): message = b'##' + b'\xab\xcd' + b'\x00\x00\x00\x00' + b'\x00' * 55 record = [] - genfunc = self._record(record, 0xabcd, 0, 0xdeadbeef, 'dummy') - decoder = codec_v1.decode_wire_v1_stream(genfunc, 0xdeadbeef, 'dummy') + genfunc = self._record(record, 0xdeadbeef, 0xabcd, 0, 'dummy') + decoder = codec_v1.decode_stream(0xdeadbeef, genfunc, 'dummy') decoder.send(None) try: @@ -78,8 +78,8 @@ class TestWireCodecV1(unittest.TestCase): message = b'##' + b'\xab\xcd' + b'\x00\x00\x00\x37' + data record = [] - genfunc = self._record(record, 0xabcd, 55, 0xdeadbeef, 'dummy') - decoder = codec_v1.decode_wire_v1_stream(genfunc, 0xdeadbeef, 'dummy') + genfunc = self._record(record, 0xdeadbeef, 0xabcd, 55, 'dummy') + decoder = codec_v1.decode_stream(0xdeadbeef, genfunc, 'dummy') decoder.send(None) try: @@ -103,8 +103,8 @@ class TestWireCodecV1(unittest.TestCase): message_chunks = [c + '\x00' * (63 - len(c)) for c in list(chunks(message, 63))] record = [] - genfunc = self._record(record, msg_type, data_len, 0xdeadbeef, 'dummy') - decoder = codec_v1.decode_wire_v1_stream(genfunc, 0xdeadbeef, 'dummy') + genfunc = self._record(record, 0xdeadbeef, msg_type, data_len, 'dummy') + decoder = codec_v1.decode_stream(0xdeadbeef, genfunc, 'dummy') decoder.send(None) res = 1 @@ -124,7 +124,7 @@ class TestWireCodecV1(unittest.TestCase): target = self._record(record)() target.send(None) - codec_v1.encode_wire_v1_message(0xabcd, b'', target) + codec_v1.encode(codec_v1.SESSION, 0xabcd, b'', target.send) self.assertEqual(len(record), 1) self.assertEqual(record[0], b'?##\xab\xcd\x00\x00\x00\x00' + '\0' * 55) @@ -135,7 +135,7 @@ class TestWireCodecV1(unittest.TestCase): target = self._record(record)() target.send(None) - codec_v1.encode_wire_v1_message(0xabcd, data, target) + codec_v1.encode(codec_v1.SESSION, 0xabcd, data, target.send) self.assertEqual(record, [b'?##\xab\xcd\x00\x00\x00\x37' + data]) def test_encode_generated_range(self): @@ -158,7 +158,7 @@ class TestWireCodecV1(unittest.TestCase): target = genfunc() target.send(None) - codec_v1.encode_wire_v1_message(msg_type, data, target) + codec_v1.encode(codec_v1.SESSION, msg_type, data, target.send) self.assertEqual(received, len(reports)) def _record(self, record, *_args): diff --git a/tests/test_trezor.wire.codec.py b/tests/test_trezor.wire.codec_v2.py similarity index 76% rename from tests/test_trezor.wire.codec.py rename to tests/test_trezor.wire.codec_v2.py index 00599ebcf..60226b4ed 100644 --- a/tests/test_trezor.wire.codec.py +++ b/tests/test_trezor.wire.codec_v2.py @@ -6,7 +6,7 @@ import ubinascii from trezor.crypto import random from trezor.utils import chunks -from trezor.wire import codec +from trezor.wire import codec_v2 class TestWireCodec(unittest.TestCase): # pylint: disable=C0301 @@ -14,66 +14,66 @@ class TestWireCodec(unittest.TestCase): def test_parse(self): d = b'O' + b'\x01\x23\x45\x67' + bytes(range(0, 59)) - m, s, d = codec.parse_report(d) + m, s, d = codec_v2.parse_report(d) self.assertEqual(m, b'O'[0]) self.assertEqual(s, 0x01234567) self.assertEqual(d, bytes(range(0, 59))) - t, l, d = codec.parse_message(d) + t, l, d = codec_v2.parse_message(d) self.assertEqual(t, 0x00010203) self.assertEqual(l, 0x04050607) self.assertEqual(d, bytes(range(8, 59))) - f, = codec.parse_message_footer(d[0:4]) + f, = codec_v2.parse_message_footer(d[0:4]) self.assertEqual(f, 0x08090a0b) for i in range(0, 1024): if i != 64: with self.assertRaises(ValueError): - codec.parse_report(bytes(range(0, i))) + codec_v2.parse_report(bytes(range(0, i))) if i != 59: with self.assertRaises(ValueError): - codec.parse_message(bytes(range(0, i))) + codec_v2.parse_message(bytes(range(0, i))) if i != 4: with self.assertRaises(ValueError): - codec.parse_message_footer(bytes(range(0, i))) + codec_v2.parse_message_footer(bytes(range(0, i))) def test_serialize(self): data = bytearray(range(0, 6)) - codec.serialize_report_header(data, 0x12, 0x3456789a) + codec_v2.serialize_report_header(data, 0x12, 0x3456789a) self.assertEqual(data, b'\x12\x34\x56\x78\x9a\x05') data = bytearray(range(0, 6)) - codec.serialize_opened_session(data, 0x3456789a) - self.assertEqual(data, bytes([codec.REP_MARKER_OPEN]) + b'\x34\x56\x78\x9a\x05') + codec_v2.serialize_opened_session(data, 0x3456789a) + self.assertEqual(data, bytes([codec_v2.REP_MARKER_OPEN]) + b'\x34\x56\x78\x9a\x05') data = bytearray(range(0, 14)) - codec.serialize_message_header(data, 0x01234567, 0x89abcdef) + codec_v2.serialize_message_header(data, 0x01234567, 0x89abcdef) self.assertEqual(data, b'\x00\x01\x02\x03\x04\x01\x23\x45\x67\x89\xab\xcd\xef\x0d') data = bytearray(range(0, 5)) - codec.serialize_message_footer(data, 0x89abcdef) + codec_v2.serialize_message_footer(data, 0x89abcdef) self.assertEqual(data, b'\x89\xab\xcd\xef\x04') for i in range(0, 13): data = bytearray(i) if i < 4: with self.assertRaises(ValueError): - codec.serialize_message_footer(data, 0x00) + codec_v2.serialize_message_footer(data, 0x00) if i < 5: with self.assertRaises(ValueError): - codec.serialize_report_header(data, 0x00, 0x00) + codec_v2.serialize_report_header(data, 0x00, 0x00) with self.assertRaises(ValueError): - codec.serialize_opened_session(data, 0x00) + codec_v2.serialize_opened_session(data, 0x00) with self.assertRaises(ValueError): - codec.serialize_message_header(data, 0x00, 0x00) + codec_v2.serialize_message_header(data, 0x00, 0x00) def test_decode_empty(self): message = b'\xab\xcd\xef\x12' + b'\x00\x00\x00\x00' + b'\x00' * 51 record = [] - genfunc = self._record(record, 0xabcdef12, 0, 0xdeadbeef, 'dummy') - decoder = codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy') + genfunc = self._record(record, 0xdeadbeef, 0xabcdef12, 0, 'dummy') + decoder = codec_v2.decode_stream(0xdeadbeef, genfunc, 'dummy') decoder.send(None) try: @@ -90,8 +90,8 @@ class TestWireCodec(unittest.TestCase): message = b'\xab\xcd\xef\x12' + b'\x00\x00\x00\x2f' + data + footer record = [] - genfunc = self._record(record, 0xabcdef12, 47, 0xdeadbeef, 'dummy') - decoder = codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy') + genfunc = self._record(record, 0xdeadbeef, 0xabcdef12, 47, 'dummy') + decoder = codec_v2.decode_stream(0xdeadbeef, genfunc, 'dummy') decoder.send(None) try: @@ -109,8 +109,8 @@ class TestWireCodec(unittest.TestCase): message = b'\xab\xcd\xef\x12' + b'\x00\x00\x00\x2f' + data + footer record = [] - genfunc = self._record(record, 0xabcdef12, 47, 0xdeadbeef, 'dummy') - decoder = codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy') + genfunc = self._record(record, 0xdeadbeef, 0xabcdef12, 47, 'dummy') + decoder = codec_v2.decode_stream(0xdeadbeef, genfunc, 'dummy') decoder.send(None) try: @@ -120,7 +120,7 @@ class TestWireCodec(unittest.TestCase): self.assertEqual(res, None) self.assertEqual(len(record), 2) self.assertEqual(record[0], data) - self.assertIsInstance(record[1], codec.MessageChecksumError) + self.assertIsInstance(record[1], codec_v2.MessageChecksumError) def test_decode_generated_range(self): for data_len in range(1, 512): @@ -136,8 +136,8 @@ class TestWireCodec(unittest.TestCase): message_chunks = [c + '\x00' * (59 - len(c)) for c in list(chunks(message, 59))] record = [] - genfunc = self._record(record, msg_type, data_len, 0xdeadbeef, 'dummy') - decoder = codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy') + genfunc = self._record(record, 0xdeadbeef, msg_type, data_len, 'dummy') + decoder = codec_v2.decode_stream(0xdeadbeef, genfunc, 'dummy') decoder.send(None) res = 1 @@ -157,7 +157,7 @@ class TestWireCodec(unittest.TestCase): target = self._record(record)() target.send(None) - codec.encode_wire_message(0xabcdef12, b'', 0xdeadbeef, target) + codec_v2.encode(0xdeadbeef, 0xabcdef12, b'', target.send) self.assertEqual(len(record), 1) self.assertEqual(record[0], b'H\xde\xad\xbe\xef\xab\xcd\xef\x12\x00\x00\x00\x00' + '\0' * 51) @@ -169,7 +169,7 @@ class TestWireCodec(unittest.TestCase): target = self._record(record)() target.send(None) - codec.encode_wire_message(0xabcdef12, data, 0xdeadbeef, target) + codec_v2.encode(0xdeadbeef, 0xabcdef12, data, target.send) self.assertEqual(record, [b'H\xde\xad\xbe\xef\xab\xcd\xef\x12\x00\x00\x00\x2f' + data + footer]) def test_encode_generated_range(self): @@ -199,7 +199,7 @@ class TestWireCodec(unittest.TestCase): target = genfunc() target.send(None) - codec.encode_wire_message(msg_type, data, session_id, target) + codec_v2.encode(session_id, msg_type, data, target.send) self.assertEqual(received, len(reports)) def _record(self, record, *_args):