diff --git a/src/trezor/dispatcher.py b/src/trezor/dispatcher.py index 966918fd42..8ab2e2c68a 100644 --- a/src/trezor/dispatcher.py +++ b/src/trezor/dispatcher.py @@ -16,6 +16,6 @@ def unregister(mtype): def dispatch(): - mtype, mbuf = yield from wire.read_wire_msg() + _, mtype, mbuf = yield from wire.read_wire_msg() handler = message_handlers[mtype] layout.change(handler(mtype, mbuf)) diff --git a/src/trezor/wire.py b/src/trezor/wire.py index db1df5bf0f..36397e5440 100644 --- a/src/trezor/wire.py +++ b/src/trezor/wire.py @@ -1,82 +1,113 @@ import ustruct +import ubinascii from . import msg from . import loop from . import log IFACE = const(0) -REPORT_LEN = const(64) -REPORT_NUM = const(63) -HEADER_MAGIC = const(35) # +# TREZOR wire protocol v2: +# +# HID report = 64 bytes, padded with 0x0 +# First report = !SSSSTTTTLLLLD... +# Next reports = #SSSSD...CCCC +# +# S = session id +# T = message type +# L = data length +# D = data +# C = data checksum - crc32 + +_REPORT_LEN = const(64) +_MAX_DATA_LEN = const(65536) +_HEADER_MAGIC = const(35) # ord('#') +_DATA_MAGIC = const(33) # ord('!') -def read_report(): +def _read_report(): rep, = yield loop.Select(IFACE) - assert rep[0] == REPORT_NUM, 'Report number malformed' - return rep + assert len(rep) == _REPORT_LEN, 'HID read failed' + return memoryview(rep) -def write_report(rep): +def _write_report(rep): size = msg.send(IFACE, rep) - assert size == REPORT_LEN, 'HID write failed' - yield # write_report is a generator for the sake of consistency + assert size == _REPORT_LEN, 'HID write failed' + yield # just to be a generator def read_wire_msg(): - rep = yield from read_report() - assert rep[1] == HEADER_MAGIC - assert rep[2] == HEADER_MAGIC - (mtype, mlen) = ustruct.unpack_from('>HL', rep, 3) - # TODO: validate mlen for sane values + rep = yield from _read_report() + assert rep[0] == _HEADER_MAGIC, 'Incorrect report magic' - rep = memoryview(rep) - data = rep[9:] - data = data[:mlen] + # Parse message header + sid, mtype, mlen = ustruct.unpack_from('>LLL', rep, 1) # Skip magic + assert mlen < _MAX_DATA_LEN, 'Message too large to read' - mbuf = bytearray(data) # TODO: allocate mlen bytes - remaining = mlen - len(mbuf) + mlen += 4 # Account for the checksum + data = rep[13:][:mlen] # Skip magic and header, trim to data len + buffered = bytearray(data) # Resulting message data + remaining = mlen - len(buffered) while remaining > 0: - rep = yield from read_report() - rep = memoryview(rep) - data = rep[1:] - data = data[:remaining] - mbuf.extend(data) + rep = yield from _read_report() + assert rep[0] == _DATA_MAGIC, 'Incorrect report magic' + + # Compare the session IDs + rsid = ustruct.unpack_from('>L', rep, 1) + assert rsid == sid, 'Session ID mismatch' + + data = rep[5:][:remaining] # Skip magic and session ID, trim + buffered.extend(data) remaining -= len(data) - return (mtype, mbuf) + # Split to data and checksum + mbuf = buffered[:-4] + csum = ustruct.unpack_from('>L', buffered, -4) + + # Compare the checksums + assert csum == ubinascii.crc32(mbuf), 'Message checksum mismatch' + + return sid, mtype, mbuf -def write_wire_msg(mtype, mbuf): - rep = bytearray(REPORT_LEN) - rep[0] = REPORT_NUM - rep[1] = HEADER_MAGIC - rep[2] = HEADER_MAGIC - ustruct.pack_into('>HL', rep, 3, mtype, len(mbuf)) +def write_wire_msg(sid, mtype, mbuf): + + rep = bytearray(_REPORT_LEN) + rep[0] = _HEADER_MAGIC + ustruct.pack_into('>LLL', rep, 1, sid, mtype, len(mbuf)) rep = memoryview(rep) mbuf = memoryview(mbuf) - data = rep[9:] + data = rep[13:] # Skip magic and header + + csum = ubinascii.crc32(mbuf) + footer = ustruct.pack('>L', csum) while True: n = min(len(data), len(mbuf)) - data[:n] = mbuf[:n] - i = n - while i < len(data): - data[i] = 0 - i += 1 - yield from write_report(rep) - mbuf = mbuf[n:] + data[:n] = mbuf[:n] # Copy as much data as possible from mbuf to data + mbuf = mbuf[n:] # Skip written bytes + data = data[n:] # Skip written bytes + + # Continue with the footer if mbuf is empty and we have space + if not mbuf and data: + mbuf = footer + continue + + yield from _write_report(rep) if not mbuf: break - data = rep[1:] + + # Reset to skip the magic and session ID + data = rep[5:] def read(*types): if __debug__: log.debug(__name__, 'Reading one of %s', types) - mtype, mbuf = yield from read_wire_msg() + _, mtype, mbuf = yield from read_wire_msg() for t in types: if t.wire_type == mtype: return t.loads(mbuf) @@ -89,7 +120,7 @@ def write(m): log.debug(__name__, 'Writing %s', m) mbuf = m.dumps() mtype = m.message_type.wire_type - yield from write_wire_msg(mtype, mbuf) + yield from write_wire_msg(0, mtype, mbuf) def call(req, *types):