1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-14 09:20:55 +00:00

wire: refactoring

- prefer importing modules instead of module members
- session_id is always first argument
- prefer much shorter names, don't expect users to import module members
- shuffle around session-specific code
- reduce allocations
This commit is contained in:
Jan Pochyla 2016-12-08 16:14:47 +01:00
parent 0b7874ad43
commit d56dc88861
9 changed files with 394 additions and 290 deletions

View File

@ -1,109 +1,62 @@
from protobuf import build_protobuf_message import ubinascii
import protobuf
from trezor.loop import schedule_task, Future
from trezor.crypto import random
from trezor.messages import get_protobuf_type
from trezor.workflow import start_workflow
from trezor import log from trezor import log
from trezor import loop
from trezor import messages
from trezor import msg
from trezor import workflow
from .io import read_report_stream, write_report_stream from . import codec_v1
from .dispatcher import dispatch_reports_by_session from . import codec_v2
from .codec import \ from . import sessions
decode_wire_stream, encode_wire_message, \
encode_session_open_message, encode_session_close_message
from .codec_v1 import \
SESSION_V1, decode_wire_v1_stream, encode_wire_v1_message
_session_handlers = {} # session id -> generator _interface = None
_workflow_genfuncs = {} # wire type -> (generator function, args)
_opened_sessions = set() # session ids
# TODO: get rid of this, use callbacks instead _workflow_callbacks = {} # wire type -> function returning workflow
report_writer = write_report_stream() _workflow_args = {} # wire type -> args
report_writer.send(None)
def generate_session_id(): def register(wire_type, callback, *args):
while True: if wire_type in _workflow_callbacks:
session_id = random.uniform(0xffffffff) + 1 raise KeyError('Message %d already registered' % wire_type)
if session_id not in _opened_sessions: _workflow_callbacks[wire_type] = callback
return session_id _workflow_args[wire_type] = args
def open_session(session_id=None): def setup(iface):
if session_id is None: global _interface
session_id = generate_session_id()
_opened_sessions.add(session_id) # setup wire interface for reading and writing
log.info(__name__, 'session %d: open', session_id) _interface = iface
return session_id
# implicitly register v1 codec on its session. v2 sessions are
# opened/closed explicitely through session control messages.
_session_open(codec_v1.SESSION)
# run session dispatcher
loop.schedule_task(_dispatch_reports())
def close_session(session_id): async def read(session_id, *wire_types):
_opened_sessions.discard(session_id) log.info(__name__, 'session %d: read types %s', session_id, wire_types)
_session_handlers.pop(session_id, None) signal = loop.Signal()
log.info(__name__, 'session %d: close', session_id) sessions.listen(session_id, _handle_response, wire_types, signal)
def register_type(wire_type, genfunc, *args):
if wire_type in _workflow_genfuncs:
raise KeyError('message of type %d already registered' % wire_type)
log.info(__name__, 'register type %d', wire_type)
_workflow_genfuncs[wire_type] = (genfunc, args)
def register_session(session_id, handler):
if session_id not in _opened_sessions:
raise KeyError('session %d is unknown' % session_id)
if session_id in _session_handlers:
raise KeyError('session %d is already being listened on' % session_id)
log.info(__name__, 'session %d: listening', session_id)
_session_handlers[session_id] = handler
def setup():
session_dispatcher = dispatch_reports_by_session(
_session_handlers,
_handle_open_session,
_handle_close_session,
_handle_unknown_session)
session_dispatcher.send(None)
schedule_task(read_report_stream(session_dispatcher))
v1_handler = decode_wire_v1_stream(_handle_registered_type, SESSION_V1)
v1_handler.send(None)
open_session(SESSION_V1)
register_session(SESSION_V1, v1_handler)
async def read_message(session_id, *exp_types):
log.info(__name__, 'session %d: read types %s', session_id, exp_types)
signal = Signal()
if session_id == SESSION_V1:
wire_decoder = decode_wire_v1_stream(
_dispatch_and_build_protobuf, session_id, exp_types, signal)
else:
wire_decoder = decode_wire_stream(
_dispatch_and_build_protobuf, session_id, exp_types, signal)
wire_decoder.send(None)
register_session(session_id, wire_decoder)
return await signal return await signal
async def write_message(session_id, pbuf_message): async def write(session_id, pbuf_msg):
log.info(__name__, 'session %d: write %s', session_id, pbuf_message) log.info(__name__, 'session %d: write %s', session_id, pbuf_msg)
pbuf_type = pbuf_message.__class__ pbuf_type = pbuf_msg.__class__
msg_data = pbuf_type.dumps(pbuf_message) msg_data = pbuf_type.dumps(pbuf_msg)
msg_type = pbuf_type.MESSAGE_WIRE_TYPE msg_type = pbuf_type.MESSAGE_WIRE_TYPE
sessions.get_codec(session_id).encode(
if session_id == SESSION_V1: session_id, msg_type, msg_data, _write_report)
encode_wire_v1_message(msg_type, msg_data, report_writer)
else:
encode_wire_message(msg_type, msg_data, session_id, report_writer)
async def reply_message(session_id, pbuf_message, *exp_types): async def call(session_id, pbuf_msg, *response_types):
await write_message(session_id, pbuf_message) await write(session_id, pbuf_msg)
return await read_message(session_id, *exp_types) return await read(session_id, *response_types)
class FailureError(Exception): class FailureError(Exception):
@ -113,94 +66,107 @@ class FailureError(Exception):
def to_protobuf(self): def to_protobuf(self):
from trezor.messages.Failure import Failure from trezor.messages.Failure import Failure
return Failure(code=self.args[0], return Failure(code=self.args[0], message=self.args[1])
message=self.args[1])
async def monitor_workflow(workflow, session_id): def protobuf_workflow(session_id, msg_type, data_len, callback, *args):
return _build_protobuf(msg_type, _start_protobuf_workflow, session_id, callback, args)
def _start_protobuf_workflow(pbuf_msg, session_id, callback, args):
wf = callback(session_id, pbuf_msg, *args)
wf = _wrap_protobuf_workflow(wf, session_id)
workflow.start(wf)
async def _wrap_protobuf_workflow(wf, session_id):
try: try:
result = await workflow result = await wf
except FailureError as e: except FailureError as e:
await write_message(session_id, e.to_protobuf()) await write(session_id, e.to_protobuf())
raise raise
except Exception as e: except Exception as e:
from trezor.messages.Failure import Failure from trezor.messages.Failure import Failure
from trezor.messages.FailureType import FirmwareError from trezor.messages.FailureType import FirmwareError
await write_message(session_id, await write(session_id, Failure(
Failure(code=FirmwareError, code=FirmwareError, message='Firmware Error'))
message='Firmware Error'))
raise raise
else: else:
if result is not None: if result is not None:
await write_message(session_id, result) await write(session_id, result)
return result return result
finally: finally:
if session_id in _opened_sessions: if session_id in sessions.opened:
if session_id == SESSION_V1: sessions.listen(session_id, _handle_workflow)
wire_decoder = decode_wire_v1_stream(
_handle_registered_type, session_id)
else:
wire_decoder = decode_wire_stream(
_handle_registered_type, session_id)
wire_decoder.send(None)
register_session(session_id, wire_decoder)
def protobuf_handler(msg_type, data_len, session_id, callback, *args): def _build_protobuf(msg_type, callback, *args):
def finalizer(message): pbuf_type = messages.get_protobuf_type(msg_type)
workflow = callback(message, session_id, *args) builder = protobuf.build_message(pbuf_type, callback, *args)
monitored = monitor_workflow(workflow, session_id)
start_workflow(monitored)
pbuf_type = get_protobuf_type(msg_type)
builder = build_protobuf_message(pbuf_type, finalizer)
builder.send(None) builder.send(None)
return pbuf_type.load(target=builder) return pbuf_type.load(target=builder)
def _handle_open_session(): def _handle_response(session_id, msg_type, data_len, response_types, signal):
session_id = open_session() if msg_type in response_types:
wire_decoder = decode_wire_stream(_handle_registered_type, session_id) return _build_protobuf(msg_type, signal.send)
wire_decoder.send(None)
register_session(session_id, wire_decoder)
encode_session_open_message(session_id, report_writer)
def _handle_close_session(session_id):
close_session(session_id)
encode_session_close_message(session_id, report_writer)
def _handle_unknown_session(session_id, report_data):
pass # TODO
def _dispatch_and_build_protobuf(msg_type, data_len, session_id, exp_types, signal):
if msg_type in exp_types:
pbuf_type = get_protobuf_type(msg_type)
builder = build_protobuf_message(pbuf_type, signal.send)
builder.send(None)
return pbuf_type.load(target=builder)
else: else:
from trezor.messages.FailureType import UnexpectedMessage from trezor.messages.FailureType import UnexpectedMessage
signal.send(FailureError(UnexpectedMessage, 'Unexpected message')) signal.send(FailureError(UnexpectedMessage, 'Unexpected message'))
return _handle_registered_type(msg_type, data_len, session_id) return _handle_workflow(session_id, msg_type, data_len)
def _handle_registered_type(msg_type, data_len, session_id): def _handle_workflow(session_id, msg_type, data_len):
fallback = (_handle_unexpected_type, ()) if msg_type in _workflow_callbacks:
genfunc, args = _workflow_genfuncs.get(msg_type, fallback) args = _workflow_args[msg_type]
return genfunc(msg_type, data_len, session_id, *args) callback = _workflow_callbacks[msg_type]
return callback(session_id, msg_type, data_len, *args)
else:
return _handle_unexpected(session_id, msg_type, data_len)
def _handle_unexpected_type(msg_type, data_len, session_id): def _handle_unexpected(session_id, msg_type, data_len):
log.warning(__name__, 'session %d: skip type %d, len %d', log.warning(
session_id, msg_type, data_len) __name__, 'session %d: skip type %d, len %d', session_id, msg_type, data_len)
try: try:
while True: while True:
yield yield
except EOFError: except EOFError:
pass pass
def _write_report(report):
if __debug__:
log.info(__name__, 'write report %s', ubinascii.hexlify(report))
msg.send(_interface, report)
def _dispatch_reports():
while True:
report, = yield loop.Select(_interface)
report = memoryview(report)
if __debug__:
log.debug(__name__, 'read report %s', ubinascii.hexlify(report))
sessions.dispatch(
report, _session_open, _session_close, _session_unknown)
def _session_open(session_id=None):
session_id = sessions.open(session_id)
sessions.listen(session_id, _handle_workflow)
sessions.get_codec(session_id).encode_session_open(
session_id, _write_report)
def _session_close(session_id):
sessions.close(session_id)
sessions.get_codec(session_id).encode_session_close(
session_id, _write_report)
def _session_unknown(session_id, report_data):
pass

