1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-13 17:00:59 +00:00

wire: refactoring

- prefer importing modules instead of module members
- session_id is always first argument
- prefer much shorter names, don't expect users to import module members
- shuffle around session-specific code
- reduce allocations
This commit is contained in:
Jan Pochyla 2016-12-08 16:14:47 +01:00
parent 0b7874ad43
commit d56dc88861
9 changed files with 394 additions and 290 deletions

View File

@ -1,109 +1,62 @@
from protobuf import build_protobuf_message
import ubinascii
import protobuf
from trezor.loop import schedule_task, Future
from trezor.crypto import random
from trezor.messages import get_protobuf_type
from trezor.workflow import start_workflow
from trezor import log
from trezor import loop
from trezor import messages
from trezor import msg
from trezor import workflow
from .io import read_report_stream, write_report_stream
from .dispatcher import dispatch_reports_by_session
from .codec import \
decode_wire_stream, encode_wire_message, \
encode_session_open_message, encode_session_close_message
from .codec_v1 import \
SESSION_V1, decode_wire_v1_stream, encode_wire_v1_message
from . import codec_v1
from . import codec_v2
from . import sessions
_session_handlers = {} # session id -> generator
_workflow_genfuncs = {} # wire type -> (generator function, args)
_opened_sessions = set() # session ids
_interface = None
# TODO: get rid of this, use callbacks instead
report_writer = write_report_stream()
report_writer.send(None)
_workflow_callbacks = {} # wire type -> function returning workflow
_workflow_args = {} # wire type -> args
def generate_session_id():
while True:
session_id = random.uniform(0xffffffff) + 1
if session_id not in _opened_sessions:
return session_id
def register(wire_type, callback, *args):
if wire_type in _workflow_callbacks:
raise KeyError('Message %d already registered' % wire_type)
_workflow_callbacks[wire_type] = callback
_workflow_args[wire_type] = args
def open_session(session_id=None):
if session_id is None:
session_id = generate_session_id()
_opened_sessions.add(session_id)
log.info(__name__, 'session %d: open', session_id)
return session_id
def setup(iface):
global _interface
# setup wire interface for reading and writing
_interface = iface
# implicitly register v1 codec on its session. v2 sessions are
# opened/closed explicitely through session control messages.
_session_open(codec_v1.SESSION)
# run session dispatcher
loop.schedule_task(_dispatch_reports())
def close_session(session_id):
_opened_sessions.discard(session_id)
_session_handlers.pop(session_id, None)
log.info(__name__, 'session %d: close', session_id)
def register_type(wire_type, genfunc, *args):
if wire_type in _workflow_genfuncs:
raise KeyError('message of type %d already registered' % wire_type)
log.info(__name__, 'register type %d', wire_type)
_workflow_genfuncs[wire_type] = (genfunc, args)
def register_session(session_id, handler):
if session_id not in _opened_sessions:
raise KeyError('session %d is unknown' % session_id)
if session_id in _session_handlers:
raise KeyError('session %d is already being listened on' % session_id)
log.info(__name__, 'session %d: listening', session_id)
_session_handlers[session_id] = handler
def setup():
session_dispatcher = dispatch_reports_by_session(
_session_handlers,
_handle_open_session,
_handle_close_session,
_handle_unknown_session)
session_dispatcher.send(None)
schedule_task(read_report_stream(session_dispatcher))
v1_handler = decode_wire_v1_stream(_handle_registered_type, SESSION_V1)
v1_handler.send(None)
open_session(SESSION_V1)
register_session(SESSION_V1, v1_handler)
async def read_message(session_id, *exp_types):
log.info(__name__, 'session %d: read types %s', session_id, exp_types)
signal = Signal()
if session_id == SESSION_V1:
wire_decoder = decode_wire_v1_stream(
_dispatch_and_build_protobuf, session_id, exp_types, signal)
else:
wire_decoder = decode_wire_stream(
_dispatch_and_build_protobuf, session_id, exp_types, signal)
wire_decoder.send(None)
register_session(session_id, wire_decoder)
async def read(session_id, *wire_types):
log.info(__name__, 'session %d: read types %s', session_id, wire_types)
signal = loop.Signal()
sessions.listen(session_id, _handle_response, wire_types, signal)
return await signal
async def write_message(session_id, pbuf_message):
log.info(__name__, 'session %d: write %s', session_id, pbuf_message)
pbuf_type = pbuf_message.__class__
msg_data = pbuf_type.dumps(pbuf_message)
async def write(session_id, pbuf_msg):
log.info(__name__, 'session %d: write %s', session_id, pbuf_msg)
pbuf_type = pbuf_msg.__class__
msg_data = pbuf_type.dumps(pbuf_msg)
msg_type = pbuf_type.MESSAGE_WIRE_TYPE
if session_id == SESSION_V1:
encode_wire_v1_message(msg_type, msg_data, report_writer)
else:
encode_wire_message(msg_type, msg_data, session_id, report_writer)
sessions.get_codec(session_id).encode(
session_id, msg_type, msg_data, _write_report)
async def reply_message(session_id, pbuf_message, *exp_types):
await write_message(session_id, pbuf_message)
return await read_message(session_id, *exp_types)
async def call(session_id, pbuf_msg, *response_types):
await write(session_id, pbuf_msg)
return await read(session_id, *response_types)
class FailureError(Exception):
@ -113,94 +66,107 @@ class FailureError(Exception):
def to_protobuf(self):
from trezor.messages.Failure import Failure
return Failure(code=self.args[0],
message=self.args[1])
return Failure(code=self.args[0], message=self.args[1])
async def monitor_workflow(workflow, session_id):
def protobuf_workflow(session_id, msg_type, data_len, callback, *args):
return _build_protobuf(msg_type, _start_protobuf_workflow, session_id, callback, args)
def _start_protobuf_workflow(pbuf_msg, session_id, callback, args):
wf = callback(session_id, pbuf_msg, *args)
wf = _wrap_protobuf_workflow(wf, session_id)
workflow.start(wf)
async def _wrap_protobuf_workflow(wf, session_id):
try:
result = await workflow
result = await wf
except FailureError as e:
await write_message(session_id, e.to_protobuf())
await write(session_id, e.to_protobuf())
raise
except Exception as e:
from trezor.messages.Failure import Failure
from trezor.messages.FailureType import FirmwareError
await write_message(session_id,
Failure(code=FirmwareError,
message='Firmware Error'))
await write(session_id, Failure(
code=FirmwareError, message='Firmware Error'))
raise
else:
if result is not None:
await write_message(session_id, result)
await write(session_id, result)
return result
finally:
if session_id in _opened_sessions:
if session_id == SESSION_V1:
wire_decoder = decode_wire_v1_stream(
_handle_registered_type, session_id)
else:
wire_decoder = decode_wire_stream(
_handle_registered_type, session_id)
wire_decoder.send(None)
register_session(session_id, wire_decoder)
if session_id in sessions.opened:
sessions.listen(session_id, _handle_workflow)
def protobuf_handler(msg_type, data_len, session_id, callback, *args):
def finalizer(message):
workflow = callback(message, session_id, *args)
monitored = monitor_workflow(workflow, session_id)
start_workflow(monitored)
pbuf_type = get_protobuf_type(msg_type)
builder = build_protobuf_message(pbuf_type, finalizer)
def _build_protobuf(msg_type, callback, *args):
pbuf_type = messages.get_protobuf_type(msg_type)
builder = protobuf.build_message(pbuf_type, callback, *args)
builder.send(None)
return pbuf_type.load(target=builder)
def _handle_open_session():
session_id = open_session()
wire_decoder = decode_wire_stream(_handle_registered_type, session_id)
wire_decoder.send(None)
register_session(session_id, wire_decoder)
encode_session_open_message(session_id, report_writer)
def _handle_close_session(session_id):
close_session(session_id)
encode_session_close_message(session_id, report_writer)
def _handle_unknown_session(session_id, report_data):
pass # TODO
def _dispatch_and_build_protobuf(msg_type, data_len, session_id, exp_types, signal):
if msg_type in exp_types:
pbuf_type = get_protobuf_type(msg_type)
builder = build_protobuf_message(pbuf_type, signal.send)
builder.send(None)
return pbuf_type.load(target=builder)
def _handle_response(session_id, msg_type, data_len, response_types, signal):
if msg_type in response_types:
return _build_protobuf(msg_type, signal.send)
else:
from trezor.messages.FailureType import UnexpectedMessage
signal.send(FailureError(UnexpectedMessage, 'Unexpected message'))
return _handle_registered_type(msg_type, data_len, session_id)
return _handle_workflow(session_id, msg_type, data_len)
def _handle_registered_type(msg_type, data_len, session_id):
fallback = (_handle_unexpected_type, ())
genfunc, args = _workflow_genfuncs.get(msg_type, fallback)
return genfunc(msg_type, data_len, session_id, *args)
def _handle_workflow(session_id, msg_type, data_len):
if msg_type in _workflow_callbacks:
args = _workflow_args[msg_type]
callback = _workflow_callbacks[msg_type]
return callback(session_id, msg_type, data_len, *args)
else:
return _handle_unexpected(session_id, msg_type, data_len)
def _handle_unexpected_type(msg_type, data_len, session_id):
log.warning(__name__, 'session %d: skip type %d, len %d',
session_id, msg_type, data_len)
def _handle_unexpected(session_id, msg_type, data_len):
log.warning(
__name__, 'session %d: skip type %d, len %d', session_id, msg_type, data_len)
try:
while True:
yield
except EOFError:
pass
def _write_report(report):
if __debug__:
log.info(__name__, 'write report %s', ubinascii.hexlify(report))
msg.send(_interface, report)
def _dispatch_reports():
while True:
report, = yield loop.Select(_interface)
report = memoryview(report)
if __debug__:
log.debug(__name__, 'read report %s', ubinascii.hexlify(report))
sessions.dispatch(
report, _session_open, _session_close, _session_unknown)
def _session_open(session_id=None):
session_id = sessions.open(session_id)
sessions.listen(session_id, _handle_workflow)
sessions.get_codec(session_id).encode_session_open(
session_id, _write_report)
def _session_close(session_id):
sessions.close(session_id)
sessions.get_codec(session_id).encode_session_close(
session_id, _write_report)
def _session_unknown(session_id, report_data):
pass

