1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-06-30 20:02:34 +00:00
This commit is contained in:
Martin Milata 2025-06-25 00:57:12 +02:00
parent 1f1fef21e5
commit e5895c3413
10 changed files with 45 additions and 16 deletions

View File

@ -246,9 +246,8 @@ class ProtoMessage:
if "wire_type" in extensions: if "wire_type" in extensions:
wire_type = extensions["wire_type"] wire_type = extensions["wire_type"]
elif "wire_enum" in extensions: elif "wire_enum" in extensions:
message_type_enum = find_by_name( enum_name = extensions["wire_enum"].bytes.decode()
descriptor.enums, extensions["wire_enum"].bytes.decode() message_type_enum = find_by_name(descriptor.enums, enum_name)
)
custom_message_type = find_by_name(message_type_enum.value, message.name) custom_message_type = find_by_name(message_type_enum.value, message.name)
if message_type is not None and custom_message_type != message_type: if message_type is not None and custom_message_type != message_type:
# we are specifying two different message ids via MessageType and via the custom wire_enum # we are specifying two different message ids via MessageType and via the custom wire_enum
@ -388,6 +387,15 @@ class Descriptor:
# recursively search for nested types in newly added messages # recursively search for nested types in newly added messages
self._nested_types_from_message(message.orig) self._nested_types_from_message(message.orig)
# list of all enums containing wire_ids
self.all_message_types = [MESSAGE_TYPE_ENUM]
for m in self.messages:
if "wire_enum" not in m.extensions:
continue
wire_enum_name = m.extensions["wire_enum"].bytes.decode()
if wire_enum_name not in self.all_message_types:
self.all_message_types.append(wire_enum_name)
if not self.messages and not self.enums: if not self.messages and not self.enums:
raise RuntimeError("No messages and no enums found.") raise RuntimeError("No messages and no enums found.")
@ -694,6 +702,9 @@ class RustBlobRenderer:
return b"".join( return b"".join(
NAME_ENTRY.build((self.qstr_map[e.name], self.enum_map[e.name])) NAME_ENTRY.build((self.qstr_map[e.name], self.enum_map[e.name]))
for e in enums for e in enums
# currently we only care about wire_id enums
# omit everything else to save space
if e.name in self.descriptor.all_message_types
) )
def build_blob_wire(self): def build_blob_wire(self):

View File

@ -240,7 +240,7 @@ impl Decoder {
} }
FieldType::Enum(enum_type) => { FieldType::Enum(enum_type) => {
let enum_val = num.try_into()?; let enum_val = num.try_into()?;
if enum_type.values.contains(&enum_val) { if enum_type.contains(enum_val) {
Ok(enum_val.into()) Ok(enum_val.into())
} else { } else {
Err(error::invalid_value(field.name.into())) Err(error::invalid_value(field.name.into()))

View File

@ -111,6 +111,11 @@ impl EnumDef {
get_enum(msg_offset) get_enum(msg_offset)
}) })
} }
pub fn contains(&self, x: u16) -> bool {
// NOTE: binary_search instead of contains might be faster
self.values.contains(&x)
}
} }
#[repr(C, packed)] #[repr(C, packed)]

View File

