mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-14 17:31:04 +00:00
wire: refactoring
- prefer importing modules instead of module members - session_id is always first argument - prefer much shorter names, don't expect users to import module members - shuffle around session-specific code - reduce allocations
This commit is contained in:
parent
0b7874ad43
commit
d56dc88861
@ -1,109 +1,62 @@
|
|||||||
from protobuf import build_protobuf_message
|
import ubinascii
|
||||||
|
import protobuf
|
||||||
|
|
||||||
from trezor.loop import schedule_task, Future
|
|
||||||
from trezor.crypto import random
|
|
||||||
from trezor.messages import get_protobuf_type
|
|
||||||
from trezor.workflow import start_workflow
|
|
||||||
from trezor import log
|
from trezor import log
|
||||||
|
from trezor import loop
|
||||||
|
from trezor import messages
|
||||||
|
from trezor import msg
|
||||||
|
from trezor import workflow
|
||||||
|
|
||||||
from .io import read_report_stream, write_report_stream
|
from . import codec_v1
|
||||||
from .dispatcher import dispatch_reports_by_session
|
from . import codec_v2
|
||||||
from .codec import \
|
from . import sessions
|
||||||
decode_wire_stream, encode_wire_message, \
|
|
||||||
encode_session_open_message, encode_session_close_message
|
|
||||||
from .codec_v1 import \
|
|
||||||
SESSION_V1, decode_wire_v1_stream, encode_wire_v1_message
|
|
||||||
|
|
||||||
_session_handlers = {} # session id -> generator
|
_interface = None
|
||||||
_workflow_genfuncs = {} # wire type -> (generator function, args)
|
|
||||||
_opened_sessions = set() # session ids
|
|
||||||
|
|
||||||
# TODO: get rid of this, use callbacks instead
|
_workflow_callbacks = {} # wire type -> function returning workflow
|
||||||
report_writer = write_report_stream()
|
_workflow_args = {} # wire type -> args
|
||||||
report_writer.send(None)
|
|
||||||
|
|
||||||
|
|
||||||
def generate_session_id():
|
def register(wire_type, callback, *args):
|
||||||
while True:
|
if wire_type in _workflow_callbacks:
|
||||||
session_id = random.uniform(0xffffffff) + 1
|
raise KeyError('Message %d already registered' % wire_type)
|
||||||
if session_id not in _opened_sessions:
|
_workflow_callbacks[wire_type] = callback
|
||||||
return session_id
|
_workflow_args[wire_type] = args
|
||||||
|
|
||||||
|
|
||||||
def open_session(session_id=None):
|
def setup(iface):
|
||||||
if session_id is None:
|
global _interface
|
||||||
session_id = generate_session_id()
|
|
||||||
_opened_sessions.add(session_id)
|
# setup wire interface for reading and writing
|
||||||
log.info(__name__, 'session %d: open', session_id)
|
_interface = iface
|
||||||
return session_id
|
|
||||||
|
# implicitly register v1 codec on its session. v2 sessions are
|
||||||
|
# opened/closed explicitely through session control messages.
|
||||||
|
_session_open(codec_v1.SESSION)
|
||||||
|
|
||||||
|
# run session dispatcher
|
||||||
|
loop.schedule_task(_dispatch_reports())
|
||||||
|
|
||||||
|
|
||||||
def close_session(session_id):
|
async def read(session_id, *wire_types):
|
||||||
_opened_sessions.discard(session_id)
|
log.info(__name__, 'session %d: read types %s', session_id, wire_types)
|
||||||
_session_handlers.pop(session_id, None)
|
signal = loop.Signal()
|
||||||
log.info(__name__, 'session %d: close', session_id)
|
sessions.listen(session_id, _handle_response, wire_types, signal)
|
||||||
|
|
||||||
|
|
||||||
def register_type(wire_type, genfunc, *args):
|
|
||||||
if wire_type in _workflow_genfuncs:
|
|
||||||
raise KeyError('message of type %d already registered' % wire_type)
|
|
||||||
log.info(__name__, 'register type %d', wire_type)
|
|
||||||
_workflow_genfuncs[wire_type] = (genfunc, args)
|
|
||||||
|
|
||||||
|
|
||||||
def register_session(session_id, handler):
|
|
||||||
if session_id not in _opened_sessions:
|
|
||||||
raise KeyError('session %d is unknown' % session_id)
|
|
||||||
if session_id in _session_handlers:
|
|
||||||
raise KeyError('session %d is already being listened on' % session_id)
|
|
||||||
log.info(__name__, 'session %d: listening', session_id)
|
|
||||||
_session_handlers[session_id] = handler
|
|
||||||
|
|
||||||
|
|
||||||
def setup():
|
|
||||||
session_dispatcher = dispatch_reports_by_session(
|
|
||||||
_session_handlers,
|
|
||||||
_handle_open_session,
|
|
||||||
_handle_close_session,
|
|
||||||
_handle_unknown_session)
|
|
||||||
session_dispatcher.send(None)
|
|
||||||
schedule_task(read_report_stream(session_dispatcher))
|
|
||||||
|
|
||||||
v1_handler = decode_wire_v1_stream(_handle_registered_type, SESSION_V1)
|
|
||||||
v1_handler.send(None)
|
|
||||||
open_session(SESSION_V1)
|
|
||||||
register_session(SESSION_V1, v1_handler)
|
|
||||||
|
|
||||||
|
|
||||||
async def read_message(session_id, *exp_types):
|
|
||||||
log.info(__name__, 'session %d: read types %s', session_id, exp_types)
|
|
||||||
signal = Signal()
|
|
||||||
if session_id == SESSION_V1:
|
|
||||||
wire_decoder = decode_wire_v1_stream(
|
|
||||||
_dispatch_and_build_protobuf, session_id, exp_types, signal)
|
|
||||||
else:
|
|
||||||
wire_decoder = decode_wire_stream(
|
|
||||||
_dispatch_and_build_protobuf, session_id, exp_types, signal)
|
|
||||||
wire_decoder.send(None)
|
|
||||||
register_session(session_id, wire_decoder)
|
|
||||||
return await signal
|
return await signal
|
||||||
|
|
||||||
|
|
||||||
async def write_message(session_id, pbuf_message):
|
async def write(session_id, pbuf_msg):
|
||||||
log.info(__name__, 'session %d: write %s', session_id, pbuf_message)
|
log.info(__name__, 'session %d: write %s', session_id, pbuf_msg)
|
||||||
pbuf_type = pbuf_message.__class__
|
pbuf_type = pbuf_msg.__class__
|
||||||
msg_data = pbuf_type.dumps(pbuf_message)
|
msg_data = pbuf_type.dumps(pbuf_msg)
|
||||||
msg_type = pbuf_type.MESSAGE_WIRE_TYPE
|
msg_type = pbuf_type.MESSAGE_WIRE_TYPE
|
||||||
|
sessions.get_codec(session_id).encode(
|
||||||
if session_id == SESSION_V1:
|
session_id, msg_type, msg_data, _write_report)
|
||||||
encode_wire_v1_message(msg_type, msg_data, report_writer)
|
|
||||||
else:
|
|
||||||
encode_wire_message(msg_type, msg_data, session_id, report_writer)
|
|
||||||
|
|
||||||
|
|
||||||
async def reply_message(session_id, pbuf_message, *exp_types):
|
async def call(session_id, pbuf_msg, *response_types):
|
||||||
await write_message(session_id, pbuf_message)
|
await write(session_id, pbuf_msg)
|
||||||
return await read_message(session_id, *exp_types)
|
return await read(session_id, *response_types)
|
||||||
|
|
||||||
|
|
||||||
class FailureError(Exception):
|
class FailureError(Exception):
|
||||||
@ -113,94 +66,107 @@ class FailureError(Exception):
|
|||||||
|
|
||||||
def to_protobuf(self):
|
def to_protobuf(self):
|
||||||
from trezor.messages.Failure import Failure
|
from trezor.messages.Failure import Failure
|
||||||
return Failure(code=self.args[0],
|
return Failure(code=self.args[0], message=self.args[1])
|
||||||
message=self.args[1])
|
|
||||||
|
|
||||||
|
|
||||||
async def monitor_workflow(workflow, session_id):
|
def protobuf_workflow(session_id, msg_type, data_len, callback, *args):
|
||||||
|
return _build_protobuf(msg_type, _start_protobuf_workflow, session_id, callback, args)
|
||||||
|
|
||||||
|
|
||||||
|
def _start_protobuf_workflow(pbuf_msg, session_id, callback, args):
|
||||||
|
wf = callback(session_id, pbuf_msg, *args)
|
||||||
|
wf = _wrap_protobuf_workflow(wf, session_id)
|
||||||
|
workflow.start(wf)
|
||||||
|
|
||||||
|
|
||||||
|
async def _wrap_protobuf_workflow(wf, session_id):
|
||||||
try:
|
try:
|
||||||
result = await workflow
|
result = await wf
|
||||||
|
|
||||||
except FailureError as e:
|
except FailureError as e:
|
||||||
await write_message(session_id, e.to_protobuf())
|
await write(session_id, e.to_protobuf())
|
||||||
raise
|
raise
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
from trezor.messages.Failure import Failure
|
from trezor.messages.Failure import Failure
|
||||||
from trezor.messages.FailureType import FirmwareError
|
from trezor.messages.FailureType import FirmwareError
|
||||||
await write_message(session_id,
|
await write(session_id, Failure(
|
||||||
Failure(code=FirmwareError,
|
code=FirmwareError, message='Firmware Error'))
|
||||||
message='Firmware Error'))
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if result is not None:
|
if result is not None:
|
||||||
await write_message(session_id, result)
|
await write(session_id, result)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
if session_id in _opened_sessions:
|
if session_id in sessions.opened:
|
||||||
if session_id == SESSION_V1:
|
sessions.listen(session_id, _handle_workflow)
|
||||||
wire_decoder = decode_wire_v1_stream(
|
|
||||||
_handle_registered_type, session_id)
|
|
||||||
else:
|
|
||||||
wire_decoder = decode_wire_stream(
|
|
||||||
_handle_registered_type, session_id)
|
|
||||||
wire_decoder.send(None)
|
|
||||||
register_session(session_id, wire_decoder)
|
|
||||||
|
|
||||||
|
|
||||||
def protobuf_handler(msg_type, data_len, session_id, callback, *args):
|
def _build_protobuf(msg_type, callback, *args):
|
||||||
def finalizer(message):
|
pbuf_type = messages.get_protobuf_type(msg_type)
|
||||||
workflow = callback(message, session_id, *args)
|
builder = protobuf.build_message(pbuf_type, callback, *args)
|
||||||
monitored = monitor_workflow(workflow, session_id)
|
|
||||||
start_workflow(monitored)
|
|
||||||
pbuf_type = get_protobuf_type(msg_type)
|
|
||||||
builder = build_protobuf_message(pbuf_type, finalizer)
|
|
||||||
builder.send(None)
|
builder.send(None)
|
||||||
return pbuf_type.load(target=builder)
|
return pbuf_type.load(target=builder)
|
||||||
|
|
||||||
|
|
||||||
def _handle_open_session():
|
def _handle_response(session_id, msg_type, data_len, response_types, signal):
|
||||||
session_id = open_session()
|
if msg_type in response_types:
|
||||||
wire_decoder = decode_wire_stream(_handle_registered_type, session_id)
|
return _build_protobuf(msg_type, signal.send)
|
||||||
wire_decoder.send(None)
|
|
||||||
register_session(session_id, wire_decoder)
|
|
||||||
encode_session_open_message(session_id, report_writer)
|
|
||||||
|
|
||||||
|
|
||||||
def _handle_close_session(session_id):
|
|
||||||
close_session(session_id)
|
|
||||||
encode_session_close_message(session_id, report_writer)
|
|
||||||
|
|
||||||
|
|
||||||
def _handle_unknown_session(session_id, report_data):
|
|
||||||
pass # TODO
|
|
||||||
|
|
||||||
|
|
||||||
def _dispatch_and_build_protobuf(msg_type, data_len, session_id, exp_types, signal):
|
|
||||||
if msg_type in exp_types:
|
|
||||||
pbuf_type = get_protobuf_type(msg_type)
|
|
||||||
builder = build_protobuf_message(pbuf_type, signal.send)
|
|
||||||
builder.send(None)
|
|
||||||
return pbuf_type.load(target=builder)
|
|
||||||
else:
|
else:
|
||||||
from trezor.messages.FailureType import UnexpectedMessage
|
from trezor.messages.FailureType import UnexpectedMessage
|
||||||
signal.send(FailureError(UnexpectedMessage, 'Unexpected message'))
|
signal.send(FailureError(UnexpectedMessage, 'Unexpected message'))
|
||||||
return _handle_registered_type(msg_type, data_len, session_id)
|
return _handle_workflow(session_id, msg_type, data_len)
|
||||||
|
|
||||||
|
|
||||||
def _handle_registered_type(msg_type, data_len, session_id):
|
def _handle_workflow(session_id, msg_type, data_len):
|
||||||
fallback = (_handle_unexpected_type, ())
|
if msg_type in _workflow_callbacks:
|
||||||
genfunc, args = _workflow_genfuncs.get(msg_type, fallback)
|
args = _workflow_args[msg_type]
|
||||||
return genfunc(msg_type, data_len, session_id, *args)
|
callback = _workflow_callbacks[msg_type]
|
||||||
|
return callback(session_id, msg_type, data_len, *args)
|
||||||
|
else:
|
||||||
|
return _handle_unexpected(session_id, msg_type, data_len)
|
||||||
|
|
||||||
|
|
||||||
def _handle_unexpected_type(msg_type, data_len, session_id):
|
def _handle_unexpected(session_id, msg_type, data_len):
|
||||||
log.warning(__name__, 'session %d: skip type %d, len %d',
|
log.warning(
|
||||||
session_id, msg_type, data_len)
|
__name__, 'session %d: skip type %d, len %d', session_id, msg_type, data_len)
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
yield
|
yield
|
||||||
except EOFError:
|
except EOFError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _write_report(report):
|
||||||
|
if __debug__:
|
||||||
|
log.info(__name__, 'write report %s', ubinascii.hexlify(report))
|
||||||
|
msg.send(_interface, report)
|
||||||
|
|
||||||
|
|
||||||
|
def _dispatch_reports():
|
||||||
|
while True:
|
||||||
|
report, = yield loop.Select(_interface)
|
||||||
|
report = memoryview(report)
|
||||||
|
if __debug__:
|
||||||
|
log.debug(__name__, 'read report %s', ubinascii.hexlify(report))
|
||||||
|
sessions.dispatch(
|
||||||
|
report, _session_open, _session_close, _session_unknown)
|
||||||
|
|
||||||
|
|
||||||
|
def _session_open(session_id=None):
|
||||||
|
session_id = sessions.open(session_id)
|
||||||
|
sessions.listen(session_id, _handle_workflow)
|
||||||
|
sessions.get_codec(session_id).encode_session_open(
|
||||||
|
session_id, _write_report)
|
||||||
|
|
||||||
|
|
||||||
|
def _session_close(session_id):
|
||||||
|
sessions.close(session_id)
|
||||||
|
sessions.get_codec(session_id).encode_session_close(
|
||||||
|
session_id, _write_report)
|
||||||
|
|
||||||
|
|
||||||
|
def _session_unknown(session_id, report_data):
|
||||||
|
pass
|
||||||
|
@ -1,50 +1,50 @@
|
|||||||
from micropython import const
|
from micropython import const
|
||||||
|
|
||||||
import ustruct
|
import ustruct
|
||||||
|
|
||||||
SESSION_V1 = const(0)
|
SESSION = const(0)
|
||||||
REP_MARKER_V1 = const(63) # ord('?')
|
REP_MARKER = const(63) # ord('?')
|
||||||
REP_MARKER_V1_LEN = const(1) # len('?')
|
REP_MARKER_LEN = const(1) # len('?')
|
||||||
|
|
||||||
_REP_LEN = const(64)
|
_REP_LEN = const(64)
|
||||||
_MSG_HEADER_MAGIC = const(35) # org('#')
|
_MSG_HEADER_MAGIC = const(35) # org('#')
|
||||||
_MSG_HEADER_V1 = '>BBHL' # magic, magic, wire type, data length
|
_MSG_HEADER = '>BBHL' # magic, magic, wire type, data length
|
||||||
_MSG_HEADER_V1_LEN = ustruct.calcsize(_MSG_HEADER_V1)
|
_MSG_HEADER_LEN = ustruct.calcsize(_MSG_HEADER)
|
||||||
|
|
||||||
|
|
||||||
def detect_v1(data):
|
def detect(data):
|
||||||
return (data[0] == REP_MARKER_V1)
|
return data[0] == REP_MARKER
|
||||||
|
|
||||||
|
|
||||||
def parse_report_v1(data):
|
def parse_report(data):
|
||||||
if len(data) != _REP_LEN:
|
if len(data) != _REP_LEN:
|
||||||
raise ValueError('Invalid buffer size')
|
raise ValueError('Invalid buffer size')
|
||||||
return None, SESSION_V1, data[1:]
|
return None, SESSION, data[1:]
|
||||||
|
|
||||||
|
|
||||||
def parse_message(data):
|
def parse_message(data):
|
||||||
magic1, magic2, msg_type, data_len = ustruct.unpack(_MSG_HEADER_V1, data)
|
magic1, magic2, msg_type, data_len = ustruct.unpack(_MSG_HEADER, data)
|
||||||
if magic1 != _MSG_HEADER_MAGIC or magic2 != _MSG_HEADER_MAGIC:
|
if magic1 != _MSG_HEADER_MAGIC or magic2 != _MSG_HEADER_MAGIC:
|
||||||
raise ValueError('Corrupted magic bytes')
|
raise ValueError('Corrupted magic bytes')
|
||||||
return msg_type, data_len, data[_MSG_HEADER_V1_LEN:]
|
return msg_type, data_len, data[_MSG_HEADER_LEN:]
|
||||||
|
|
||||||
|
|
||||||
def serialize_message_header(data, msg_type, msg_len):
|
def serialize_message_header(data, msg_type, msg_len):
|
||||||
if len(data) < REP_MARKER_V1_LEN + _MSG_HEADER_V1_LEN:
|
if len(data) < REP_MARKER_LEN + _MSG_HEADER_LEN:
|
||||||
raise ValueError('Invalid buffer size')
|
raise ValueError('Invalid buffer size')
|
||||||
if msg_type < 0 or msg_type > 65535:
|
if msg_type < 0 or msg_type > 65535:
|
||||||
raise ValueError('Value is out of range')
|
raise ValueError('Value is out of range')
|
||||||
ustruct.pack_into(
|
ustruct.pack_into(
|
||||||
_MSG_HEADER_V1, data, REP_MARKER_V1_LEN,
|
_MSG_HEADER, data, REP_MARKER_LEN,
|
||||||
_MSG_HEADER_MAGIC, _MSG_HEADER_MAGIC, msg_type, msg_len)
|
_MSG_HEADER_MAGIC, _MSG_HEADER_MAGIC, msg_type, msg_len)
|
||||||
|
|
||||||
|
|
||||||
def decode_wire_v1_stream(genfunc, session_id, *args):
|
def decode_stream(session_id, callback, *args):
|
||||||
'''Decode a v1 wire message from the report data and stream it to target.
|
'''Decode a v1 wire message from the report data and stream it to target.
|
||||||
|
|
||||||
Receives report payloads. After first report, creates target by calling
|
Receives report payloads. After first report, creates target by calling
|
||||||
`genfunc(msg_type, data_len, session_id, *args)` and sends chunks of message
|
`callback(session_id, msg_type, data_len, *args)` and sends chunks of message
|
||||||
data.
|
data. Throws `EOFError` to target after last data chunk.
|
||||||
Throws `EOFError` to target after last data chunk.
|
|
||||||
|
|
||||||
Pass report payloads as `memoryview` for cheaper slicing.
|
Pass report payloads as `memoryview` for cheaper slicing.
|
||||||
'''
|
'''
|
||||||
@ -52,7 +52,7 @@ Pass report payloads as `memoryview` for cheaper slicing.
|
|||||||
message = yield # read first report
|
message = yield # read first report
|
||||||
msg_type, data_len, data = parse_message(message)
|
msg_type, data_len, data = parse_message(message)
|
||||||
|
|
||||||
target = genfunc(msg_type, data_len, session_id, *args)
|
target = callback(session_id, msg_type, data_len, *args)
|
||||||
target.send(None)
|
target.send(None)
|
||||||
|
|
||||||
while data_len > 0:
|
while data_len > 0:
|
||||||
@ -68,18 +68,18 @@ Pass report payloads as `memoryview` for cheaper slicing.
|
|||||||
target.throw(EOFError())
|
target.throw(EOFError())
|
||||||
|
|
||||||
|
|
||||||
def encode_wire_v1_message(msg_type, msg_data, target):
|
def encode(session_id, msg_type, msg_data, callback):
|
||||||
'''Encode a full v1 wire message directly to reports and stream it to target.
|
'''Encode a full v1 wire message directly to reports and stream it to callback.
|
||||||
|
|
||||||
Target receives `memoryview`s of HID reports which are valid until the targets
|
Callback receives `memoryview`s of HID reports which are valid until the
|
||||||
`send()` method returns.
|
callback returns.
|
||||||
'''
|
'''
|
||||||
report = memoryview(bytearray(_REP_LEN))
|
report = memoryview(bytearray(_REP_LEN))
|
||||||
report[0] = REP_MARKER_V1
|
report[0] = REP_MARKER
|
||||||
serialize_message_header(report, msg_type, len(msg_data))
|
serialize_message_header(report, msg_type, len(msg_data))
|
||||||
|
|
||||||
source_data = memoryview(msg_data)
|
source_data = memoryview(msg_data)
|
||||||
target_data = report[REP_MARKER_V1_LEN + _MSG_HEADER_V1_LEN:]
|
target_data = report[REP_MARKER_LEN + _MSG_HEADER_LEN:]
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
# move as much as possible from source to target
|
# move as much as possible from source to target
|
||||||
@ -95,10 +95,20 @@ Target receives `memoryview`s of HID reports which are valid until the targets
|
|||||||
target_data[x] = 0
|
target_data[x] = 0
|
||||||
x += 1
|
x += 1
|
||||||
|
|
||||||
target.send(report)
|
callback(report)
|
||||||
|
|
||||||
if not source_data:
|
if not source_data:
|
||||||
break
|
break
|
||||||
|
|
||||||
# reset to skip the magic, not the whole header anymore
|
# reset to skip the magic, not the whole header anymore
|
||||||
target_data = report[REP_MARKER_V1_LEN:]
|
target_data = report[REP_MARKER_LEN:]
|
||||||
|
|
||||||
|
|
||||||
|
def encode_session_open(session_id, callback):
|
||||||
|
# v1 codec does not have explicit session support
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def encode_session_close(session_id, callback):
|
||||||
|
# v1 codec does not have explicit session support
|
||||||
|
pass
|
@ -81,11 +81,11 @@ class MessageChecksumError(Exception):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def decode_wire_stream(genfunc, session_id, *args):
|
def decode_stream(session_id, callback, *args):
|
||||||
'''Decode a wire message from the report data and stream it to target.
|
'''Decode a wire message from the report data and stream it to target.
|
||||||
|
|
||||||
Receives report payloads. After first report, creates target by calling
|
Receives report payloads. After first report, creates target by calling
|
||||||
`genfunc(msg_type, data_len, session_id, *args)` and sends chunks of message
|
`callback(session_id, msg_type, data_len, *args)` and sends chunks of message
|
||||||
data.
|
data.
|
||||||
Throws `EOFError` to target after last data chunk, in case of valid checksum.
|
Throws `EOFError` to target after last data chunk, in case of valid checksum.
|
||||||
Throws `MessageChecksumError` to target if data doesn't match the checksum.
|
Throws `MessageChecksumError` to target if data doesn't match the checksum.
|
||||||
@ -95,7 +95,7 @@ Pass report payloads as `memoryview` for cheaper slicing.
|
|||||||
message = yield # read first report
|
message = yield # read first report
|
||||||
msg_type, data_len, data_tail = parse_message(message)
|
msg_type, data_len, data_tail = parse_message(message)
|
||||||
|
|
||||||
target = genfunc(msg_type, data_len, session_id, *args)
|
target = callback(session_id, msg_type, data_len, *args)
|
||||||
target.send(None)
|
target.send(None)
|
||||||
|
|
||||||
checksum = 0 # crc32
|
checksum = 0 # crc32
|
||||||
@ -126,11 +126,11 @@ Pass report payloads as `memoryview` for cheaper slicing.
|
|||||||
target.throw(EOFError())
|
target.throw(EOFError())
|
||||||
|
|
||||||
|
|
||||||
def encode_wire_message(msg_type, msg_data, session_id, target):
|
def encode(session_id, msg_type, msg_data, callback):
|
||||||
'''Encode a full wire message directly to reports and stream it to target.
|
'''Encode a full wire message directly to reports and stream it to callback.
|
||||||
|
|
||||||
Target receives `memoryview`s of HID reports which are valid until the targets
|
Callback receives `memoryview`s of HID reports which are valid until the
|
||||||
`send()` method returns.
|
callback returns.
|
||||||
'''
|
'''
|
||||||
report = memoryview(bytearray(_REP_LEN))
|
report = memoryview(bytearray(_REP_LEN))
|
||||||
serialize_report_header(report, REP_MARKER_HEADER, session_id)
|
serialize_report_header(report, REP_MARKER_HEADER, session_id)
|
||||||
@ -166,7 +166,7 @@ Target receives `memoryview`s of HID reports which are valid until the targets
|
|||||||
target_data[x] = 0
|
target_data[x] = 0
|
||||||
x += 1
|
x += 1
|
||||||
|
|
||||||
target.send(report)
|
callback(report)
|
||||||
|
|
||||||
if not source_data and not msg_footer:
|
if not source_data and not msg_footer:
|
||||||
break
|
break
|
||||||
@ -178,13 +178,13 @@ Target receives `memoryview`s of HID reports which are valid until the targets
|
|||||||
target_data = report[_REP_HEADER_LEN:]
|
target_data = report[_REP_HEADER_LEN:]
|
||||||
|
|
||||||
|
|
||||||
def encode_session_open_message(session_id, target):
|
def encode_session_open(session_id, callback):
|
||||||
report = bytearray(_REP_LEN)
|
report = bytearray(_REP_LEN)
|
||||||
serialize_report_header(report, REP_MARKER_OPEN, session_id)
|
serialize_report_header(report, REP_MARKER_OPEN, session_id)
|
||||||
target.send(report)
|
callback(report)
|
||||||
|
|
||||||
|
|
||||||
def encode_session_close_message(session_id, target):
|
def encode_session_close(session_id, callback):
|
||||||
report = bytearray(_REP_LEN)
|
report = bytearray(_REP_LEN)
|
||||||
serialize_report_header(report, REP_MARKER_CLOSE, session_id)
|
serialize_report_header(report, REP_MARKER_CLOSE, session_id)
|
||||||
target.send(report)
|
callback(report)
|
@ -1,46 +0,0 @@
|
|||||||
from trezor import log
|
|
||||||
from .codec import parse_report, REP_MARKER_OPEN, REP_MARKER_CLOSE
|
|
||||||
from .codec_v1 import detect_v1, parse_report_v1
|
|
||||||
|
|
||||||
|
|
||||||
def dispatch_reports_by_session(handlers,
|
|
||||||
open_callback,
|
|
||||||
close_callback,
|
|
||||||
unknown_callback):
|
|
||||||
'''
|
|
||||||
Consumes reports adhering to the wire codec and dispatches the report
|
|
||||||
payloads by between the passed handlers.
|
|
||||||
'''
|
|
||||||
|
|
||||||
while True:
|
|
||||||
|
|
||||||
data = (yield)
|
|
||||||
if detect_v1(data):
|
|
||||||
marker, session_id, report_data = parse_report_v1(data)
|
|
||||||
else:
|
|
||||||
marker, session_id, report_data = parse_report(data)
|
|
||||||
|
|
||||||
if marker == REP_MARKER_OPEN:
|
|
||||||
log.debug(__name__, 'request for new session')
|
|
||||||
open_callback()
|
|
||||||
continue
|
|
||||||
|
|
||||||
elif marker == REP_MARKER_CLOSE:
|
|
||||||
log.debug(__name__, 'request for closing session %d', session_id)
|
|
||||||
close_callback(session_id)
|
|
||||||
continue
|
|
||||||
|
|
||||||
elif session_id not in handlers:
|
|
||||||
log.debug(__name__, 'report on unknown session %d', session_id)
|
|
||||||
unknown_callback(session_id, report_data)
|
|
||||||
continue
|
|
||||||
|
|
||||||
log.debug(__name__, 'report on session %d', session_id)
|
|
||||||
handler = handlers[session_id]
|
|
||||||
|
|
||||||
try:
|
|
||||||
handler.send(report_data)
|
|
||||||
except StopIteration:
|
|
||||||
handlers.pop(session_id)
|
|
||||||
except Exception as e:
|
|
||||||
log.exception(__name__, e)
|
|
@ -1,19 +0,0 @@
|
|||||||
from micropython import const
|
|
||||||
from ubinascii import hexlify
|
|
||||||
from trezor import msg, loop, log
|
|
||||||
|
|
||||||
_DEFAULT_IFACE = const(0xFF00) # TODO: use proper interface
|
|
||||||
|
|
||||||
|
|
||||||
def read_report_stream(target, iface=_DEFAULT_IFACE):
|
|
||||||
while True:
|
|
||||||
report, = yield loop.Select(iface)
|
|
||||||
log.debug(__name__, 'read report %s', hexlify(report))
|
|
||||||
target.send(memoryview(report))
|
|
||||||
|
|
||||||
|
|
||||||
def write_report_stream(iface=_DEFAULT_IFACE):
|
|
||||||
while True:
|
|
||||||
report = yield
|
|
||||||
log.info(__name__, 'write report %s', hexlify(report))
|
|
||||||
msg.send(iface, report)
|
|
82
src/trezor/wire/sessions.py
Normal file
82
src/trezor/wire/sessions.py
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
from trezor import log
|
||||||
|
from trezor.crypto import random
|
||||||
|
|
||||||
|
from . import codec_v1
|
||||||
|
from . import codec_v2
|
||||||
|
|
||||||
|
opened = set() # opened session ids
|
||||||
|
readers = {} # session id -> generator
|
||||||
|
|
||||||
|
|
||||||
|
def generate():
|
||||||
|
while True:
|
||||||
|
session_id = random.uniform(0xffffffff) + 1
|
||||||
|
if session_id not in opened:
|
||||||
|
return session_id
|
||||||
|
|
||||||
|
|
||||||
|
def open(session_id=None):
|
||||||
|
if session_id is None:
|
||||||
|
session_id = generate()
|
||||||
|
log.info(__name__, 'session %d: open', session_id)
|
||||||
|
opened.add(session_id)
|
||||||
|
return session_id
|
||||||
|
|
||||||
|
|
||||||
|
def close(session_id):
|
||||||
|
log.info(__name__, 'session %d: close', session_id)
|
||||||
|
opened.discard(session_id)
|
||||||
|
readers.pop(session_id, None)
|
||||||
|
|
||||||
|
|
||||||
|
def get_codec(session_id):
|
||||||
|
if session_id == codec_v1.SESSION:
|
||||||
|
return codec_v1
|
||||||
|
else:
|
||||||
|
return codec_v2
|
||||||
|
|
||||||
|
|
||||||
|
def listen(session_id, handler, *args):
|
||||||
|
if session_id not in opened:
|
||||||
|
raise KeyError('Session %d is unknown' % session_id)
|
||||||
|
if session_id in readers:
|
||||||
|
raise KeyError('Session %d is already being listened on' % session_id)
|
||||||
|
log.info(__name__, 'session %d: listening', session_id)
|
||||||
|
decoder = get_codec(session_id).decode_stream(session_id, handler, *args)
|
||||||
|
decoder.send(None)
|
||||||
|
readers[session_id] = decoder
|
||||||
|
|
||||||
|
|
||||||
|
def dispatch(report, open_callback, close_callback, unknown_callback):
|
||||||
|
'''
|
||||||
|
Dispatches payloads of reports adhering to one of the wire codecs.
|
||||||
|
'''
|
||||||
|
|
||||||
|
if codec_v1.detect(report):
|
||||||
|
marker, session_id, report_data = codec_v1.parse_report(report)
|
||||||
|
else:
|
||||||
|
marker, session_id, report_data = codec_v2.parse_report(report)
|
||||||
|
|
||||||
|
if marker == codec_v2.REP_MARKER_OPEN:
|
||||||
|
log.debug(__name__, 'request for new session')
|
||||||
|
open_callback()
|
||||||
|
return
|
||||||
|
elif marker == codec_v2.REP_MARKER_CLOSE:
|
||||||
|
log.debug(__name__, 'request for closing session %d', session_id)
|
||||||
|
close_callback(session_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
if session_id not in readers:
|
||||||
|
log.warning(__name__, 'report on unknown session %d', session_id)
|
||||||
|
unknown_callback(session_id, report_data)
|
||||||
|
return
|
||||||
|
|
||||||
|
log.debug(__name__, 'report on session %d', session_id)
|
||||||
|
reader = readers[session_id]
|
||||||
|
|
||||||
|
try:
|
||||||
|
reader.send(report_data)
|
||||||
|
except StopIteration:
|
||||||
|
readers.pop(session_id)
|
||||||
|
except Exception as e:
|
||||||
|
log.exception(__name__, e)
|
111
tests/test_apps.wallet.signtx.py
Normal file
111
tests/test_apps.wallet.signtx.py
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
from common import *
|
||||||
|
|
||||||
|
from trezor.utils import chunks
|
||||||
|
from trezor.crypto import bip32, bip39
|
||||||
|
from trezor.messages.SignTx import SignTx
|
||||||
|
from trezor.messages.TxInputType import TxInputType
|
||||||
|
from trezor.messages.TxOutputType import TxOutputType
|
||||||
|
from trezor.messages.TxOutputBinType import TxOutputBinType
|
||||||
|
from trezor.messages.TxRequest import TxRequest
|
||||||
|
from trezor.messages.TxAck import TxAck
|
||||||
|
from trezor.messages.TransactionType import TransactionType
|
||||||
|
from trezor.messages.RequestType import TXINPUT, TXOUTPUT, TXMETA, TXFINISHED
|
||||||
|
from trezor.messages.TxRequestDetailsType import TxRequestDetailsType
|
||||||
|
from trezor.messages.TxRequestSerializedType import TxRequestSerializedType
|
||||||
|
from trezor.messages import OutputScriptType
|
||||||
|
|
||||||
|
from apps.common import coins
|
||||||
|
from apps.wallet.sign_tx import signing
|
||||||
|
|
||||||
|
class TestSignTx(unittest.TestCase):
|
||||||
|
# pylint: disable=C0301
|
||||||
|
|
||||||
|
def test_one_one_fee(self):
|
||||||
|
# tx: d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882
|
||||||
|
# input 0: 0.0039 BTC
|
||||||
|
|
||||||
|
coin_bitcoin = coins.by_name('Bitcoin')
|
||||||
|
|
||||||
|
ptx1 = TransactionType(version=1, lock_time=0, inputs_cnt=2, outputs_cnt=1)
|
||||||
|
pinp1 = TxInputType(script_sig=unhexlify('483045022072ba61305fe7cb542d142b8f3299a7b10f9ea61f6ffaab5dca8142601869d53c0221009a8027ed79eb3b9bc13577ac2853269323434558528c6b6a7e542be46e7e9a820141047a2d177c0f3626fc68c53610b0270fa6156181f46586c679ba6a88b34c6f4874686390b4d92e5769fbb89c8050b984f4ec0b257a0e5c4ff8bd3b035a51709503'),
|
||||||
|
prev_hash=unhexlify('c16a03f1cf8f99f6b5297ab614586cacec784c2d259af245909dedb0e39eddcf'),
|
||||||
|
prev_index=1,
|
||||||
|
script_type=None,
|
||||||
|
sequence=None)
|
||||||
|
pinp2 = TxInputType(script_sig=unhexlify('48304502200fd63adc8f6cb34359dc6cca9e5458d7ea50376cbd0a74514880735e6d1b8a4c0221008b6ead7fe5fbdab7319d6dfede3a0bc8e2a7c5b5a9301636d1de4aa31a3ee9b101410486ad608470d796236b003635718dfc07c0cac0cfc3bfc3079e4f491b0426f0676e6643a39198e8e7bdaffb94f4b49ea21baa107ec2e237368872836073668214'),
|
||||||
|
prev_hash=unhexlify('1ae39a2f8d59670c8fc61179148a8e61e039d0d9e8ab08610cb69b4a19453eaf'),
|
||||||
|
prev_index=1,
|
||||||
|
script_type=None,
|
||||||
|
sequence=None)
|
||||||
|
pout1 = TxOutputBinType(script_pubkey=unhexlify('76a91424a56db43cf6f2b02e838ea493f95d8d6047423188ac'),
|
||||||
|
amount=390000,
|
||||||
|
address_n=None)
|
||||||
|
|
||||||
|
inp1 = TxInputType(address_n=[0], # 14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e
|
||||||
|
# amount=390000,
|
||||||
|
prev_hash=unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882'),
|
||||||
|
prev_index=0,
|
||||||
|
script_type=None,
|
||||||
|
sequence=None)
|
||||||
|
out1 = TxOutputType(address='1MJ2tj2ThBE62zXbBYA5ZaN3fdve5CPAz1',
|
||||||
|
amount=390000 - 10000,
|
||||||
|
script_type=OutputScriptType.PAYTOADDRESS,
|
||||||
|
address_n=None)
|
||||||
|
tx = SignTx(coin_name=None, version=None, lock_time=None, inputs_count=1, outputs_count=1)
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
None,
|
||||||
|
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None)),
|
||||||
|
TxAck(tx=TransactionType(inputs=[inp1])),
|
||||||
|
TxRequest(request_type=TXMETA, details=TxRequestDetailsType(request_index=None, tx_hash=unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882')), serialized=None),
|
||||||
|
TxAck(tx=ptx1),
|
||||||
|
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882')), serialized=None),
|
||||||
|
TxAck(tx=TransactionType(inputs=[pinp1])),
|
||||||
|
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=1, tx_hash=unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882')), serialized=None),
|
||||||
|
TxAck(tx=TransactionType(inputs=[pinp2])),
|
||||||
|
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882')), serialized=None),
|
||||||
|
TxAck(tx=TransactionType(bin_outputs=[pout1])),
|
||||||
|
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None),
|
||||||
|
TxAck(tx=TransactionType(outputs=[out1])),
|
||||||
|
signing.UiConfirmOutput(out1, coin_bitcoin),
|
||||||
|
True,
|
||||||
|
signing.UiConfirmTotal(380000, 10000, coin_bitcoin),
|
||||||
|
True,
|
||||||
|
# ButtonRequest(code=ButtonRequest_ConfirmOutput),
|
||||||
|
# ButtonRequest(code=ButtonRequest_SignTx),
|
||||||
|
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None),
|
||||||
|
TxAck(tx=TransactionType(inputs=[inp1])),
|
||||||
|
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None),
|
||||||
|
TxAck(tx=TransactionType(outputs=[out1])),
|
||||||
|
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=TxRequestSerializedType(
|
||||||
|
signature_index=0,
|
||||||
|
signature=unhexlify('30450221009a0b7be0d4ed3146ee262b42202841834698bb3ee39c24e7437df208b8b7077102202b79ab1e7736219387dffe8d615bbdba87e11477104b867ef47afed1a5ede781'),
|
||||||
|
serialized_tx=unhexlify('010000000182488650ef25a58fef6788bd71b8212038d7f2bbe4750bc7bcb44701e85ef6d5000000006b4830450221009a0b7be0d4ed3146ee262b42202841834698bb3ee39c24e7437df208b8b7077102202b79ab1e7736219387dffe8d615bbdba87e11477104b867ef47afed1a5ede7810121023230848585885f63803a0a8aecdd6538792d5c539215c91698e315bf0253b43dffffffff'))),
|
||||||
|
TxAck(tx=TransactionType(outputs=[out1])),
|
||||||
|
TxRequest(request_type=TXFINISHED, details=None, serialized=TxRequestSerializedType(
|
||||||
|
signature_index=None,
|
||||||
|
signature=None,
|
||||||
|
serialized_tx=unhexlify('0160cc0500000000001976a914de9b2a8da088824e8fe51debea566617d851537888ac00000000'),
|
||||||
|
)),
|
||||||
|
]
|
||||||
|
|
||||||
|
seed = bip39.seed('alcohol woman abuse must during monitor noble actual mixed trade anger aisle', '')
|
||||||
|
root = bip32.from_seed(seed, 'secp256k1')
|
||||||
|
|
||||||
|
signer = signing.sign_tx(tx, root)
|
||||||
|
for request, response in chunks(messages, 2):
|
||||||
|
self.assertEqualEx(signer.send(request), response)
|
||||||
|
with self.assertRaises(StopIteration):
|
||||||
|
signer.send(None)
|
||||||
|
|
||||||
|
def assertEqualEx(self, a, b):
|
||||||
|
# hack to avoid adding __eq__ to signing.Ui* classes
|
||||||
|
if ((isinstance(a, signing.UiConfirmOutput) and isinstance(b, signing.UiConfirmOutput)) or
|
||||||
|
(isinstance(a, signing.UiConfirmTotal) and isinstance(b, signing.UiConfirmTotal))):
|
||||||
|
return self.assertEqual(a.__dict__, b.__dict__)
|
||||||
|
else:
|
||||||
|
return self.assertEqual(a, b)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
@ -13,16 +13,16 @@ class TestWireCodecV1(unittest.TestCase):
|
|||||||
def test_detect(self):
|
def test_detect(self):
|
||||||
for i in range(0, 256):
|
for i in range(0, 256):
|
||||||
if i == ord(b'?'):
|
if i == ord(b'?'):
|
||||||
self.assertTrue(codec_v1.detect_v1(bytes([i]) + b'\x00' * 63))
|
self.assertTrue(codec_v1.detect(bytes([i]) + b'\x00' * 63))
|
||||||
else:
|
else:
|
||||||
self.assertFalse(codec_v1.detect_v1(bytes([i]) + b'\x00' * 63))
|
self.assertFalse(codec_v1.detect(bytes([i]) + b'\x00' * 63))
|
||||||
|
|
||||||
def test_parse(self):
|
def test_parse(self):
|
||||||
d = bytes(range(0, 55))
|
d = bytes(range(0, 55))
|
||||||
m = b'##\x00\x00\x00\x00\x00\x37' + d
|
m = b'##\x00\x00\x00\x00\x00\x37' + d
|
||||||
r = b'?' + m
|
r = b'?' + m
|
||||||
|
|
||||||
rm, rs, rd = codec_v1.parse_report_v1(r)
|
rm, rs, rd = codec_v1.parse_report(r)
|
||||||
self.assertEqual(rm, None)
|
self.assertEqual(rm, None)
|
||||||
self.assertEqual(rs, 0)
|
self.assertEqual(rs, 0)
|
||||||
self.assertEqual(rd, m)
|
self.assertEqual(rd, m)
|
||||||
@ -35,7 +35,7 @@ class TestWireCodecV1(unittest.TestCase):
|
|||||||
for i in range(0, 1024):
|
for i in range(0, 1024):
|
||||||
if i != 64:
|
if i != 64:
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
codec_v1.parse_report_v1(bytes(range(0, i)))
|
codec_v1.parse_report(bytes(range(0, i)))
|
||||||
|
|
||||||
for hx in range(0, 256):
|
for hx in range(0, 256):
|
||||||
for hy in range(0, 256):
|
for hy in range(0, 256):
|
||||||
@ -61,8 +61,8 @@ class TestWireCodecV1(unittest.TestCase):
|
|||||||
message = b'##' + b'\xab\xcd' + b'\x00\x00\x00\x00' + b'\x00' * 55
|
message = b'##' + b'\xab\xcd' + b'\x00\x00\x00\x00' + b'\x00' * 55
|
||||||
|
|
||||||
record = []
|
record = []
|
||||||
genfunc = self._record(record, 0xabcd, 0, 0xdeadbeef, 'dummy')
|
genfunc = self._record(record, 0xdeadbeef, 0xabcd, 0, 'dummy')
|
||||||
decoder = codec_v1.decode_wire_v1_stream(genfunc, 0xdeadbeef, 'dummy')
|
decoder = codec_v1.decode_stream(0xdeadbeef, genfunc, 'dummy')
|
||||||
decoder.send(None)
|
decoder.send(None)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -78,8 +78,8 @@ class TestWireCodecV1(unittest.TestCase):
|
|||||||
message = b'##' + b'\xab\xcd' + b'\x00\x00\x00\x37' + data
|
message = b'##' + b'\xab\xcd' + b'\x00\x00\x00\x37' + data
|
||||||
|
|
||||||
record = []
|
record = []
|
||||||
genfunc = self._record(record, 0xabcd, 55, 0xdeadbeef, 'dummy')
|
genfunc = self._record(record, 0xdeadbeef, 0xabcd, 55, 'dummy')
|
||||||
decoder = codec_v1.decode_wire_v1_stream(genfunc, 0xdeadbeef, 'dummy')
|
decoder = codec_v1.decode_stream(0xdeadbeef, genfunc, 'dummy')
|
||||||
decoder.send(None)
|
decoder.send(None)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -103,8 +103,8 @@ class TestWireCodecV1(unittest.TestCase):
|
|||||||
message_chunks = [c + '\x00' * (63 - len(c)) for c in list(chunks(message, 63))]
|
message_chunks = [c + '\x00' * (63 - len(c)) for c in list(chunks(message, 63))]
|
||||||
|
|
||||||
record = []
|
record = []
|
||||||
genfunc = self._record(record, msg_type, data_len, 0xdeadbeef, 'dummy')
|
genfunc = self._record(record, 0xdeadbeef, msg_type, data_len, 'dummy')
|
||||||
decoder = codec_v1.decode_wire_v1_stream(genfunc, 0xdeadbeef, 'dummy')
|
decoder = codec_v1.decode_stream(0xdeadbeef, genfunc, 'dummy')
|
||||||
decoder.send(None)
|
decoder.send(None)
|
||||||
|
|
||||||
res = 1
|
res = 1
|
||||||
@ -124,7 +124,7 @@ class TestWireCodecV1(unittest.TestCase):
|
|||||||
target = self._record(record)()
|
target = self._record(record)()
|
||||||
target.send(None)
|
target.send(None)
|
||||||
|
|
||||||
codec_v1.encode_wire_v1_message(0xabcd, b'', target)
|
codec_v1.encode(codec_v1.SESSION, 0xabcd, b'', target.send)
|
||||||
self.assertEqual(len(record), 1)
|
self.assertEqual(len(record), 1)
|
||||||
self.assertEqual(record[0], b'?##\xab\xcd\x00\x00\x00\x00' + '\0' * 55)
|
self.assertEqual(record[0], b'?##\xab\xcd\x00\x00\x00\x00' + '\0' * 55)
|
||||||
|
|
||||||
@ -135,7 +135,7 @@ class TestWireCodecV1(unittest.TestCase):
|
|||||||
target = self._record(record)()
|
target = self._record(record)()
|
||||||
target.send(None)
|
target.send(None)
|
||||||
|
|
||||||
codec_v1.encode_wire_v1_message(0xabcd, data, target)
|
codec_v1.encode(codec_v1.SESSION, 0xabcd, data, target.send)
|
||||||
self.assertEqual(record, [b'?##\xab\xcd\x00\x00\x00\x37' + data])
|
self.assertEqual(record, [b'?##\xab\xcd\x00\x00\x00\x37' + data])
|
||||||
|
|
||||||
def test_encode_generated_range(self):
|
def test_encode_generated_range(self):
|
||||||
@ -158,7 +158,7 @@ class TestWireCodecV1(unittest.TestCase):
|
|||||||
target = genfunc()
|
target = genfunc()
|
||||||
target.send(None)
|
target.send(None)
|
||||||
|
|
||||||
codec_v1.encode_wire_v1_message(msg_type, data, target)
|
codec_v1.encode(codec_v1.SESSION, msg_type, data, target.send)
|
||||||
self.assertEqual(received, len(reports))
|
self.assertEqual(received, len(reports))
|
||||||
|
|
||||||
def _record(self, record, *_args):
|
def _record(self, record, *_args):
|
||||||
|
@ -6,7 +6,7 @@ import ubinascii
|
|||||||
from trezor.crypto import random
|
from trezor.crypto import random
|
||||||
from trezor.utils import chunks
|
from trezor.utils import chunks
|
||||||
|
|
||||||
from trezor.wire import codec
|
from trezor.wire import codec_v2
|
||||||
|
|
||||||
class TestWireCodec(unittest.TestCase):
|
class TestWireCodec(unittest.TestCase):
|
||||||
# pylint: disable=C0301
|
# pylint: disable=C0301
|
||||||
@ -14,66 +14,66 @@ class TestWireCodec(unittest.TestCase):
|
|||||||
def test_parse(self):
|
def test_parse(self):
|
||||||
d = b'O' + b'\x01\x23\x45\x67' + bytes(range(0, 59))
|
d = b'O' + b'\x01\x23\x45\x67' + bytes(range(0, 59))
|
||||||
|
|
||||||
m, s, d = codec.parse_report(d)
|
m, s, d = codec_v2.parse_report(d)
|
||||||
self.assertEqual(m, b'O'[0])
|
self.assertEqual(m, b'O'[0])
|
||||||
self.assertEqual(s, 0x01234567)
|
self.assertEqual(s, 0x01234567)
|
||||||
self.assertEqual(d, bytes(range(0, 59)))
|
self.assertEqual(d, bytes(range(0, 59)))
|
||||||
|
|
||||||
t, l, d = codec.parse_message(d)
|
t, l, d = codec_v2.parse_message(d)
|
||||||
self.assertEqual(t, 0x00010203)
|
self.assertEqual(t, 0x00010203)
|
||||||
self.assertEqual(l, 0x04050607)
|
self.assertEqual(l, 0x04050607)
|
||||||
self.assertEqual(d, bytes(range(8, 59)))
|
self.assertEqual(d, bytes(range(8, 59)))
|
||||||
|
|
||||||
f, = codec.parse_message_footer(d[0:4])
|
f, = codec_v2.parse_message_footer(d[0:4])
|
||||||
self.assertEqual(f, 0x08090a0b)
|
self.assertEqual(f, 0x08090a0b)
|
||||||
|
|
||||||
for i in range(0, 1024):
|
for i in range(0, 1024):
|
||||||
if i != 64:
|
if i != 64:
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
codec.parse_report(bytes(range(0, i)))
|
codec_v2.parse_report(bytes(range(0, i)))
|
||||||
if i != 59:
|
if i != 59:
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
codec.parse_message(bytes(range(0, i)))
|
codec_v2.parse_message(bytes(range(0, i)))
|
||||||
if i != 4:
|
if i != 4:
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
codec.parse_message_footer(bytes(range(0, i)))
|
codec_v2.parse_message_footer(bytes(range(0, i)))
|
||||||
|
|
||||||
def test_serialize(self):
|
def test_serialize(self):
|
||||||
data = bytearray(range(0, 6))
|
data = bytearray(range(0, 6))
|
||||||
codec.serialize_report_header(data, 0x12, 0x3456789a)
|
codec_v2.serialize_report_header(data, 0x12, 0x3456789a)
|
||||||
self.assertEqual(data, b'\x12\x34\x56\x78\x9a\x05')
|
self.assertEqual(data, b'\x12\x34\x56\x78\x9a\x05')
|
||||||
|
|
||||||
data = bytearray(range(0, 6))
|
data = bytearray(range(0, 6))
|
||||||
codec.serialize_opened_session(data, 0x3456789a)
|
codec_v2.serialize_opened_session(data, 0x3456789a)
|
||||||
self.assertEqual(data, bytes([codec.REP_MARKER_OPEN]) + b'\x34\x56\x78\x9a\x05')
|
self.assertEqual(data, bytes([codec_v2.REP_MARKER_OPEN]) + b'\x34\x56\x78\x9a\x05')
|
||||||
|
|
||||||
data = bytearray(range(0, 14))
|
data = bytearray(range(0, 14))
|
||||||
codec.serialize_message_header(data, 0x01234567, 0x89abcdef)
|
codec_v2.serialize_message_header(data, 0x01234567, 0x89abcdef)
|
||||||
self.assertEqual(data, b'\x00\x01\x02\x03\x04\x01\x23\x45\x67\x89\xab\xcd\xef\x0d')
|
self.assertEqual(data, b'\x00\x01\x02\x03\x04\x01\x23\x45\x67\x89\xab\xcd\xef\x0d')
|
||||||
|
|
||||||
data = bytearray(range(0, 5))
|
data = bytearray(range(0, 5))
|
||||||
codec.serialize_message_footer(data, 0x89abcdef)
|
codec_v2.serialize_message_footer(data, 0x89abcdef)
|
||||||
self.assertEqual(data, b'\x89\xab\xcd\xef\x04')
|
self.assertEqual(data, b'\x89\xab\xcd\xef\x04')
|
||||||
|
|
||||||
for i in range(0, 13):
|
for i in range(0, 13):
|
||||||
data = bytearray(i)
|
data = bytearray(i)
|
||||||
if i < 4:
|
if i < 4:
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
codec.serialize_message_footer(data, 0x00)
|
codec_v2.serialize_message_footer(data, 0x00)
|
||||||
if i < 5:
|
if i < 5:
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
codec.serialize_report_header(data, 0x00, 0x00)
|
codec_v2.serialize_report_header(data, 0x00, 0x00)
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
codec.serialize_opened_session(data, 0x00)
|
codec_v2.serialize_opened_session(data, 0x00)
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
codec.serialize_message_header(data, 0x00, 0x00)
|
codec_v2.serialize_message_header(data, 0x00, 0x00)
|
||||||
|
|
||||||
def test_decode_empty(self):
|
def test_decode_empty(self):
|
||||||
message = b'\xab\xcd\xef\x12' + b'\x00\x00\x00\x00' + b'\x00' * 51
|
message = b'\xab\xcd\xef\x12' + b'\x00\x00\x00\x00' + b'\x00' * 51
|
||||||
|
|
||||||
record = []
|
record = []
|
||||||
genfunc = self._record(record, 0xabcdef12, 0, 0xdeadbeef, 'dummy')
|
genfunc = self._record(record, 0xdeadbeef, 0xabcdef12, 0, 'dummy')
|
||||||
decoder = codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy')
|
decoder = codec_v2.decode_stream(0xdeadbeef, genfunc, 'dummy')
|
||||||
decoder.send(None)
|
decoder.send(None)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -90,8 +90,8 @@ class TestWireCodec(unittest.TestCase):
|
|||||||
message = b'\xab\xcd\xef\x12' + b'\x00\x00\x00\x2f' + data + footer
|
message = b'\xab\xcd\xef\x12' + b'\x00\x00\x00\x2f' + data + footer
|
||||||
|
|
||||||
record = []
|
record = []
|
||||||
genfunc = self._record(record, 0xabcdef12, 47, 0xdeadbeef, 'dummy')
|
genfunc = self._record(record, 0xdeadbeef, 0xabcdef12, 47, 'dummy')
|
||||||
decoder = codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy')
|
decoder = codec_v2.decode_stream(0xdeadbeef, genfunc, 'dummy')
|
||||||
decoder.send(None)
|
decoder.send(None)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -109,8 +109,8 @@ class TestWireCodec(unittest.TestCase):
|
|||||||
message = b'\xab\xcd\xef\x12' + b'\x00\x00\x00\x2f' + data + footer
|
message = b'\xab\xcd\xef\x12' + b'\x00\x00\x00\x2f' + data + footer
|
||||||
|
|
||||||
record = []
|
record = []
|
||||||
genfunc = self._record(record, 0xabcdef12, 47, 0xdeadbeef, 'dummy')
|
genfunc = self._record(record, 0xdeadbeef, 0xabcdef12, 47, 'dummy')
|
||||||
decoder = codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy')
|
decoder = codec_v2.decode_stream(0xdeadbeef, genfunc, 'dummy')
|
||||||
decoder.send(None)
|
decoder.send(None)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -120,7 +120,7 @@ class TestWireCodec(unittest.TestCase):
|
|||||||
self.assertEqual(res, None)
|
self.assertEqual(res, None)
|
||||||
self.assertEqual(len(record), 2)
|
self.assertEqual(len(record), 2)
|
||||||
self.assertEqual(record[0], data)
|
self.assertEqual(record[0], data)
|
||||||
self.assertIsInstance(record[1], codec.MessageChecksumError)
|
self.assertIsInstance(record[1], codec_v2.MessageChecksumError)
|
||||||
|
|
||||||
def test_decode_generated_range(self):
|
def test_decode_generated_range(self):
|
||||||
for data_len in range(1, 512):
|
for data_len in range(1, 512):
|
||||||
@ -136,8 +136,8 @@ class TestWireCodec(unittest.TestCase):
|
|||||||
message_chunks = [c + '\x00' * (59 - len(c)) for c in list(chunks(message, 59))]
|
message_chunks = [c + '\x00' * (59 - len(c)) for c in list(chunks(message, 59))]
|
||||||
|
|
||||||
record = []
|
record = []
|
||||||
genfunc = self._record(record, msg_type, data_len, 0xdeadbeef, 'dummy')
|
genfunc = self._record(record, 0xdeadbeef, msg_type, data_len, 'dummy')
|
||||||
decoder = codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy')
|
decoder = codec_v2.decode_stream(0xdeadbeef, genfunc, 'dummy')
|
||||||
decoder.send(None)
|
decoder.send(None)
|
||||||
|
|
||||||
res = 1
|
res = 1
|
||||||
@ -157,7 +157,7 @@ class TestWireCodec(unittest.TestCase):
|
|||||||
target = self._record(record)()
|
target = self._record(record)()
|
||||||
target.send(None)
|
target.send(None)
|
||||||
|
|
||||||
codec.encode_wire_message(0xabcdef12, b'', 0xdeadbeef, target)
|
codec_v2.encode(0xdeadbeef, 0xabcdef12, b'', target.send)
|
||||||
self.assertEqual(len(record), 1)
|
self.assertEqual(len(record), 1)
|
||||||
self.assertEqual(record[0], b'H\xde\xad\xbe\xef\xab\xcd\xef\x12\x00\x00\x00\x00' + '\0' * 51)
|
self.assertEqual(record[0], b'H\xde\xad\xbe\xef\xab\xcd\xef\x12\x00\x00\x00\x00' + '\0' * 51)
|
||||||
|
|
||||||
@ -169,7 +169,7 @@ class TestWireCodec(unittest.TestCase):
|
|||||||
target = self._record(record)()
|
target = self._record(record)()
|
||||||
target.send(None)
|
target.send(None)
|
||||||
|
|
||||||
codec.encode_wire_message(0xabcdef12, data, 0xdeadbeef, target)
|
codec_v2.encode(0xdeadbeef, 0xabcdef12, data, target.send)
|
||||||
self.assertEqual(record, [b'H\xde\xad\xbe\xef\xab\xcd\xef\x12\x00\x00\x00\x2f' + data + footer])
|
self.assertEqual(record, [b'H\xde\xad\xbe\xef\xab\xcd\xef\x12\x00\x00\x00\x2f' + data + footer])
|
||||||
|
|
||||||
def test_encode_generated_range(self):
|
def test_encode_generated_range(self):
|
||||||
@ -199,7 +199,7 @@ class TestWireCodec(unittest.TestCase):
|
|||||||
target = genfunc()
|
target = genfunc()
|
||||||
target.send(None)
|
target.send(None)
|
||||||
|
|
||||||
codec.encode_wire_message(msg_type, data, session_id, target)
|
codec_v2.encode(session_id, msg_type, data, target.send)
|
||||||
self.assertEqual(received, len(reports))
|
self.assertEqual(received, len(reports))
|
||||||
|
|
||||||
def _record(self, record, *_args):
|
def _record(self, record, *_args):
|
Loading…
Reference in New Issue
Block a user