View File

@ -1,50 +1,50 @@
from micropython import const
import ustruct
SESSION_V1 = const(0)
REP_MARKER_V1 = const(63) # ord('?')
REP_MARKER_V1_LEN = const(1) # len('?')
SESSION = const(0)
REP_MARKER = const(63) # ord('?')
REP_MARKER_LEN = const(1) # len('?')
_REP_LEN = const(64)
_MSG_HEADER_MAGIC = const(35) # org('#')
_MSG_HEADER_V1 = '>BBHL' # magic, magic, wire type, data length
_MSG_HEADER_V1_LEN = ustruct.calcsize(_MSG_HEADER_V1)
_MSG_HEADER = '>BBHL' # magic, magic, wire type, data length
_MSG_HEADER_LEN = ustruct.calcsize(_MSG_HEADER)
def detect_v1(data):
return (data[0] == REP_MARKER_V1)
def detect(data):
return data[0] == REP_MARKER
def parse_report_v1(data):
def parse_report(data):
if len(data) != _REP_LEN:
raise ValueError('Invalid buffer size')
return None, SESSION_V1, data[1:]
return None, SESSION, data[1:]
def parse_message(data):
magic1, magic2, msg_type, data_len = ustruct.unpack(_MSG_HEADER_V1, data)
magic1, magic2, msg_type, data_len = ustruct.unpack(_MSG_HEADER, data)
if magic1 != _MSG_HEADER_MAGIC or magic2 != _MSG_HEADER_MAGIC:
raise ValueError('Corrupted magic bytes')
return msg_type, data_len, data[_MSG_HEADER_V1_LEN:]
return msg_type, data_len, data[_MSG_HEADER_LEN:]
def serialize_message_header(data, msg_type, msg_len):
if len(data) < REP_MARKER_V1_LEN + _MSG_HEADER_V1_LEN:
if len(data) < REP_MARKER_LEN + _MSG_HEADER_LEN:
raise ValueError('Invalid buffer size')
if msg_type < 0 or msg_type > 65535:
raise ValueError('Value is out of range')
ustruct.pack_into(
_MSG_HEADER_V1, data, REP_MARKER_V1_LEN,
_MSG_HEADER, data, REP_MARKER_LEN,
_MSG_HEADER_MAGIC, _MSG_HEADER_MAGIC, msg_type, msg_len)
def decode_wire_v1_stream(genfunc, session_id, *args):
def decode_stream(session_id, callback, *args):
'''Decode a v1 wire message from the report data and stream it to target.
Receives report payloads. After first report, creates target by calling
`genfunc(msg_type, data_len, session_id, *args)` and sends chunks of message
data.
Throws `EOFError` to target after last data chunk.
`callback(session_id, msg_type, data_len, *args)` and sends chunks of message
data. Throws `EOFError` to target after last data chunk.
Pass report payloads as `memoryview` for cheaper slicing.
'''
@ -52,7 +52,7 @@ Pass report payloads as `memoryview` for cheaper slicing.
message = yield # read first report
msg_type, data_len, data = parse_message(message)
target = genfunc(msg_type, data_len, session_id, *args)
target = callback(session_id, msg_type, data_len, *args)
target.send(None)
while data_len > 0:
@ -68,18 +68,18 @@ Pass report payloads as `memoryview` for cheaper slicing.
target.throw(EOFError())
def encode_wire_v1_message(msg_type, msg_data, target):
'''Encode a full v1 wire message directly to reports and stream it to target.
def encode(session_id, msg_type, msg_data, callback):
'''Encode a full v1 wire message directly to reports and stream it to callback.
Target receives `memoryview`s of HID reports which are valid until the targets
`send()` method returns.
'''
Callback receives `memoryview`s of HID reports which are valid until the
callback returns.
'''
report = memoryview(bytearray(_REP_LEN))
report[0] = REP_MARKER_V1
report[0] = REP_MARKER
serialize_message_header(report, msg_type, len(msg_data))
source_data = memoryview(msg_data)
target_data = report[REP_MARKER_V1_LEN + _MSG_HEADER_V1_LEN:]
target_data = report[REP_MARKER_LEN + _MSG_HEADER_LEN:]
while True:
# move as much as possible from source to target
@ -95,10 +95,20 @@ Target receives `memoryview`s of HID reports which are valid until the targets
target_data[x] = 0
x += 1
target.send(report)
callback(report)
if not source_data:
break
# reset to skip the magic, not the whole header anymore
target_data = report[REP_MARKER_V1_LEN:]
target_data = report[REP_MARKER_LEN:]
def encode_session_open(session_id, callback):
# v1 codec does not have explicit session support
pass
def encode_session_close(session_id, callback):
# v1 codec does not have explicit session support
pass