View File

@ -1,50 +1,50 @@
from micropython import const from micropython import const
import ustruct import ustruct
SESSION_V1 = const(0) SESSION = const(0)
REP_MARKER_V1 = const(63) # ord('?') REP_MARKER = const(63) # ord('?')
REP_MARKER_V1_LEN = const(1) # len('?') REP_MARKER_LEN = const(1) # len('?')
_REP_LEN = const(64) _REP_LEN = const(64)
_MSG_HEADER_MAGIC = const(35) # org('#') _MSG_HEADER_MAGIC = const(35) # org('#')
_MSG_HEADER_V1 = '>BBHL' # magic, magic, wire type, data length _MSG_HEADER = '>BBHL' # magic, magic, wire type, data length
_MSG_HEADER_V1_LEN = ustruct.calcsize(_MSG_HEADER_V1) _MSG_HEADER_LEN = ustruct.calcsize(_MSG_HEADER)
def detect_v1(data): def detect(data):
return (data[0] == REP_MARKER_V1) return data[0] == REP_MARKER
def parse_report_v1(data): def parse_report(data):
if len(data) != _REP_LEN: if len(data) != _REP_LEN:
raise ValueError('Invalid buffer size') raise ValueError('Invalid buffer size')
return None, SESSION_V1, data[1:] return None, SESSION, data[1:]
def parse_message(data): def parse_message(data):
magic1, magic2, msg_type, data_len = ustruct.unpack(_MSG_HEADER_V1, data) magic1, magic2, msg_type, data_len = ustruct.unpack(_MSG_HEADER, data)
if magic1 != _MSG_HEADER_MAGIC or magic2 != _MSG_HEADER_MAGIC: if magic1 != _MSG_HEADER_MAGIC or magic2 != _MSG_HEADER_MAGIC:
raise ValueError('Corrupted magic bytes') raise ValueError('Corrupted magic bytes')
return msg_type, data_len, data[_MSG_HEADER_V1_LEN:] return msg_type, data_len, data[_MSG_HEADER_LEN:]
def serialize_message_header(data, msg_type, msg_len): def serialize_message_header(data, msg_type, msg_len):
if len(data) < REP_MARKER_V1_LEN + _MSG_HEADER_V1_LEN: if len(data) < REP_MARKER_LEN + _MSG_HEADER_LEN:
raise ValueError('Invalid buffer size') raise ValueError('Invalid buffer size')
if msg_type < 0 or msg_type > 65535: if msg_type < 0 or msg_type > 65535:
raise ValueError('Value is out of range') raise ValueError('Value is out of range')
ustruct.pack_into( ustruct.pack_into(
_MSG_HEADER_V1, data, REP_MARKER_V1_LEN, _MSG_HEADER, data, REP_MARKER_LEN,
_MSG_HEADER_MAGIC, _MSG_HEADER_MAGIC, msg_type, msg_len) _MSG_HEADER_MAGIC, _MSG_HEADER_MAGIC, msg_type, msg_len)
def decode_wire_v1_stream(genfunc, session_id, *args): def decode_stream(session_id, callback, *args):
'''Decode a v1 wire message from the report data and stream it to target. '''Decode a v1 wire message from the report data and stream it to target.
Receives report payloads. After first report, creates target by calling Receives report payloads. After first report, creates target by calling
`genfunc(msg_type, data_len, session_id, *args)` and sends chunks of message `callback(session_id, msg_type, data_len, *args)` and sends chunks of message
data. data. Throws `EOFError` to target after last data chunk.
Throws `EOFError` to target after last data chunk.
Pass report payloads as `memoryview` for cheaper slicing. Pass report payloads as `memoryview` for cheaper slicing.
''' '''
@ -52,7 +52,7 @@ Pass report payloads as `memoryview` for cheaper slicing.
message = yield # read first report message = yield # read first report
msg_type, data_len, data = parse_message(message) msg_type, data_len, data = parse_message(message)
target = genfunc(msg_type, data_len, session_id, *args) target = callback(session_id, msg_type, data_len, *args)
target.send(None) target.send(None)
while data_len > 0: while data_len > 0:
@ -68,18 +68,18 @@ Pass report payloads as `memoryview` for cheaper slicing.
target.throw(EOFError()) target.throw(EOFError())
def encode_wire_v1_message(msg_type, msg_data, target): def encode(session_id, msg_type, msg_data, callback):
'''Encode a full v1 wire message directly to reports and stream it to target. '''Encode a full v1 wire message directly to reports and stream it to callback.
Target receives `memoryview`s of HID reports which are valid until the targets Callback receives `memoryview`s of HID reports which are valid until the
`send()` method returns. callback returns.
''' '''
report = memoryview(bytearray(_REP_LEN)) report = memoryview(bytearray(_REP_LEN))
report[0] = REP_MARKER_V1 report[0] = REP_MARKER
serialize_message_header(report, msg_type, len(msg_data)) serialize_message_header(report, msg_type, len(msg_data))
source_data = memoryview(msg_data) source_data = memoryview(msg_data)
target_data = report[REP_MARKER_V1_LEN + _MSG_HEADER_V1_LEN:] target_data = report[REP_MARKER_LEN + _MSG_HEADER_LEN:]
while True: while True:
# move as much as possible from source to target # move as much as possible from source to target
@ -95,10 +95,20 @@ Target receives `memoryview`s of HID reports which are valid until the targets
target_data[x] = 0 target_data[x] = 0
x += 1 x += 1
target.send(report) callback(report)
if not source_data: if not source_data:
break break
# reset to skip the magic, not the whole header anymore # reset to skip the magic, not the whole header anymore
target_data = report[REP_MARKER_V1_LEN:] target_data = report[REP_MARKER_LEN:]
def encode_session_open(session_id, callback):
# v1 codec does not have explicit session support
pass
def encode_session_close(session_id, callback):
# v1 codec does not have explicit session support
pass

