mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-06-30 20:02:34 +00:00
wip
This commit is contained in:
parent
1f1fef21e5
commit
e5895c3413
@ -246,9 +246,8 @@ class ProtoMessage:
|
|||||||
if "wire_type" in extensions:
|
if "wire_type" in extensions:
|
||||||
wire_type = extensions["wire_type"]
|
wire_type = extensions["wire_type"]
|
||||||
elif "wire_enum" in extensions:
|
elif "wire_enum" in extensions:
|
||||||
message_type_enum = find_by_name(
|
enum_name = extensions["wire_enum"].bytes.decode()
|
||||||
descriptor.enums, extensions["wire_enum"].bytes.decode()
|
message_type_enum = find_by_name(descriptor.enums, enum_name)
|
||||||
)
|
|
||||||
custom_message_type = find_by_name(message_type_enum.value, message.name)
|
custom_message_type = find_by_name(message_type_enum.value, message.name)
|
||||||
if message_type is not None and custom_message_type != message_type:
|
if message_type is not None and custom_message_type != message_type:
|
||||||
# we are specifying two different message ids via MessageType and via the custom wire_enum
|
# we are specifying two different message ids via MessageType and via the custom wire_enum
|
||||||
@ -388,6 +387,15 @@ class Descriptor:
|
|||||||
# recursively search for nested types in newly added messages
|
# recursively search for nested types in newly added messages
|
||||||
self._nested_types_from_message(message.orig)
|
self._nested_types_from_message(message.orig)
|
||||||
|
|
||||||
|
# list of all enums containing wire_ids
|
||||||
|
self.all_message_types = [MESSAGE_TYPE_ENUM]
|
||||||
|
for m in self.messages:
|
||||||
|
if "wire_enum" not in m.extensions:
|
||||||
|
continue
|
||||||
|
wire_enum_name = m.extensions["wire_enum"].bytes.decode()
|
||||||
|
if wire_enum_name not in self.all_message_types:
|
||||||
|
self.all_message_types.append(wire_enum_name)
|
||||||
|
|
||||||
if not self.messages and not self.enums:
|
if not self.messages and not self.enums:
|
||||||
raise RuntimeError("No messages and no enums found.")
|
raise RuntimeError("No messages and no enums found.")
|
||||||
|
|
||||||
@ -694,6 +702,9 @@ class RustBlobRenderer:
|
|||||||
return b"".join(
|
return b"".join(
|
||||||
NAME_ENTRY.build((self.qstr_map[e.name], self.enum_map[e.name]))
|
NAME_ENTRY.build((self.qstr_map[e.name], self.enum_map[e.name]))
|
||||||
for e in enums
|
for e in enums
|
||||||
|
# currently we only care about wire_id enums
|
||||||
|
# omit everything else to save space
|
||||||
|
if e.name in self.descriptor.all_message_types
|
||||||
)
|
)
|
||||||
|
|
||||||
def build_blob_wire(self):
|
def build_blob_wire(self):
|
||||||
|
@ -240,7 +240,7 @@ impl Decoder {
|
|||||||
}
|
}
|
||||||
FieldType::Enum(enum_type) => {
|
FieldType::Enum(enum_type) => {
|
||||||
let enum_val = num.try_into()?;
|
let enum_val = num.try_into()?;
|
||||||
if enum_type.values.contains(&enum_val) {
|
if enum_type.contains(enum_val) {
|
||||||
Ok(enum_val.into())
|
Ok(enum_val.into())
|
||||||
} else {
|
} else {
|
||||||
Err(error::invalid_value(field.name.into()))
|
Err(error::invalid_value(field.name.into()))
|
||||||
|
@ -111,6 +111,11 @@ impl EnumDef {
|
|||||||
get_enum(msg_offset)
|
get_enum(msg_offset)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn contains(&self, x: u16) -> bool {
|
||||||
|
// NOTE: binary_search instead of contains might be faster
|
||||||
|
self.values.contains(&x)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[repr(C, packed)]
|
#[repr(C, packed)]
|
||||||
|
@ -18,7 +18,7 @@ use crate::{
|
|||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
decode::{protobuf_decode, Decoder},
|
decode::{protobuf_decode, Decoder},
|
||||||
defs::{find_name_by_msg_offset, get_msg, MsgDef},
|
defs::{find_name_by_msg_offset, get_msg, EnumDef, MsgDef},
|
||||||
encode::{protobuf_encode, protobuf_len},
|
encode::{protobuf_encode, protobuf_len},
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -283,6 +283,15 @@ unsafe extern "C" fn msg_def_obj_is_type_of(self_in: Obj, obj: Obj) -> Obj {
|
|||||||
static MSG_DEF_OBJ_IS_TYPE_OF_OBJ: ffi::mp_obj_fun_builtin_fixed_t =
|
static MSG_DEF_OBJ_IS_TYPE_OF_OBJ: ffi::mp_obj_fun_builtin_fixed_t =
|
||||||
obj_fn_2!(msg_def_obj_is_type_of);
|
obj_fn_2!(msg_def_obj_is_type_of);
|
||||||
|
|
||||||
|
fn validate_wire_id_enum(wire_id: u16, enum_name: Qstr) -> Result<(), Error> {
|
||||||
|
let def =
|
||||||
|
EnumDef::for_name(enum_name.to_u16()).ok_or_else(|| Error::KeyError(enum_name.into()))?;
|
||||||
|
if !def.contains(wire_id) {
|
||||||
|
return Err(Error::OutOfRange);
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
#[no_mangle]
|
#[no_mangle]
|
||||||
pub extern "C" fn protobuf_debug_msg_type() -> &'static Type {
|
pub extern "C" fn protobuf_debug_msg_type() -> &'static Type {
|
||||||
MsgObj::obj_type()
|
MsgObj::obj_type()
|
||||||
@ -303,9 +312,11 @@ pub extern "C" fn protobuf_type_for_name(name: Obj) -> Obj {
|
|||||||
unsafe { util::try_or_raise(block) }
|
unsafe { util::try_or_raise(block) }
|
||||||
}
|
}
|
||||||
|
|
||||||
pub extern "C" fn protobuf_type_for_wire(wire_id: Obj) -> Obj {
|
pub extern "C" fn protobuf_type_for_wire(wire_id: Obj, enum_name: Obj) -> Obj {
|
||||||
let block = || {
|
let block = || {
|
||||||
let wire_id = u16::try_from(wire_id)?;
|
let wire_id = u16::try_from(wire_id)?;
|
||||||
|
let enum_name = Qstr::try_from(enum_name)?;
|
||||||
|
validate_wire_id_enum(wire_id, enum_name)?;
|
||||||
let def = MsgDef::for_wire_id(wire_id).ok_or_else(|| Error::KeyError(wire_id.into()))?;
|
let def = MsgDef::for_wire_id(wire_id).ok_or_else(|| Error::KeyError(wire_id.into()))?;
|
||||||
let obj = MsgDefObj::alloc(def)?.into();
|
let obj = MsgDefObj::alloc(def)?.into();
|
||||||
Ok(obj)
|
Ok(obj)
|
||||||
@ -340,9 +351,9 @@ pub static mp_module_trezorproto: Module = obj_module! {
|
|||||||
/// """Find the message definition for the given protobuf name."""
|
/// """Find the message definition for the given protobuf name."""
|
||||||
Qstr::MP_QSTR_type_for_name => obj_fn_1!(protobuf_type_for_name).as_obj(),
|
Qstr::MP_QSTR_type_for_name => obj_fn_1!(protobuf_type_for_name).as_obj(),
|
||||||
|
|
||||||
/// def type_for_wire(wire_id: int) -> type[MessageType]:
|
/// def type_for_wire(wire_id: int, enum_name: str) -> type[MessageType]:
|
||||||
/// """Find the message definition for the given wire type (numeric identifier)."""
|
/// """Find the message definition for the given wire type (numeric identifier). TODO document enum_name"""
|
||||||
Qstr::MP_QSTR_type_for_wire => obj_fn_1!(protobuf_type_for_wire).as_obj(),
|
Qstr::MP_QSTR_type_for_wire => obj_fn_2!(protobuf_type_for_wire).as_obj(),
|
||||||
|
|
||||||
/// def decode(
|
/// def decode(
|
||||||
/// buffer: bytes,
|
/// buffer: bytes,
|
||||||
|
@ -386,11 +386,13 @@ Q(workflow_handlers)
|
|||||||
Q(writers)
|
Q(writers)
|
||||||
|
|
||||||
#if USE_THP
|
#if USE_THP
|
||||||
|
Q(ThpMessageType)
|
||||||
Q(ThpPairingMethod)
|
Q(ThpPairingMethod)
|
||||||
Q(apps.thp)
|
Q(apps.thp)
|
||||||
Q(apps.thp.credential_manager)
|
Q(apps.thp.credential_manager)
|
||||||
Q(credential_manager)
|
Q(credential_manager)
|
||||||
Q(thp)
|
Q(thp)
|
||||||
|
Q(trezor.enums.ThpMessageType)
|
||||||
Q(trezor.enums.ThpPairingMethod)
|
Q(trezor.enums.ThpPairingMethod)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
@ -23,8 +23,8 @@ def type_for_name(name: str) -> type[MessageType]:
|
|||||||
|
|
||||||
|
|
||||||
# rust/src/protobuf/obj.rs
|
# rust/src/protobuf/obj.rs
|
||||||
def type_for_wire(wire_id: int) -> type[MessageType]:
|
def type_for_wire(wire_id: int, enum_name: str) -> type[MessageType]:
|
||||||
"""Find the message definition for the given wire type (numeric identifier)."""
|
"""Find the message definition for the given wire type (numeric identifier). TODO document enum_name"""
|
||||||
|
|
||||||
|
|
||||||
# rust/src/protobuf/obj.rs
|
# rust/src/protobuf/obj.rs
|
||||||
|
@ -405,7 +405,7 @@ if __debug__:
|
|||||||
|
|
||||||
req_type = None
|
req_type = None
|
||||||
try:
|
try:
|
||||||
req_type = protobuf.type_for_wire(msg.type)
|
req_type = protobuf.type_for_wire(msg.type, "MessageType")
|
||||||
msg_type = req_type.MESSAGE_NAME
|
msg_type = req_type.MESSAGE_NAME
|
||||||
except Exception:
|
except Exception:
|
||||||
msg_type = f"{msg.type} - unknown message type"
|
msg_type = f"{msg.type} - unknown message type"
|
||||||
|
@ -17,7 +17,7 @@ def load_message_buffer(
|
|||||||
msg_wire_type: int,
|
msg_wire_type: int,
|
||||||
experimental_enabled: bool = True,
|
experimental_enabled: bool = True,
|
||||||
) -> MessageType:
|
) -> MessageType:
|
||||||
msg_type = type_for_wire(msg_wire_type)
|
msg_type = type_for_wire(msg_wire_type, "MessageType")
|
||||||
return decode(buffer, msg_type, experimental_enabled)
|
return decode(buffer, msg_type, experimental_enabled)
|
||||||
|
|
||||||
|
|
||||||
|
@ -63,7 +63,7 @@ class CodecContext(Context):
|
|||||||
raise UnexpectedMessageException(msg)
|
raise UnexpectedMessageException(msg)
|
||||||
|
|
||||||
if expected_type is None:
|
if expected_type is None:
|
||||||
expected_type = protobuf.type_for_wire(msg.type)
|
expected_type = protobuf.type_for_wire(msg.type, "MessageType")
|
||||||
|
|
||||||
if __debug__:
|
if __debug__:
|
||||||
log.debug(
|
log.debug(
|
||||||
|
@ -64,7 +64,7 @@ async def handle_single_message(ctx: Context, msg: Message) -> bool:
|
|||||||
|
|
||||||
if __debug__:
|
if __debug__:
|
||||||
try:
|
try:
|
||||||
msg_type = protobuf.type_for_wire(msg.type).MESSAGE_NAME
|
msg_type = protobuf.type_for_wire(msg.type, "MessageType").MESSAGE_NAME
|
||||||
except Exception:
|
except Exception:
|
||||||
msg_type = f"{msg.type} - unknown message type"
|
msg_type = f"{msg.type} - unknown message type"
|
||||||
log.info(
|
log.info(
|
||||||
@ -91,7 +91,7 @@ async def handle_single_message(ctx: Context, msg: Message) -> bool:
|
|||||||
try:
|
try:
|
||||||
# Find a protobuf.MessageType subclass that describes this
|
# Find a protobuf.MessageType subclass that describes this
|
||||||
# message. Raises if the type is not found.
|
# message. Raises if the type is not found.
|
||||||
req_type = protobuf.type_for_wire(msg.type)
|
req_type = protobuf.type_for_wire(msg.type, "MessageType")
|
||||||
|
|
||||||
# Try to decode the message according to schema from
|
# Try to decode the message according to schema from
|
||||||
# `req_type`. Raises if the message is malformed.
|
# `req_type`. Raises if the message is malformed.
|
||||||
|
Loading…
Reference in New Issue
Block a user