View File

@ -81,11 +81,11 @@ class MessageChecksumError(Exception):
pass
def decode_wire_stream(genfunc, session_id, *args):
def decode_stream(session_id, callback, *args):
'''Decode a wire message from the report data and stream it to target.
Receives report payloads. After first report, creates target by calling
`genfunc(msg_type, data_len, session_id, *args)` and sends chunks of message
`callback(session_id, msg_type, data_len, *args)` and sends chunks of message
data.
Throws `EOFError` to target after last data chunk, in case of valid checksum.
Throws `MessageChecksumError` to target if data doesn't match the checksum.
@ -95,7 +95,7 @@ Pass report payloads as `memoryview` for cheaper slicing.
message = yield # read first report
msg_type, data_len, data_tail = parse_message(message)
target = genfunc(msg_type, data_len, session_id, *args)
target = callback(session_id, msg_type, data_len, *args)
target.send(None)
checksum = 0 # crc32
@ -126,11 +126,11 @@ Pass report payloads as `memoryview` for cheaper slicing.
target.throw(EOFError())
def encode_wire_message(msg_type, msg_data, session_id, target):
'''Encode a full wire message directly to reports and stream it to target.
def encode(session_id, msg_type, msg_data, callback):
'''Encode a full wire message directly to reports and stream it to callback.
Target receives `memoryview`s of HID reports which are valid until the targets
`send()` method returns.
Callback receives `memoryview`s of HID reports which are valid until the
callback returns.
'''
report = memoryview(bytearray(_REP_LEN))
serialize_report_header(report, REP_MARKER_HEADER, session_id)
@ -166,7 +166,7 @@ Target receives `memoryview`s of HID reports which are valid until the targets
target_data[x] = 0
x += 1
target.send(report)
callback(report)
if not source_data and not msg_footer:
break
@ -178,13 +178,13 @@ Target receives `memoryview`s of HID reports which are valid until the targets
target_data = report[_REP_HEADER_LEN:]
def encode_session_open_message(session_id, target):
def encode_session_open(session_id, callback):
report = bytearray(_REP_LEN)
serialize_report_header(report, REP_MARKER_OPEN, session_id)
target.send(report)
callback(report)
def encode_session_close_message(session_id, target):
def encode_session_close(session_id, callback):
report = bytearray(_REP_LEN)
serialize_report_header(report, REP_MARKER_CLOSE, session_id)
target.send(report)
callback(report)