View File

@ -81,11 +81,11 @@ class MessageChecksumError(Exception):
pass pass
def decode_wire_stream(genfunc, session_id, *args): def decode_stream(session_id, callback, *args):
'''Decode a wire message from the report data and stream it to target. '''Decode a wire message from the report data and stream it to target.
Receives report payloads. After first report, creates target by calling Receives report payloads. After first report, creates target by calling
`genfunc(msg_type, data_len, session_id, *args)` and sends chunks of message `callback(session_id, msg_type, data_len, *args)` and sends chunks of message
data. data.
Throws `EOFError` to target after last data chunk, in case of valid checksum. Throws `EOFError` to target after last data chunk, in case of valid checksum.
Throws `MessageChecksumError` to target if data doesn't match the checksum. Throws `MessageChecksumError` to target if data doesn't match the checksum.
@ -95,7 +95,7 @@ Pass report payloads as `memoryview` for cheaper slicing.
message = yield # read first report message = yield # read first report
msg_type, data_len, data_tail = parse_message(message) msg_type, data_len, data_tail = parse_message(message)
target = genfunc(msg_type, data_len, session_id, *args) target = callback(session_id, msg_type, data_len, *args)
target.send(None) target.send(None)
checksum = 0 # crc32 checksum = 0 # crc32
@ -126,11 +126,11 @@ Pass report payloads as `memoryview` for cheaper slicing.
target.throw(EOFError()) target.throw(EOFError())
def encode_wire_message(msg_type, msg_data, session_id, target): def encode(session_id, msg_type, msg_data, callback):
'''Encode a full wire message directly to reports and stream it to target. '''Encode a full wire message directly to reports and stream it to callback.
Target receives `memoryview`s of HID reports which are valid until the targets Callback receives `memoryview`s of HID reports which are valid until the
`send()` method returns. callback returns.
''' '''
report = memoryview(bytearray(_REP_LEN)) report = memoryview(bytearray(_REP_LEN))
serialize_report_header(report, REP_MARKER_HEADER, session_id) serialize_report_header(report, REP_MARKER_HEADER, session_id)
@ -166,7 +166,7 @@ Target receives `memoryview`s of HID reports which are valid until the targets
target_data[x] = 0 target_data[x] = 0
x += 1 x += 1
target.send(report) callback(report)
if not source_data and not msg_footer: if not source_data and not msg_footer:
break break
@ -178,13 +178,13 @@ Target receives `memoryview`s of HID reports which are valid until the targets
target_data = report[_REP_HEADER_LEN:] target_data = report[_REP_HEADER_LEN:]
def encode_session_open_message(session_id, target): def encode_session_open(session_id, callback):
report = bytearray(_REP_LEN) report = bytearray(_REP_LEN)
serialize_report_header(report, REP_MARKER_OPEN, session_id) serialize_report_header(report, REP_MARKER_OPEN, session_id)
target.send(report) callback(report)
def encode_session_close_message(session_id, target): def encode_session_close(session_id, callback):
report = bytearray(_REP_LEN) report = bytearray(_REP_LEN)
serialize_report_header(report, REP_MARKER_CLOSE, session_id) serialize_report_header(report, REP_MARKER_CLOSE, session_id)
target.send(report) callback(report)