@ -18,7 +18,7 @@ use crate::{
use super::{ use super::{
decode::{protobuf_decode, Decoder}, decode::{protobuf_decode, Decoder},
defs::{find_name_by_msg_offset, get_msg, MsgDef}, defs::{find_name_by_msg_offset, get_msg, EnumDef, MsgDef},
encode::{protobuf_encode, protobuf_len}, encode::{protobuf_encode, protobuf_len},
}; };
@ -283,6 +283,15 @@ unsafe extern "C" fn msg_def_obj_is_type_of(self_in: Obj, obj: Obj) -> Obj {
static MSG_DEF_OBJ_IS_TYPE_OF_OBJ: ffi::mp_obj_fun_builtin_fixed_t = static MSG_DEF_OBJ_IS_TYPE_OF_OBJ: ffi::mp_obj_fun_builtin_fixed_t =
obj_fn_2!(msg_def_obj_is_type_of); obj_fn_2!(msg_def_obj_is_type_of);
fn validate_wire_id_enum(wire_id: u16, enum_name: Qstr) -> Result<(), Error> {
let def =
EnumDef::for_name(enum_name.to_u16()).ok_or_else(|| Error::KeyError(enum_name.into()))?;
if !def.contains(wire_id) {
return Err(Error::OutOfRange);
}
Ok(())
}
#[no_mangle] #[no_mangle]
pub extern "C" fn protobuf_debug_msg_type() -> &'static Type { pub extern "C" fn protobuf_debug_msg_type() -> &'static Type {
MsgObj::obj_type() MsgObj::obj_type()
@ -303,9 +312,11 @@ pub extern "C" fn protobuf_type_for_name(name: Obj) -> Obj {
unsafe { util::try_or_raise(block) } unsafe { util::try_or_raise(block) }
} }
pub extern "C" fn protobuf_type_for_wire(wire_id: Obj) -> Obj { pub extern "C" fn protobuf_type_for_wire(wire_id: Obj, enum_name: Obj) -> Obj {
let block = || { let block = || {
let wire_id = u16::try_from(wire_id)?; let wire_id = u16::try_from(wire_id)?;
let enum_name = Qstr::try_from(enum_name)?;
validate_wire_id_enum(wire_id, enum_name)?;
let def = MsgDef::for_wire_id(wire_id).ok_or_else(|| Error::KeyError(wire_id.into()))?; let def = MsgDef::for_wire_id(wire_id).ok_or_else(|| Error::KeyError(wire_id.into()))?;
let obj = MsgDefObj::alloc(def)?.into(); let obj = MsgDefObj::alloc(def)?.into();
Ok(obj) Ok(obj)
@ -340,9 +351,9 @@ pub static mp_module_trezorproto: Module = obj_module! {
/// """Find the message definition for the given protobuf name.""" /// """Find the message definition for the given protobuf name."""
Qstr::MP_QSTR_type_for_name => obj_fn_1!(protobuf_type_for_name).as_obj(), Qstr::MP_QSTR_type_for_name => obj_fn_1!(protobuf_type_for_name).as_obj(),
/// def type_for_wire(wire_id: int) -> type[MessageType]: /// def type_for_wire(wire_id: int, enum_name: str) -> type[MessageType]:
/// """Find the message definition for the given wire type (numeric identifier).""" /// """Find the message definition for the given wire type (numeric identifier). TODO document enum_name"""
Qstr::MP_QSTR_type_for_wire => obj_fn_1!(protobuf_type_for_wire).as_obj(), Qstr::MP_QSTR_type_for_wire => obj_fn_2!(protobuf_type_for_wire).as_obj(),
/// def decode( /// def decode(
/// buffer: bytes, /// buffer: bytes,

View File

@ -386,11 +386,13 @@ Q(workflow_handlers)
Q(writers) Q(writers)
#if USE_THP #if USE_THP
Q(ThpMessageType)
Q(ThpPairingMethod) Q(ThpPairingMethod)
Q(apps.thp) Q(apps.thp)
Q(apps.thp.credential_manager) Q(apps.thp.credential_manager)
Q(credential_manager) Q(credential_manager)
Q(thp) Q(thp)
Q(trezor.enums.ThpMessageType)
Q(trezor.enums.ThpPairingMethod) Q(trezor.enums.ThpPairingMethod)
#endif #endif

View File

@ -23,8 +23,8 @@ def type_for_name(name: str) -> type[MessageType]:
# rust/src/protobuf/obj.rs # rust/src/protobuf/obj.rs
def type_for_wire(wire_id: int) -> type[MessageType]: def type_for_wire(wire_id: int, enum_name: str) -> type[MessageType]:
"""Find the message definition for the given wire type (numeric identifier).""" """Find the message definition for the given wire type (numeric identifier). TODO document enum_name"""
# rust/src/protobuf/obj.rs # rust/src/protobuf/obj.rs

View File

@ -405,7 +405,7 @@ if __debug__:
req_type = None req_type = None
try: try:
req_type = protobuf.type_for_wire(msg.type) req_type = protobuf.type_for_wire(msg.type, "MessageType")
msg_type = req_type.MESSAGE_NAME msg_type = req_type.MESSAGE_NAME
except Exception: except Exception:
msg_type = f"{msg.type} - unknown message type" msg_type = f"{msg.type} - unknown message type"

View File

@ -17,7 +17,7 @@ def load_message_buffer(
msg_wire_type: int, msg_wire_type: int,
experimental_enabled: bool = True, experimental_enabled: bool = True,
) -> MessageType: ) -> MessageType:
msg_type = type_for_wire(msg_wire_type) msg_type = type_for_wire(msg_wire_type, "MessageType")
return decode(buffer, msg_type, experimental_enabled) return decode(buffer, msg_type, experimental_enabled)

View File

@ -63,7 +63,7 @@ class CodecContext(Context):
raise UnexpectedMessageException(msg) raise UnexpectedMessageException(msg)
if expected_type is None: if expected_type is None:
expected_type = protobuf.type_for_wire(msg.type) expected_type = protobuf.type_for_wire(msg.type, "MessageType")
if __debug__: if __debug__:
log.debug( log.debug(

View File

@ -64,7 +64,7 @@ async def handle_single_message(ctx: Context, msg: Message) -> bool:
if __debug__: if __debug__:
try: try:
msg_type = protobuf.type_for_wire(msg.type).MESSAGE_NAME msg_type = protobuf.type_for_wire(msg.type, "MessageType").MESSAGE_NAME
except Exception: except Exception:
msg_type = f"{msg.type} - unknown message type" msg_type = f"{msg.type} - unknown message type"
log.info( log.info(
@ -91,7 +91,7 @@ async def handle_single_message(ctx: Context, msg: Message) -> bool:
try: try:
# Find a protobuf.MessageType subclass that describes this # Find a protobuf.MessageType subclass that describes this
# message. Raises if the type is not found. # message. Raises if the type is not found.
req_type = protobuf.type_for_wire(msg.type) req_type = protobuf.type_for_wire(msg.type, "MessageType")
# Try to decode the message according to schema from # Try to decode the message according to schema from
# `req_type`. Raises if the message is malformed. # `req_type`. Raises if the message is malformed.