View File

@ -1,46 +0,0 @@
from trezor import log
from .codec import parse_report, REP_MARKER_OPEN, REP_MARKER_CLOSE
from .codec_v1 import detect_v1, parse_report_v1
def dispatch_reports_by_session(handlers,
open_callback,
close_callback,
unknown_callback):
'''
Consumes reports adhering to the wire codec and dispatches the report
payloads by between the passed handlers.
'''
while True:
data = (yield)
if detect_v1(data):
marker, session_id, report_data = parse_report_v1(data)
else:
marker, session_id, report_data = parse_report(data)
if marker == REP_MARKER_OPEN:
log.debug(__name__, 'request for new session')
open_callback()
continue
elif marker == REP_MARKER_CLOSE:
log.debug(__name__, 'request for closing session %d', session_id)
close_callback(session_id)
continue
elif session_id not in handlers:
log.debug(__name__, 'report on unknown session %d', session_id)
unknown_callback(session_id, report_data)
continue
log.debug(__name__, 'report on session %d', session_id)
handler = handlers[session_id]
try:
handler.send(report_data)
except StopIteration:
handlers.pop(session_id)
except Exception as e:
log.exception(__name__, e)

View File

@ -1,19 +0,0 @@
from micropython import const
from ubinascii import hexlify
from trezor import msg, loop, log
_DEFAULT_IFACE = const(0xFF00) # TODO: use proper interface
def read_report_stream(target, iface=_DEFAULT_IFACE):
while True:
report, = yield loop.Select(iface)
log.debug(__name__, 'read report %s', hexlify(report))
target.send(memoryview(report))
def write_report_stream(iface=_DEFAULT_IFACE):
while True:
report = yield
log.info(__name__, 'write report %s', hexlify(report))
msg.send(iface, report)

View File