View File

@ -1,46 +0,0 @@
from trezor import log
from .codec import parse_report, REP_MARKER_OPEN, REP_MARKER_CLOSE
from .codec_v1 import detect_v1, parse_report_v1
def dispatch_reports_by_session(handlers,
open_callback,
close_callback,
unknown_callback):
'''
Consumes reports adhering to the wire codec and dispatches the report
payloads by between the passed handlers.
'''
while True:
data = (yield)
if detect_v1(data):
marker, session_id, report_data = parse_report_v1(data)
else:
marker, session_id, report_data = parse_report(data)
if marker == REP_MARKER_OPEN:
log.debug(__name__, 'request for new session')
open_callback()
continue
elif marker == REP_MARKER_CLOSE:
log.debug(__name__, 'request for closing session %d', session_id)
close_callback(session_id)
continue
elif session_id not in handlers:
log.debug(__name__, 'report on unknown session %d', session_id)
unknown_callback(session_id, report_data)
continue
log.debug(__name__, 'report on session %d', session_id)
handler = handlers[session_id]
try:
handler.send(report_data)
except StopIteration:
handlers.pop(session_id)
except Exception as e:
log.exception(__name__, e)

View File

@ -1,19 +0,0 @@
from micropython import const
from ubinascii import hexlify
from trezor import msg, loop, log
_DEFAULT_IFACE = const(0xFF00) # TODO: use proper interface
def read_report_stream(target, iface=_DEFAULT_IFACE):
while True:
report, = yield loop.Select(iface)
log.debug(__name__, 'read report %s', hexlify(report))
target.send(memoryview(report))
def write_report_stream(iface=_DEFAULT_IFACE):
while True:
report = yield
log.info(__name__, 'write report %s', hexlify(report))
msg.send(iface, report)

View File

