diff --git a/core/src/apps/base.py b/core/src/apps/base.py index 16b8ee966..e3b73efa3 100644 --- a/core/src/apps/base.py +++ b/core/src/apps/base.py @@ -145,6 +145,7 @@ async def handle_DoPreauthorized( req = await ctx.call_any(PreauthorizedRequest(), *wire_types) + assert req.MESSAGE_WIRE_TYPE is not None handler = workflow_handlers.find_registered_handler( ctx.iface, req.MESSAGE_WIRE_TYPE ) diff --git a/core/src/apps/common/authorization.py b/core/src/apps/common/authorization.py index 4a5fc7502..156f821b2 100644 --- a/core/src/apps/common/authorization.py +++ b/core/src/apps/common/authorization.py @@ -1,7 +1,8 @@ import protobuf import storage.cache -from trezor import messages, utils -from trezor.messages import MessageType +from trezor import protobuf +from trezor.enums import MessageType +from trezor.utils import ensure if False: from typing import Iterable @@ -16,9 +17,12 @@ def is_set() -> bool: def set(auth_message: protobuf.MessageType) -> None: - buffer = bytearray(protobuf.count_message(auth_message)) - writer = utils.BufferWriter(buffer) - protobuf.dump_message(writer, auth_message) + buffer = protobuf.dump_message_buffer(auth_message) + + # only wire-level messages can be stored as authorization + # (because only wire-level messages have wire_type, which we use as identifier) + ensure(auth_message.MESSAGE_WIRE_TYPE is not None) + assert auth_message.MESSAGE_WIRE_TYPE is not None # so that mypy knows as well storage.cache.set( storage.cache.APP_COMMON_AUTHORIZATION_TYPE, auth_message.MESSAGE_WIRE_TYPE.to_bytes(2, "big"), @@ -32,11 +36,8 @@ def get() -> protobuf.MessageType | None: return None msg_wire_type = int.from_bytes(stored_auth_type, "big") - msg_type = messages.get_type(msg_wire_type) buffer = storage.cache.get(storage.cache.APP_COMMON_AUTHORIZATION_DATA) - reader = utils.BufferReader(buffer) - - return protobuf.load_message(reader, msg_type) + return protobuf.load_message_buffer(buffer, msg_wire_type) def get_wire_types() -> Iterable[int]: diff --git a/core/src/apps/monero/get_tx_keys.py b/core/src/apps/monero/get_tx_keys.py index 3c960fd2b..748fdebc4 100644 --- a/core/src/apps/monero/get_tx_keys.py +++ b/core/src/apps/monero/get_tx_keys.py @@ -50,7 +50,7 @@ async def get_tx_keys(ctx, msg: MoneroGetTxKeyRequest, keychain): # and then is used to store the derivations if applicable plain_buff = chacha_poly.decrypt_pack(tx_enc_key, msg.tx_enc_keys) utils.ensure(len(plain_buff) % 32 == 0, "Tx key buffer has invalid size") - del msg.tx_enc_keys + msg.tx_enc_keys = b"" # If return only derivations do tx_priv * view_pub if do_deriv: diff --git a/core/src/apps/monero/signing/offloading_keys.py b/core/src/apps/monero/signing/offloading_keys.py index 9b7449844..8fee2a8a4 100644 --- a/core/src/apps/monero/signing/offloading_keys.py +++ b/core/src/apps/monero/signing/offloading_keys.py @@ -134,7 +134,7 @@ def gen_hmac_vini( used only once and hard to check. I.e., indices in step 2 are uncheckable, decoy keys in step 9 are just random keys. """ - import protobuf + from trezor import protobuf from apps.monero.xmr.keccak_hasher import get_keccak_writer kwriter = get_keccak_writer() @@ -146,7 +146,7 @@ def gen_hmac_vini( src_entr.real_out_additional_tx_keys[src_entr.real_output_in_tx_index] ] - protobuf.dump_message(kwriter, src_entr) + kwriter.write(protobuf.dump_message_buffer(src_entr)) src_entr.outputs = real_outputs src_entr.real_out_additional_tx_keys = real_additional kwriter.write(vini_bin) @@ -162,11 +162,11 @@ def gen_hmac_vouti( """ Generates HMAC for (TxDestinationEntry[i] || tx.vout[i]) """ - import protobuf + from trezor import protobuf from apps.monero.xmr.keccak_hasher import get_keccak_writer kwriter = get_keccak_writer() - protobuf.dump_message(kwriter, dst_entr) + kwriter.write(protobuf.dump_message_buffer(dst_entr)) kwriter.write(tx_out_bin) hmac_key_vouti = hmac_key_txout(key, idx) @@ -180,11 +180,11 @@ def gen_hmac_tsxdest( """ Generates HMAC for TxDestinationEntry[i] """ - import protobuf + from trezor import protobuf from apps.monero.xmr.keccak_hasher import get_keccak_writer kwriter = get_keccak_writer() - protobuf.dump_message(kwriter, dst_entr) + kwriter.write(protobuf.dump_message_buffer(dst_entr)) hmac_key = hmac_key_txdst(key, idx) hmac_tsxdest = crypto.compute_hmac(hmac_key, kwriter.get_digest()) diff --git a/core/src/apps/monero/signing/step_01_init_transaction.py b/core/src/apps/monero/signing/step_01_init_transaction.py index 4a621c48b..87d0a37c1 100644 --- a/core/src/apps/monero/signing/step_01_init_transaction.py +++ b/core/src/apps/monero/signing/step_01_init_transaction.py @@ -116,7 +116,7 @@ async def init_transaction( from trezor.messages.MoneroTransactionInitAck import MoneroTransactionInitAck from trezor.messages.MoneroTransactionRsigData import MoneroTransactionRsigData - rsig_data = MoneroTransactionRsigData(offload_type=state.rsig_offload) + rsig_data = MoneroTransactionRsigData(offload_type=int(state.rsig_offload)) return MoneroTransactionInitAck(hmacs=hmacs, rsig_data=rsig_data) @@ -273,11 +273,11 @@ def _compute_sec_keys(state: State, tsx_data: MoneroTransactionData): """ Generate master key H( H(TsxData || tx_priv) || rand ) """ - import protobuf + from trezor import protobuf from apps.monero.xmr.keccak_hasher import get_keccak_writer writer = get_keccak_writer() - protobuf.dump_message(writer, tsx_data) + writer.write(protobuf.dump_message_buffer(tsx_data)) writer.write(crypto.encodeint(state.tx_priv)) master_key = crypto.keccak_2hash( diff --git a/core/src/apps/monero/xmr/serialize/readwriter.py b/core/src/apps/monero/xmr/serialize/readwriter.py index 54cd8ddaa..67eec2496 100644 --- a/core/src/apps/monero/xmr/serialize/readwriter.py +++ b/core/src/apps/monero/xmr/serialize/readwriter.py @@ -62,9 +62,6 @@ class MemoryReaderWriter: return nread - async def areadinto(self, buf): - return self.readinto(buf) - def write(self, buf): nwritten = len(buf) nall = len(self.buffer) @@ -98,9 +95,6 @@ class MemoryReaderWriter: self.ndata += nwritten return nwritten - async def awrite(self, buf): - return self.write(buf) - def get_buffer(self): mv = memoryview(self.buffer) return mv[self.offset : self.woffset] diff --git a/core/src/protobuf.py b/core/src/protobuf.py deleted file mode 100644 index 5e879b329..000000000 --- a/core/src/protobuf.py +++ /dev/null @@ -1,456 +0,0 @@ -""" -Extremely minimal streaming codec for a subset of protobuf. Supports uint32, -bytes, string, embedded message and repeated fields. -""" - -if False: - from typing import ( - Any, - Callable, - Iterable, - TypeVar, - Union, - ) - from typing_extensions import Protocol - - class Reader(Protocol): - def readinto(self, buf: bytearray) -> int: - """ - Reads `len(buf)` bytes into `buf`, or raises `EOFError`. - """ - - class Writer(Protocol): - def write(self, buf: bytes) -> int: - """ - Writes all bytes from `buf`, or raises `EOFError`. - """ - - WriteMethod = Callable[[bytes], Any] - - -_UVARINT_BUFFER = bytearray(1) - - -def load_uvarint(reader: Reader) -> int: - buffer = _UVARINT_BUFFER - result = 0 - shift = 0 - byte = 0x80 - while byte & 0x80: - reader.readinto(buffer) - byte = buffer[0] - result += (byte & 0x7F) << shift - shift += 7 - return result - - -def dump_uvarint(write: WriteMethod, n: int) -> None: - if n < 0: - raise ValueError("Cannot dump signed value, convert it to unsigned first.") - buffer = _UVARINT_BUFFER - shifted = 1 - while shifted: - shifted = n >> 7 - buffer[0] = (n & 0x7F) | (0x80 if shifted else 0x00) - write(buffer) - n = shifted - - -def count_uvarint(n: int) -> int: - if n < 0: - raise ValueError("Cannot dump signed value, convert it to unsigned first.") - if n <= 0x7F: - return 1 - if n <= 0x3FFF: - return 2 - if n <= 0x1F_FFFF: - return 3 - if n <= 0xFFF_FFFF: - return 4 - if n <= 0x7_FFFF_FFFF: - return 5 - if n <= 0x3FF_FFFF_FFFF: - return 6 - if n <= 0x1_FFFF_FFFF_FFFF: - return 7 - if n <= 0xFF_FFFF_FFFF_FFFF: - return 8 - if n <= 0x7FFF_FFFF_FFFF_FFFF: - return 9 - raise ValueError - - -# protobuf interleaved signed encoding: -# https://developers.google.com/protocol-buffers/docs/encoding#structure -# the idea is to save the sign in LSbit instead of twos-complement. -# so counting up, you go: 0, -1, 1, -2, 2, ... (as the first bit changes, sign flips) -# -# To achieve this with a twos-complement number: -# 1. shift left by 1, leaving LSbit free -# 2. if the number is negative, do bitwise negation. -# This keeps positive number the same, and converts negative from twos-complement -# to the appropriate value, while setting the sign bit. -# -# The original algorithm makes use of the fact that arithmetic (signed) shift -# keeps the sign bits, so for a n-bit number, (x >> n) gets us "all sign bits". -# Then you can take "number XOR all-sign-bits", which is XOR 0 (identity) for positive -# and XOR 1 (bitwise negation) for negative. Cute and efficient. -# -# But this is harder in Python because we don't natively know the bit size of the number. -# So we have to branch on whether the number is negative. - - -def sint_to_uint(sint: int) -> int: - res = sint << 1 - if sint < 0: - res = ~res - return res - - -def uint_to_sint(uint: int) -> int: - sign = uint & 1 - res = uint >> 1 - if sign: - res = ~res - return res - - -class UVarintType: - WIRE_TYPE = 0 - - -class SVarintType: - WIRE_TYPE = 0 - - -class BoolType: - WIRE_TYPE = 0 - - -class EnumType: - WIRE_TYPE = 0 - - def __init__(self, name: str, enum_values: Iterable[int]) -> None: - self.enum_values = enum_values - - def validate(self, fvalue: int) -> int: - if fvalue in self.enum_values: - return fvalue - else: - raise TypeError("Invalid enum value") - - -class BytesType: - WIRE_TYPE = 2 - - -class UnicodeType: - WIRE_TYPE = 2 - - -if False: - MessageTypeDef = Union[ - type[UVarintType], - type[SVarintType], - type[BoolType], - EnumType, - type[BytesType], - type[UnicodeType], - type["MessageType"], - ] - FieldDef = tuple[str, MessageTypeDef, Any] - FieldDict = dict[int, FieldDef] - - FieldCache = dict[type["MessageType"], FieldDict] - - LoadedMessageType = TypeVar("LoadedMessageType", bound="MessageType") - - -class MessageType: - WIRE_TYPE = 2 - UNSTABLE = False - - # Type id for the wire codec. - # Technically, not every protobuf message has this. - MESSAGE_WIRE_TYPE = -1 - - @classmethod - def get_fields(cls) -> FieldDict: - return {} - - @classmethod - def cache_subordinate_types(cls, field_cache: FieldCache) -> None: - if cls in field_cache: - fields = field_cache[cls] - else: - fields = cls.get_fields() - field_cache[cls] = fields - - for _, field_type, _ in fields.values(): - if isinstance(field_type, MessageType): - field_type.cache_subordinate_types(field_cache) - - def __eq__(self, rhs: Any) -> bool: - return self.__class__ is rhs.__class__ and self.__dict__ == rhs.__dict__ - - def __repr__(self) -> str: - return "<%s>" % self.__class__.__name__ - - -class LimitedReader: - def __init__(self, reader: Reader, limit: int) -> None: - self.reader = reader - self.limit = limit - - def readinto(self, buf: bytearray) -> int: - if self.limit < len(buf): - raise EOFError - else: - nread = self.reader.readinto(buf) - self.limit -= nread - return nread - - -FLAG_REPEATED = object() -FLAG_REQUIRED = object() -FLAG_EXPERIMENTAL = object() - - -def load_message( - reader: Reader, - msg_type: type[LoadedMessageType], - field_cache: FieldCache | None = None, - experimental_enabled: bool = True, -) -> LoadedMessageType: - if field_cache is None: - field_cache = {} - fields = field_cache.get(msg_type) - if fields is None: - fields = msg_type.get_fields() - field_cache[msg_type] = fields - - if msg_type.UNSTABLE and not experimental_enabled: - raise ValueError # experimental messages not enabled - - # we need to avoid calling __init__, which enforces required arguments - msg: LoadedMessageType = object.__new__(msg_type) - # pre-seed the object with defaults - for fname, _, fdefault in fields.values(): - if fdefault is FLAG_REPEATED: - fdefault = [] - elif fdefault is FLAG_EXPERIMENTAL: - fdefault = None - setattr(msg, fname, fdefault) - - if False: - SingularValue = Union[int, bool, bytearray, str, MessageType] - Value = Union[SingularValue, list[SingularValue]] - fvalue: Value = 0 - - while True: - try: - fkey = load_uvarint(reader) - except EOFError: - break # no more fields to load - - ftag = fkey >> 3 - wtype = fkey & 7 - - field = fields.get(ftag, None) - - if field is None: # unknown field, skip it - if wtype == 0: - load_uvarint(reader) - elif wtype == 2: - ivalue = load_uvarint(reader) - reader.readinto(bytearray(ivalue)) - else: - raise ValueError - continue - - fname, ftype, fdefault = field - if wtype != ftype.WIRE_TYPE: - raise TypeError # parsed wire type differs from the schema - - if fdefault is FLAG_EXPERIMENTAL and not experimental_enabled: - raise ValueError # experimental fields not enabled - - ivalue = load_uvarint(reader) - - if ftype is UVarintType: - fvalue = ivalue - elif ftype is SVarintType: - fvalue = uint_to_sint(ivalue) - elif ftype is BoolType: - fvalue = bool(ivalue) - elif isinstance(ftype, EnumType): - fvalue = ftype.validate(ivalue) - elif ftype is BytesType: - fvalue = bytearray(ivalue) - reader.readinto(fvalue) - elif ftype is UnicodeType: - fvalue = bytearray(ivalue) - reader.readinto(fvalue) - fvalue = bytes(fvalue).decode() - elif issubclass(ftype, MessageType): - fvalue = load_message( - LimitedReader(reader, ivalue), ftype, field_cache, experimental_enabled - ) - else: - raise TypeError # field type is unknown - - if fdefault is FLAG_REPEATED: - getattr(msg, fname).append(fvalue) - else: - setattr(msg, fname, fvalue) - - for fname, _, _ in fields.values(): - if getattr(msg, fname) is FLAG_REQUIRED: - # The message is intended to be user-facing when decoding from wire, - # but not when used internally. - raise ValueError("Required field '{}' was not received".format(fname)) - - return msg - - -def dump_message( - writer: Writer, msg: MessageType, field_cache: FieldCache | None = None -) -> None: - repvalue = [0] - - if field_cache is None: - field_cache = {} - fields = field_cache.get(type(msg)) - if fields is None: - fields = msg.get_fields() - field_cache[type(msg)] = fields - - for ftag in fields: - fname, ftype, fdefault = fields[ftag] - - fvalue = getattr(msg, fname, None) - if fvalue is None: - continue - - fkey = (ftag << 3) | ftype.WIRE_TYPE - - if fdefault is not FLAG_REPEATED: - repvalue[0] = fvalue - fvalue = repvalue - - for svalue in fvalue: - dump_uvarint(writer.write, fkey) - - if ftype is UVarintType: - dump_uvarint(writer.write, svalue) - - elif ftype is SVarintType: - dump_uvarint(writer.write, sint_to_uint(svalue)) - - elif ftype is BoolType: - dump_uvarint(writer.write, int(svalue)) - - elif isinstance(ftype, EnumType): - dump_uvarint(writer.write, svalue) - - elif ftype is BytesType: - if isinstance(svalue, list): - dump_uvarint(writer.write, _count_bytes_list(svalue)) - for sub_svalue in svalue: - writer.write(sub_svalue) - else: - dump_uvarint(writer.write, len(svalue)) - writer.write(svalue) - - elif ftype is UnicodeType: - svalue = svalue.encode() - dump_uvarint(writer.write, len(svalue)) - writer.write(svalue) - - elif issubclass(ftype, MessageType): - ffields = field_cache.get(ftype) - if ffields is None: - ffields = ftype.get_fields() - field_cache[ftype] = ffields - dump_uvarint(writer.write, count_message(svalue, field_cache)) - dump_message(writer, svalue, field_cache) - - else: - raise TypeError - - -def count_message(msg: MessageType, field_cache: FieldCache | None = None) -> int: - nbytes = 0 - repvalue = [0] - - if field_cache is None: - field_cache = {} - fields = field_cache.get(type(msg)) - if fields is None: - fields = msg.get_fields() - field_cache[type(msg)] = fields - - for ftag in fields: - fname, ftype, fdefault = fields[ftag] - - fvalue = getattr(msg, fname, None) - if fvalue is None: - continue - - fkey = (ftag << 3) | ftype.WIRE_TYPE - - if fdefault is not FLAG_REPEATED: - repvalue[0] = fvalue - fvalue = repvalue - - # length of all the field keys - nbytes += count_uvarint(fkey) * len(fvalue) - - if ftype is UVarintType: - for svalue in fvalue: - nbytes += count_uvarint(svalue) - - elif ftype is SVarintType: - for svalue in fvalue: - nbytes += count_uvarint(sint_to_uint(svalue)) - - elif ftype is BoolType: - for svalue in fvalue: - nbytes += count_uvarint(int(svalue)) - - elif isinstance(ftype, EnumType): - for svalue in fvalue: - nbytes += count_uvarint(svalue) - - elif ftype is BytesType: - for svalue in fvalue: - if isinstance(svalue, list): - svalue = _count_bytes_list(svalue) - else: - svalue = len(svalue) - nbytes += count_uvarint(svalue) - nbytes += svalue - - elif ftype is UnicodeType: - for svalue in fvalue: - svalue = len(svalue.encode()) - nbytes += count_uvarint(svalue) - nbytes += svalue - - elif issubclass(ftype, MessageType): - for svalue in fvalue: - fsize = count_message(svalue, field_cache) - nbytes += count_uvarint(fsize) - nbytes += fsize - - else: - raise TypeError - - return nbytes - - -def _count_bytes_list(svalue: list[bytes]) -> int: - res = 0 - for x in svalue: - res += len(x) - return res diff --git a/core/src/trezor/protobuf.py b/core/src/trezor/protobuf.py new file mode 100644 index 000000000..6fc9b82b9 --- /dev/null +++ b/core/src/trezor/protobuf.py @@ -0,0 +1,42 @@ +import trezorproto + +decode = trezorproto.decode +encode = trezorproto.encode +encoded_length = trezorproto.encoded_length +type_for_name = trezorproto.type_for_name +type_for_wire = trezorproto.type_for_wire + +# XXX +# Note that MessageType "subclasses" are not true subclasses, but instead instances +# of the built-in metaclass MsgDef. MessageType instances are in fact instances of +# the built-in type Msg. That is why isinstance checks do not work, and instead the +# MessageTypeSubclass.is_type_of() method must be used. +if False: + from typing import Type, TypeGuard, TypeVar + + T = TypeVar("T", bound="MessageType") + + class MsgDef(type): + @classmethod + def is_type_of(cls: Type[Type[T]], msg: "MessageType") -> TypeGuard[T]: + """Identify if the provided message belongs to this type.""" + raise NotImplementedError + + class MessageType(metaclass=MsgDef): + MESSAGE_NAME: str = "MessageType" + MESSAGE_WIRE_TYPE: int | None = None + + +def load_message_buffer( + buffer: bytes, + msg_wire_type: int, + experimental_enabled: bool = True, +) -> MessageType: + msg_type = type_for_wire(msg_wire_type) + return decode(buffer, msg_type, experimental_enabled) + + +def dump_message_buffer(msg: MessageType) -> bytearray: + buffer = bytearray(encoded_length(msg)) + encode(buffer, msg) + return buffer diff --git a/core/src/trezor/utils.py b/core/src/trezor/utils.py index 26039f6ba..2f2852359 100644 --- a/core/src/trezor/utils.py +++ b/core/src/trezor/utils.py @@ -145,10 +145,6 @@ class HashWriter: def write(self, buf: bytes) -> None: # alias for extend() self.ctx.update(buf) - async def awrite(self, buf: bytes) -> int: # AsyncWriter interface - self.ctx.update(buf) - return len(buf) - def get_digest(self) -> bytes: return self.ctx.digest() diff --git a/core/src/trezor/wire/__init__.py b/core/src/trezor/wire/__init__.py index 040ff5095..ca67810cd 100644 --- a/core/src/trezor/wire/__init__.py +++ b/core/src/trezor/wire/__init__.py @@ -35,11 +35,10 @@ reads the message's header. When the message type is known the first handler is """ -import protobuf from storage.cache import InvalidSessionError -from trezor import log, loop, messages, utils, workflow -from trezor.messages import FailureType -from trezor.messages.Failure import Failure +from trezor.enums import FailureType +from trezor.messages import Failure +from trezor import log, loop, protobuf, utils, workflow from trezor.wire import codec_v1 from trezor.wire.errors import ActionCancelled, DataError, Error @@ -74,17 +73,19 @@ def setup(iface: WireInterface, is_debug_session: bool = False) -> None: if False: - from typing import Protocol + from typing import Protocol, TypeVar + + LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType) class GenericContext(Protocol): async def call( self, msg: protobuf.MessageType, - expected_type: type[protobuf.LoadedMessageType], + expected_type: type[protobuf.MessageType], ) -> Any: ... - async def read(self, expected_type: type[protobuf.LoadedMessageType]) -> Any: + async def read(self, expected_type: type[protobuf.MessageType]) -> Any: ... async def write(self, msg: protobuf.MessageType) -> None: @@ -96,15 +97,14 @@ if False: def _wrap_protobuf_load( - reader: protobuf.Reader, - expected_type: type[protobuf.LoadedMessageType], - field_cache: protobuf.FieldCache | None = None, -) -> protobuf.LoadedMessageType: + buffer: bytes, + expected_type: type[LoadedMessageType], +) -> LoadedMessageType: try: - return protobuf.load_message( - reader, expected_type, field_cache, experimental_enabled - ) + return protobuf.decode(buffer, expected_type, experimental_enabled) except Exception as e: + if __debug__: + log.exception(__name__, e) if e.args: raise DataError("Failed to decode message: {}".format(e.args[0])) else: @@ -139,19 +139,15 @@ class Context: self.iface = iface self.sid = sid self.buffer = buffer - self.buffer_writer = utils.BufferWriter(self.buffer) - - self._field_cache: protobuf.FieldCache = {} async def call( self, msg: protobuf.MessageType, - expected_type: type[protobuf.LoadedMessageType], - field_cache: protobuf.FieldCache | None = None, - ) -> protobuf.LoadedMessageType: - await self.write(msg, field_cache) + expected_type: type[LoadedMessageType], + ) -> LoadedMessageType: + await self.write(msg) del msg - return await self.read(expected_type, field_cache) + return await self.read(expected_type) async def call_any( self, msg: protobuf.MessageType, *expected_wire_types: int @@ -160,21 +156,17 @@ class Context: del msg return await self.read_any(expected_wire_types) - async def read_from_wire(self) -> codec_v1.Message: - return await codec_v1.read_message(self.iface, self.buffer) + def read_from_wire(self) -> Awaitable[codec_v1.Message]: + return codec_v1.read_message(self.iface, self.buffer) - async def read( - self, - expected_type: type[protobuf.LoadedMessageType], - field_cache: protobuf.FieldCache | None = None, - ) -> protobuf.LoadedMessageType: + async def read(self, expected_type: type[LoadedMessageType]) -> LoadedMessageType: if __debug__: log.debug( __name__, "%s:%x expect: %s", self.iface.iface_num(), self.sid, - expected_type, + expected_type.MESSAGE_NAME, ) # Load the full message into a buffer, parse out type and data payload @@ -191,13 +183,13 @@ class Context: "%s:%x read: %s", self.iface.iface_num(), self.sid, - expected_type, + expected_type.MESSAGE_NAME, ) workflow.idle_timer.touch() # look up the protobuf class and parse the message - return _wrap_protobuf_load(msg.data, expected_type, field_cache) + return _wrap_protobuf_load(msg.data, expected_type) async def read_any( self, expected_wire_types: Iterable[int] @@ -220,11 +212,15 @@ class Context: raise UnexpectedMessageError(msg) # find the protobuf type - exptype = messages.get_type(msg.type) + exptype = protobuf.type_for_wire(msg.type) if __debug__: log.debug( - __name__, "%s:%x read: %s", self.iface.iface_num(), self.sid, exptype + __name__, + "%s:%x read: %s", + self.iface.iface_num(), + self.sid, + exptype.MESSAGE_NAME, ) workflow.idle_timer.touch() @@ -232,41 +228,36 @@ class Context: # parse the message and return it return _wrap_protobuf_load(msg.data, exptype) - async def write( - self, - msg: protobuf.MessageType, - field_cache: protobuf.FieldCache | None = None, - ) -> None: + async def write(self, msg: protobuf.MessageType) -> None: if __debug__: log.debug( - __name__, "%s:%x write: %s", self.iface.iface_num(), self.sid, msg + __name__, + "%s:%x write: %s", + self.iface.iface_num(), + self.sid, + msg.MESSAGE_NAME, ) - if field_cache is None: - field_cache = self._field_cache + # cannot write message without wire type + assert msg.MESSAGE_WIRE_TYPE is not None - # write the message - msg_size = protobuf.count_message(msg, field_cache) + msg_size = protobuf.encoded_length(msg) - # prepare buffer - if msg_size <= len(self.buffer_writer.buffer): + if msg_size <= len(self.buffer): # reuse preallocated - buffer_writer = self.buffer_writer + buffer = self.buffer else: # message is too big, we need to allocate a new buffer - buffer_writer = utils.BufferWriter(bytearray(msg_size)) + buffer = bytearray(msg_size) + + msg_size = protobuf.encode(buffer, msg) - buffer_writer.seek(0) - protobuf.dump_message(buffer_writer, msg, field_cache) await codec_v1.write_message( self.iface, msg.MESSAGE_WIRE_TYPE, - memoryview(buffer_writer.buffer)[:msg_size], + memoryview(buffer)[:msg_size], ) - # make sure we don't keep around fields of all protobuf types ever - self._field_cache.clear() - def wait(self, *tasks: Awaitable) -> Any: """ Wait until one of the passed tasks finishes, and return the result, @@ -301,8 +292,8 @@ async def _handle_single_message( """ if __debug__: try: - msg_type = messages.get_type(msg.type).__name__ - except KeyError: + msg_type = protobuf.type_for_wire(msg.type).MESSAGE_NAME + except Exception: msg_type = "%d - unknown message type" % msg.type log.debug( __name__, @@ -328,7 +319,7 @@ async def _handle_single_message( try: # Find a protobuf.MessageType subclass that describes this # message. Raises if the type is not found. - req_type = messages.get_type(msg.type) + req_type = protobuf.type_for_wire(msg.type) # Try to decode the message according to schema from # `req_type`. Raises if the message is malformed. @@ -424,6 +415,11 @@ async def handle_session( next_msg = await _handle_single_message( ctx, msg, use_workflow=not is_debug_session ) + except Exception as exc: + # Log and ignore. The session handler can only exit explicitly in the + # following finally block. + if __debug__: + log.exception(__name__, exc) finally: if not __debug__ or not is_debug_session: # Unload modules imported by the workflow. Should not raise. diff --git a/core/src/trezor/wire/codec_v1.py b/core/src/trezor/wire/codec_v1.py index 67c89c070..153704aa2 100644 --- a/core/src/trezor/wire/codec_v1.py +++ b/core/src/trezor/wire/codec_v1.py @@ -23,7 +23,7 @@ class CodecError(Exception): class Message: - def __init__(self, mtype: int, mdata: utils.BufferReader) -> None: + def __init__(self, mtype: int, mdata: bytes) -> None: self.type = mtype self.data = mdata @@ -70,7 +70,7 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag if read_and_throw_away: raise CodecError("Message too large") - return Message(mtype, utils.BufferReader(mdata)) + return Message(mtype, mdata) async def write_message(iface: WireInterface, mtype: int, mdata: bytes) -> None: diff --git a/core/tests/test_protobuf.py b/core/tests/test_protobuf.py deleted file mode 100644 index 31b69c881..000000000 --- a/core/tests/test_protobuf.py +++ /dev/null @@ -1,129 +0,0 @@ -from common import * - -import protobuf -from trezor.utils import BufferReader, BufferWriter - - -class Message(protobuf.MessageType): - def __init__(self, sint_field: int = 0, enum_field: int = 0) -> None: - self.sint_field = sint_field - self.enum_field = enum_field - - @classmethod - def get_fields(cls): - return { - 1: ("sint_field", protobuf.SVarintType, 0), - 2: ("enum_field", protobuf.EnumType("t", (0, 5, 25)), 0), - } - - -class MessageWithRequiredAndDefault(protobuf.MessageType): - def __init__(self, required_field, default_field) -> None: - self.required_field = required_field - self.default_field = default_field - - @classmethod - def get_fields(cls): - return { - 1: ("required_field", protobuf.UVarintType, protobuf.FLAG_REQUIRED), - 2: ("default_field", protobuf.SVarintType, -1), - } - - -def load_uvarint(data: bytes) -> int: - reader = BufferReader(data) - return protobuf.load_uvarint(reader) - - -def dump_uvarint(value: int) -> bytearray: - w = bytearray() - protobuf.dump_uvarint(w.extend, value) - return w - - -def dump_message(msg: protobuf.MessageType) -> bytearray: - length = protobuf.count_message(msg) - buffer = bytearray(length) - protobuf.dump_message(BufferWriter(buffer), msg) - return buffer - - -def load_message(msg_type, buffer: bytearray) -> protobuf.MessageType: - return protobuf.load_message(BufferReader(buffer), msg_type) - - -class TestProtobuf(unittest.TestCase): - def test_dump_uvarint(self): - self.assertEqual(dump_uvarint(0), b"\x00") - self.assertEqual(dump_uvarint(1), b"\x01") - self.assertEqual(dump_uvarint(0xFF), b"\xff\x01") - self.assertEqual(dump_uvarint(123456), b"\xc0\xc4\x07") - with self.assertRaises(ValueError): - dump_uvarint(-1) - - def test_load_uvarint(self): - self.assertEqual(load_uvarint(b"\x00"), 0) - self.assertEqual(load_uvarint(b"\x01"), 1) - self.assertEqual(load_uvarint(b"\xff\x01"), 0xFF) - self.assertEqual(load_uvarint(b"\xc0\xc4\x07"), 123456) - - def test_sint_uint(self): - self.assertEqual(protobuf.uint_to_sint(0), 0) - self.assertEqual(protobuf.sint_to_uint(0), 0) - - self.assertEqual(protobuf.sint_to_uint(-1), 1) - self.assertEqual(protobuf.sint_to_uint(1), 2) - - self.assertEqual(protobuf.uint_to_sint(1), -1) - self.assertEqual(protobuf.uint_to_sint(2), 1) - - # roundtrip: - self.assertEqual( - protobuf.uint_to_sint(protobuf.sint_to_uint(1234567891011)), 1234567891011 - ) - self.assertEqual( - protobuf.uint_to_sint(protobuf.sint_to_uint(-(2 ** 32))), -(2 ** 32) - ) - - def test_validate_enum(self): - # ok message: - msg = Message(-42, 5) - msg_encoded = dump_message(msg) - nmsg = load_message(Message, msg_encoded) - - self.assertEqual(msg.sint_field, nmsg.sint_field) - self.assertEqual(msg.enum_field, nmsg.enum_field) - - # bad enum value: - msg = Message(-42, 42) - msg_encoded = dump_message(msg) - with self.assertRaises(TypeError): - load_message(Message, msg_encoded) - - def test_required(self): - msg = MessageWithRequiredAndDefault(required_field=1, default_field=2) - msg_encoded = dump_message(msg) - nmsg = load_message(MessageWithRequiredAndDefault, msg_encoded) - - self.assertEqual(nmsg.required_field, 1) - self.assertEqual(nmsg.default_field, 2) - - # try a message without the required_field - msg = MessageWithRequiredAndDefault(required_field=None, default_field=2) - # encoding always succeeds - msg_encoded = dump_message(msg) - with self.assertRaises(ValueError): - load_message(MessageWithRequiredAndDefault, msg_encoded) - - # try a message without the default field - msg = MessageWithRequiredAndDefault(required_field=1, default_field=None) - msg_encoded = dump_message(msg) - nmsg = load_message(MessageWithRequiredAndDefault, msg_encoded) - - self.assertEqual(nmsg.required_field, 1) - self.assertEqual(nmsg.default_field, -1) - - - -if __name__ == "__main__": - unittest.main() diff --git a/core/tests/test_trezor.protobuf.py b/core/tests/test_trezor.protobuf.py new file mode 100644 index 000000000..4576c5ac5 --- /dev/null +++ b/core/tests/test_trezor.protobuf.py @@ -0,0 +1,115 @@ +from common import * + +from trezor import protobuf +from trezor.messages import WebAuthnCredential, EosAsset, Failure, SignMessage + + +def load_uvarint32(data: bytes) -> int: + # use known uint32 field in an all-optional message + buffer = bytearray(len(data) + 1) + buffer[1:] = data + buffer[0] = (1 << 3) | 0 # field number 1, wire type 0 + msg = protobuf.decode(buffer, WebAuthnCredential, False) + return msg.index + + +def load_uvarint64(data: bytes) -> int: + # use known uint64 field in an all-optional message + buffer = bytearray(len(data) + 1) + buffer[1:] = data + buffer[0] = (2 << 3) | 0 # field number 1, wire type 0 + msg = protobuf.decode(buffer, EosAsset, False) + return msg.symbol + + +def dump_uvarint32(value: int) -> bytearray: + # use known uint32 field in an all-optional message + msg = WebAuthnCredential(index=value) + length = protobuf.encoded_length(msg) + buffer = bytearray(length) + protobuf.encode(buffer, msg) + assert buffer[0] == (1 << 3) | 0 # field number 1, wire type 0 + return buffer[1:] + + +def dump_uvarint64(value: int) -> bytearray: + # use known uint64 field in an all-optional message + msg = EosAsset(symbol=value) + length = protobuf.encoded_length(msg) + buffer = bytearray(length) + protobuf.encode(buffer, msg) + assert buffer[0] == (2 << 3) | 0 # field number 2, wire type 0 + return buffer[1:] + + +def dump_message(msg: protobuf.MessageType) -> bytearray: + length = protobuf.encoded_length(msg) + buffer = bytearray(length) + protobuf.encode(buffer, msg) + return buffer + + +def load_message(msg_type: Type[protobuf.MessageType], buffer: bytes) -> protobuf.MessageType: + return protobuf.decode(buffer, msg_type, False) + + +class TestProtobuf(unittest.TestCase): + def test_dump_uvarint(self): + for dump_uvarint in (dump_uvarint32, dump_uvarint64): + self.assertEqual(dump_uvarint(0), b"\x00") + self.assertEqual(dump_uvarint(1), b"\x01") + self.assertEqual(dump_uvarint(0xFF), b"\xff\x01") + self.assertEqual(dump_uvarint(123456), b"\xc0\xc4\x07") + with self.assertRaises(ValueError): + dump_uvarint(-1) + + def test_load_uvarint(self): + for load_uvarint in (load_uvarint32, load_uvarint64): + self.assertEqual(load_uvarint(b"\x00"), 0) + self.assertEqual(load_uvarint(b"\x01"), 1) + self.assertEqual(load_uvarint(b"\xff\x01"), 0xFF) + self.assertEqual(load_uvarint(b"\xc0\xc4\x07"), 123456) + + def test_validate_enum(self): + # ok message: + msg = Failure(code=7) + msg_encoded = dump_message(msg) + nmsg = load_message(Failure, msg_encoded) + + self.assertEqual(msg.code, nmsg.code) + + # bad enum value: + msg = Failure(code=1000) + msg_encoded = dump_message(msg) + with self.assertRaises(ValueError): + load_message(Failure, msg_encoded) + + def test_required(self): + msg = SignMessage(message=b"hello", coin_name="foo", script_type=1) + msg_encoded = dump_message(msg) + nmsg = load_message(SignMessage, msg_encoded) + + self.assertEqual(nmsg.message, b"hello") + self.assertEqual(nmsg.coin_name, "foo") + self.assertEqual(nmsg.script_type, 1) + + # try a message without the required_field + msg = SignMessage(message=None) + # encoding always succeeds + msg_encoded = dump_message(msg) + with self.assertRaises(ValueError): + load_message(SignMessage, msg_encoded) + + # try a message without the default field + msg = SignMessage(message=b"hello") + msg.coin_name = None + msg_encoded = dump_message(msg) + nmsg = load_message(SignMessage, msg_encoded) + + self.assertEqual(nmsg.message, b"hello") + self.assertEqual(nmsg.coin_name, "Bitcoin") + + + +if __name__ == "__main__": + unittest.main() diff --git a/core/tests/test_trezor.wire.codec_v1.py b/core/tests/test_trezor.wire.codec_v1.py index 3e8cab290..1e2939a4b 100644 --- a/core/tests/test_trezor.wire.codec_v1.py +++ b/core/tests/test_trezor.wire.codec_v1.py @@ -53,7 +53,7 @@ class TestWireCodecV1(unittest.TestCase): # e.value is StopIteration. e.value.value is the return value of the call result = e.value.value self.assertEqual(result.type, MESSAGE_TYPE) - self.assertEqual(result.data.buffer, b"") + self.assertEqual(result.data, b"") # message should have been read into the buffer self.assertEqual(buffer, b"\x00" * 64) @@ -83,7 +83,7 @@ class TestWireCodecV1(unittest.TestCase): # e.value is StopIteration. e.value.value is the return value of the call result = e.value.value self.assertEqual(result.type, MESSAGE_TYPE) - self.assertEqual(result.data.buffer, message) + self.assertEqual(result.data, message) # message should have been read into the buffer self.assertEqual(buffer, message) @@ -108,7 +108,7 @@ class TestWireCodecV1(unittest.TestCase): # e.value is StopIteration. e.value.value is the return value of the call result = e.value.value self.assertEqual(result.type, MESSAGE_TYPE) - self.assertEqual(result.data.buffer, message) + self.assertEqual(result.data, message) # read should have allocated its own buffer and not touch ours self.assertEqual(buffer, b"\x00") @@ -176,7 +176,7 @@ class TestWireCodecV1(unittest.TestCase): result = e.value.value self.assertEqual(result.type, MESSAGE_TYPE) - self.assertEqual(result.data.buffer, message) + self.assertEqual(result.data, message) def test_read_huge_packet(self): PACKET_COUNT = 100_000