@ -0,0 +1,82 @@
from trezor import log
from trezor.crypto import random
from . import codec_v1
from . import codec_v2
opened = set() # opened session ids
readers = {} # session id -> generator
def generate():
while True:
session_id = random.uniform(0xffffffff) + 1
if session_id not in opened:
return session_id
def open(session_id=None):
if session_id is None:
session_id = generate()
log.info(__name__, 'session %d: open', session_id)
opened.add(session_id)
return session_id
def close(session_id):
log.info(__name__, 'session %d: close', session_id)
opened.discard(session_id)
readers.pop(session_id, None)
def get_codec(session_id):
if session_id == codec_v1.SESSION:
return codec_v1
else:
return codec_v2
def listen(session_id, handler, *args):
if session_id not in opened:
raise KeyError('Session %d is unknown' % session_id)
if session_id in readers:
raise KeyError('Session %d is already being listened on' % session_id)
log.info(__name__, 'session %d: listening', session_id)
decoder = get_codec(session_id).decode_stream(session_id, handler, *args)
decoder.send(None)
readers[session_id] = decoder
def dispatch(report, open_callback, close_callback, unknown_callback):
'''
Dispatches payloads of reports adhering to one of the wire codecs.
'''
if codec_v1.detect(report):
marker, session_id, report_data = codec_v1.parse_report(report)
else:
marker, session_id, report_data = codec_v2.parse_report(report)
if marker == codec_v2.REP_MARKER_OPEN:
log.debug(__name__, 'request for new session')
open_callback()
return
elif marker == codec_v2.REP_MARKER_CLOSE:
log.debug(__name__, 'request for closing session %d', session_id)
close_callback(session_id)
return
if session_id not in readers:
log.warning(__name__, 'report on unknown session %d', session_id)
unknown_callback(session_id, report_data)
return
log.debug(__name__, 'report on session %d', session_id)
reader = readers[session_id]
try:
reader.send(report_data)
except StopIteration:
readers.pop(session_id)
except Exception as e:
log.exception(__name__, e)

View File

@ -0,0 +1,111 @@
from common import *
from trezor.utils import chunks
from trezor.crypto import bip32, bip39
from trezor.messages.SignTx import SignTx
from trezor.messages.TxInputType import TxInputType
from trezor.messages.TxOutputType import TxOutputType
from trezor.messages.TxOutputBinType import TxOutputBinType
from trezor.messages.TxRequest import TxRequest
from trezor.messages.TxAck import TxAck
from trezor.messages.TransactionType import TransactionType
from trezor.messages.RequestType import TXINPUT, TXOUTPUT, TXMETA, TXFINISHED
from trezor.messages.TxRequestDetailsType import TxRequestDetailsType
from trezor.messages.TxRequestSerializedType import TxRequestSerializedType
from trezor.messages import OutputScriptType
from apps.common import coins
from apps.wallet.sign_tx import signing
class TestSignTx(unittest.TestCase):
# pylint: disable=C0301
def test_one_one_fee(self):
# tx: d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882
# input 0: 0.0039 BTC
coin_bitcoin = coins.by_name('Bitcoin')
ptx1 = TransactionType(version=1, lock_time=0, inputs_cnt=2, outputs_cnt=1)
pinp1 = TxInputType(script_sig=unhexlify('483045022072ba61305fe7cb542d142b8f3299a7b10f9ea61f6ffaab5dca8142601869d53c0221009a8027ed79eb3b9bc13577ac2853269323434558528c6b6a7e542be46e7e9a820141047a2d177c0f3626fc68c53610b0270fa6156181f46586c679ba6a88b34c6f4874686390b4d92e5769fbb89c8050b984f4ec0b257a0e5c4ff8bd3b035a51709503'),
prev_hash=unhexlify('c16a03f1cf8f99f6b5297ab614586cacec784c2d259af245909dedb0e39eddcf'),
prev_index=1,
script_type=None,
sequence=None)
pinp2 = TxInputType(script_sig=unhexlify('48304502200fd63adc8f6cb34359dc6cca9e5458d7ea50376cbd0a74514880735e6d1b8a4c0221008b6ead7fe5fbdab7319d6dfede3a0bc8e2a7c5b5a9301636d1de4aa31a3ee9b101410486ad608470d796236b003635718dfc07c0cac0cfc3bfc3079e4f491b0426f0676e6643a39198e8e7bdaffb94f4b49ea21baa107ec2e237368872836073668214'),
prev_hash=unhexlify('1ae39a2f8d59670c8fc61179148a8e61e039d0d9e8ab08610cb69b4a19453eaf'),
prev_index=1,
script_type=None,
sequence=None)
pout1 = TxOutputBinType(script_pubkey=unhexlify('76a91424a56db43cf6f2b02e838ea493f95d8d6047423188ac'),
amount=390000,
address_n=None)
inp1 = TxInputType(address_n=[0], # 14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e
# amount=390000,
prev_hash=unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882'),
prev_index=0,
script_type=None,
sequence=None)
out1 = TxOutputType(address='1MJ2tj2ThBE62zXbBYA5ZaN3fdve5CPAz1',
amount=390000 - 10000,
script_type=OutputScriptType.PAYTOADDRESS,
address_n=None)
tx = SignTx(coin_name=None, version=None, lock_time=None, inputs_count=1, outputs_count=1)
messages = [
None,
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None)),
TxAck(tx=TransactionType(inputs=[inp1])),
TxRequest(request_type=TXMETA, details=TxRequestDetailsType(request_index=None, tx_hash=unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882')), serialized=None),
TxAck(tx=ptx1),
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882')), serialized=None),
TxAck(tx=TransactionType(inputs=[pinp1])),
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=1, tx_hash=unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882')), serialized=None),
TxAck(tx=TransactionType(inputs=[pinp2])),
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882')), serialized=None),
TxAck(tx=TransactionType(bin_outputs=[pout1])),
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None),
TxAck(tx=TransactionType(outputs=[out1])),
signing.UiConfirmOutput(out1, coin_bitcoin),
True,
signing.UiConfirmTotal(380000, 10000, coin_bitcoin),
True,
# ButtonRequest(code=ButtonRequest_ConfirmOutput),
# ButtonRequest(code=ButtonRequest_SignTx),
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None),
TxAck(tx=TransactionType(inputs=[inp1])),
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None),
TxAck(tx=TransactionType(outputs=[out1])),
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=TxRequestSerializedType(
signature_index=0,
signature=unhexlify('30450221009a0b7be0d4ed3146ee262b42202841834698bb3ee39c24e7437df208b8b7077102202b79ab1e7736219387dffe8d615bbdba87e11477104b867ef47afed1a5ede781'),
serialized_tx=unhexlify('010000000182488650ef25a58fef6788bd71b8212038d7f2bbe4750bc7bcb44701e85ef6d5000000006b4830450221009a0b7be0d4ed3146ee262b42202841834698bb3ee39c24e7437df208b8b7077102202b79ab1e7736219387dffe8d615bbdba87e11477104b867ef47afed1a5ede7810121023230848585885f63803a0a8aecdd6538792d5c539215c91698e315bf0253b43dffffffff'))),
TxAck(tx=TransactionType(outputs=[out1])),
TxRequest(request_type=TXFINISHED, details=None, serialized=TxRequestSerializedType(
signature_index=None,
signature=None,
serialized_tx=unhexlify('0160cc0500000000001976a914de9b2a8da088824e8fe51debea566617d851537888ac00000000'),
)),
]
seed = bip39.seed('alcohol woman abuse must during monitor noble actual mixed trade anger aisle', '')
root = bip32.from_seed(seed, 'secp256k1')
signer = signing.sign_tx(tx, root)
for request, response in chunks(messages, 2):
self.assertEqualEx(signer.send(request), response)
with self.assertRaises(StopIteration):
signer.send(None)
def assertEqualEx(self, a, b):
# hack to avoid adding __eq__ to signing.Ui* classes
if ((isinstance(a, signing.UiConfirmOutput) and isinstance(b, signing.UiConfirmOutput)) or
(isinstance(a, signing.UiConfirmTotal) and isinstance(b, signing.UiConfirmTotal))):
return self.assertEqual(a.__dict__, b.__dict__)
else:
return self.assertEqual(a, b)
if __name__ == '__main__':
unittest.main()

