diff --git a/core/embed/extmod/modtrezorutils/modtrezorutils.c b/core/embed/extmod/modtrezorutils/modtrezorutils.c index 81287990a..04e9c2003 100644 --- a/core/embed/extmod/modtrezorutils/modtrezorutils.c +++ b/core/embed/extmod/modtrezorutils/modtrezorutils.c @@ -58,15 +58,16 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorutils_consteq_obj, mod_trezorutils_consteq); /// def memcpy( -/// dst: bytearray, dst_ofs: int, src: bytes, src_ofs: int, n: int +/// dst: bytearray, dst_ofs: int, src: bytes, src_ofs: int, n: int = None /// ) -> int: /// """ /// Copies at most `n` bytes from `src` at offset `src_ofs` to -/// `dst` at offset `dst_ofs`. Returns the number of actually -/// copied bytes. +/// `dst` at offset `dst_ofs`. Returns the number of actually +/// copied bytes. If `n` is not specified, tries to copy +/// as much as possible. /// """ STATIC mp_obj_t mod_trezorutils_memcpy(size_t n_args, const mp_obj_t *args) { - mp_arg_check_num(n_args, 0, 5, 5, false); + mp_arg_check_num(n_args, 0, 4, 5, false); mp_buffer_info_t dst; mp_get_buffer_raise(args[0], &dst, MP_BUFFER_WRITE); @@ -76,7 +77,12 @@ STATIC mp_obj_t mod_trezorutils_memcpy(size_t n_args, const mp_obj_t *args) { mp_get_buffer_raise(args[2], &src, MP_BUFFER_READ); uint32_t src_ofs = trezor_obj_get_uint(args[3]); - uint32_t n = trezor_obj_get_uint(args[4]); + uint32_t n = 0; + if (n_args > 4) { + n = trezor_obj_get_uint(args[4]); + } else { + n = src.len; + } size_t dst_rem = (dst_ofs < dst.len) ? dst.len - dst_ofs : 0; size_t src_rem = (src_ofs < src.len) ? src.len - src_ofs : 0; @@ -86,7 +92,7 @@ STATIC mp_obj_t mod_trezorutils_memcpy(size_t n_args, const mp_obj_t *args) { return mp_obj_new_int(ncpy); } -STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorutils_memcpy_obj, 5, 5, +STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorutils_memcpy_obj, 4, 5, mod_trezorutils_memcpy); /// def halt(msg: str = None) -> None: diff --git a/core/mocks/generated/trezorutils.pyi b/core/mocks/generated/trezorutils.pyi index aed9911c0..26746e2d2 100644 --- a/core/mocks/generated/trezorutils.pyi +++ b/core/mocks/generated/trezorutils.pyi @@ -13,12 +13,13 @@ def consteq(sec: bytes, pub: bytes) -> bool: # extmod/modtrezorutils/modtrezorutils.c def memcpy( - dst: bytearray, dst_ofs: int, src: bytes, src_ofs: int, n: int + dst: bytearray, dst_ofs: int, src: bytes, src_ofs: int, n: int = None ) -> int: """ Copies at most `n` bytes from `src` at offset `src_ofs` to - `dst` at offset `dst_ofs`. Returns the number of actually - copied bytes. + `dst` at offset `dst_ofs`. Returns the number of actually + copied bytes. If `n` is not specified, tries to copy + as much as possible. """ diff --git a/core/src/protobuf.py b/core/src/protobuf.py index 74e91603d..ff568d44d 100644 --- a/core/src/protobuf.py +++ b/core/src/protobuf.py @@ -9,14 +9,14 @@ if False: from typing import Any, Dict, Iterable, List, Optional, Type, TypeVar, Union from typing_extensions import Protocol - class AsyncReader(Protocol): - async def areadinto(self, buf: bytearray) -> int: + class Reader(Protocol): + def readinto(self, buf: bytearray) -> int: """ Reads `len(buf)` bytes into `buf`, or raises `EOFError`. """ - class AsyncWriter(Protocol): - async def awrite(self, buf: bytes) -> int: + class Writer(Protocol): + def write(self, buf: bytes) -> int: """ Writes all bytes from `buf`, or raises `EOFError`. """ @@ -25,20 +25,20 @@ if False: _UVARINT_BUFFER = bytearray(1) -async def load_uvarint(reader: AsyncReader) -> int: +def load_uvarint(reader: Reader) -> int: buffer = _UVARINT_BUFFER result = 0 shift = 0 byte = 0x80 while byte & 0x80: - await reader.areadinto(buffer) + reader.readinto(buffer) byte = buffer[0] result += (byte & 0x7F) << shift shift += 7 return result -async def dump_uvarint(writer: AsyncWriter, n: int) -> None: +def dump_uvarint(writer: Writer, n: int) -> None: if n < 0: raise ValueError("Cannot dump signed value, convert it to unsigned first.") buffer = _UVARINT_BUFFER @@ -46,7 +46,7 @@ async def dump_uvarint(writer: AsyncWriter, n: int) -> None: while shifted: shifted = n >> 7 buffer[0] = (n & 0x7F) | (0x80 if shifted else 0x00) - await writer.awrite(buffer) + writer.write(buffer) n = shifted @@ -165,15 +165,15 @@ class MessageType: class LimitedReader: - def __init__(self, reader: AsyncReader, limit: int) -> None: + def __init__(self, reader: Reader, limit: int) -> None: self.reader = reader self.limit = limit - async def areadinto(self, buf: bytearray) -> int: + def readinto(self, buf: bytearray) -> int: if self.limit < len(buf): raise EOFError else: - nread = await self.reader.areadinto(buf) + nread = self.reader.readinto(buf) self.limit -= nread return nread @@ -184,8 +184,8 @@ if False: LoadedMessageType = TypeVar("LoadedMessageType", bound=MessageType) -async def load_message( - reader: AsyncReader, msg_type: Type[LoadedMessageType] +def load_message( + reader: Reader, msg_type: Type[LoadedMessageType] ) -> LoadedMessageType: fields = msg_type.get_fields() msg = msg_type() @@ -197,7 +197,7 @@ async def load_message( while True: try: - fkey = await load_uvarint(reader) + fkey = load_uvarint(reader) except EOFError: break # no more fields to load @@ -208,10 +208,10 @@ async def load_message( if field is None: # unknown field, skip it if wtype == 0: - await load_uvarint(reader) + load_uvarint(reader) elif wtype == 2: - ivalue = await load_uvarint(reader) - await reader.areadinto(bytearray(ivalue)) + ivalue = load_uvarint(reader) + reader.readinto(bytearray(ivalue)) else: raise ValueError continue @@ -220,7 +220,7 @@ async def load_message( if wtype != ftype.WIRE_TYPE: raise TypeError # parsed wire type differs from the schema - ivalue = await load_uvarint(reader) + ivalue = load_uvarint(reader) if ftype is UVarintType: fvalue = ivalue @@ -232,13 +232,13 @@ async def load_message( fvalue = ftype.validate(ivalue) elif ftype is BytesType: fvalue = bytearray(ivalue) - await reader.areadinto(fvalue) + reader.readinto(fvalue) elif ftype is UnicodeType: fvalue = bytearray(ivalue) - await reader.areadinto(fvalue) + reader.readinto(fvalue) fvalue = bytes(fvalue).decode() elif issubclass(ftype, MessageType): - fvalue = await load_message(LimitedReader(reader, ivalue), ftype) + fvalue = load_message(LimitedReader(reader, ivalue), ftype) else: raise TypeError # field type is unknown @@ -257,9 +257,7 @@ async def load_message( return msg -async def dump_message( - writer: AsyncWriter, msg: MessageType, fields: Dict = None -) -> None: +def dump_message(writer: Writer, msg: MessageType, fields: Dict = None) -> None: repvalue = [0] if fields is None: @@ -281,39 +279,39 @@ async def dump_message( ffields = None # type: Optional[Dict] for svalue in fvalue: - await dump_uvarint(writer, fkey) + dump_uvarint(writer, fkey) if ftype is UVarintType: - await dump_uvarint(writer, svalue) + dump_uvarint(writer, svalue) elif ftype is SVarintType: - await dump_uvarint(writer, sint_to_uint(svalue)) + dump_uvarint(writer, sint_to_uint(svalue)) elif ftype is BoolType: - await dump_uvarint(writer, int(svalue)) + dump_uvarint(writer, int(svalue)) elif isinstance(ftype, EnumType): - await dump_uvarint(writer, svalue) + dump_uvarint(writer, svalue) elif ftype is BytesType: if isinstance(svalue, list): - await dump_uvarint(writer, _count_bytes_list(svalue)) + dump_uvarint(writer, _count_bytes_list(svalue)) for sub_svalue in svalue: - await writer.awrite(sub_svalue) + writer.write(sub_svalue) else: - await dump_uvarint(writer, len(svalue)) - await writer.awrite(svalue) + dump_uvarint(writer, len(svalue)) + writer.write(svalue) elif ftype is UnicodeType: svalue = svalue.encode() - await dump_uvarint(writer, len(svalue)) - await writer.awrite(svalue) + dump_uvarint(writer, len(svalue)) + writer.write(svalue) elif issubclass(ftype, MessageType): if ffields is None: ffields = ftype.get_fields() - await dump_uvarint(writer, count_message(svalue, ffields)) - await dump_message(writer, svalue, ffields) + dump_uvarint(writer, count_message(svalue, ffields)) + dump_message(writer, svalue, ffields) else: raise TypeError diff --git a/core/src/trezor/wire/__init__.py b/core/src/trezor/wire/__init__.py index c64927b4e..f79b5cc16 100644 --- a/core/src/trezor/wire/__init__.py +++ b/core/src/trezor/wire/__init__.py @@ -138,6 +138,7 @@ class Context: def __init__(self, iface: WireInterface, sid: int) -> None: self.iface = iface self.sid = sid + self.buffer_io = codec_v1.BytesIO(bytearray(8192)) async def call( self, msg: protobuf.MessageType, expected_type: Type[protobuf.LoadedMessageType] @@ -153,11 +154,13 @@ class Context: del msg return await self.read_any(expected_wire_types) + async def read_from_wire(self) -> codec_v1.Message: + self.buffer_io.seek(0) + return await codec_v1.read_message(self.iface, self.buffer_io.buffer) + async def read( self, expected_type: Type[protobuf.LoadedMessageType] ) -> protobuf.LoadedMessageType: - reader = self.make_reader() - if __debug__: log.debug( __name__, @@ -167,14 +170,13 @@ class Context: expected_type, ) - # Wait for the message header, contained in the first report. After - # we receive it, we have a message type to match on. - await reader.aopen() + # Load the full message into a buffer, parse out type and data payload + msg = await self.read_from_wire() - # If we got a message with unexpected type, raise the reader via + # If we got a message with unexpected type, raise the message via # `UnexpectedMessageError` and let the session handler deal with it. - if reader.type != expected_type.MESSAGE_WIRE_TYPE: - raise UnexpectedMessageError(reader) + if msg.type != expected_type.MESSAGE_WIRE_TYPE: + raise UnexpectedMessageError(msg) if __debug__: log.debug( @@ -187,14 +189,13 @@ class Context: workflow.idle_timer.touch() - # parse the message and return it - return await protobuf.load_message(reader, expected_type) + # look up the protobuf class and parse the message + pbtype = messages.get_type(msg.type) + return protobuf.load_message(msg.data, pbtype) async def read_any( self, expected_wire_types: Iterable[int] ) -> protobuf.MessageType: - reader = self.make_reader() - if __debug__: log.debug( __name__, @@ -204,17 +205,16 @@ class Context: expected_wire_types, ) - # Wait for the message header, contained in the first report. After - # we receive it, we have a message type to match on. - await reader.aopen() + # Load the full message into a buffer, parse out type and data payload + msg = await self.read_from_wire() - # If we got a message with unexpected type, raise the reader via + # If we got a message with unexpected type, raise the message via # `UnexpectedMessageError` and let the session handler deal with it. - if reader.type not in expected_wire_types: - raise UnexpectedMessageError(reader) + if msg.type not in expected_wire_types: + raise UnexpectedMessageError(msg) # find the protobuf type - exptype = messages.get_type(reader.type) + exptype = messages.get_type(msg.type) if __debug__: log.debug( @@ -224,24 +224,20 @@ class Context: workflow.idle_timer.touch() # parse the message and return it - return await protobuf.load_message(reader, exptype) + return protobuf.load_message(msg.data, exptype) async def write(self, msg: protobuf.MessageType) -> None: - writer = self.make_writer() - if __debug__: log.debug( __name__, "%s:%x write: %s", self.iface.iface_num(), self.sid, msg ) - # get the message size - fields = msg.get_fields() - size = protobuf.count_message(msg, fields) - # write the message - writer.setheader(msg.MESSAGE_WIRE_TYPE, size) - await protobuf.dump_message(writer, msg, fields) - await writer.aclose() + self.buffer_io.seek(0) + protobuf.dump_message(self.buffer_io, msg) + await codec_v1.write_message( + self.iface, msg.MESSAGE_WIRE_TYPE, self.buffer_io.get_written() + ) def wait(self, *tasks: Awaitable) -> Any: """ @@ -251,43 +247,35 @@ class Context: """ return loop.race(self.read_any(()), *tasks) - def make_reader(self) -> codec_v1.Reader: - return codec_v1.Reader(self.iface) - - def make_writer(self) -> codec_v1.Writer: - return codec_v1.Writer(self.iface) - class UnexpectedMessageError(Exception): - def __init__(self, reader: codec_v1.Reader) -> None: - self.reader = reader + def __init__(self, msg: codec_v1.Message) -> None: + self.msg = msg async def handle_session( iface: WireInterface, session_id: int, use_workflow: bool = True ) -> None: ctx = Context(iface, session_id) - next_reader = None # type: Optional[codec_v1.Reader] + next_msg = None # type: Optional[codec_v1.Message] res_msg = None # type: Optional[protobuf.MessageType] - req_reader = None req_type = None req_msg = None while True: try: - if next_reader is None: + if next_msg is None: # We are not currently reading a message, so let's wait for one. # If the decoding fails, exception is raised and we try again # (with the same `Reader` instance, it's OK). Even in case of # de-synchronized wire communication, report with a message # header is eventually received, after a couple of tries. - req_reader = ctx.make_reader() - await req_reader.aopen() + msg = await ctx.read_from_wire() if __debug__: try: - msg_type = messages.get_type(req_reader.type).__name__ + msg_type = messages.get_type(msg.type).__name__ except KeyError: - msg_type = "%d - unknown message type" % req_reader.type + msg_type = "%d - unknown message type" % msg.type log.debug( __name__, "%s:%x receive: <%s>", @@ -298,8 +286,8 @@ async def handle_session( else: # We have a reader left over from earlier. We should process # this message instead of waiting for new one. - req_reader = next_reader - next_reader = None + msg = next_msg + next_msg = None # Now we are in a middle of reading a message and we need to decide # what to do with it, based on its type from the message header. @@ -312,13 +300,11 @@ async def handle_session( # We need to find a handler for this message type. Should not # raise. - handler = find_handler(iface, req_reader.type) + handler = find_handler(iface, msg.type) if handler is None: # If no handler is found, we can skip decoding and directly - # respond with failure, but first, we should read the rest of - # the message reports. Should not raise. - await read_and_throw_away(req_reader) + # respond with failure. Should not raise. res_msg = unexpected_message() else: @@ -332,11 +318,11 @@ async def handle_session( try: # Find a protobuf.MessageType subclass that describes this # message. Raises if the type is not found. - req_type = messages.get_type(req_reader.type) + req_type = messages.get_type(msg.type) # Try to decode the message according to schema from # `req_type`. Raises if the message is malformed. - req_msg = await protobuf.load_message(req_reader, req_type) + req_msg = protobuf.load_message(msg.data, req_type) # At this point, message reports are all processed and # correctly parsed into `req_msg`. @@ -364,7 +350,7 @@ async def handle_session( # TODO: # We might handle only the few common cases here, like # Initialize and Cancel. - next_reader = exc.reader + next_msg = exc.msg res_msg = None except Exception as exc: @@ -401,7 +387,6 @@ async def handle_session( # Cleanup, so garbage collection triggered after un-importing can # pick up the trash. - req_reader = None req_type = None req_msg = None res_msg = None diff --git a/core/src/trezor/wire/codec_v1.py b/core/src/trezor/wire/codec_v1.py index 8a92a5aca..af0526d2c 100644 --- a/core/src/trezor/wire/codec_v1.py +++ b/core/src/trezor/wire/codec_v1.py @@ -18,143 +18,112 @@ SESSION_ID = const(0) INVALID_TYPE = const(-1) -class Reader: - """ - Decoder for a wire codec over the HID (or UDP) layer. Provides readable - async-file-like interface. - """ - - def __init__(self, iface: WireInterface) -> None: - self.iface = iface - self.type = INVALID_TYPE - self.size = 0 - self.ofs = 0 - self.data = bytes() - - def __repr__(self) -> str: - return "" % self.type - - async def aopen(self) -> None: - """ - Start reading a message by waiting for initial message report. Because - the first report contains the message header, `self.type` and - `self.size` are initialized and available after `aopen()` returns. - """ - read = loop.wait(self.iface.iface_num() | io.POLL_READ) - while True: - # wait for initial report - report = await read - marker = report[0] - if marker == _REP_MARKER: - _, m1, m2, mtype, msize = ustruct.unpack(_REP_INIT, report) - if m1 != _REP_MAGIC or m2 != _REP_MAGIC: - raise ValueError - break +class CodecError(Exception): + pass - # load received message header - self.type = mtype - self.size = msize - self.data = report[_REP_INIT_DATA : _REP_INIT_DATA + msize] - self.ofs = 0 - - async def areadinto(self, buf: bytearray) -> int: - """ - Read exactly `len(buf)` bytes into `buf`, waiting for additional - reports, if needed. Raises `EOFError` if end-of-message is encountered - before the full read can be completed. - """ - if self.size < len(buf): - raise EOFError - read = loop.wait(self.iface.iface_num() | io.POLL_READ) - nread = 0 - while nread < len(buf): - if self.ofs == len(self.data): - # we are at the end of received data - # wait for continuation report - while True: - report = await read - marker = report[0] - if marker == _REP_MARKER: - break - self.data = report[_REP_CONT_DATA : _REP_CONT_DATA + self.size] - self.ofs = 0 - - # copy as much as possible to target buffer - nbytes = utils.memcpy(buf, nread, self.data, self.ofs, len(buf)) - nread += nbytes - self.ofs += nbytes - self.size -= nbytes +class BytesIO: + def __init__(self, buffer: bytearray) -> None: + self.buffer = buffer + self.offset = 0 + def seek(self, offset: int) -> None: + offset = min(offset, len(self.buffer)) + offset = max(offset, 0) + self.offset = offset + + def readinto(self, dst: bytearray) -> int: + buffer = self.buffer + offset = self.offset + if len(dst) > len(buffer) - offset: + raise EOFError + nread = utils.memcpy(dst, 0, buffer, offset) + self.offset += nread return nread + def write(self, src: bytes) -> int: + buffer = self.buffer + offset = self.offset + if len(src) > len(buffer) - offset: + raise EOFError + nwrite = utils.memcpy(buffer, offset, src, 0) + self.offset += nwrite + return nwrite + + def get_written(self) -> bytes: + return memoryview(self.buffer)[: self.offset] + -class Writer: - """ - Encoder for a wire codec over the HID (or UDP) layer. Provides writable - async-file-like interface. - """ - - def __init__(self, iface: WireInterface): - self.iface = iface - self.type = INVALID_TYPE - self.size = 0 - self.ofs = 0 - self.data = bytearray(_REP_LEN) - - def setheader(self, mtype: int, msize: int) -> None: - """ - Reset the writer state and load the message header with passed type and - total message size. - """ +class Message: + def __init__(self, mtype: int, mdata: BytesIO) -> None: self.type = mtype - self.size = msize - ustruct.pack_into( - _REP_INIT, self.data, 0, _REP_MARKER, _REP_MAGIC, _REP_MAGIC, mtype, msize - ) - self.ofs = _REP_INIT_DATA - - async def awrite(self, buf: bytes) -> int: - """ - Encode and write every byte from `buf`. Does not need to be called in - case message has zero length. Raises `EOFError` if the length of `buf` - exceeds the remaining message length. - """ - if self.size < len(buf): - raise EOFError + self.data = mdata + + +async def read_message(iface: WireInterface, buffer: bytearray) -> Message: + read = loop.wait(iface.iface_num() | io.POLL_READ) + + # wait for initial report + report = await read + if report[0] != _REP_MARKER: + raise CodecError("Invalid magic") + _, magic1, magic2, mtype, msize = ustruct.unpack(_REP_INIT, report) + if magic1 != _REP_MAGIC or magic2 != _REP_MAGIC: + raise CodecError("Invalid magic") + + throw_away = False + if msize > len(buffer): + throw_away = True + + # prepare the backing buffer + mdata = memoryview(buffer)[:msize] + + # buffer the initial data + nread = utils.memcpy(mdata, 0, report, _REP_INIT_DATA) + + while nread < msize: + # wait for continuation report + report = await read + if report[0] != _REP_MARKER: + raise CodecError("Invalid magic") + + # buffer the continuation data + if not throw_away: + nread += utils.memcpy(mdata, nread, report, _REP_CONT_DATA) + + if throw_away: + raise CodecError("Message too large") + + return Message(mtype, BytesIO(mdata)) + + +async def write_message(iface: WireInterface, mtype: int, mdata: bytearray) -> None: + write = loop.wait(iface.iface_num() | io.POLL_WRITE) + + # gather data from msg + msize = len(mdata) + + # prepare the report buffer with header data + report = bytearray(_REP_LEN) + repofs = _REP_INIT_DATA + ustruct.pack_into( + _REP_INIT, report, 0, _REP_MARKER, _REP_MAGIC, _REP_MAGIC, mtype, msize + ) + + nwritten = 0 + while True: + # copy as much as possible to the report buffer + nwritten += utils.memcpy(report, repofs, mdata, nwritten) + + # write the report + while True: + await write + n = iface.write(report) + if n == len(report): + break - write = loop.wait(self.iface.iface_num() | io.POLL_WRITE) - nwritten = 0 - while nwritten < len(buf): - # copy as much as possible to report buffer - nbytes = utils.memcpy(self.data, self.ofs, buf, nwritten, len(buf)) - nwritten += nbytes - self.ofs += nbytes - self.size -= nbytes - - if self.ofs == _REP_LEN: - # we are at the end of the report, flush it - while True: - await write - n = self.iface.write(self.data) - if n == len(self.data): - break - self.ofs = _REP_CONT_DATA - - return nwritten - - async def aclose(self) -> None: - """Flush and close the message transmission.""" - if self.ofs != _REP_CONT_DATA: - # we didn't write anything or last write() wasn't report-aligned, - # pad the final report and flush it - while self.ofs < _REP_LEN: - self.data[self.ofs] = 0x00 - self.ofs += 1 - - write = loop.wait(self.iface.iface_num() | io.POLL_WRITE) - while True: - await write - n = self.iface.write(self.data) - if n == len(self.data): - break + # if we have more data to write, use continuation reports for it + if nwritten < msize: + repofs = _REP_CONT_DATA + else: + break