From e5895c3413f7872c91550c0a2bf3bc541a56d75c Mon Sep 17 00:00:00 2001 From: Martin Milata Date: Wed, 25 Jun 2025 00:57:12 +0200 Subject: [PATCH] wip --- common/protob/pb2py | 17 ++++++++++++++--- core/embed/rust/src/protobuf/decode.rs | 2 +- core/embed/rust/src/protobuf/defs.rs | 5 +++++ core/embed/rust/src/protobuf/obj.rs | 21 ++++++++++++++++----- core/embed/upymod/qstrdefsport.h | 2 ++ core/mocks/generated/trezorproto.pyi | 4 ++-- core/src/apps/debug/__init__.py | 2 +- core/src/trezor/protobuf.py | 2 +- core/src/trezor/wire/codec/codec_context.py | 2 +- core/src/trezor/wire/message_handler.py | 4 ++-- 10 files changed, 45 insertions(+), 16 deletions(-) diff --git a/common/protob/pb2py b/common/protob/pb2py index d4c6018317..3a1b991edd 100755 --- a/common/protob/pb2py +++ b/common/protob/pb2py @@ -246,9 +246,8 @@ class ProtoMessage: if "wire_type" in extensions: wire_type = extensions["wire_type"] elif "wire_enum" in extensions: - message_type_enum = find_by_name( - descriptor.enums, extensions["wire_enum"].bytes.decode() - ) + enum_name = extensions["wire_enum"].bytes.decode() + message_type_enum = find_by_name(descriptor.enums, enum_name) custom_message_type = find_by_name(message_type_enum.value, message.name) if message_type is not None and custom_message_type != message_type: # we are specifying two different message ids via MessageType and via the custom wire_enum @@ -388,6 +387,15 @@ class Descriptor: # recursively search for nested types in newly added messages self._nested_types_from_message(message.orig) + # list of all enums containing wire_ids + self.all_message_types = [MESSAGE_TYPE_ENUM] + for m in self.messages: + if "wire_enum" not in m.extensions: + continue + wire_enum_name = m.extensions["wire_enum"].bytes.decode() + if wire_enum_name not in self.all_message_types: + self.all_message_types.append(wire_enum_name) + if not self.messages and not self.enums: raise RuntimeError("No messages and no enums found.") @@ -694,6 +702,9 @@ class RustBlobRenderer: return b"".join( NAME_ENTRY.build((self.qstr_map[e.name], self.enum_map[e.name])) for e in enums + # currently we only care about wire_id enums + # omit everything else to save space + if e.name in self.descriptor.all_message_types ) def build_blob_wire(self): diff --git a/core/embed/rust/src/protobuf/decode.rs b/core/embed/rust/src/protobuf/decode.rs index 5c4a21168d..f7fd3ba652 100644 --- a/core/embed/rust/src/protobuf/decode.rs +++ b/core/embed/rust/src/protobuf/decode.rs @@ -240,7 +240,7 @@ impl Decoder { } FieldType::Enum(enum_type) => { let enum_val = num.try_into()?; - if enum_type.values.contains(&enum_val) { + if enum_type.contains(enum_val) { Ok(enum_val.into()) } else { Err(error::invalid_value(field.name.into())) diff --git a/core/embed/rust/src/protobuf/defs.rs b/core/embed/rust/src/protobuf/defs.rs index f1170a663a..a17f610a0f 100644 --- a/core/embed/rust/src/protobuf/defs.rs +++ b/core/embed/rust/src/protobuf/defs.rs @@ -111,6 +111,11 @@ impl EnumDef { get_enum(msg_offset) }) } + + pub fn contains(&self, x: u16) -> bool { + // NOTE: binary_search instead of contains might be faster + self.values.contains(&x) + } } #[repr(C, packed)] diff --git a/core/embed/rust/src/protobuf/obj.rs b/core/embed/rust/src/protobuf/obj.rs index fc3a1f3a1d..feb648a8b0 100644 --- a/core/embed/rust/src/protobuf/obj.rs +++ b/core/embed/rust/src/protobuf/obj.rs @@ -18,7 +18,7 @@ use crate::{ use super::{ decode::{protobuf_decode, Decoder}, - defs::{find_name_by_msg_offset, get_msg, MsgDef}, + defs::{find_name_by_msg_offset, get_msg, EnumDef, MsgDef}, encode::{protobuf_encode, protobuf_len}, }; @@ -283,6 +283,15 @@ unsafe extern "C" fn msg_def_obj_is_type_of(self_in: Obj, obj: Obj) -> Obj { static MSG_DEF_OBJ_IS_TYPE_OF_OBJ: ffi::mp_obj_fun_builtin_fixed_t = obj_fn_2!(msg_def_obj_is_type_of); +fn validate_wire_id_enum(wire_id: u16, enum_name: Qstr) -> Result<(), Error> { + let def = + EnumDef::for_name(enum_name.to_u16()).ok_or_else(|| Error::KeyError(enum_name.into()))?; + if !def.contains(wire_id) { + return Err(Error::OutOfRange); + } + Ok(()) +} + #[no_mangle] pub extern "C" fn protobuf_debug_msg_type() -> &'static Type { MsgObj::obj_type() @@ -303,9 +312,11 @@ pub extern "C" fn protobuf_type_for_name(name: Obj) -> Obj { unsafe { util::try_or_raise(block) } } -pub extern "C" fn protobuf_type_for_wire(wire_id: Obj) -> Obj { +pub extern "C" fn protobuf_type_for_wire(wire_id: Obj, enum_name: Obj) -> Obj { let block = || { let wire_id = u16::try_from(wire_id)?; + let enum_name = Qstr::try_from(enum_name)?; + validate_wire_id_enum(wire_id, enum_name)?; let def = MsgDef::for_wire_id(wire_id).ok_or_else(|| Error::KeyError(wire_id.into()))?; let obj = MsgDefObj::alloc(def)?.into(); Ok(obj) @@ -340,9 +351,9 @@ pub static mp_module_trezorproto: Module = obj_module! { /// """Find the message definition for the given protobuf name.""" Qstr::MP_QSTR_type_for_name => obj_fn_1!(protobuf_type_for_name).as_obj(), - /// def type_for_wire(wire_id: int) -> type[MessageType]: - /// """Find the message definition for the given wire type (numeric identifier).""" - Qstr::MP_QSTR_type_for_wire => obj_fn_1!(protobuf_type_for_wire).as_obj(), + /// def type_for_wire(wire_id: int, enum_name: str) -> type[MessageType]: + /// """Find the message definition for the given wire type (numeric identifier). TODO document enum_name""" + Qstr::MP_QSTR_type_for_wire => obj_fn_2!(protobuf_type_for_wire).as_obj(), /// def decode( /// buffer: bytes, diff --git a/core/embed/upymod/qstrdefsport.h b/core/embed/upymod/qstrdefsport.h index 52df6091c6..014b3dca5f 100644 --- a/core/embed/upymod/qstrdefsport.h +++ b/core/embed/upymod/qstrdefsport.h @@ -386,11 +386,13 @@ Q(workflow_handlers) Q(writers) #if USE_THP +Q(ThpMessageType) Q(ThpPairingMethod) Q(apps.thp) Q(apps.thp.credential_manager) Q(credential_manager) Q(thp) +Q(trezor.enums.ThpMessageType) Q(trezor.enums.ThpPairingMethod) #endif diff --git a/core/mocks/generated/trezorproto.pyi b/core/mocks/generated/trezorproto.pyi index 530ef0963f..b6580a7fbe 100644 --- a/core/mocks/generated/trezorproto.pyi +++ b/core/mocks/generated/trezorproto.pyi @@ -23,8 +23,8 @@ def type_for_name(name: str) -> type[MessageType]: # rust/src/protobuf/obj.rs -def type_for_wire(wire_id: int) -> type[MessageType]: - """Find the message definition for the given wire type (numeric identifier).""" +def type_for_wire(wire_id: int, enum_name: str) -> type[MessageType]: + """Find the message definition for the given wire type (numeric identifier). TODO document enum_name""" # rust/src/protobuf/obj.rs diff --git a/core/src/apps/debug/__init__.py b/core/src/apps/debug/__init__.py index 64032d6325..ad6fe2913f 100644 --- a/core/src/apps/debug/__init__.py +++ b/core/src/apps/debug/__init__.py @@ -405,7 +405,7 @@ if __debug__: req_type = None try: - req_type = protobuf.type_for_wire(msg.type) + req_type = protobuf.type_for_wire(msg.type, "MessageType") msg_type = req_type.MESSAGE_NAME except Exception: msg_type = f"{msg.type} - unknown message type" diff --git a/core/src/trezor/protobuf.py b/core/src/trezor/protobuf.py index 9d9210ddb4..74afd3c72d 100644 --- a/core/src/trezor/protobuf.py +++ b/core/src/trezor/protobuf.py @@ -17,7 +17,7 @@ def load_message_buffer( msg_wire_type: int, experimental_enabled: bool = True, ) -> MessageType: - msg_type = type_for_wire(msg_wire_type) + msg_type = type_for_wire(msg_wire_type, "MessageType") return decode(buffer, msg_type, experimental_enabled) diff --git a/core/src/trezor/wire/codec/codec_context.py b/core/src/trezor/wire/codec/codec_context.py index 719de212c4..bd5209067d 100644 --- a/core/src/trezor/wire/codec/codec_context.py +++ b/core/src/trezor/wire/codec/codec_context.py @@ -63,7 +63,7 @@ class CodecContext(Context): raise UnexpectedMessageException(msg) if expected_type is None: - expected_type = protobuf.type_for_wire(msg.type) + expected_type = protobuf.type_for_wire(msg.type, "MessageType") if __debug__: log.debug( diff --git a/core/src/trezor/wire/message_handler.py b/core/src/trezor/wire/message_handler.py index d2cb718c0a..fa28fbaf67 100644 --- a/core/src/trezor/wire/message_handler.py +++ b/core/src/trezor/wire/message_handler.py @@ -64,7 +64,7 @@ async def handle_single_message(ctx: Context, msg: Message) -> bool: if __debug__: try: - msg_type = protobuf.type_for_wire(msg.type).MESSAGE_NAME + msg_type = protobuf.type_for_wire(msg.type, "MessageType").MESSAGE_NAME except Exception: msg_type = f"{msg.type} - unknown message type" log.info( @@ -91,7 +91,7 @@ async def handle_single_message(ctx: Context, msg: Message) -> bool: try: # Find a protobuf.MessageType subclass that describes this # message. Raises if the type is not found. - req_type = protobuf.type_for_wire(msg.type) + req_type = protobuf.type_for_wire(msg.type, "MessageType") # Try to decode the message according to schema from # `req_type`. Raises if the message is malformed.