View File

@ -13,16 +13,16 @@ class TestWireCodecV1(unittest.TestCase):
def test_detect(self):
for i in range(0, 256):
if i == ord(b'?'):
self.assertTrue(codec_v1.detect_v1(bytes([i]) + b'\x00' * 63))
self.assertTrue(codec_v1.detect(bytes([i]) + b'\x00' * 63))
else:
self.assertFalse(codec_v1.detect_v1(bytes([i]) + b'\x00' * 63))
self.assertFalse(codec_v1.detect(bytes([i]) + b'\x00' * 63))
def test_parse(self):
d = bytes(range(0, 55))
m = b'##\x00\x00\x00\x00\x00\x37' + d
r = b'?' + m
rm, rs, rd = codec_v1.parse_report_v1(r)
rm, rs, rd = codec_v1.parse_report(r)
self.assertEqual(rm, None)
self.assertEqual(rs, 0)
self.assertEqual(rd, m)
@ -35,7 +35,7 @@ class TestWireCodecV1(unittest.TestCase):
for i in range(0, 1024):
if i != 64:
with self.assertRaises(ValueError):
codec_v1.parse_report_v1(bytes(range(0, i)))
codec_v1.parse_report(bytes(range(0, i)))
for hx in range(0, 256):
for hy in range(0, 256):
@ -61,8 +61,8 @@ class TestWireCodecV1(unittest.TestCase):
message = b'##' + b'\xab\xcd' + b'\x00\x00\x00\x00' + b'\x00' * 55
record = []
genfunc = self._record(record, 0xabcd, 0, 0xdeadbeef, 'dummy')
decoder = codec_v1.decode_wire_v1_stream(genfunc, 0xdeadbeef, 'dummy')
genfunc = self._record(record, 0xdeadbeef, 0xabcd, 0, 'dummy')
decoder = codec_v1.decode_stream(0xdeadbeef, genfunc, 'dummy')
decoder.send(None)
try:
@ -78,8 +78,8 @@ class TestWireCodecV1(unittest.TestCase):
message = b'##' + b'\xab\xcd' + b'\x00\x00\x00\x37' + data
record = []
genfunc = self._record(record, 0xabcd, 55, 0xdeadbeef, 'dummy')
decoder = codec_v1.decode_wire_v1_stream(genfunc, 0xdeadbeef, 'dummy')
genfunc = self._record(record, 0xdeadbeef, 0xabcd, 55, 'dummy')
decoder = codec_v1.decode_stream(0xdeadbeef, genfunc, 'dummy')
decoder.send(None)
try:
@ -103,8 +103,8 @@ class TestWireCodecV1(unittest.TestCase):
message_chunks = [c + '\x00' * (63 - len(c)) for c in list(chunks(message, 63))]
record = []
genfunc = self._record(record, msg_type, data_len, 0xdeadbeef, 'dummy')
decoder = codec_v1.decode_wire_v1_stream(genfunc, 0xdeadbeef, 'dummy')
genfunc = self._record(record, 0xdeadbeef, msg_type, data_len, 'dummy')
decoder = codec_v1.decode_stream(0xdeadbeef, genfunc, 'dummy')
decoder.send(None)
res = 1
@ -124,7 +124,7 @@ class TestWireCodecV1(unittest.TestCase):
target = self._record(record)()
target.send(None)
codec_v1.encode_wire_v1_message(0xabcd, b'', target)
codec_v1.encode(codec_v1.SESSION, 0xabcd, b'', target.send)
self.assertEqual(len(record), 1)
self.assertEqual(record[0], b'?##\xab\xcd\x00\x00\x00\x00' + '\0' * 55)
@ -135,7 +135,7 @@ class TestWireCodecV1(unittest.TestCase):
target = self._record(record)()
target.send(None)
codec_v1.encode_wire_v1_message(0xabcd, data, target)
codec_v1.encode(codec_v1.SESSION, 0xabcd, data, target.send)
self.assertEqual(record, [b'?##\xab\xcd\x00\x00\x00\x37' + data])
def test_encode_generated_range(self):
@ -158,7 +158,7 @@ class TestWireCodecV1(unittest.TestCase):
target = genfunc()
target.send(None)
codec_v1.encode_wire_v1_message(msg_type, data, target)
codec_v1.encode(codec_v1.SESSION, msg_type, data, target.send)
self.assertEqual(received, len(reports))
def _record(self, record, *_args):