@ -0,0 +1,82 @@
from trezor import log
from trezor.crypto import random
from . import codec_v1
from . import codec_v2
opened = set() # opened session ids
readers = {} # session id -> generator
def generate():
while True:
session_id = random.uniform(0xffffffff) + 1
if session_id not in opened:
return session_id
def open(session_id=None):
if session_id is None:
session_id = generate()
log.info(__name__, 'session %d: open', session_id)
opened.add(session_id)
return session_id
def close(session_id):
log.info(__name__, 'session %d: close', session_id)
opened.discard(session_id)
readers.pop(session_id, None)
def get_codec(session_id):
if session_id == codec_v1.SESSION:
return codec_v1
else:
return codec_v2
def listen(session_id, handler, *args):
if session_id not in opened:
raise KeyError('Session %d is unknown' % session_id)
if session_id in readers:
raise KeyError('Session %d is already being listened on' % session_id)
log.info(__name__, 'session %d: listening', session_id)
decoder = get_codec(session_id).decode_stream(session_id, handler, *args)
decoder.send(None)
readers[session_id] = decoder
def dispatch(report, open_callback, close_callback, unknown_callback):
'''
Dispatches payloads of reports adhering to one of the wire codecs.
'''
if codec_v1.detect(report):
marker, session_id, report_data = codec_v1.parse_report(report)
else:
marker, session_id, report_data = codec_v2.parse_report(report)
if marker == codec_v2.REP_MARKER_OPEN:
log.debug(__name__, 'request for new session')
open_callback()
return
elif marker == codec_v2.REP_MARKER_CLOSE:
log.debug(__name__, 'request for closing session %d', session_id)
close_callback(session_id)
return
if session_id not in readers:
log.warning(__name__, 'report on unknown session %d', session_id)
unknown_callback(session_id, report_data)
return
log.debug(__name__, 'report on session %d', session_id)
reader = readers[session_id]
try:
reader.send(report_data)
except StopIteration:
readers.pop(session_id)
except Exception as e:
log.exception(__name__, e)

View File

@ -0,0 +1,111 @@
from common import *
from trezor.utils import chunks
from trezor.crypto import bip32, bip39
from trezor.messages.SignTx import SignTx
from trezor.messages.TxInputType import TxInputType
from trezor.messages.TxOutputType import TxOutputType
from trezor.messages.TxOutputBinType import TxOutputBinType
from trezor.messages.TxRequest import TxRequest
from trezor.messages.TxAck import TxAck
from trezor.messages.TransactionType import TransactionType
from trezor.messages.RequestType import TXINPUT, TXOUTPUT, TXMETA, TXFINISHED
from trezor.messages.TxRequestDetailsType import TxRequestDetailsType
from trezor.messages.TxRequestSerializedType import TxRequestSerializedType
from trezor.messages import OutputScriptType
from apps.common import coins
from apps.wallet.sign_tx import signing
class TestSignTx(unittest.TestCase):
# pylint: disable=C0301
def test_one_one_fee(self):
# tx: d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882
# input 0: 0.0039 BTC
coin_bitcoin = coins.by_name('Bitcoin')
ptx1 = TransactionType(version=1, lock_time=0, inputs_cnt=2, outputs_cnt=1)
pinp1 = TxInputType(script_sig=unhexlify('483045022072ba61305fe7cb542d142b8f3299a7b10f9ea61f6ffaab5dca8142601869d53c0221009a8027ed79eb3b9bc13577ac2853269323434558528c6b6a7e542be46e7e9a820141047a2d177c0f3626fc68c53610b0270fa6156181f46586c679ba6a88b34c6f4874686390b4d92e5769fbb89c8050b984f4ec0b257a0e5c4ff8bd3b035a51709503'),
prev_hash=unhexlify('c16a03f1cf8f99f6b5297ab614586cacec784c2d259af245909dedb0e39eddcf'),
prev_index=1,
script_type=None,
sequence=None)
pinp2 = TxInputType(script_sig=unhexlify('48304502200fd63adc8f6cb34359dc6cca9e5458d7ea50376cbd0a74514880735e6d1b8a4c0221008b6ead7fe5fbdab7319d6dfede3a0bc8e2a7c5b5a9301636d1de4aa31a3ee9b101410486ad608470d796236b003635718dfc07c0cac0cfc3bfc3079e4f491b0426f0676e6643a39198e8e7bdaffb94f4b49ea21baa107ec2e237368872836073668214'),
prev_hash=unhexlify('1ae39a2f8d59670c8fc61179148a8e61e039d0d9e8ab08610cb69b4a19453eaf'),
prev_index=1,
script_type=None,
sequence=None)
pout1 = TxOutputBinType(script_pubkey=unhexlify('76a91424a56db43cf6f2b02e838ea493f95d8d6047423188ac'),
amount=390000,
address_n=None)
inp1 = TxInputType(address_n=[0], # 14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e
# amount=390000,
prev_hash=unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882'),
prev_index=0,
script_type=None,
sequence=None)
out1 = TxOutputType(address='1MJ2tj2ThBE62zXbBYA5ZaN3fdve5CPAz1',
amount=390000 - 10000,
script_type=OutputScriptType.PAYTOADDRESS,
address_n=None)
tx = SignTx(coin_name=None, version=None, lock_time=None, inputs_count=1, outputs_count=1)
messages = [
None,
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None)),
TxAck(tx=TransactionType(inputs=[inp1])),
TxRequest(request_type=TXMETA, details=TxRequestDetailsType(request_index=None, tx_hash=unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882')), serialized=None),
TxAck(tx=ptx1),
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882')), serialized=None),
TxAck(tx=TransactionType(inputs=[pinp1])),
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=1, tx_hash=unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882')), serialized=None),
TxAck(tx=TransactionType(inputs=[pinp2])),
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882')), serialized=None),
TxAck(tx=TransactionType(bin_outputs=[pout1])),
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None),
TxAck(tx=TransactionType(outputs=[out1])),
signing.UiConfirmOutput(out1, coin_bitcoin),
True,
signing.UiConfirmTotal(380000, 10000, coin_bitcoin),
True,
# ButtonRequest(code=ButtonRequest_ConfirmOutput),
# ButtonRequest(code=ButtonRequest_SignTx),
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None),
TxAck(tx=TransactionType(inputs=[inp1])),
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None),
TxAck(tx=TransactionType(outputs=[out1])),
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=TxRequestSerializedType(
signature_index=0,
signature=unhexlify('30450221009a0b7be0d4ed3146ee262b42202841834698bb3ee39c24e7437df208b8b7077102202b79ab1e7736219387dffe8d615bbdba87e11477104b867ef47afed1a5ede781'),
serialized_tx=unhexlify('010000000182488650ef25a58fef6788bd71b8212038d7f2bbe4750bc7bcb44701e85ef6d5000000006b4830450221009a0b7be0d4ed3146ee262b42202841834698bb3ee39c24e7437df208b8b7077102202b79ab1e7736219387dffe8d615bbdba87e11477104b867ef47afed1a5ede7810121023230848585885f63803a0a8aecdd6538792d5c539215c91698e315bf0253b43dffffffff'))),
TxAck(tx=TransactionType(outputs=[out1])),
TxRequest(request_type=TXFINISHED, details=None, serialized=TxRequestSerializedType(
signature_index=None,
signature=None,
serialized_tx=unhexlify('0160cc0500000000001976a914de9b2a8da088824e8fe51debea566617d851537888ac00000000'),
)),
]
seed = bip39.seed('alcohol woman abuse must during monitor noble actual mixed trade anger aisle', '')
root = bip32.from_seed(seed, 'secp256k1')
signer = signing.sign_tx(tx, root)
for request, response in chunks(messages, 2):
self.assertEqualEx(signer.send(request), response)
with self.assertRaises(StopIteration):
signer.send(None)
def assertEqualEx(self, a, b):
# hack to avoid adding __eq__ to signing.Ui* classes
if ((isinstance(a, signing.UiConfirmOutput) and isinstance(b, signing.UiConfirmOutput)) or
(isinstance(a, signing.UiConfirmTotal) and isinstance(b, signing.UiConfirmTotal))):
return self.assertEqual(a.__dict__, b.__dict__)
else:
return self.assertEqual(a, b)
if __name__ == '__main__':
unittest.main()

