1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-17 03:48:09 +00:00

switch to v2 wire protocol

- sessions
- crc32 checksum

TODO: tests
TODO: python-trezor implementation
TODO: dispatching
This commit is contained in:
Jan Pochyla 2016-07-19 17:09:51 +02:00 committed by Pavol Rusnak
parent 545e93d1b4
commit cb0f5e2595
No known key found for this signature in database
GPG Key ID: 91F3B339B9A02A3D
2 changed files with 74 additions and 43 deletions

View File

@ -16,6 +16,6 @@ def unregister(mtype):
def dispatch(): def dispatch():
mtype, mbuf = yield from wire.read_wire_msg() _, mtype, mbuf = yield from wire.read_wire_msg()
handler = message_handlers[mtype] handler = message_handlers[mtype]
layout.change(handler(mtype, mbuf)) layout.change(handler(mtype, mbuf))

View File

@ -1,82 +1,113 @@
import ustruct import ustruct
import ubinascii
from . import msg from . import msg
from . import loop from . import loop
from . import log from . import log
IFACE = const(0) IFACE = const(0)
REPORT_LEN = const(64) # TREZOR wire protocol v2:
REPORT_NUM = const(63) #
HEADER_MAGIC = const(35) # # HID report = 64 bytes, padded with 0x0
# First report = !SSSSTTTTLLLLD...
# Next reports = #SSSSD...CCCC
#
# S = session id
# T = message type
# L = data length
# D = data
# C = data checksum - crc32
_REPORT_LEN = const(64)
_MAX_DATA_LEN = const(65536)
_HEADER_MAGIC = const(35) # ord('#')
_DATA_MAGIC = const(33) # ord('!')
def read_report(): def _read_report():
rep, = yield loop.Select(IFACE) rep, = yield loop.Select(IFACE)
assert rep[0] == REPORT_NUM, 'Report number malformed' assert len(rep) == _REPORT_LEN, 'HID read failed'
return rep return memoryview(rep)
def write_report(rep): def _write_report(rep):
size = msg.send(IFACE, rep) size = msg.send(IFACE, rep)
assert size == REPORT_LEN, 'HID write failed' assert size == _REPORT_LEN, 'HID write failed'
yield # write_report is a generator for the sake of consistency yield # just to be a generator
def read_wire_msg(): def read_wire_msg():
rep = yield from read_report()
assert rep[1] == HEADER_MAGIC
assert rep[2] == HEADER_MAGIC
(mtype, mlen) = ustruct.unpack_from('>HL', rep, 3)
# TODO: validate mlen for sane values rep = yield from _read_report()
assert rep[0] == _HEADER_MAGIC, 'Incorrect report magic'
rep = memoryview(rep) # Parse message header
data = rep[9:] sid, mtype, mlen = ustruct.unpack_from('>LLL', rep, 1) # Skip magic
data = data[:mlen] assert mlen < _MAX_DATA_LEN, 'Message too large to read'
mbuf = bytearray(data) # TODO: allocate mlen bytes mlen += 4 # Account for the checksum
remaining = mlen - len(mbuf) data = rep[13:][:mlen] # Skip magic and header, trim to data len
buffered = bytearray(data) # Resulting message data
remaining = mlen - len(buffered)
while remaining > 0: while remaining > 0:
rep = yield from read_report() rep = yield from _read_report()
rep = memoryview(rep) assert rep[0] == _DATA_MAGIC, 'Incorrect report magic'
data = rep[1:]
data = data[:remaining] # Compare the session IDs
mbuf.extend(data) rsid = ustruct.unpack_from('>L', rep, 1)
assert rsid == sid, 'Session ID mismatch'
data = rep[5:][:remaining] # Skip magic and session ID, trim
buffered.extend(data)
remaining -= len(data) remaining -= len(data)
return (mtype, mbuf) # Split to data and checksum
mbuf = buffered[:-4]
csum = ustruct.unpack_from('>L', buffered, -4)
# Compare the checksums
assert csum == ubinascii.crc32(mbuf), 'Message checksum mismatch'
return sid, mtype, mbuf
def write_wire_msg(mtype, mbuf): def write_wire_msg(sid, mtype, mbuf):
rep = bytearray(REPORT_LEN)
rep[0] = REPORT_NUM rep = bytearray(_REPORT_LEN)
rep[1] = HEADER_MAGIC rep[0] = _HEADER_MAGIC
rep[2] = HEADER_MAGIC ustruct.pack_into('>LLL', rep, 1, sid, mtype, len(mbuf))
ustruct.pack_into('>HL', rep, 3, mtype, len(mbuf))
rep = memoryview(rep) rep = memoryview(rep)
mbuf = memoryview(mbuf) mbuf = memoryview(mbuf)
data = rep[9:] data = rep[13:] # Skip magic and header
csum = ubinascii.crc32(mbuf)
footer = ustruct.pack('>L', csum)
while True: while True:
n = min(len(data), len(mbuf)) n = min(len(data), len(mbuf))
data[:n] = mbuf[:n] data[:n] = mbuf[:n] # Copy as much data as possible from mbuf to data
i = n mbuf = mbuf[n:] # Skip written bytes
while i < len(data): data = data[n:] # Skip written bytes
data[i] = 0
i += 1 # Continue with the footer if mbuf is empty and we have space
yield from write_report(rep) if not mbuf and data:
mbuf = mbuf[n:] mbuf = footer
continue
yield from _write_report(rep)
if not mbuf: if not mbuf:
break break
data = rep[1:]
# Reset to skip the magic and session ID
data = rep[5:]
def read(*types): def read(*types):
if __debug__: if __debug__:
log.debug(__name__, 'Reading one of %s', types) log.debug(__name__, 'Reading one of %s', types)
mtype, mbuf = yield from read_wire_msg() _, mtype, mbuf = yield from read_wire_msg()
for t in types: for t in types:
if t.wire_type == mtype: if t.wire_type == mtype:
return t.loads(mbuf) return t.loads(mbuf)
@ -89,7 +120,7 @@ def write(m):
log.debug(__name__, 'Writing %s', m) log.debug(__name__, 'Writing %s', m)
mbuf = m.dumps() mbuf = m.dumps()
mtype = m.message_type.wire_type mtype = m.message_type.wire_type
yield from write_wire_msg(mtype, mbuf) yield from write_wire_msg(0, mtype, mbuf)
def call(req, *types): def call(req, *types):