View File

@ -6,7 +6,7 @@ import ubinascii
from trezor.crypto import random
from trezor.utils import chunks
from trezor.wire import codec
from trezor.wire import codec_v2
class TestWireCodec(unittest.TestCase):
# pylint: disable=C0301
@ -14,66 +14,66 @@ class TestWireCodec(unittest.TestCase):
def test_parse(self):
d = b'O' + b'\x01\x23\x45\x67' + bytes(range(0, 59))
m, s, d = codec.parse_report(d)
m, s, d = codec_v2.parse_report(d)
self.assertEqual(m, b'O'[0])
self.assertEqual(s, 0x01234567)
self.assertEqual(d, bytes(range(0, 59)))
t, l, d = codec.parse_message(d)
t, l, d = codec_v2.parse_message(d)
self.assertEqual(t, 0x00010203)
self.assertEqual(l, 0x04050607)
self.assertEqual(d, bytes(range(8, 59)))
f, = codec.parse_message_footer(d[0:4])
f, = codec_v2.parse_message_footer(d[0:4])
self.assertEqual(f, 0x08090a0b)
for i in range(0, 1024):
if i != 64:
with self.assertRaises(ValueError):
codec.parse_report(bytes(range(0, i)))
codec_v2.parse_report(bytes(range(0, i)))
if i != 59:
with self.assertRaises(ValueError):
codec.parse_message(bytes(range(0, i)))
codec_v2.parse_message(bytes(range(0, i)))
if i != 4:
with self.assertRaises(ValueError):
codec.parse_message_footer(bytes(range(0, i)))
codec_v2.parse_message_footer(bytes(range(0, i)))
def test_serialize(self):
data = bytearray(range(0, 6))
codec.serialize_report_header(data, 0x12, 0x3456789a)
codec_v2.serialize_report_header(data, 0x12, 0x3456789a)
self.assertEqual(data, b'\x12\x34\x56\x78\x9a\x05')
data = bytearray(range(0, 6))
codec.serialize_opened_session(data, 0x3456789a)
self.assertEqual(data, bytes([codec.REP_MARKER_OPEN]) + b'\x34\x56\x78\x9a\x05')
codec_v2.serialize_opened_session(data, 0x3456789a)
self.assertEqual(data, bytes([codec_v2.REP_MARKER_OPEN]) + b'\x34\x56\x78\x9a\x05')
data = bytearray(range(0, 14))
codec.serialize_message_header(data, 0x01234567, 0x89abcdef)
codec_v2.serialize_message_header(data, 0x01234567, 0x89abcdef)
self.assertEqual(data, b'\x00\x01\x02\x03\x04\x01\x23\x45\x67\x89\xab\xcd\xef\x0d')
data = bytearray(range(0, 5))
codec.serialize_message_footer(data, 0x89abcdef)
codec_v2.serialize_message_footer(data, 0x89abcdef)
self.assertEqual(data, b'\x89\xab\xcd\xef\x04')
for i in range(0, 13):
data = bytearray(i)
if i < 4:
with self.assertRaises(ValueError):
codec.serialize_message_footer(data, 0x00)
codec_v2.serialize_message_footer(data, 0x00)
if i < 5:
with self.assertRaises(ValueError):
codec.serialize_report_header(data, 0x00, 0x00)
codec_v2.serialize_report_header(data, 0x00, 0x00)
with self.assertRaises(ValueError):
codec.serialize_opened_session(data, 0x00)
codec_v2.serialize_opened_session(data, 0x00)
with self.assertRaises(ValueError):
codec.serialize_message_header(data, 0x00, 0x00)
codec_v2.serialize_message_header(data, 0x00, 0x00)
def test_decode_empty(self):
message = b'\xab\xcd\xef\x12' + b'\x00\x00\x00\x00' + b'\x00' * 51
record = []
genfunc = self._record(record, 0xabcdef12, 0, 0xdeadbeef, 'dummy')
decoder = codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy')
genfunc = self._record(record, 0xdeadbeef, 0xabcdef12, 0, 'dummy')
decoder = codec_v2.decode_stream(0xdeadbeef, genfunc, 'dummy')
decoder.send(None)
try:
@ -90,8 +90,8 @@ class TestWireCodec(unittest.TestCase):
message = b'\xab\xcd\xef\x12' + b'\x00\x00\x00\x2f' + data + footer
record = []
genfunc = self._record(record, 0xabcdef12, 47, 0xdeadbeef, 'dummy')
decoder = codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy')
genfunc = self._record(record, 0xdeadbeef, 0xabcdef12, 47, 'dummy')
decoder = codec_v2.decode_stream(0xdeadbeef, genfunc, 'dummy')
decoder.send(None)
try:
@ -109,8 +109,8 @@ class TestWireCodec(unittest.TestCase):
message = b'\xab\xcd\xef\x12' + b'\x00\x00\x00\x2f' + data + footer
record = []
genfunc = self._record(record, 0xabcdef12, 47, 0xdeadbeef, 'dummy')
decoder = codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy')
genfunc = self._record(record, 0xdeadbeef, 0xabcdef12, 47, 'dummy')
decoder = codec_v2.decode_stream(0xdeadbeef, genfunc, 'dummy')
decoder.send(None)
try:
@ -120,7 +120,7 @@ class TestWireCodec(unittest.TestCase):
self.assertEqual(res, None)
self.assertEqual(len(record), 2)
self.assertEqual(record[0], data)
self.assertIsInstance(record[1], codec.MessageChecksumError)
self.assertIsInstance(record[1], codec_v2.MessageChecksumError)
def test_decode_generated_range(self):
for data_len in range(1, 512):
@ -136,8 +136,8 @@ class TestWireCodec(unittest.TestCase):
message_chunks = [c + '\x00' * (59 - len(c)) for c in list(chunks(message, 59))]
record = []
genfunc = self._record(record, msg_type, data_len, 0xdeadbeef, 'dummy')
decoder = codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy')
genfunc = self._record(record, 0xdeadbeef, msg_type, data_len, 'dummy')
decoder = codec_v2.decode_stream(0xdeadbeef, genfunc, 'dummy')
decoder.send(None)
res = 1
@ -157,7 +157,7 @@ class TestWireCodec(unittest.TestCase):
target = self._record(record)()
target.send(None)
codec.encode_wire_message(0xabcdef12, b'', 0xdeadbeef, target)
codec_v2.encode(0xdeadbeef, 0xabcdef12, b'', target.send)
self.assertEqual(len(record), 1)
self.assertEqual(record[0], b'H\xde\xad\xbe\xef\xab\xcd\xef\x12\x00\x00\x00\x00' + '\0' * 51)
@ -169,7 +169,7 @@ class TestWireCodec(unittest.TestCase):
target = self._record(record)()
target.send(None)
codec.encode_wire_message(0xabcdef12, data, 0xdeadbeef, target)
codec_v2.encode(0xdeadbeef, 0xabcdef12, data, target.send)
self.assertEqual(record, [b'H\xde\xad\xbe\xef\xab\xcd\xef\x12\x00\x00\x00\x2f' + data + footer])
def test_encode_generated_range(self):
@ -199,7 +199,7 @@ class TestWireCodec(unittest.TestCase):
target = genfunc()
target.send(None)
codec.encode_wire_message(msg_type, data, session_id, target)
codec_v2.encode(session_id, msg_type, data, target.send)
self.assertEqual(received, len(reports))
def _record(self, record, *_args):