diff --git a/core/SConscript.firmware b/core/SConscript.firmware index 050ba9b06..0453aeeb9 100644 --- a/core/SConscript.firmware +++ b/core/SConscript.firmware @@ -174,6 +174,11 @@ SOURCE_MOD += [ 'embed/extmod/modtrezorutils/modtrezorutils.c', ] +# rust mods +SOURCE_MOD += [ + 'embed/extmod/rustmods/modtrezorproto.c', +] + # modutime SOURCE_MOD += [ 'embed/firmware/modutime.c', diff --git a/core/SConscript.unix b/core/SConscript.unix index 982dd94c9..fb5f957fc 100644 --- a/core/SConscript.unix +++ b/core/SConscript.unix @@ -171,6 +171,11 @@ SOURCE_MOD += [ 'embed/extmod/modtrezorutils/modtrezorutils.c', ] +# rust mods +SOURCE_MOD += [ + 'embed/extmod/rustmods/modtrezorproto.c', +] + # modutime SOURCE_MOD += [ 'vendor/micropython/ports/unix/modtime.c', diff --git a/core/embed/extmod/rustmods/modtrezorproto.c b/core/embed/extmod/rustmods/modtrezorproto.c new file mode 100644 index 000000000..43194a0d9 --- /dev/null +++ b/core/embed/extmod/rustmods/modtrezorproto.c @@ -0,0 +1,86 @@ +/* + * This file is part of the Trezor project, https://trezor.io/ + * + * Copyright (c) SatoshiLabs + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#include "py/runtime.h" + +#if MICROPY_PY_TREZORPROTO + +#include "librust.h" + +/// from trezor.protobuf import MessageType +/// T = TypeVar("T", bound=MessageType) + +/// def type_for_name(name: str) -> Type[MessageType]: +/// """Find the message definition for the given protobuf name.""" +STATIC MP_DEFINE_CONST_FUN_OBJ_1(mod_trezorutils_protobuf_type_for_name_obj, + protobuf_type_for_name); + +/// def type_for_wire(wire_type: int) -> Type[MessageType]: +/// """Find the message definition for the given wire type (numeric +/// identifier).""" +STATIC MP_DEFINE_CONST_FUN_OBJ_1(mod_trezorutils_protobuf_type_for_wire_obj, + protobuf_type_for_wire); + +/// def decode( +/// buffer: bytes, +/// msg_type: Type[T], +/// enable_experimental: bool, +/// ) -> T: +/// """Decode data in the buffer into the specified message type.""" +STATIC MP_DEFINE_CONST_FUN_OBJ_3(mod_trezorutils_protobuf_decode_obj, + protobuf_decode); + +/// def encoded_length(msg: MessageType) -> int: +/// """Calculate length of encoding of the specified message.""" +STATIC MP_DEFINE_CONST_FUN_OBJ_1(mod_trezorutils_protobuf_encoded_length_obj, + protobuf_len); + +/// def encode(buffer: bytearray, msg: MessageType) -> int: +/// """Encode the message into the specified buffer. Return length of +/// encoding.""" +STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorutils_protobuf_encode_obj, + protobuf_encode); + +STATIC const mp_rom_map_elem_t mp_module_trezorproto_globals_table[] = { + {MP_ROM_QSTR(MP_QSTR___name__), MP_ROM_QSTR(MP_QSTR_trezorproto)}, + + {MP_ROM_QSTR(MP_QSTR_type_for_name), + MP_ROM_PTR(&mod_trezorutils_protobuf_type_for_name_obj)}, + {MP_ROM_QSTR(MP_QSTR_type_for_wire), + MP_ROM_PTR(&mod_trezorutils_protobuf_type_for_wire_obj)}, + {MP_ROM_QSTR(MP_QSTR_decode), + MP_ROM_PTR(&mod_trezorutils_protobuf_decode_obj)}, + {MP_ROM_QSTR(MP_QSTR_encoded_length), + MP_ROM_PTR(&mod_trezorutils_protobuf_encoded_length_obj)}, + {MP_ROM_QSTR(MP_QSTR_encode), + MP_ROM_PTR(&mod_trezorutils_protobuf_encode_obj)}, +}; + +STATIC MP_DEFINE_CONST_DICT(mp_module_trezorproto_globals, + mp_module_trezorproto_globals_table); + +const mp_obj_module_t mp_module_trezorproto = { + .base = {&mp_type_module}, + .globals = (mp_obj_dict_t *)&mp_module_trezorproto_globals, +}; + +MP_REGISTER_MODULE(MP_QSTR_trezorproto, mp_module_trezorproto, + MICROPY_PY_TREZORPROTO); + +#endif // MICROPY_PY_TREZORPROTO diff --git a/core/embed/firmware/mpconfigport.h b/core/embed/firmware/mpconfigport.h index 333940472..086630e83 100644 --- a/core/embed/firmware/mpconfigport.h +++ b/core/embed/firmware/mpconfigport.h @@ -156,6 +156,7 @@ #define MICROPY_PY_TREZORIO (1) #define MICROPY_PY_TREZORUI (1) #define MICROPY_PY_TREZORUTILS (1) +#define MICROPY_PY_TREZORPROTO (1) #ifdef SYSTEM_VIEW #define MP_PLAT_PRINT_STRN(str, len) segger_print(str, len) diff --git a/core/embed/rust/.cargo/config.toml b/core/embed/rust/.cargo/config.toml index e0a2aaa11..e19eadf80 100644 --- a/core/embed/rust/.cargo/config.toml +++ b/core/embed/rust/.cargo/config.toml @@ -1,2 +1,2 @@ [build] -target-dir = "../../build/rust" +target-dir = "../../build/unix/rust" diff --git a/core/embed/rust/librust.h b/core/embed/rust/librust.h new file mode 100644 index 000000000..c79d661c1 --- /dev/null +++ b/core/embed/rust/librust.h @@ -0,0 +1,8 @@ +#include "librust_qstr.h" + +mp_obj_t protobuf_type_for_name(mp_obj_t name); +mp_obj_t protobuf_type_for_wire(mp_obj_t wire_id); +mp_obj_t protobuf_decode(mp_obj_t buf, mp_obj_t def, + mp_obj_t enable_experimental); +mp_obj_t protobuf_len(mp_obj_t obj); +mp_obj_t protobuf_encode(mp_obj_t buf, mp_obj_t obj); diff --git a/core/embed/rust/librust_qstr.h b/core/embed/rust/librust_qstr.h new file mode 100644 index 000000000..6b0b0cf6b --- /dev/null +++ b/core/embed/rust/librust_qstr.h @@ -0,0 +1,11 @@ +#pragma GCC diagnostic ignored "-Wunused-value" +#pragma GCC diagnostic ignored "-Wunused-function" + +static void _librust_qstrs(void) { + // protobuf + MP_QSTR_Msg; + MP_QSTR_MsgDef; + MP_QSTR_is_type_of; + MP_QSTR_MESSAGE_WIRE_TYPE; + MP_QSTR_MESSAGE_NAME; +} diff --git a/core/embed/rust/src/error.rs b/core/embed/rust/src/error.rs index 8cb2622af..9977d9161 100644 --- a/core/embed/rust/src/error.rs +++ b/core/embed/rust/src/error.rs @@ -9,6 +9,7 @@ pub enum Error { InvalidType, NotBuffer, NotInt, + InvalidOperation, } impl Error { @@ -22,6 +23,7 @@ impl Error { Error::InvalidType => cstr("InvalidType\0"), Error::NotBuffer => cstr("NotBuffer\0"), Error::NotInt => cstr("NotInt\0"), + Error::InvalidOperation => cstr("InvalidOperation\0"), } } } diff --git a/core/embed/rust/src/lib.rs b/core/embed/rust/src/lib.rs index 93de520a5..8f4eaf7c3 100644 --- a/core/embed/rust/src/lib.rs +++ b/core/embed/rust/src/lib.rs @@ -6,6 +6,7 @@ mod error; #[macro_use] mod micropython; +mod protobuf; mod trezorhal; mod util; diff --git a/core/embed/rust/src/micropython/buffer.rs b/core/embed/rust/src/micropython/buffer.rs index 49930cbf1..cac8d81e5 100644 --- a/core/embed/rust/src/micropython/buffer.rs +++ b/core/embed/rust/src/micropython/buffer.rs @@ -1,17 +1,22 @@ -use core::{convert::TryFrom, ops::Deref, ptr, slice}; +use core::{ + convert::TryFrom, + ops::{Deref, DerefMut}, + ptr, slice, +}; use crate::{error::Error, micropython::obj::Obj}; use super::ffi; /// Represents an immutable slice of bytes stored on the MicroPython heap and -/// owned by values that obey the buffer protocol, such as `bytes`, `str`, -/// `bytearray` or `memoryview`. +/// owned by values that obey the `MP_BUFFER_READ` buffer protocol, such as +/// `bytes`, `str`, `bytearray` or `memoryview`. /// /// # Safety /// /// In most cases, it is unsound to store `Buffer` values in a GC-unreachable -/// location, such as static data. +/// location, such as static data. It is also unsound to let the contents be +/// modified while a reference to them is being held. pub struct Buffer { ptr: *const u8, len: usize, @@ -21,24 +26,12 @@ impl TryFrom for Buffer { type Error = Error; fn try_from(obj: Obj) -> Result { - let mut bufinfo = ffi::mp_buffer_info_t { - buf: ptr::null_mut(), - len: 0, - typecode: 0, - }; - // SAFETY: We assume that if `ffi::mp_get_buffer` returns successfully, - // `bufinfo.buf` contains a pointer to data of `bufinfo.len` bytes. Here - // we consider this data either GC-allocated or effectively 'static, and - // store the pointer directly in `Buffer`. It is unsound to store - // `Buffer` values in a GC-unreachable location, such as static data. - if unsafe { ffi::mp_get_buffer(obj, &mut bufinfo, ffi::MP_BUFFER_READ as _) } { - Ok(Self { - ptr: bufinfo.buf as _, - len: bufinfo.len as _, - }) - } else { - Err(Error::NotBuffer) - } + let bufinfo = get_buffer_info(obj, ffi::MP_BUFFER_READ)?; + + Ok(Self { + ptr: bufinfo.buf as _, + len: bufinfo.len as _, + }) } } @@ -52,15 +45,102 @@ impl Deref for Buffer { impl AsRef<[u8]> for Buffer { fn as_ref(&self) -> &[u8] { - if self.ptr.is_null() { - // `self.ptr` can be null if len == 0. - &[] - } else { - // SAFETY: We assume that `self.ptr` is pointing to memory: - // - without any mutable references, - // - immutable for the whole lifetime of `&self`, - // - with at least `self.len` bytes. - unsafe { slice::from_raw_parts(self.ptr, self.len) } - } + buffer_as_ref(self.ptr, self.len) + } +} + +/// Represents a mutable slice of bytes stored on the MicroPython heap and +/// owned by values that obey the `MP_BUFFER_WRITE` buffer protocol, such as +/// `bytearray` or `memoryview`. +/// +/// # Safety +/// +/// In most cases, it is unsound to store `Buffer` values in a GC-unreachable +/// location, such as static data. It is also unsound to let the contents be +/// modified while the reference to them is being held. +pub struct BufferMut { + ptr: *mut u8, + len: usize, +} + +impl TryFrom for BufferMut { + type Error = Error; + + fn try_from(obj: Obj) -> Result { + let bufinfo = get_buffer_info(obj, ffi::MP_BUFFER_WRITE)?; + + Ok(Self { + ptr: bufinfo.buf as _, + len: bufinfo.len as _, + }) + } +} + +impl Deref for BufferMut { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + self.as_ref() + } +} + +impl DerefMut for BufferMut { + fn deref_mut(&mut self) -> &mut Self::Target { + self.as_mut() + } +} + +impl AsRef<[u8]> for BufferMut { + fn as_ref(&self) -> &[u8] { + buffer_as_ref(self.ptr, self.len) + } +} + +impl AsMut<[u8]> for BufferMut { + fn as_mut(&mut self) -> &mut [u8] { + buffer_as_mut(self.ptr, self.len) + } +} + +fn get_buffer_info(obj: Obj, flags: u32) -> Result { + let mut bufinfo = ffi::mp_buffer_info_t { + buf: ptr::null_mut(), + len: 0, + typecode: 0, + }; + // SAFETY: We assume that if `ffi::mp_get_buffer` returns successfully, + // `bufinfo.buf` contains a pointer to data of `bufinfo.len` bytes. Later + // we consider this data either GC-allocated or effectively `'static`, embedding + // them in `Buffer`/`BufferMut`. + if unsafe { ffi::mp_get_buffer(obj, &mut bufinfo, flags as _) } { + Ok(bufinfo) + } else { + Err(Error::NotBuffer) + } +} + +fn buffer_as_ref<'a>(ptr: *const u8, len: usize) -> &'a [u8] { + if ptr.is_null() { + // `ptr` can be null if len == 0. + &[] + } else { + // SAFETY: We assume that `ptr` is pointing to memory: + // - without any mutable references, + // - valid and immutable in `'a`, + // - of at least `len` bytes. + unsafe { slice::from_raw_parts(ptr, len) } + } +} + +fn buffer_as_mut<'a>(ptr: *mut u8, len: usize) -> &'a mut [u8] { + if ptr.is_null() { + // `ptr` can be null if len == 0. + &mut [] + } else { + // SAFETY: We assume that `ptr` is pointing to memory: + // - without any mutable references, + // - valid and mutable in `'a`, + // - of at least `len` bytes. + unsafe { slice::from_raw_parts_mut(ptr, len) } } } diff --git a/core/embed/rust/src/protobuf/decode.rs b/core/embed/rust/src/protobuf/decode.rs new file mode 100644 index 000000000..e95e2fc0a --- /dev/null +++ b/core/embed/rust/src/protobuf/decode.rs @@ -0,0 +1,313 @@ +use core::convert::{TryFrom, TryInto}; +use core::str; + +use crate::{ + error::Error, + micropython::{buffer::Buffer, gc::Gc, list::List, map::Map, obj::Obj, qstr::Qstr}, + util, +}; + +use super::{ + defs::{self, FieldDef, FieldType, MsgDef}, + obj::{MsgDefObj, MsgObj}, + zigzag, +}; + +#[no_mangle] +pub extern "C" fn protobuf_type_for_name(name: Obj) -> Obj { + util::try_or_raise(|| { + let name = Qstr::try_from(name)?; + let def = MsgDef::for_name(name.to_u16()).ok_or(Error::Missing)?; + let obj = MsgDefObj::alloc(def).into(); + Ok(obj) + }) +} + +#[no_mangle] +pub extern "C" fn protobuf_type_for_wire(wire_id: Obj) -> Obj { + util::try_or_raise(|| { + let wire_id = u16::try_from(wire_id)?; + let def = MsgDef::for_wire_id(wire_id).ok_or(Error::Missing)?; + let obj = MsgDefObj::alloc(def).into(); + Ok(obj) + }) +} + +#[no_mangle] +pub extern "C" fn protobuf_decode(buf: Obj, msg_def: Obj, enable_experimental: Obj) -> Obj { + util::try_or_raise(|| { + let buf = Buffer::try_from(buf)?; + let def = Gc::::try_from(msg_def)?; + let enable_experimental = bool::try_from(enable_experimental)?; + + if !enable_experimental && def.msg().is_experimental { + // Refuse to decode message defs marked as experimental if not + // explicitly allowed. Messages can also mark certain fields as + // experimental (not the whole message). This is enforced during the + // decoding. + return Err(Error::InvalidType); + } + + let stream = &mut InputStream::new(&buf); + let decoder = Decoder { + enable_experimental, + }; + + let obj = decoder.message_from_stream(stream, def.msg())?; + Ok(obj) + }) +} + +pub struct Decoder { + pub enable_experimental: bool, +} + +impl Decoder { + /// Create a new message instance and decode `stream` into it, handling the + /// default and required fields correctly. + pub fn message_from_stream( + &self, + stream: &mut InputStream, + msg: &MsgDef, + ) -> Result { + let mut obj = self.empty_message(msg); + // SAFETY: We assume that `obj` is not aliased here. + let map = unsafe { Gc::as_mut(&mut obj) }.map_mut(); + self.decode_fields_into(stream, msg, map)?; + self.decode_defaults_into(msg, map)?; + self.assign_required_into(msg, map)?; + Ok(obj.into()) + } + + /// Create a new message instance and fill it from `values`, handling the + /// default and required fields correctly. + pub fn message_from_values(&self, values: &Map, msg: &MsgDef) -> Result { + let mut obj = self.empty_message(msg); + // SAFETY: We assume that `obj` is not aliased here. + let map = unsafe { Gc::as_mut(&mut obj) }.map_mut(); + for elem in values.elems() { + map.set(elem.key, elem.value); + } + self.decode_defaults_into(msg, map)?; + self.assign_required_into(msg, map)?; + Ok(obj.into()) + } + + /// Allocate the backing message object with enough pre-allocated space for + /// all fields. + pub fn empty_message(&self, msg: &MsgDef) -> Gc { + MsgObj::alloc_with_capacity(msg.fields.len(), msg) + } + + /// Decode message fields one-by-one from the input stream, assigning them + /// into `map`. + fn decode_fields_into( + &self, + stream: &mut InputStream, + msg: &MsgDef, + map: &mut Map, + ) -> Result<(), Error> { + // Loop, trying to read the field key that contains the tag and primitive value + // type. If we fail to read the key, we are at the end of the stream. + while let Ok(field_key) = stream.read_uvarint() { + let field_tag = u8::try_from(field_key >> 3)?; + let prim_type = u8::try_from(field_key & 7)?; + + match msg.field(field_tag) { + Some(field) => { + let field_value = self.decode_field(stream, field)?; + let field_name = Qstr::from(field.name); + if field.is_repeated() { + // Repeated field, values are stored in a list. First, look up the list + // object. If it exists, append to it. If it doesn't, create a new list with + // this field's value and assign it. + if let Ok(obj) = map.get(field_name) { + let mut list = Gc::::try_from(obj)?; + // SAFETY: We assume that `list` is not aliased here. This holds for + // uses in `message_from_stream` and `message_from_values`, because we + // start with an empty `Map` and fill with unique lists. + unsafe { Gc::as_mut(&mut list) }.append(field_value); + } else { + let list = List::alloc(&[field_value]); + map.set(field_name, list); + } + } else { + // Singular field, assign the value directly. + map.set(field_name, field_value); + } + } + None => { + // Unknown field, skip it. + match prim_type { + defs::PRIMITIVE_TYPE_VARINT => { + stream.read_uvarint()?; + } + defs::PRIMITIVE_TYPE_LENGTH_DELIMITED => { + let num = stream.read_uvarint()?; + let len = num.try_into()?; + stream.read(len)?; + } + _ => { + return Err(Error::InvalidType); + } + } + } + } + } + Ok(()) + } + + /// Fill in the default values by decoding them from the defaults stream. + /// Only singular fields are allowed to have a default value, this is + /// enforced in the blob compilation. + fn decode_defaults_into(&self, msg: &MsgDef, map: &mut Map) -> Result<(), Error> { + let stream = &mut InputStream::new(msg.defaults); + + // The format of the defaults stream is a sequence of records: + // - one-byte field tag (without the primitive type). + // - Protobuf-encoded default value. + // We need to look to the field descriptor to know how to interpret the value + // after the field tag. + while let Ok(field_tag) = stream.read_byte() { + let field = msg.field(field_tag).ok_or(Error::Missing)?; + let field_name = Qstr::from(field.name); + if map.contains_key(field_name) { + // Field already has a value assigned, skip it. + match field.get_type().primitive_type() { + defs::PRIMITIVE_TYPE_VARINT => { + stream.read_uvarint()?; + } + defs::PRIMITIVE_TYPE_LENGTH_DELIMITED => { + let num = stream.read_uvarint()?; + let len = num.try_into()?; + stream.read(len)?; + } + _ => { + return Err(Error::InvalidType); + } + } + } else { + // Decode the value and assign it. + let field_value = self.decode_field(stream, field)?; + map.set(field_name, field_value); + } + } + Ok(()) + } + + /// Walk the fields definitions and make sure that all required fields are + /// assigned and all optional missing fields are set to `None`. + fn assign_required_into(&self, msg: &MsgDef, map: &mut Map) -> Result<(), Error> { + for field in msg.fields { + let field_name = Qstr::from(field.name); + if map.contains_key(field_name) { + // Field is assigned, skip. + continue; + } + if field.is_required() { + // Required field is missing, abort. + return Err(Error::Missing); + } + if field.is_repeated() { + // Optional repeated field, set to a new empty list. + map.set(field_name, List::alloc(&[])); + } else { + // Optional singular field, set to None. + map.set(field_name, Obj::const_none()); + } + } + Ok(()) + } + + /// Decode one field value from the input stream. + fn decode_field(&self, stream: &mut InputStream, field: &FieldDef) -> Result { + if field.is_experimental() && !self.enable_experimental { + return Err(Error::InvalidType); + } + let num = stream.read_uvarint()?; + match field.get_type() { + FieldType::UVarInt => Ok(num.into()), + FieldType::SVarInt => { + let signed_int = zigzag::to_signed(num); + Ok(signed_int.into()) + } + FieldType::Bool => { + let boolean = num != 0; + Ok(boolean.into()) + } + FieldType::Bytes => { + let buf_len = num.try_into()?; + let buf = stream.read(buf_len)?; + Ok(buf.into()) + } + FieldType::String => { + let buf_len = num.try_into()?; + let buf = stream.read(buf_len)?; + let unicode = str::from_utf8(buf).map_err(|_| Error::InvalidType)?; + Ok(unicode.into()) + } + FieldType::Enum(enum_type) => { + let enum_val = num.try_into()?; + if enum_type.values.contains(&enum_val) { + Ok(enum_val.into()) + } else { + Err(Error::InvalidType) + } + } + FieldType::Msg(msg_type) => { + let msg_len = num.try_into()?; + let sub_stream = &mut stream.read_stream(msg_len)?; + self.message_from_stream(sub_stream, &msg_type) + } + } + } +} + +pub struct InputStream<'a> { + buf: &'a [u8], + pos: usize, +} + +impl<'a> InputStream<'a> { + pub fn new(buf: &'a [u8]) -> Self { + Self { buf, pos: 0 } + } + + pub fn read_stream(&mut self, len: usize) -> Result { + let buf = self + .buf + .get(self.pos..self.pos + len) + .ok_or(Error::Missing)?; + self.pos += len; + Ok(Self::new(buf)) + } + + pub fn read(&mut self, len: usize) -> Result<&[u8], Error> { + let buf = self + .buf + .get(self.pos..self.pos + len) + .ok_or(Error::Missing)?; + self.pos += len; + Ok(buf) + } + + pub fn read_byte(&mut self) -> Result { + let val = self.buf.get(self.pos).copied().ok_or(Error::Missing)?; + self.pos += 1; + Ok(val) + } + + pub fn read_uvarint(&mut self) -> Result { + let mut uint = 0; + let mut shift = 0; + loop { + let byte = self.read_byte()?; + uint += (byte as u64 & 0x7F) << shift; + shift += 7; + if byte & 0x80 == 0 { + break; + } + } + Ok(uint) + } +} diff --git a/core/embed/rust/src/protobuf/defs.rs b/core/embed/rust/src/protobuf/defs.rs new file mode 100644 index 000000000..f66dc6e2a --- /dev/null +++ b/core/embed/rust/src/protobuf/defs.rs @@ -0,0 +1,224 @@ +use core::{mem, slice}; +use crate::error::Error; + +pub struct MsgDef { + pub fields: &'static [FieldDef], + pub defaults: &'static [u8], + pub is_experimental: bool, + pub wire_id: Option, + pub offset: u16, +} + +impl MsgDef { + pub fn for_name(msg_name: u16) -> Option { + find_msg_offset_by_name(msg_name).map(|msg_offset| unsafe { + // SAFETY: We are taking the offset right out of the definitions so we can be + // sure it's to be trusted. + get_msg(msg_offset) + }) + } + + pub fn for_wire_id(wire_id: u16) -> Option { + find_msg_offset_by_wire(wire_id).map(|msg_offset| unsafe { + // SAFETY: We are taking the offset right out of the definitions so we can be + // sure it's to be trusted. + get_msg(msg_offset) + }) + } + + pub fn field(&self, tag: u8) -> Option<&FieldDef> { + self.fields.iter().find(|field| field.tag == tag) + } +} + +#[repr(C, packed)] +pub struct FieldDef { + pub tag: u8, + flags_and_type: u8, + enum_or_msg_offset: u16, + pub name: u16, +} + +impl FieldDef { + pub fn get_type(&self) -> FieldType { + match self.ftype() { + 0 => FieldType::UVarInt, + 1 => FieldType::SVarInt, + 2 => FieldType::Bool, + 3 => FieldType::Bytes, + 4 => FieldType::String, + 5 => FieldType::Enum(unsafe { get_enum(self.enum_or_msg_offset) }), + 6 => FieldType::Msg(unsafe { get_msg(self.enum_or_msg_offset) }), + _ => unreachable!(), + } + } + + pub fn is_required(&self) -> bool { + self.flags() & 0b_1000_0000 != 0 + } + + pub fn is_repeated(&self) -> bool { + self.flags() & 0b_0100_0000 != 0 + } + + pub fn is_experimental(&self) -> bool { + self.flags() & 0b_0010_0000 != 0 + } + + fn flags(&self) -> u8 { + self.flags_and_type & 0xF0 + } + + fn ftype(&self) -> u8 { + self.flags_and_type & 0x0F + } +} + +pub enum FieldType { + UVarInt, + SVarInt, + Bool, + Bytes, + String, + Enum(EnumDef), + Msg(MsgDef), +} + +pub const PRIMITIVE_TYPE_VARINT: u8 = 0; +pub const PRIMITIVE_TYPE_LENGTH_DELIMITED: u8 = 2; + +impl FieldType { + pub fn primitive_type(&self) -> u8 { + match self { + FieldType::UVarInt | FieldType::SVarInt | FieldType::Bool | FieldType::Enum(_) => { + PRIMITIVE_TYPE_VARINT + } + FieldType::Bytes | FieldType::String | FieldType::Msg(_) => { + PRIMITIVE_TYPE_LENGTH_DELIMITED + } + } + } +} + +pub struct EnumDef { + pub values: &'static [u16], +} + +#[repr(C, packed)] +struct NameDef { + msg_name: u16, + msg_offset: u16, +} + +static ENUM_DEFS: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/../../../../proto_enums.data")); +static MSG_DEFS: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/../../../..//proto_msgs.data")); +static NAME_DEFS: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/../../../..//proto_names.data")); +static WIRE_DEFS: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/../../../..//proto_wire.data")); + +pub fn find_name_by_msg_offset(msg_offset: u16) -> Result { + let name_defs: &[NameDef] = unsafe { + slice::from_raw_parts( + NAME_DEFS.as_ptr().cast(), + NAME_DEFS.len() / mem::size_of::(), + ) + }; + + name_defs.iter() + .filter(|def| def.msg_offset == msg_offset) + .next() + .map(|def| def.msg_name) + .ok_or(Error::Missing) +} + +fn find_msg_offset_by_name(msg_name: u16) -> Option { + let name_defs: &[NameDef] = unsafe { + slice::from_raw_parts( + NAME_DEFS.as_ptr().cast(), + NAME_DEFS.len() / mem::size_of::(), + ) + }; + name_defs + .binary_search_by_key(&msg_name, |def| def.msg_name) + .map(|i| name_defs[i].msg_offset) + .ok() +} + +fn find_msg_offset_by_wire(wire_id: u16) -> Option { + #[repr(C, packed)] + struct WireDef { + wire_id: u16, + msg_offset: u16, + } + + let wire_defs: &[WireDef] = unsafe { + slice::from_raw_parts( + WIRE_DEFS.as_ptr().cast(), + WIRE_DEFS.len() / mem::size_of::(), + ) + }; + wire_defs + .binary_search_by_key(&wire_id, |def| def.wire_id) + .map(|i| wire_defs[i].msg_offset) + .ok() +} + +pub unsafe fn get_msg(msg_offset: u16) -> MsgDef { + // #[repr(C, packed)] + // struct MsgDef { + // fields_count: u8, + // defaults_size: u8, + // flags_and_wire_id: u16, + // fields: [Field], + // defaults: [u8], + // } + + // SAFETY: `msg_offset` has to point to a beginning of a valid message + // definition inside `MSG_DEFS`. + unsafe { + let ptr = MSG_DEFS.as_ptr().add(msg_offset as usize); + let fields_count = ptr.offset(0).read() as usize; + let defaults_size = ptr.offset(1).read() as usize; + + let flags_and_wire_id_lo = ptr.offset(2).read(); + let flags_and_wire_id_hi = ptr.offset(3).read(); + let flags_and_wire_id = u16::from_le_bytes([flags_and_wire_id_lo, flags_and_wire_id_hi]); + + let is_experimental = flags_and_wire_id & 0x8000 != 0; + let wire_id = match flags_and_wire_id & 0x7FFF { + 0x7FFF => None, + some_wire_id => Some(some_wire_id), + }; + + let fields_size = fields_count * mem::size_of::(); + let fields_ptr = ptr.offset(4); + let defaults_ptr = ptr.offset(4).add(fields_size); + + MsgDef { + fields: slice::from_raw_parts(fields_ptr.cast(), fields_count), + defaults: slice::from_raw_parts(defaults_ptr.cast(), defaults_size), + is_experimental, + wire_id, + offset: msg_offset, + } + } +} + +unsafe fn get_enum(enum_offset: u16) -> EnumDef { + // #[repr(C, packed)] + // struct EnumDef { + // count: u8, + // vals: [u16], + // } + + // SAFETY: `enum_offset` has to point to a beginning of a valid enum + // definition inside `ENUM_DEFS`. + unsafe { + let ptr = ENUM_DEFS.as_ptr().add(enum_offset as usize); + let count = ptr.offset(0).read() as usize; + let vals = ptr.offset(1); + + EnumDef { + values: slice::from_raw_parts(vals.cast(), count), + } + } +} diff --git a/core/embed/rust/src/protobuf/encode.rs b/core/embed/rust/src/protobuf/encode.rs new file mode 100644 index 000000000..ca79b2c93 --- /dev/null +++ b/core/embed/rust/src/protobuf/encode.rs @@ -0,0 +1,236 @@ +use core::convert::TryFrom; + +use crate::{ + error::Error, + micropython::{ + buffer::{Buffer, BufferMut}, + gc::Gc, + iter::{Iter, IterBuf}, + list::List, + obj::Obj, + qstr::Qstr, + }, + util, +}; + +use super::{ + defs::{FieldDef, FieldType, MsgDef}, + obj::{MsgObj}, + zigzag, +}; + +#[no_mangle] +pub extern "C" fn protobuf_len(obj: Obj) -> Obj { + util::try_or_raise(|| { + let obj = Gc::::try_from(obj)?; + + let stream = &mut CounterStream { len: 0 }; + + Encoder.encode_message(stream, &obj.def(), &obj)?; + + Ok(stream.len.into()) + }) +} + +#[no_mangle] +pub extern "C" fn protobuf_encode(buf: Obj, obj: Obj) -> Obj { + util::try_or_raise(|| { + let obj = Gc::::try_from(obj)?; + + let buf = &mut BufferMut::try_from(buf)?; + let stream = &mut BufferStream::new(unsafe { + // SAFETY: We assume there are no other refs into `buf` at this point. This + // specifically means that no fields of `obj` should reference `buf` memory. + buf.as_mut() + }); + + Encoder.encode_message(stream, &obj.def(), &obj)?; + + Ok(stream.len().into()) + }) +} + +pub struct Encoder; + +impl Encoder { + pub fn encode_message( + &self, + stream: &mut impl OutputStream, + msg: &MsgDef, + obj: &MsgObj, + ) -> Result<(), Error> { + for field in msg.fields { + let field_name = Qstr::from(field.name); + + // Lookup the field by name. If not set or None, skip. + let field_value = match obj.map().get(field_name) { + Ok(value) => value, + Err(_) => continue, + }; + if field_value == Obj::const_none() { + continue; + } + + let field_key = { + let prim_type = field.get_type().primitive_type(); + let prim_type = prim_type as u64; + let field_tag = field.tag as u64; + field_tag << 3 | prim_type + }; + + if field.is_repeated() { + let mut iter_buf = IterBuf::new(); + let iter = Iter::try_from_obj_with_buf(field_value, &mut iter_buf)?; + for iter_value in iter { + stream.write_uvarint(field_key)?; + self.encode_field(stream, field, iter_value)?; + } + } else { + stream.write_uvarint(field_key)?; + self.encode_field(stream, field, field_value)?; + } + } + + Ok(()) + } + + pub fn encode_field( + &self, + stream: &mut impl OutputStream, + field: &FieldDef, + value: Obj, + ) -> Result<(), Error> { + match field.get_type() { + FieldType::UVarInt | FieldType::Enum(_) => { + let uint = u64::try_from(value)?; + stream.write_uvarint(uint)?; + } + FieldType::SVarInt => { + let sint = i64::try_from(value)?; + let uint = zigzag::to_unsigned(sint); + stream.write_uvarint(uint)?; + } + FieldType::Bool => { + let boolean = bool::try_from(value)?; + let uint = if boolean { 1 } else { 0 }; + stream.write_uvarint(uint)?; + } + FieldType::Bytes | FieldType::String => { + if Gc::::try_from(value).is_ok() { + // As an optimization, we support defining bytes and string + // fields also as a list-of-values. Serialize them as if + // concatenated. + + let mut iter_buf = IterBuf::new(); + + // Serialize the total length of the buffer. + let mut len = 0; + let iter = Iter::try_from_obj_with_buf(value, &mut iter_buf)?; + for value in iter { + let buffer = Buffer::try_from(value)?; + len += buffer.len(); + } + stream.write_uvarint(len as u64)?; + + // Serialize the buffers one-by-one. + let iter = Iter::try_from_obj_with_buf(value, &mut iter_buf)?; + for value in iter { + let buffer = Buffer::try_from(value)?; + stream.write(&buffer)?; + } + } else { + // Single length-delimited field. + let buffer = Buffer::try_from(value)?; + stream.write_uvarint(buffer.len() as u64)?; + stream.write(&buffer)?; + } + } + FieldType::Msg(msg_type) => { + let value = &Gc::::try_from(value)?; + // Calculate the message size by encoding it through `CountingWriter`. + let counter = &mut CounterStream { len: 0 }; + self.encode_message(counter, &msg_type, value)?; + + // Encode the message as length-delimited bytes. + stream.write_uvarint(counter.len as u64)?; + self.encode_message(stream, &msg_type, value)?; + } + } + + Ok(()) + } +} + +pub trait OutputStream { + fn write(&mut self, buf: &[u8]) -> Result<(), Error>; + fn write_byte(&mut self, val: u8) -> Result<(), Error>; + + fn write_uvarint(&mut self, mut num: u64) -> Result<(), Error> { + loop { + let shifted = num >> 7; + let byte = (num & 0x7F) as u8; + if shifted != 0 { + num = shifted; + self.write_byte(byte | 0x80)?; + } else { + break self.write_byte(byte); + } + } + } +} + +pub struct CounterStream { + pub len: usize, +} + +impl OutputStream for CounterStream { + fn write(&mut self, buf: &[u8]) -> Result<(), Error> { + self.len += buf.len(); + Ok(()) + } + + fn write_byte(&mut self, _val: u8) -> Result<(), Error> { + self.len += 1; + Ok(()) + } +} + +pub struct BufferStream<'a> { + buf: &'a mut [u8], + pos: usize, +} + +impl<'a> BufferStream<'a> { + pub fn new(buf: &'a mut [u8]) -> Self { + Self { buf, pos: 0 } + } + + pub fn len(&self) -> usize { + self.pos + } +} + +impl<'a> OutputStream for BufferStream<'a> { + fn write(&mut self, val: &[u8]) -> Result<(), Error> { + let pos = &mut self.pos; + let len = val.len(); + self.buf + .get_mut(*pos..*pos + len) + .map(|buf| { + *pos += len; + buf.copy_from_slice(val); + }) + .ok_or(Error::Missing) + } + + fn write_byte(&mut self, val: u8) -> Result<(), Error> { + let pos = &mut self.pos; + self.buf + .get_mut(*pos) + .map(|buf| { + *pos += 1; + *buf = val; + }) + .ok_or(Error::Missing) + } +} diff --git a/core/embed/rust/src/protobuf/mod.rs b/core/embed/rust/src/protobuf/mod.rs new file mode 100644 index 000000000..6731fc447 --- /dev/null +++ b/core/embed/rust/src/protobuf/mod.rs @@ -0,0 +1,5 @@ +mod decode; +mod defs; +mod encode; +mod obj; +mod zigzag; diff --git a/core/embed/rust/src/protobuf/obj.rs b/core/embed/rust/src/protobuf/obj.rs new file mode 100644 index 000000000..cbb8faf36 --- /dev/null +++ b/core/embed/rust/src/protobuf/obj.rs @@ -0,0 +1,264 @@ +use core::convert::TryFrom; + +use crate::{ + error::Error, + micropython::{ + dict::Dict, + ffi, + gc::Gc, + map::Map, + obj::{Obj, ObjBase}, + qstr::Qstr, + typ::Type, + }, + util, +}; + +use super::decode::Decoder; +use super::defs::{find_name_by_msg_offset, get_msg, MsgDef}; + +#[repr(C)] +pub struct MsgObj { + base: ObjBase, + map: Map, + msg_wire_id: Option, + msg_offset: u16, +} + +impl MsgObj { + pub fn alloc_with_capacity(capacity: usize, msg: &MsgDef) -> Gc { + Gc::new(Self { + base: Self::obj_type().to_base(), + map: Map::with_capacity(capacity), + msg_wire_id: msg.wire_id, + msg_offset: msg.offset, + }) + } + + pub fn map(&self) -> &Map { + &self.map + } + + pub fn map_mut(&mut self) -> &mut Map { + &mut self.map + } + + pub fn def(&self) -> MsgDef { + unsafe { get_msg(self.msg_offset) } + } + + fn obj_type() -> &'static Type { + static TYPE: Type = obj_type! { + name: Qstr::MP_QSTR_Msg, + attr_fn: msg_obj_attr, + }; + &TYPE + } +} + +impl MsgObj { + fn getattr(&self, attr: Qstr) -> Result { + if let Ok(obj) = self.map.get(attr) { + // Message field was found, return its value. + return Ok(obj); + } + + // Built-in attribute + match attr { + Qstr::MP_QSTR_MESSAGE_WIRE_TYPE => { + // Return the wire ID of this message def, or None if not set. + Ok(self.msg_wire_id.map_or(Obj::const_none(), |wire_id| wire_id.into())) + } + Qstr::MP_QSTR_MESSAGE_NAME => { + // Return the qstr name of this message def + Ok(Qstr::from_u16(find_name_by_msg_offset(self.msg_offset)?).into()) + } + Qstr::MP_QSTR___dict__ => { + // Conversion to dict. Allocate a new dict object with a copy of our map + // and return it. This is a bit different from how uPy does it now, because + // we're returning a mutable dict. + Ok(Gc::new(Dict::with_map(self.map.clone())).into()) + } + _ => { Err(Error::Missing) } + } + } + + fn setattr(&mut self, attr: Qstr, value: Obj) -> Result<(), Error> { + if value == Obj::const_null() { + // this would be a delattr + return Err(Error::InvalidOperation); + } + + if self.map.contains_key(attr) { + self.map.set(attr, value); + Ok(()) + } else { + Err(Error::Missing) + } + } +} + +impl Into for Gc { + fn into(self) -> Obj { + // SAFETY: + // - We are GC-allocated. + // - We are `repr(C)`. + // - We have a `base` as the first field with the correct type. + unsafe { Obj::from_ptr(Self::into_raw(self).cast()) } + } +} + +impl TryFrom for Gc { + type Error = Error; + + fn try_from(value: Obj) -> Result { + if MsgObj::obj_type().is_type_of(value) { + // SAFETY: We assume that if `value` is an object pointer with the correct type, + // it is always GC-allocated. + let this = unsafe { Gc::from_raw(value.as_ptr().cast()) }; + Ok(this) + } else { + Err(Error::InvalidType) + } + } +} + +unsafe extern "C" fn msg_obj_attr(self_in: Obj, attr: ffi::qstr, dest: *mut Obj) { + util::try_or_raise(|| { + let mut this = Gc::::try_from(self_in)?; + let attr = Qstr::from_u16(attr as _); + + unsafe { + if dest.read() == Obj::const_null() { + // Load attribute + dest.write(this.getattr(attr)?); + } else { + let value = dest.offset(1).read(); + // Store attribute + Gc::as_mut(&mut this).setattr(attr, value)?; + dest.write(Obj::const_null()); + } + Ok(()) + } + }) +} + +#[repr(C)] +pub struct MsgDefObj { + base: ObjBase, + def: MsgDef, +} + +impl MsgDefObj { + pub fn alloc(def: MsgDef) -> Gc { + Gc::new(Self { + base: Self::obj_type().to_base(), + def, + }) + } + + pub fn msg(&self) -> &MsgDef { + &self.def + } + + fn obj_type() -> &'static Type { + static TYPE: Type = obj_type! { + name: Qstr::MP_QSTR_MsgDef, + attr_fn: msg_def_obj_attr, + call_fn: msg_def_obj_call, + }; + &TYPE + } +} + +impl Into for Gc { + fn into(self) -> Obj { + // SAFETY: + // - We are GC-allocated. + // - We are `repr(C)`. + // - We have a `base` as the first field with the correct type. + unsafe { Obj::from_ptr(Self::into_raw(self).cast()) } + } +} + +impl TryFrom for Gc { + type Error = Error; + + fn try_from(value: Obj) -> Result { + if MsgDefObj::obj_type().is_type_of(value) { + // SAFETY: We assume that if `value` is an object pointer with the correct type, + // it is always GC-allocated. + let this = unsafe { Gc::from_raw(value.as_ptr().cast()) }; + Ok(this) + } else { + Err(Error::InvalidType) + } + } +} + +unsafe extern "C" fn msg_def_obj_attr(self_in: Obj, attr: ffi::qstr, dest: *mut Obj) { + util::try_or_raise(|| { + let this= Gc::::try_from(self_in)?; + let attr = Qstr::from_u16(attr as _); + + if unsafe { dest.read() } != Obj::const_null() { + return Err(Error::InvalidOperation); + } + + match attr { + Qstr::MP_QSTR_MESSAGE_NAME => { + // Return the qstr name of this message def + let name = Qstr::from_u16(find_name_by_msg_offset(this.def.offset)?); + unsafe { dest.write(name.into()); }; + } + Qstr::MP_QSTR_MESSAGE_WIRE_TYPE => { + // Return the wire type of this message def + let wire_id_obj = this + .def + .wire_id + .map_or_else(Obj::const_none, |wire_id| wire_id.into()); + unsafe { dest.write(wire_id_obj); }; + } + Qstr::MP_QSTR_is_type_of => { + // Return the is_type_of bound method + // dest[0] = function_obj + // dest[1] = self + unsafe { + dest.write(MSG_DEF_OBJ_IS_TYPE_OF_OBJ.to_obj()); + dest.offset(1).write(self_in); + } + } + _ => { return Err(Error::Missing); } + } + Ok(()) + }); +} + +unsafe extern "C" fn msg_def_obj_call( + self_in: Obj, + n_args: usize, + n_kw: usize, + args: *const Obj, +) -> Obj { + util::try_with_args_and_kwargs_inline(n_args, n_kw, args, |_args, kwargs| { + let this = Gc::::try_from(self_in)?; + let decoder = Decoder { + enable_experimental: true, + }; + let obj = decoder.message_from_values(kwargs, this.msg())?; + Ok(obj) + }) +} + +unsafe extern "C" fn msg_def_obj_is_type_of(self_in: Obj, obj: Obj) -> Obj { + util::try_or_raise(|| { + let this = Gc::::try_from(self_in)?; + let msg = Gc::::try_from(obj); + match msg { + Ok(msg) if msg.msg_offset == this.def.offset => Ok(Obj::const_true()), + _ => Ok(Obj::const_false()), + } + }) +} + +static MSG_DEF_OBJ_IS_TYPE_OF_OBJ: ffi::mp_obj_fun_builtin_fixed_t = obj_fn_2!(msg_def_obj_is_type_of); diff --git a/core/embed/rust/src/protobuf/zigzag.rs b/core/embed/rust/src/protobuf/zigzag.rs new file mode 100644 index 000000000..10f2b96fb --- /dev/null +++ b/core/embed/rust/src/protobuf/zigzag.rs @@ -0,0 +1,9 @@ +// https://developers.google.com/protocol-buffers/docs/encoding#signed_integers + +pub fn to_unsigned(sint: i64) -> u64 { + ((sint << 1) ^ (sint >> 63)) as u64 +} + +pub fn to_signed(uint: u64) -> i64 { + ((uint >> 1) as i64) ^ (-((uint & 1) as i64)) +} diff --git a/core/embed/rust/src/util.rs b/core/embed/rust/src/util.rs index 0241012e2..92079fc79 100644 --- a/core/embed/rust/src/util.rs +++ b/core/embed/rust/src/util.rs @@ -9,7 +9,7 @@ use crate::{ }, }; -pub fn try_or_raise(func: impl FnOnce() -> Result) -> Obj { +pub fn try_or_raise(func: impl FnOnce() -> Result) -> T { func().unwrap_or_else(|err| raise_value_error(err.as_cstr())) } diff --git a/core/embed/unix/mpconfigport.h b/core/embed/unix/mpconfigport.h index 2c4e943d1..9901b06d8 100644 --- a/core/embed/unix/mpconfigport.h +++ b/core/embed/unix/mpconfigport.h @@ -195,6 +195,7 @@ extern const struct _mp_print_t mp_stderr_print; #define MICROPY_PY_TREZORIO (1) #define MICROPY_PY_TREZORUI (1) #define MICROPY_PY_TREZORUTILS (1) +#define MICROPY_PY_TREZORPROTO (1) #define MP_STATE_PORT MP_STATE_VM diff --git a/core/mocks/generated/trezorproto.pyi b/core/mocks/generated/trezorproto.pyi new file mode 100644 index 000000000..0ecc346f1 --- /dev/null +++ b/core/mocks/generated/trezorproto.pyi @@ -0,0 +1,34 @@ +from typing import * +from trezor.protobuf import MessageType +T = TypeVar("T", bound=MessageType) + + +# extmod/rustmods/modtrezorproto.c +def type_for_name(name: str) -> Type[MessageType]: + """Find the message definition for the given protobuf name.""" + + +# extmod/rustmods/modtrezorproto.c +def type_for_wire(wire_type: int) -> Type[MessageType]: + """Find the message definition for the given wire type (numeric + identifier).""" + + +# extmod/rustmods/modtrezorproto.c +def decode( + buffer: bytes, + msg_type: Type[T], + enable_experimental: bool, +) -> T: + """Decode data in the buffer into the specified message type.""" + + +# extmod/rustmods/modtrezorproto.c +def encoded_length(msg: MessageType) -> int: + """Calculate length of encoding of the specified message.""" + + +# extmod/rustmods/modtrezorproto.c +def encode(buffer: bytearray, msg: MessageType) -> int: + """Encode the message into the specified buffer. Return length of + encoding."""