View File

@ -13,16 +13,16 @@ class TestWireCodecV1(unittest.TestCase):
def test_detect(self): def test_detect(self):
for i in range(0, 256): for i in range(0, 256):
if i == ord(b'?'): if i == ord(b'?'):
self.assertTrue(codec_v1.detect_v1(bytes([i]) + b'\x00' * 63)) self.assertTrue(codec_v1.detect(bytes([i]) + b'\x00' * 63))
else: else:
self.assertFalse(codec_v1.detect_v1(bytes([i]) + b'\x00' * 63)) self.assertFalse(codec_v1.detect(bytes([i]) + b'\x00' * 63))
def test_parse(self): def test_parse(self):
d = bytes(range(0, 55)) d = bytes(range(0, 55))
m = b'##\x00\x00\x00\x00\x00\x37' + d m = b'##\x00\x00\x00\x00\x00\x37' + d
r = b'?' + m r = b'?' + m
rm, rs, rd = codec_v1.parse_report_v1(r) rm, rs, rd = codec_v1.parse_report(r)
self.assertEqual(rm, None) self.assertEqual(rm, None)
self.assertEqual(rs, 0) self.assertEqual(rs, 0)
self.assertEqual(rd, m) self.assertEqual(rd, m)
@ -35,7 +35,7 @@ class TestWireCodecV1(unittest.TestCase):
for i in range(0, 1024): for i in range(0, 1024):
if i != 64: if i != 64:
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
codec_v1.parse_report_v1(bytes(range(0, i))) codec_v1.parse_report(bytes(range(0, i)))
for hx in range(0, 256): for hx in range(0, 256):
for hy in range(0, 256): for hy in range(0, 256):
@ -61,8 +61,8 @@ class TestWireCodecV1(unittest.TestCase):
message = b'##' + b'\xab\xcd' + b'\x00\x00\x00\x00' + b'\x00' * 55 message = b'##' + b'\xab\xcd' + b'\x00\x00\x00\x00' + b'\x00' * 55
record = [] record = []
genfunc = self._record(record, 0xabcd, 0, 0xdeadbeef, 'dummy') genfunc = self._record(record, 0xdeadbeef, 0xabcd, 0, 'dummy')
decoder = codec_v1.decode_wire_v1_stream(genfunc, 0xdeadbeef, 'dummy') decoder = codec_v1.decode_stream(0xdeadbeef, genfunc, 'dummy')
decoder.send(None) decoder.send(None)
try: try:
@ -78,8 +78,8 @@ class TestWireCodecV1(unittest.TestCase):
message = b'##' + b'\xab\xcd' + b'\x00\x00\x00\x37' + data message = b'##' + b'\xab\xcd' + b'\x00\x00\x00\x37' + data
record = [] record = []
genfunc = self._record(record, 0xabcd, 55, 0xdeadbeef, 'dummy') genfunc = self._record(record, 0xdeadbeef, 0xabcd, 55, 'dummy')
decoder = codec_v1.decode_wire_v1_stream(genfunc, 0xdeadbeef, 'dummy') decoder = codec_v1.decode_stream(0xdeadbeef, genfunc, 'dummy')
decoder.send(None) decoder.send(None)
try: try:
@ -103,8 +103,8 @@ class TestWireCodecV1(unittest.TestCase):
message_chunks = [c + '\x00' * (63 - len(c)) for c in list(chunks(message, 63))] message_chunks = [c + '\x00' * (63 - len(c)) for c in list(chunks(message, 63))]
record = [] record = []
genfunc = self._record(record, msg_type, data_len, 0xdeadbeef, 'dummy') genfunc = self._record(record, 0xdeadbeef, msg_type, data_len, 'dummy')
decoder = codec_v1.decode_wire_v1_stream(genfunc, 0xdeadbeef, 'dummy') decoder = codec_v1.decode_stream(0xdeadbeef, genfunc, 'dummy')
decoder.send(None) decoder.send(None)
res = 1 res = 1
@ -124,7 +124,7 @@ class TestWireCodecV1(unittest.TestCase):
target = self._record(record)() target = self._record(record)()
target.send(None) target.send(None)
codec_v1.encode_wire_v1_message(0xabcd, b'', target) codec_v1.encode(codec_v1.SESSION, 0xabcd, b'', target.send)
self.assertEqual(len(record), 1) self.assertEqual(len(record), 1)
self.assertEqual(record[0], b'?##\xab\xcd\x00\x00\x00\x00' + '\0' * 55) self.assertEqual(record[0], b'?##\xab\xcd\x00\x00\x00\x00' + '\0' * 55)
@ -135,7 +135,7 @@ class TestWireCodecV1(unittest.TestCase):
target = self._record(record)() target = self._record(record)()
target.send(None) target.send(None)
codec_v1.encode_wire_v1_message(0xabcd, data, target) codec_v1.encode(codec_v1.SESSION, 0xabcd, data, target.send)
self.assertEqual(record, [b'?##\xab\xcd\x00\x00\x00\x37' + data]) self.assertEqual(record, [b'?##\xab\xcd\x00\x00\x00\x37' + data])
def test_encode_generated_range(self): def test_encode_generated_range(self):
@ -158,7 +158,7 @@ class TestWireCodecV1(unittest.TestCase):
target = genfunc() target = genfunc()
target.send(None) target.send(None)
codec_v1.encode_wire_v1_message(msg_type, data, target) codec_v1.encode(codec_v1.SESSION, msg_type, data, target.send)
self.assertEqual(received, len(reports)) self.assertEqual(received, len(reports))
def _record(self, record, *_args): def _record(self, record, *_args):

