mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-04-27 12:39:04 +00:00
add a broken PoC of wire/protobuf stream reading
This commit is contained in:
parent
d6592a48e6
commit
3eaf001eef
266
src/lib/protobuf/protobuf_streaming.py
Normal file
266
src/lib/protobuf/protobuf_streaming.py
Normal file
@ -0,0 +1,266 @@
|
|||||||
|
'''Streaming protobuf codec.
|
||||||
|
|
||||||
|
Handles asynchronous encoding and decoding of protobuf value streams.
|
||||||
|
|
||||||
|
Value format: ((field_type, field_flags, field_name), field_value)
|
||||||
|
field_type: Either one of UVarintType, BoolType, BytesType, UnicodeType,
|
||||||
|
or an instance of EmbeddedMessage.
|
||||||
|
field_flags (int): Field bit flags `FLAG_REQUIRED`, `FLAG_REPEATED`.
|
||||||
|
field_name (str): Field name string.
|
||||||
|
field_value: Depends on field_type. EmbeddedMessage has `field_value == None`.
|
||||||
|
'''
|
||||||
|
|
||||||
|
|
||||||
|
def build_protobuf_message(message_type, future):
|
||||||
|
message = message_type()
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
field, field_value = yield
|
||||||
|
field_type, field_flags, field_name = field
|
||||||
|
if not _is_scalar_type(field_type):
|
||||||
|
field_value = yield from build_protobuf_message(field_type, future)
|
||||||
|
if field_flags & FLAG_REPEATED:
|
||||||
|
field_value = getattr(
|
||||||
|
message, field_name, []).append(field_value)
|
||||||
|
setattr(message, field_name, field_value)
|
||||||
|
except EOFError:
|
||||||
|
future.resolve(message)
|
||||||
|
|
||||||
|
|
||||||
|
def print_protobuf_message(message_type):
|
||||||
|
print('OPEN', message_type)
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
field, field_value = yield
|
||||||
|
field_type, _, field_name = field
|
||||||
|
if not _is_scalar_type(field_type):
|
||||||
|
yield from print_protobuf_message(field_type)
|
||||||
|
else:
|
||||||
|
print('FIELD', field_name, field_type, field_value)
|
||||||
|
except EOFError:
|
||||||
|
print('CLOSE', message_type)
|
||||||
|
|
||||||
|
|
||||||
|
class UVarintType:
|
||||||
|
WIRE_TYPE = 0
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def dump(target, value):
|
||||||
|
shifted_value = True
|
||||||
|
while shifted_value:
|
||||||
|
shifted_value = value >> 7
|
||||||
|
yield from target.write(chr((value & 0x7F) | (
|
||||||
|
0x80 if shifted_value != 0 else 0x00)))
|
||||||
|
value = shifted_value
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load(source):
|
||||||
|
value, shift, quantum = 0, 0, 0x80
|
||||||
|
while (quantum & 0x80) == 0x80:
|
||||||
|
data = yield from source.read(1)
|
||||||
|
quantum = ord(data)
|
||||||
|
value, shift = value + ((quantum & 0x7F) << shift), shift + 7
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
class BoolType:
|
||||||
|
WIRE_TYPE = 0
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def dump(target, value):
|
||||||
|
yield from target.write('\x01' if value else '\x00')
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load(source):
|
||||||
|
varint = yield from UVarintType.load(source)
|
||||||
|
return varint != 0
|
||||||
|
|
||||||
|
|
||||||
|
class BytesType:
|
||||||
|
WIRE_TYPE = 2
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def dump(target, value):
|
||||||
|
yield from UVarintType.dump(target, len(value))
|
||||||
|
yield from target.write(value)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load(source):
|
||||||
|
size = yield from UVarintType.load(source)
|
||||||
|
data = yield from source.read(size)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
class UnicodeType:
|
||||||
|
WIRE_TYPE = 2
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def dump(target, value):
|
||||||
|
yield from BytesType.dump(target, bytes(value, 'utf-8'))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load(source):
|
||||||
|
data = yield from BytesType.load(source)
|
||||||
|
return data.decode('utf-8', 'strict')
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddedMessage:
|
||||||
|
WIRE_TYPE = 2
|
||||||
|
|
||||||
|
def __init__(self, message_type):
|
||||||
|
'''Initializes a new instance. The argument is an underlying message type.'''
|
||||||
|
self.message_type = message_type
|
||||||
|
|
||||||
|
def __call__(self):
|
||||||
|
'''Creates a message of the underlying message type.'''
|
||||||
|
return self.message_type()
|
||||||
|
|
||||||
|
def dump(self, target, value):
|
||||||
|
buf = self.message_type.dumps(value)
|
||||||
|
yield from BytesType.dump(target, buf)
|
||||||
|
|
||||||
|
def load(self, source, target):
|
||||||
|
emb_size = yield from UVarintType.load(source)
|
||||||
|
emb_source = source.limit(emb_size)
|
||||||
|
yield from self.message_type.load(emb_source, target)
|
||||||
|
|
||||||
|
|
||||||
|
FLAG_SIMPLE = const(0)
|
||||||
|
FLAG_REQUIRED = const(1)
|
||||||
|
FLAG_REPEATED = const(2)
|
||||||
|
|
||||||
|
|
||||||
|
# Packs a tag and a wire_type into single int according to the protobuf spec.
|
||||||
|
_pack_key = lambda tag, wire_type: (tag << 3) | wire_type
|
||||||
|
# Unpacks a key into a tag and a wire_type according to the protobuf spec.
|
||||||
|
_unpack_key = lambda key: (key >> 3, key & 7)
|
||||||
|
# Determines if a field type is a scalar or not.
|
||||||
|
_is_scalar_type = lambda field_type: not isinstance(
|
||||||
|
field_type, EmbeddedMessage)
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncBytearrayWriter:
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.buf = bytearray()
|
||||||
|
|
||||||
|
async def write(self, b):
|
||||||
|
self.buf.extend(b)
|
||||||
|
|
||||||
|
|
||||||
|
class MessageType:
|
||||||
|
'''Represents a message type.'''
|
||||||
|
|
||||||
|
def __init__(self, name=None):
|
||||||
|
self.__name = name
|
||||||
|
self.__fields = {} # tag -> tuple of field_type, field_flags, field_name
|
||||||
|
self.__defaults = {} # tag -> default_value
|
||||||
|
|
||||||
|
def add_field(self, tag, name, field_type, flags=FLAG_SIMPLE, default=None):
|
||||||
|
'''Adds a field to the message type.'''
|
||||||
|
if tag in self.__fields:
|
||||||
|
raise ValueError('The tag %s is already used.' % tag)
|
||||||
|
if default is not None:
|
||||||
|
self.__defaults[tag] = default
|
||||||
|
self.__fields[tag] = (field_type, flags, name)
|
||||||
|
|
||||||
|
def __call__(self, **fields):
|
||||||
|
'''Creates an instance of this message type.'''
|
||||||
|
return Message(self, **fields)
|
||||||
|
|
||||||
|
def dumps(self, value):
|
||||||
|
target = AsyncBytearrayWriter()
|
||||||
|
yield from self.dump(target, value)
|
||||||
|
return target.buf
|
||||||
|
|
||||||
|
def dump(self, target, value):
|
||||||
|
if self is not value.message_type:
|
||||||
|
raise TypeError('Incompatible type')
|
||||||
|
for tag, field in self.__fields.items():
|
||||||
|
field_type, field_flags, field_name = field
|
||||||
|
if field_name not in value.__dict__:
|
||||||
|
if field_flags & FLAG_REQUIRED:
|
||||||
|
raise ValueError(
|
||||||
|
'The field with the tag %s is required but a value is missing.' % tag)
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
if field_flags & FLAG_REPEATED:
|
||||||
|
# repeated value
|
||||||
|
key = _pack_key(tag, field_type.WIRE_TYPE)
|
||||||
|
# send the values sequentially
|
||||||
|
for single_value in getattr(value, field_name):
|
||||||
|
yield from UVarintType.dump(target, key)
|
||||||
|
yield from field_type.dump(target, single_value)
|
||||||
|
else:
|
||||||
|
# single value
|
||||||
|
yield from UVarintType.dump(target, _pack_key(tag, field_type.WIRE_TYPE))
|
||||||
|
yield from field_type.dump(target, getattr(value, field_name))
|
||||||
|
|
||||||
|
def load(self, source, target):
|
||||||
|
found_tags = set()
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
key = yield from UVarintType.load(source)
|
||||||
|
tag, wire_type = _unpack_key(key)
|
||||||
|
found_tags.add(tag)
|
||||||
|
|
||||||
|
if tag in self.__fields:
|
||||||
|
# retrieve the field descriptor by tag
|
||||||
|
field = self.__fields[tag]
|
||||||
|
field_type, _, _ = field
|
||||||
|
if wire_type != field_type.WIRE_TYPE:
|
||||||
|
raise TypeError(
|
||||||
|
'Value of tag %s has incorrect wiretype %s, %s expected.' %
|
||||||
|
(tag, wire_type, field_type.WIRE_TYPE))
|
||||||
|
else:
|
||||||
|
# unknown field, skip it
|
||||||
|
field_type = {0: UVarintType, 2: BytesType}[wire_type]
|
||||||
|
yield from field_type.load(source)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if _is_scalar_type(field_type):
|
||||||
|
field_value = yield from field_type.load(source)
|
||||||
|
target.send((field, field_value))
|
||||||
|
else:
|
||||||
|
yield from field_type.load(source, target)
|
||||||
|
|
||||||
|
except EOFError:
|
||||||
|
for tag, field in self.__fields.items():
|
||||||
|
# send the default value
|
||||||
|
if tag not in found_tags and tag in self.__defaults:
|
||||||
|
target.send((field, self.__defaults[tag]))
|
||||||
|
found_tags.add(tag)
|
||||||
|
|
||||||
|
# check if all required fields are present
|
||||||
|
_, field_flags, field_name = field
|
||||||
|
if field_flags & FLAG_REQUIRED and tag not in found_tags:
|
||||||
|
if field_flags & FLAG_REPEATED:
|
||||||
|
# no values were in input stream, but required field.
|
||||||
|
# send empty list
|
||||||
|
target.send((field, []))
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
'The field %s (\'%s\') is required but missing.' % (tag, field_name))
|
||||||
|
target.throw(EOFError)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return '<MessageType: %s>' % self.__name
|
||||||
|
|
||||||
|
|
||||||
|
class Message:
|
||||||
|
'''Represents a message instance.'''
|
||||||
|
|
||||||
|
def __init__(self, message_type, **fields):
|
||||||
|
'''Initializes a new instance of the specified message type.'''
|
||||||
|
self.message_type = message_type
|
||||||
|
for key in fields:
|
||||||
|
setattr(self, key, fields[key])
|
||||||
|
|
||||||
|
def dump(self, target):
|
||||||
|
yield from self.message_type.dump(target, self)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
values = self.__dict__
|
||||||
|
values = {k: values[k] for k in values if k != 'message_type'}
|
||||||
|
return '<%s: %s>' % (self.message_type.__name, values)
|
166
src/trezor/wire_streaming.py
Normal file
166
src/trezor/wire_streaming.py
Normal file
@ -0,0 +1,166 @@
|
|||||||
|
import ustruct
|
||||||
|
import ubinascii
|
||||||
|
|
||||||
|
from . import msg
|
||||||
|
from . import loop
|
||||||
|
from .crypto import random
|
||||||
|
|
||||||
|
|
||||||
|
MESSAGE_IFACE = const(0)
|
||||||
|
EMPTY_SESSION = const(0)
|
||||||
|
|
||||||
|
sessions = {}
|
||||||
|
|
||||||
|
|
||||||
|
def generate_session_id():
|
||||||
|
return random.uniform(0xffffffff) + 1
|
||||||
|
|
||||||
|
|
||||||
|
async def dispatch_reports():
|
||||||
|
while True:
|
||||||
|
report = await _read_report()
|
||||||
|
session_id, report_data = _parse_report(report)
|
||||||
|
sessions[session_id].send(report_data)
|
||||||
|
|
||||||
|
|
||||||
|
async def read_session_message(session_id, types):
|
||||||
|
future = loop.Future()
|
||||||
|
pbuf_decoder = _decode_protobuf_message(types, future)
|
||||||
|
wire_decoder = _decode_wire_message(pbuf_decoder)
|
||||||
|
assert session_id not in sessions
|
||||||
|
sessions[session_id] = wire_decoder
|
||||||
|
try:
|
||||||
|
result = await future
|
||||||
|
finally:
|
||||||
|
del sessions[session_id]
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def lookup_protobuf_type(msg_type, pbuf_types):
|
||||||
|
for pt in pbuf_types:
|
||||||
|
if pt.wire_type == msg_type:
|
||||||
|
return pt
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_protobuf_message(types, future):
|
||||||
|
msg_type, _ = yield
|
||||||
|
pbuf_type = lookup_protobuf_type(msg_type, types)
|
||||||
|
target = build_protobuf_message(pbuf_type, future)
|
||||||
|
yield from pbuf_type.load(AsyncBytearrayReader(), target)
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncBytearrayReader:
|
||||||
|
|
||||||
|
def __init__(self, buf=None, n=None):
|
||||||
|
self.buf = buf if buf is not None else bytearray()
|
||||||
|
self.n = n
|
||||||
|
|
||||||
|
def read(self, n):
|
||||||
|
if self.n is not None:
|
||||||
|
self.n -= n
|
||||||
|
if self.n <= 0:
|
||||||
|
raise EOFError()
|
||||||
|
buf = self.buf
|
||||||
|
while len(buf) < n:
|
||||||
|
buf.extend((yield)) # buffer next data chunk
|
||||||
|
result, buf[:] = buf[:n], buf[n:]
|
||||||
|
return result
|
||||||
|
|
||||||
|
def limit(self, n):
|
||||||
|
return AsyncBytearrayReader(self.buf, n)
|
||||||
|
|
||||||
|
|
||||||
|
async def _read_report():
|
||||||
|
report, = await loop.Select(MESSAGE_IFACE)
|
||||||
|
return memoryview(report) # make slicing cheap
|
||||||
|
|
||||||
|
|
||||||
|
async def _write_report(report):
|
||||||
|
return msg.send(MESSAGE_IFACE, report)
|
||||||
|
|
||||||
|
|
||||||
|
# TREZOR wire protocol v2:
|
||||||
|
#
|
||||||
|
# HID report (64B):
|
||||||
|
# - report magic (1B)
|
||||||
|
# - session (4B, BE)
|
||||||
|
# - payload (59B)
|
||||||
|
#
|
||||||
|
# message:
|
||||||
|
# - streamed as payloads of HID reports:
|
||||||
|
# - message type (4B, BE)
|
||||||
|
# - data length (4B, BE)
|
||||||
|
# - data (var-length)
|
||||||
|
# - data checksum (4B, BE)
|
||||||
|
|
||||||
|
|
||||||
|
REP_HEADER = '>BL' # marker, session id
|
||||||
|
MSG_HEADER = '>LL' # msg type, data length
|
||||||
|
MSG_FOOTER = '>L' # data checksum
|
||||||
|
|
||||||
|
REP_HEADER_LEN = ustruct.calcsize(REP_HEADER)
|
||||||
|
MSG_HEADER_LEN = ustruct.calcsize(MSG_HEADER)
|
||||||
|
MSG_FOOTER_LEN = ustruct.calcsize(MSG_FOOTER)
|
||||||
|
|
||||||
|
|
||||||
|
class MessageChecksumError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_report(data):
|
||||||
|
marker, session_id = ustruct.parse(REP_HEADER, data)
|
||||||
|
return session_id, data[REP_HEADER_LEN:]
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_message(data):
|
||||||
|
msg_type, data_len = ustruct.parse(MSG_HEADER, data)
|
||||||
|
return msg_type, data_len, data[MSG_HEADER_LEN:]
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_footer(data):
|
||||||
|
data_checksum, = ustruct.parse(MSG_FOOTER, data)
|
||||||
|
return data_checksum,
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_wire_message(target):
|
||||||
|
'''Decode a wire message from the report data and stream it to target.
|
||||||
|
|
||||||
|
Receives report payloads.
|
||||||
|
Sends (msg_type, data_len) to target, followed by data chunks.
|
||||||
|
Throws EOFError after last data chunk, in case of valid checksum.
|
||||||
|
Throws MessageChecksumError to target if data doesn't match the checksum.
|
||||||
|
'''
|
||||||
|
message = (yield) # read first report
|
||||||
|
msg_type, data_len, data_tail = _parse_message(message)
|
||||||
|
target.send((msg_type, data_len))
|
||||||
|
|
||||||
|
checksum = 0 # crc32
|
||||||
|
nreports = 1
|
||||||
|
|
||||||
|
while data_len > 0:
|
||||||
|
if nreports > 1:
|
||||||
|
data_tail = (yield) # read next report
|
||||||
|
nreports += 1
|
||||||
|
|
||||||
|
data_chunk = data_tail[:data_len] # slice off the garbage at the end
|
||||||
|
data_tail = data_tail[len(data_chunk):] # slice off what we have read
|
||||||
|
data_len -= len(data_chunk)
|
||||||
|
target.send(data_chunk)
|
||||||
|
|
||||||
|
checksum = ubinascii.crc32(checksum, data_chunk)
|
||||||
|
|
||||||
|
data_footer = data_tail[:MSG_FOOTER_LEN]
|
||||||
|
if len(data_footer) < MSG_FOOTER_LEN:
|
||||||
|
data_tail = (yield) # read report with the rest of checksum
|
||||||
|
data_footer += data_tail[:MSG_FOOTER_LEN - len(data_footer)]
|
||||||
|
|
||||||
|
data_checksum, = _parse_footer(data_footer)
|
||||||
|
if data_checksum != checksum:
|
||||||
|
target.throw(MessageChecksumError, 'Message checksum mismatch')
|
||||||
|
else:
|
||||||
|
target.throw(EOFError)
|
||||||
|
|
||||||
|
|
||||||
|
def _encode_message(target):
|
||||||
|
pass
|
Loading…
Reference in New Issue
Block a user