mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-06-30 11:52:34 +00:00
wip
This commit is contained in:
parent
1f1fef21e5
commit
e5895c3413
@ -246,9 +246,8 @@ class ProtoMessage:
|
||||
if "wire_type" in extensions:
|
||||
wire_type = extensions["wire_type"]
|
||||
elif "wire_enum" in extensions:
|
||||
message_type_enum = find_by_name(
|
||||
descriptor.enums, extensions["wire_enum"].bytes.decode()
|
||||
)
|
||||
enum_name = extensions["wire_enum"].bytes.decode()
|
||||
message_type_enum = find_by_name(descriptor.enums, enum_name)
|
||||
custom_message_type = find_by_name(message_type_enum.value, message.name)
|
||||
if message_type is not None and custom_message_type != message_type:
|
||||
# we are specifying two different message ids via MessageType and via the custom wire_enum
|
||||
@ -388,6 +387,15 @@ class Descriptor:
|
||||
# recursively search for nested types in newly added messages
|
||||
self._nested_types_from_message(message.orig)
|
||||
|
||||
# list of all enums containing wire_ids
|
||||
self.all_message_types = [MESSAGE_TYPE_ENUM]
|
||||
for m in self.messages:
|
||||
if "wire_enum" not in m.extensions:
|
||||
continue
|
||||
wire_enum_name = m.extensions["wire_enum"].bytes.decode()
|
||||
if wire_enum_name not in self.all_message_types:
|
||||
self.all_message_types.append(wire_enum_name)
|
||||
|
||||
if not self.messages and not self.enums:
|
||||
raise RuntimeError("No messages and no enums found.")
|
||||
|
||||
@ -694,6 +702,9 @@ class RustBlobRenderer:
|
||||
return b"".join(
|
||||
NAME_ENTRY.build((self.qstr_map[e.name], self.enum_map[e.name]))
|
||||
for e in enums
|
||||
# currently we only care about wire_id enums
|
||||
# omit everything else to save space
|
||||
if e.name in self.descriptor.all_message_types
|
||||
)
|
||||
|
||||
def build_blob_wire(self):
|
||||
|
@ -240,7 +240,7 @@ impl Decoder {
|
||||
}
|
||||
FieldType::Enum(enum_type) => {
|
||||
let enum_val = num.try_into()?;
|
||||
if enum_type.values.contains(&enum_val) {
|
||||
if enum_type.contains(enum_val) {
|
||||
Ok(enum_val.into())
|
||||
} else {
|
||||
Err(error::invalid_value(field.name.into()))
|
||||
|
@ -111,6 +111,11 @@ impl EnumDef {
|
||||
get_enum(msg_offset)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn contains(&self, x: u16) -> bool {
|
||||
// NOTE: binary_search instead of contains might be faster
|
||||
self.values.contains(&x)
|
||||
}
|
||||
}
|
||||
|
||||
#[repr(C, packed)]
|
||||
|
@ -18,7 +18,7 @@ use crate::{
|
||||
|
||||
use super::{
|
||||
decode::{protobuf_decode, Decoder},
|
||||
defs::{find_name_by_msg_offset, get_msg, MsgDef},
|
||||
defs::{find_name_by_msg_offset, get_msg, EnumDef, MsgDef},
|
||||
encode::{protobuf_encode, protobuf_len},
|
||||
};
|
||||
|
||||
@ -283,6 +283,15 @@ unsafe extern "C" fn msg_def_obj_is_type_of(self_in: Obj, obj: Obj) -> Obj {
|
||||
static MSG_DEF_OBJ_IS_TYPE_OF_OBJ: ffi::mp_obj_fun_builtin_fixed_t =
|
||||
obj_fn_2!(msg_def_obj_is_type_of);
|
||||
|
||||
fn validate_wire_id_enum(wire_id: u16, enum_name: Qstr) -> Result<(), Error> {
|
||||
let def =
|
||||
EnumDef::for_name(enum_name.to_u16()).ok_or_else(|| Error::KeyError(enum_name.into()))?;
|
||||
if !def.contains(wire_id) {
|
||||
return Err(Error::OutOfRange);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn protobuf_debug_msg_type() -> &'static Type {
|
||||
MsgObj::obj_type()
|
||||
@ -303,9 +312,11 @@ pub extern "C" fn protobuf_type_for_name(name: Obj) -> Obj {
|
||||
unsafe { util::try_or_raise(block) }
|
||||
}
|
||||
|
||||
pub extern "C" fn protobuf_type_for_wire(wire_id: Obj) -> Obj {
|
||||
pub extern "C" fn protobuf_type_for_wire(wire_id: Obj, enum_name: Obj) -> Obj {
|
||||
let block = || {
|
||||
let wire_id = u16::try_from(wire_id)?;
|
||||
let enum_name = Qstr::try_from(enum_name)?;
|
||||
validate_wire_id_enum(wire_id, enum_name)?;
|
||||
let def = MsgDef::for_wire_id(wire_id).ok_or_else(|| Error::KeyError(wire_id.into()))?;
|
||||
let obj = MsgDefObj::alloc(def)?.into();
|
||||
Ok(obj)
|
||||
@ -340,9 +351,9 @@ pub static mp_module_trezorproto: Module = obj_module! {
|
||||
/// """Find the message definition for the given protobuf name."""
|
||||
Qstr::MP_QSTR_type_for_name => obj_fn_1!(protobuf_type_for_name).as_obj(),
|
||||
|
||||
/// def type_for_wire(wire_id: int) -> type[MessageType]:
|
||||
/// """Find the message definition for the given wire type (numeric identifier)."""
|
||||
Qstr::MP_QSTR_type_for_wire => obj_fn_1!(protobuf_type_for_wire).as_obj(),
|
||||
/// def type_for_wire(wire_id: int, enum_name: str) -> type[MessageType]:
|
||||
/// """Find the message definition for the given wire type (numeric identifier). TODO document enum_name"""
|
||||
Qstr::MP_QSTR_type_for_wire => obj_fn_2!(protobuf_type_for_wire).as_obj(),
|
||||
|
||||
/// def decode(
|
||||
/// buffer: bytes,
|
||||
|
@ -386,11 +386,13 @@ Q(workflow_handlers)
|
||||
Q(writers)
|
||||
|
||||
#if USE_THP
|
||||
Q(ThpMessageType)
|
||||
Q(ThpPairingMethod)
|
||||
Q(apps.thp)
|
||||
Q(apps.thp.credential_manager)
|
||||
Q(credential_manager)
|
||||
Q(thp)
|
||||
Q(trezor.enums.ThpMessageType)
|
||||
Q(trezor.enums.ThpPairingMethod)
|
||||
#endif
|
||||
|
||||
|
@ -23,8 +23,8 @@ def type_for_name(name: str) -> type[MessageType]:
|
||||
|
||||
|
||||
# rust/src/protobuf/obj.rs
|
||||
def type_for_wire(wire_id: int) -> type[MessageType]:
|
||||
"""Find the message definition for the given wire type (numeric identifier)."""
|
||||
def type_for_wire(wire_id: int, enum_name: str) -> type[MessageType]:
|
||||
"""Find the message definition for the given wire type (numeric identifier). TODO document enum_name"""
|
||||
|
||||
|
||||
# rust/src/protobuf/obj.rs
|
||||
|
@ -405,7 +405,7 @@ if __debug__:
|
||||
|
||||
req_type = None
|
||||
try:
|
||||
req_type = protobuf.type_for_wire(msg.type)
|
||||
req_type = protobuf.type_for_wire(msg.type, "MessageType")
|
||||
msg_type = req_type.MESSAGE_NAME
|
||||
except Exception:
|
||||
msg_type = f"{msg.type} - unknown message type"
|
||||
|
@ -17,7 +17,7 @@ def load_message_buffer(
|
||||
msg_wire_type: int,
|
||||
experimental_enabled: bool = True,
|
||||
) -> MessageType:
|
||||
msg_type = type_for_wire(msg_wire_type)
|
||||
msg_type = type_for_wire(msg_wire_type, "MessageType")
|
||||
return decode(buffer, msg_type, experimental_enabled)
|
||||
|
||||
|
||||
|
@ -63,7 +63,7 @@ class CodecContext(Context):
|
||||
raise UnexpectedMessageException(msg)
|
||||
|
||||
if expected_type is None:
|
||||
expected_type = protobuf.type_for_wire(msg.type)
|
||||
expected_type = protobuf.type_for_wire(msg.type, "MessageType")
|
||||
|
||||
if __debug__:
|
||||
log.debug(
|
||||
|
@ -64,7 +64,7 @@ async def handle_single_message(ctx: Context, msg: Message) -> bool:
|
||||
|
||||
if __debug__:
|
||||
try:
|
||||
msg_type = protobuf.type_for_wire(msg.type).MESSAGE_NAME
|
||||
msg_type = protobuf.type_for_wire(msg.type, "MessageType").MESSAGE_NAME
|
||||
except Exception:
|
||||
msg_type = f"{msg.type} - unknown message type"
|
||||
log.info(
|
||||
@ -91,7 +91,7 @@ async def handle_single_message(ctx: Context, msg: Message) -> bool:
|
||||
try:
|
||||
# Find a protobuf.MessageType subclass that describes this
|
||||
# message. Raises if the type is not found.
|
||||
req_type = protobuf.type_for_wire(msg.type)
|
||||
req_type = protobuf.type_for_wire(msg.type, "MessageType")
|
||||
|
||||
# Try to decode the message according to schema from
|
||||
# `req_type`. Raises if the message is malformed.
|
||||
|
Loading…
Reference in New Issue
Block a user