View File

@ -6,7 +6,7 @@ import ubinascii
from trezor.crypto import random from trezor.crypto import random
from trezor.utils import chunks from trezor.utils import chunks
from trezor.wire import codec from trezor.wire import codec_v2
class TestWireCodec(unittest.TestCase): class TestWireCodec(unittest.TestCase):
# pylint: disable=C0301 # pylint: disable=C0301
@ -14,66 +14,66 @@ class TestWireCodec(unittest.TestCase):
def test_parse(self): def test_parse(self):
d = b'O' + b'\x01\x23\x45\x67' + bytes(range(0, 59)) d = b'O' + b'\x01\x23\x45\x67' + bytes(range(0, 59))
m, s, d = codec.parse_report(d) m, s, d = codec_v2.parse_report(d)
self.assertEqual(m, b'O'[0]) self.assertEqual(m, b'O'[0])
self.assertEqual(s, 0x01234567) self.assertEqual(s, 0x01234567)
self.assertEqual(d, bytes(range(0, 59))) self.assertEqual(d, bytes(range(0, 59)))
t, l, d = codec.parse_message(d) t, l, d = codec_v2.parse_message(d)
self.assertEqual(t, 0x00010203) self.assertEqual(t, 0x00010203)
self.assertEqual(l, 0x04050607) self.assertEqual(l, 0x04050607)
self.assertEqual(d, bytes(range(8, 59))) self.assertEqual(d, bytes(range(8, 59)))
f, = codec.parse_message_footer(d[0:4]) f, = codec_v2.parse_message_footer(d[0:4])
self.assertEqual(f, 0x08090a0b) self.assertEqual(f, 0x08090a0b)
for i in range(0, 1024): for i in range(0, 1024):
if i != 64: if i != 64:
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
codec.parse_report(bytes(range(0, i))) codec_v2.parse_report(bytes(range(0, i)))
if i != 59: if i != 59:
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
codec.parse_message(bytes(range(0, i))) codec_v2.parse_message(bytes(range(0, i)))
if i != 4: if i != 4:
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
codec.parse_message_footer(bytes(range(0, i))) codec_v2.parse_message_footer(bytes(range(0, i)))
def test_serialize(self): def test_serialize(self):
data = bytearray(range(0, 6)) data = bytearray(range(0, 6))
codec.serialize_report_header(data, 0x12, 0x3456789a) codec_v2.serialize_report_header(data, 0x12, 0x3456789a)
self.assertEqual(data, b'\x12\x34\x56\x78\x9a\x05') self.assertEqual(data, b'\x12\x34\x56\x78\x9a\x05')
data = bytearray(range(0, 6)) data = bytearray(range(0, 6))
codec.serialize_opened_session(data, 0x3456789a) codec_v2.serialize_opened_session(data, 0x3456789a)
self.assertEqual(data, bytes([codec.REP_MARKER_OPEN]) + b'\x34\x56\x78\x9a\x05') self.assertEqual(data, bytes([codec_v2.REP_MARKER_OPEN]) + b'\x34\x56\x78\x9a\x05')
data = bytearray(range(0, 14)) data = bytearray(range(0, 14))
codec.serialize_message_header(data, 0x01234567, 0x89abcdef) codec_v2.serialize_message_header(data, 0x01234567, 0x89abcdef)
self.assertEqual(data, b'\x00\x01\x02\x03\x04\x01\x23\x45\x67\x89\xab\xcd\xef\x0d') self.assertEqual(data, b'\x00\x01\x02\x03\x04\x01\x23\x45\x67\x89\xab\xcd\xef\x0d')
data = bytearray(range(0, 5)) data = bytearray(range(0, 5))
codec.serialize_message_footer(data, 0x89abcdef) codec_v2.serialize_message_footer(data, 0x89abcdef)
self.assertEqual(data, b'\x89\xab\xcd\xef\x04') self.assertEqual(data, b'\x89\xab\xcd\xef\x04')
for i in range(0, 13): for i in range(0, 13):
data = bytearray(i) data = bytearray(i)
if i < 4: if i < 4:
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
codec.serialize_message_footer(data, 0x00) codec_v2.serialize_message_footer(data, 0x00)
if i < 5: if i < 5:
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
codec.serialize_report_header(data, 0x00, 0x00) codec_v2.serialize_report_header(data, 0x00, 0x00)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
codec.serialize_opened_session(data, 0x00) codec_v2.serialize_opened_session(data, 0x00)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
codec.serialize_message_header(data, 0x00, 0x00) codec_v2.serialize_message_header(data, 0x00, 0x00)
def test_decode_empty(self): def test_decode_empty(self):
message = b'\xab\xcd\xef\x12' + b'\x00\x00\x00\x00' + b'\x00' * 51 message = b'\xab\xcd\xef\x12' + b'\x00\x00\x00\x00' + b'\x00' * 51
record = [] record = []
genfunc = self._record(record, 0xabcdef12, 0, 0xdeadbeef, 'dummy') genfunc = self._record(record, 0xdeadbeef, 0xabcdef12, 0, 'dummy')
decoder = codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy') decoder = codec_v2.decode_stream(0xdeadbeef, genfunc, 'dummy')
decoder.send(None) decoder.send(None)
try: try:
@ -90,8 +90,8 @@ class TestWireCodec(unittest.TestCase):
message = b'\xab\xcd\xef\x12' + b'\x00\x00\x00\x2f' + data + footer message = b'\xab\xcd\xef\x12' + b'\x00\x00\x00\x2f' + data + footer
record = [] record = []
genfunc = self._record(record, 0xabcdef12, 47, 0xdeadbeef, 'dummy') genfunc = self._record(record, 0xdeadbeef, 0xabcdef12, 47, 'dummy')
decoder = codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy') decoder = codec_v2.decode_stream(0xdeadbeef, genfunc, 'dummy')
decoder.send(None) decoder.send(None)
try: try:
@ -109,8 +109,8 @@ class TestWireCodec(unittest.TestCase):
message = b'\xab\xcd\xef\x12' + b'\x00\x00\x00\x2f' + data + footer message = b'\xab\xcd\xef\x12' + b'\x00\x00\x00\x2f' + data + footer
record = [] record = []
genfunc = self._record(record, 0xabcdef12, 47, 0xdeadbeef, 'dummy') genfunc = self._record(record, 0xdeadbeef, 0xabcdef12, 47, 'dummy')
decoder = codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy') decoder = codec_v2.decode_stream(0xdeadbeef, genfunc, 'dummy')
decoder.send(None) decoder.send(None)
try: try:
@ -120,7 +120,7 @@ class TestWireCodec(unittest.TestCase):
self.assertEqual(res, None) self.assertEqual(res, None)
self.assertEqual(len(record), 2) self.assertEqual(len(record), 2)
self.assertEqual(record[0], data) self.assertEqual(record[0], data)
self.assertIsInstance(record[1], codec.MessageChecksumError) self.assertIsInstance(record[1], codec_v2.MessageChecksumError)
def test_decode_generated_range(self): def test_decode_generated_range(self):
for data_len in range(1, 512): for data_len in range(1, 512):
@ -136,8 +136,8 @@ class TestWireCodec(unittest.TestCase):
message_chunks = [c + '\x00' * (59 - len(c)) for c in list(chunks(message, 59))] message_chunks = [c + '\x00' * (59 - len(c)) for c in list(chunks(message, 59))]
record = [] record = []
genfunc = self._record(record, msg_type, data_len, 0xdeadbeef, 'dummy') genfunc = self._record(record, 0xdeadbeef, msg_type, data_len, 'dummy')
decoder = codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy') decoder = codec_v2.decode_stream(0xdeadbeef, genfunc, 'dummy')
decoder.send(None) decoder.send(None)
res = 1 res = 1
@ -157,7 +157,7 @@ class TestWireCodec(unittest.TestCase):
target = self._record(record)() target = self._record(record)()
target.send(None) target.send(None)
codec.encode_wire_message(0xabcdef12, b'', 0xdeadbeef, target) codec_v2.encode(0xdeadbeef, 0xabcdef12, b'', target.send)
self.assertEqual(len(record), 1) self.assertEqual(len(record), 1)
self.assertEqual(record[0], b'H\xde\xad\xbe\xef\xab\xcd\xef\x12\x00\x00\x00\x00' + '\0' * 51) self.assertEqual(record[0], b'H\xde\xad\xbe\xef\xab\xcd\xef\x12\x00\x00\x00\x00' + '\0' * 51)
@ -169,7 +169,7 @@ class TestWireCodec(unittest.TestCase):
target = self._record(record)() target = self._record(record)()
target.send(None) target.send(None)
codec.encode_wire_message(0xabcdef12, data, 0xdeadbeef, target) codec_v2.encode(0xdeadbeef, 0xabcdef12, data, target.send)
self.assertEqual(record, [b'H\xde\xad\xbe\xef\xab\xcd\xef\x12\x00\x00\x00\x2f' + data + footer]) self.assertEqual(record, [b'H\xde\xad\xbe\xef\xab\xcd\xef\x12\x00\x00\x00\x2f' + data + footer])
def test_encode_generated_range(self): def test_encode_generated_range(self):
@ -199,7 +199,7 @@ class TestWireCodec(unittest.TestCase):
target = genfunc() target = genfunc()
target.send(None) target.send(None)
codec.encode_wire_message(msg_type, data, session_id, target) codec_v2.encode(session_id, msg_type, data, target.send)
self.assertEqual(received, len(reports)) self.assertEqual(received, len(reports))
def _record(self, record, *_args): def _record(self, record, *_args):