mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-28 17:18:29 +00:00
feat(core): Add Rust Protobuf codec
This commit is contained in:
parent
f9d4be268e
commit
8a21e3fc73
@ -174,6 +174,11 @@ SOURCE_MOD += [
|
||||
'embed/extmod/modtrezorutils/modtrezorutils.c',
|
||||
]
|
||||
|
||||
# rust mods
|
||||
SOURCE_MOD += [
|
||||
'embed/extmod/rustmods/modtrezorproto.c',
|
||||
]
|
||||
|
||||
# modutime
|
||||
SOURCE_MOD += [
|
||||
'embed/firmware/modutime.c',
|
||||
|
@ -171,6 +171,11 @@ SOURCE_MOD += [
|
||||
'embed/extmod/modtrezorutils/modtrezorutils.c',
|
||||
]
|
||||
|
||||
# rust mods
|
||||
SOURCE_MOD += [
|
||||
'embed/extmod/rustmods/modtrezorproto.c',
|
||||
]
|
||||
|
||||
# modutime
|
||||
SOURCE_MOD += [
|
||||
'vendor/micropython/ports/unix/modtime.c',
|
||||
|
86
core/embed/extmod/rustmods/modtrezorproto.c
Normal file
86
core/embed/extmod/rustmods/modtrezorproto.c
Normal file
@ -0,0 +1,86 @@
|
||||
/*
|
||||
* This file is part of the Trezor project, https://trezor.io/
|
||||
*
|
||||
* Copyright (c) SatoshiLabs
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
#include "py/runtime.h"
|
||||
|
||||
#if MICROPY_PY_TREZORPROTO
|
||||
|
||||
#include "librust.h"
|
||||
|
||||
/// from trezor.protobuf import MessageType
|
||||
/// T = TypeVar("T", bound=MessageType)
|
||||
|
||||
/// def type_for_name(name: str) -> Type[MessageType]:
|
||||
/// """Find the message definition for the given protobuf name."""
|
||||
STATIC MP_DEFINE_CONST_FUN_OBJ_1(mod_trezorutils_protobuf_type_for_name_obj,
|
||||
protobuf_type_for_name);
|
||||
|
||||
/// def type_for_wire(wire_type: int) -> Type[MessageType]:
|
||||
/// """Find the message definition for the given wire type (numeric
|
||||
/// identifier)."""
|
||||
STATIC MP_DEFINE_CONST_FUN_OBJ_1(mod_trezorutils_protobuf_type_for_wire_obj,
|
||||
protobuf_type_for_wire);
|
||||
|
||||
/// def decode(
|
||||
/// buffer: bytes,
|
||||
/// msg_type: Type[T],
|
||||
/// enable_experimental: bool,
|
||||
/// ) -> T:
|
||||
/// """Decode data in the buffer into the specified message type."""
|
||||
STATIC MP_DEFINE_CONST_FUN_OBJ_3(mod_trezorutils_protobuf_decode_obj,
|
||||
protobuf_decode);
|
||||
|
||||
/// def encoded_length(msg: MessageType) -> int:
|
||||
/// """Calculate length of encoding of the specified message."""
|
||||
STATIC MP_DEFINE_CONST_FUN_OBJ_1(mod_trezorutils_protobuf_encoded_length_obj,
|
||||
protobuf_len);
|
||||
|
||||
/// def encode(buffer: bytearray, msg: MessageType) -> int:
|
||||
/// """Encode the message into the specified buffer. Return length of
|
||||
/// encoding."""
|
||||
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorutils_protobuf_encode_obj,
|
||||
protobuf_encode);
|
||||
|
||||
STATIC const mp_rom_map_elem_t mp_module_trezorproto_globals_table[] = {
|
||||
{MP_ROM_QSTR(MP_QSTR___name__), MP_ROM_QSTR(MP_QSTR_trezorproto)},
|
||||
|
||||
{MP_ROM_QSTR(MP_QSTR_type_for_name),
|
||||
MP_ROM_PTR(&mod_trezorutils_protobuf_type_for_name_obj)},
|
||||
{MP_ROM_QSTR(MP_QSTR_type_for_wire),
|
||||
MP_ROM_PTR(&mod_trezorutils_protobuf_type_for_wire_obj)},
|
||||
{MP_ROM_QSTR(MP_QSTR_decode),
|
||||
MP_ROM_PTR(&mod_trezorutils_protobuf_decode_obj)},
|
||||
{MP_ROM_QSTR(MP_QSTR_encoded_length),
|
||||
MP_ROM_PTR(&mod_trezorutils_protobuf_encoded_length_obj)},
|
||||
{MP_ROM_QSTR(MP_QSTR_encode),
|
||||
MP_ROM_PTR(&mod_trezorutils_protobuf_encode_obj)},
|
||||
};
|
||||
|
||||
STATIC MP_DEFINE_CONST_DICT(mp_module_trezorproto_globals,
|
||||
mp_module_trezorproto_globals_table);
|
||||
|
||||
const mp_obj_module_t mp_module_trezorproto = {
|
||||
.base = {&mp_type_module},
|
||||
.globals = (mp_obj_dict_t *)&mp_module_trezorproto_globals,
|
||||
};
|
||||
|
||||
MP_REGISTER_MODULE(MP_QSTR_trezorproto, mp_module_trezorproto,
|
||||
MICROPY_PY_TREZORPROTO);
|
||||
|
||||
#endif // MICROPY_PY_TREZORPROTO
|
@ -156,6 +156,7 @@
|
||||
#define MICROPY_PY_TREZORIO (1)
|
||||
#define MICROPY_PY_TREZORUI (1)
|
||||
#define MICROPY_PY_TREZORUTILS (1)
|
||||
#define MICROPY_PY_TREZORPROTO (1)
|
||||
|
||||
#ifdef SYSTEM_VIEW
|
||||
#define MP_PLAT_PRINT_STRN(str, len) segger_print(str, len)
|
||||
|
@ -1,2 +1,2 @@
|
||||
[build]
|
||||
target-dir = "../../build/rust"
|
||||
target-dir = "../../build/unix/rust"
|
||||
|
8
core/embed/rust/librust.h
Normal file
8
core/embed/rust/librust.h
Normal file
@ -0,0 +1,8 @@
|
||||
#include "librust_qstr.h"
|
||||
|
||||
mp_obj_t protobuf_type_for_name(mp_obj_t name);
|
||||
mp_obj_t protobuf_type_for_wire(mp_obj_t wire_id);
|
||||
mp_obj_t protobuf_decode(mp_obj_t buf, mp_obj_t def,
|
||||
mp_obj_t enable_experimental);
|
||||
mp_obj_t protobuf_len(mp_obj_t obj);
|
||||
mp_obj_t protobuf_encode(mp_obj_t buf, mp_obj_t obj);
|
11
core/embed/rust/librust_qstr.h
Normal file
11
core/embed/rust/librust_qstr.h
Normal file
@ -0,0 +1,11 @@
|
||||
#pragma GCC diagnostic ignored "-Wunused-value"
|
||||
#pragma GCC diagnostic ignored "-Wunused-function"
|
||||
|
||||
static void _librust_qstrs(void) {
|
||||
// protobuf
|
||||
MP_QSTR_Msg;
|
||||
MP_QSTR_MsgDef;
|
||||
MP_QSTR_is_type_of;
|
||||
MP_QSTR_MESSAGE_WIRE_TYPE;
|
||||
MP_QSTR_MESSAGE_NAME;
|
||||
}
|
@ -9,6 +9,7 @@ pub enum Error {
|
||||
InvalidType,
|
||||
NotBuffer,
|
||||
NotInt,
|
||||
InvalidOperation,
|
||||
}
|
||||
|
||||
impl Error {
|
||||
@ -22,6 +23,7 @@ impl Error {
|
||||
Error::InvalidType => cstr("InvalidType\0"),
|
||||
Error::NotBuffer => cstr("NotBuffer\0"),
|
||||
Error::NotInt => cstr("NotInt\0"),
|
||||
Error::InvalidOperation => cstr("InvalidOperation\0"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -6,6 +6,7 @@
|
||||
mod error;
|
||||
#[macro_use]
|
||||
mod micropython;
|
||||
mod protobuf;
|
||||
mod trezorhal;
|
||||
mod util;
|
||||
|
||||
|
@ -1,17 +1,22 @@
|
||||
use core::{convert::TryFrom, ops::Deref, ptr, slice};
|
||||
use core::{
|
||||
convert::TryFrom,
|
||||
ops::{Deref, DerefMut},
|
||||
ptr, slice,
|
||||
};
|
||||
|
||||
use crate::{error::Error, micropython::obj::Obj};
|
||||
|
||||
use super::ffi;
|
||||
|
||||
/// Represents an immutable slice of bytes stored on the MicroPython heap and
|
||||
/// owned by values that obey the buffer protocol, such as `bytes`, `str`,
|
||||
/// `bytearray` or `memoryview`.
|
||||
/// owned by values that obey the `MP_BUFFER_READ` buffer protocol, such as
|
||||
/// `bytes`, `str`, `bytearray` or `memoryview`.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// In most cases, it is unsound to store `Buffer` values in a GC-unreachable
|
||||
/// location, such as static data.
|
||||
/// location, such as static data. It is also unsound to let the contents be
|
||||
/// modified while a reference to them is being held.
|
||||
pub struct Buffer {
|
||||
ptr: *const u8,
|
||||
len: usize,
|
||||
@ -21,24 +26,12 @@ impl TryFrom<Obj> for Buffer {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(obj: Obj) -> Result<Self, Self::Error> {
|
||||
let mut bufinfo = ffi::mp_buffer_info_t {
|
||||
buf: ptr::null_mut(),
|
||||
len: 0,
|
||||
typecode: 0,
|
||||
};
|
||||
// SAFETY: We assume that if `ffi::mp_get_buffer` returns successfully,
|
||||
// `bufinfo.buf` contains a pointer to data of `bufinfo.len` bytes. Here
|
||||
// we consider this data either GC-allocated or effectively 'static, and
|
||||
// store the pointer directly in `Buffer`. It is unsound to store
|
||||
// `Buffer` values in a GC-unreachable location, such as static data.
|
||||
if unsafe { ffi::mp_get_buffer(obj, &mut bufinfo, ffi::MP_BUFFER_READ as _) } {
|
||||
let bufinfo = get_buffer_info(obj, ffi::MP_BUFFER_READ)?;
|
||||
|
||||
Ok(Self {
|
||||
ptr: bufinfo.buf as _,
|
||||
len: bufinfo.len as _,
|
||||
})
|
||||
} else {
|
||||
Err(Error::NotBuffer)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -52,15 +45,102 @@ impl Deref for Buffer {
|
||||
|
||||
impl AsRef<[u8]> for Buffer {
|
||||
fn as_ref(&self) -> &[u8] {
|
||||
if self.ptr.is_null() {
|
||||
// `self.ptr` can be null if len == 0.
|
||||
buffer_as_ref(self.ptr, self.len)
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents a mutable slice of bytes stored on the MicroPython heap and
|
||||
/// owned by values that obey the `MP_BUFFER_WRITE` buffer protocol, such as
|
||||
/// `bytearray` or `memoryview`.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// In most cases, it is unsound to store `Buffer` values in a GC-unreachable
|
||||
/// location, such as static data. It is also unsound to let the contents be
|
||||
/// modified while the reference to them is being held.
|
||||
pub struct BufferMut {
|
||||
ptr: *mut u8,
|
||||
len: usize,
|
||||
}
|
||||
|
||||
impl TryFrom<Obj> for BufferMut {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(obj: Obj) -> Result<Self, Self::Error> {
|
||||
let bufinfo = get_buffer_info(obj, ffi::MP_BUFFER_WRITE)?;
|
||||
|
||||
Ok(Self {
|
||||
ptr: bufinfo.buf as _,
|
||||
len: bufinfo.len as _,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for BufferMut {
|
||||
type Target = [u8];
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.as_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl DerefMut for BufferMut {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
self.as_mut()
|
||||
}
|
||||
}
|
||||
|
||||
impl AsRef<[u8]> for BufferMut {
|
||||
fn as_ref(&self) -> &[u8] {
|
||||
buffer_as_ref(self.ptr, self.len)
|
||||
}
|
||||
}
|
||||
|
||||
impl AsMut<[u8]> for BufferMut {
|
||||
fn as_mut(&mut self) -> &mut [u8] {
|
||||
buffer_as_mut(self.ptr, self.len)
|
||||
}
|
||||
}
|
||||
|
||||
fn get_buffer_info(obj: Obj, flags: u32) -> Result<ffi::mp_buffer_info_t, Error> {
|
||||
let mut bufinfo = ffi::mp_buffer_info_t {
|
||||
buf: ptr::null_mut(),
|
||||
len: 0,
|
||||
typecode: 0,
|
||||
};
|
||||
// SAFETY: We assume that if `ffi::mp_get_buffer` returns successfully,
|
||||
// `bufinfo.buf` contains a pointer to data of `bufinfo.len` bytes. Later
|
||||
// we consider this data either GC-allocated or effectively `'static`, embedding
|
||||
// them in `Buffer`/`BufferMut`.
|
||||
if unsafe { ffi::mp_get_buffer(obj, &mut bufinfo, flags as _) } {
|
||||
Ok(bufinfo)
|
||||
} else {
|
||||
Err(Error::NotBuffer)
|
||||
}
|
||||
}
|
||||
|
||||
fn buffer_as_ref<'a>(ptr: *const u8, len: usize) -> &'a [u8] {
|
||||
if ptr.is_null() {
|
||||
// `ptr` can be null if len == 0.
|
||||
&[]
|
||||
} else {
|
||||
// SAFETY: We assume that `self.ptr` is pointing to memory:
|
||||
// SAFETY: We assume that `ptr` is pointing to memory:
|
||||
// - without any mutable references,
|
||||
// - immutable for the whole lifetime of `&self`,
|
||||
// - with at least `self.len` bytes.
|
||||
unsafe { slice::from_raw_parts(self.ptr, self.len) }
|
||||
// - valid and immutable in `'a`,
|
||||
// - of at least `len` bytes.
|
||||
unsafe { slice::from_raw_parts(ptr, len) }
|
||||
}
|
||||
}
|
||||
|
||||
fn buffer_as_mut<'a>(ptr: *mut u8, len: usize) -> &'a mut [u8] {
|
||||
if ptr.is_null() {
|
||||
// `ptr` can be null if len == 0.
|
||||
&mut []
|
||||
} else {
|
||||
// SAFETY: We assume that `ptr` is pointing to memory:
|
||||
// - without any mutable references,
|
||||
// - valid and mutable in `'a`,
|
||||
// - of at least `len` bytes.
|
||||
unsafe { slice::from_raw_parts_mut(ptr, len) }
|
||||
}
|
||||
}
|
||||
|
313
core/embed/rust/src/protobuf/decode.rs
Normal file
313
core/embed/rust/src/protobuf/decode.rs
Normal file
@ -0,0 +1,313 @@
|
||||
use core::convert::{TryFrom, TryInto};
|
||||
use core::str;
|
||||
|
||||
use crate::{
|
||||
error::Error,
|
||||
micropython::{buffer::Buffer, gc::Gc, list::List, map::Map, obj::Obj, qstr::Qstr},
|
||||
util,
|
||||
};
|
||||
|
||||
use super::{
|
||||
defs::{self, FieldDef, FieldType, MsgDef},
|
||||
obj::{MsgDefObj, MsgObj},
|
||||
zigzag,
|
||||
};
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn protobuf_type_for_name(name: Obj) -> Obj {
|
||||
util::try_or_raise(|| {
|
||||
let name = Qstr::try_from(name)?;
|
||||
let def = MsgDef::for_name(name.to_u16()).ok_or(Error::Missing)?;
|
||||
let obj = MsgDefObj::alloc(def).into();
|
||||
Ok(obj)
|
||||
})
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn protobuf_type_for_wire(wire_id: Obj) -> Obj {
|
||||
util::try_or_raise(|| {
|
||||
let wire_id = u16::try_from(wire_id)?;
|
||||
let def = MsgDef::for_wire_id(wire_id).ok_or(Error::Missing)?;
|
||||
let obj = MsgDefObj::alloc(def).into();
|
||||
Ok(obj)
|
||||
})
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn protobuf_decode(buf: Obj, msg_def: Obj, enable_experimental: Obj) -> Obj {
|
||||
util::try_or_raise(|| {
|
||||
let buf = Buffer::try_from(buf)?;
|
||||
let def = Gc::<MsgDefObj>::try_from(msg_def)?;
|
||||
let enable_experimental = bool::try_from(enable_experimental)?;
|
||||
|
||||
if !enable_experimental && def.msg().is_experimental {
|
||||
// Refuse to decode message defs marked as experimental if not
|
||||
// explicitly allowed. Messages can also mark certain fields as
|
||||
// experimental (not the whole message). This is enforced during the
|
||||
// decoding.
|
||||
return Err(Error::InvalidType);
|
||||
}
|
||||
|
||||
let stream = &mut InputStream::new(&buf);
|
||||
let decoder = Decoder {
|
||||
enable_experimental,
|
||||
};
|
||||
|
||||
let obj = decoder.message_from_stream(stream, def.msg())?;
|
||||
Ok(obj)
|
||||
})
|
||||
}
|
||||
|
||||
pub struct Decoder {
|
||||
pub enable_experimental: bool,
|
||||
}
|
||||
|
||||
impl Decoder {
|
||||
/// Create a new message instance and decode `stream` into it, handling the
|
||||
/// default and required fields correctly.
|
||||
pub fn message_from_stream(
|
||||
&self,
|
||||
stream: &mut InputStream,
|
||||
msg: &MsgDef,
|
||||
) -> Result<Obj, Error> {
|
||||
let mut obj = self.empty_message(msg);
|
||||
// SAFETY: We assume that `obj` is not aliased here.
|
||||
let map = unsafe { Gc::as_mut(&mut obj) }.map_mut();
|
||||
self.decode_fields_into(stream, msg, map)?;
|
||||
self.decode_defaults_into(msg, map)?;
|
||||
self.assign_required_into(msg, map)?;
|
||||
Ok(obj.into())
|
||||
}
|
||||
|
||||
/// Create a new message instance and fill it from `values`, handling the
|
||||
/// default and required fields correctly.
|
||||
pub fn message_from_values(&self, values: &Map, msg: &MsgDef) -> Result<Obj, Error> {
|
||||
let mut obj = self.empty_message(msg);
|
||||
// SAFETY: We assume that `obj` is not aliased here.
|
||||
let map = unsafe { Gc::as_mut(&mut obj) }.map_mut();
|
||||
for elem in values.elems() {
|
||||
map.set(elem.key, elem.value);
|
||||
}
|
||||
self.decode_defaults_into(msg, map)?;
|
||||
self.assign_required_into(msg, map)?;
|
||||
Ok(obj.into())
|
||||
}
|
||||
|
||||
/// Allocate the backing message object with enough pre-allocated space for
|
||||
/// all fields.
|
||||
pub fn empty_message(&self, msg: &MsgDef) -> Gc<MsgObj> {
|
||||
MsgObj::alloc_with_capacity(msg.fields.len(), msg)
|
||||
}
|
||||
|
||||
/// Decode message fields one-by-one from the input stream, assigning them
|
||||
/// into `map`.
|
||||
fn decode_fields_into(
|
||||
&self,
|
||||
stream: &mut InputStream,
|
||||
msg: &MsgDef,
|
||||
map: &mut Map,
|
||||
) -> Result<(), Error> {
|
||||
// Loop, trying to read the field key that contains the tag and primitive value
|
||||
// type. If we fail to read the key, we are at the end of the stream.
|
||||
while let Ok(field_key) = stream.read_uvarint() {
|
||||
let field_tag = u8::try_from(field_key >> 3)?;
|
||||
let prim_type = u8::try_from(field_key & 7)?;
|
||||
|
||||
match msg.field(field_tag) {
|
||||
Some(field) => {
|
||||
let field_value = self.decode_field(stream, field)?;
|
||||
let field_name = Qstr::from(field.name);
|
||||
if field.is_repeated() {
|
||||
// Repeated field, values are stored in a list. First, look up the list
|
||||
// object. If it exists, append to it. If it doesn't, create a new list with
|
||||
// this field's value and assign it.
|
||||
if let Ok(obj) = map.get(field_name) {
|
||||
let mut list = Gc::<List>::try_from(obj)?;
|
||||
// SAFETY: We assume that `list` is not aliased here. This holds for
|
||||
// uses in `message_from_stream` and `message_from_values`, because we
|
||||
// start with an empty `Map` and fill with unique lists.
|
||||
unsafe { Gc::as_mut(&mut list) }.append(field_value);
|
||||
} else {
|
||||
let list = List::alloc(&[field_value]);
|
||||
map.set(field_name, list);
|
||||
}
|
||||
} else {
|
||||
// Singular field, assign the value directly.
|
||||
map.set(field_name, field_value);
|
||||
}
|
||||
}
|
||||
None => {
|
||||
// Unknown field, skip it.
|
||||
match prim_type {
|
||||
defs::PRIMITIVE_TYPE_VARINT => {
|
||||
stream.read_uvarint()?;
|
||||
}
|
||||
defs::PRIMITIVE_TYPE_LENGTH_DELIMITED => {
|
||||
let num = stream.read_uvarint()?;
|
||||
let len = num.try_into()?;
|
||||
stream.read(len)?;
|
||||
}
|
||||
_ => {
|
||||
return Err(Error::InvalidType);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Fill in the default values by decoding them from the defaults stream.
|
||||
/// Only singular fields are allowed to have a default value, this is
|
||||
/// enforced in the blob compilation.
|
||||
fn decode_defaults_into(&self, msg: &MsgDef, map: &mut Map) -> Result<(), Error> {
|
||||
let stream = &mut InputStream::new(msg.defaults);
|
||||
|
||||
// The format of the defaults stream is a sequence of records:
|
||||
// - one-byte field tag (without the primitive type).
|
||||
// - Protobuf-encoded default value.
|
||||
// We need to look to the field descriptor to know how to interpret the value
|
||||
// after the field tag.
|
||||
while let Ok(field_tag) = stream.read_byte() {
|
||||
let field = msg.field(field_tag).ok_or(Error::Missing)?;
|
||||
let field_name = Qstr::from(field.name);
|
||||
if map.contains_key(field_name) {
|
||||
// Field already has a value assigned, skip it.
|
||||
match field.get_type().primitive_type() {
|
||||
defs::PRIMITIVE_TYPE_VARINT => {
|
||||
stream.read_uvarint()?;
|
||||
}
|
||||
defs::PRIMITIVE_TYPE_LENGTH_DELIMITED => {
|
||||
let num = stream.read_uvarint()?;
|
||||
let len = num.try_into()?;
|
||||
stream.read(len)?;
|
||||
}
|
||||
_ => {
|
||||
return Err(Error::InvalidType);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Decode the value and assign it.
|
||||
let field_value = self.decode_field(stream, field)?;
|
||||
map.set(field_name, field_value);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Walk the fields definitions and make sure that all required fields are
|
||||
/// assigned and all optional missing fields are set to `None`.
|
||||
fn assign_required_into(&self, msg: &MsgDef, map: &mut Map) -> Result<(), Error> {
|
||||
for field in msg.fields {
|
||||
let field_name = Qstr::from(field.name);
|
||||
if map.contains_key(field_name) {
|
||||
// Field is assigned, skip.
|
||||
continue;
|
||||
}
|
||||
if field.is_required() {
|
||||
// Required field is missing, abort.
|
||||
return Err(Error::Missing);
|
||||
}
|
||||
if field.is_repeated() {
|
||||
// Optional repeated field, set to a new empty list.
|
||||
map.set(field_name, List::alloc(&[]));
|
||||
} else {
|
||||
// Optional singular field, set to None.
|
||||
map.set(field_name, Obj::const_none());
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Decode one field value from the input stream.
|
||||
fn decode_field(&self, stream: &mut InputStream, field: &FieldDef) -> Result<Obj, Error> {
|
||||
if field.is_experimental() && !self.enable_experimental {
|
||||
return Err(Error::InvalidType);
|
||||
}
|
||||
let num = stream.read_uvarint()?;
|
||||
match field.get_type() {
|
||||
FieldType::UVarInt => Ok(num.into()),
|
||||
FieldType::SVarInt => {
|
||||
let signed_int = zigzag::to_signed(num);
|
||||
Ok(signed_int.into())
|
||||
}
|
||||
FieldType::Bool => {
|
||||
let boolean = num != 0;
|
||||
Ok(boolean.into())
|
||||
}
|
||||
FieldType::Bytes => {
|
||||
let buf_len = num.try_into()?;
|
||||
let buf = stream.read(buf_len)?;
|
||||
Ok(buf.into())
|
||||
}
|
||||
FieldType::String => {
|
||||
let buf_len = num.try_into()?;
|
||||
let buf = stream.read(buf_len)?;
|
||||
let unicode = str::from_utf8(buf).map_err(|_| Error::InvalidType)?;
|
||||
Ok(unicode.into())
|
||||
}
|
||||
FieldType::Enum(enum_type) => {
|
||||
let enum_val = num.try_into()?;
|
||||
if enum_type.values.contains(&enum_val) {
|
||||
Ok(enum_val.into())
|
||||
} else {
|
||||
Err(Error::InvalidType)
|
||||
}
|
||||
}
|
||||
FieldType::Msg(msg_type) => {
|
||||
let msg_len = num.try_into()?;
|
||||
let sub_stream = &mut stream.read_stream(msg_len)?;
|
||||
self.message_from_stream(sub_stream, &msg_type)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct InputStream<'a> {
|
||||
buf: &'a [u8],
|
||||
pos: usize,
|
||||
}
|
||||
|
||||
impl<'a> InputStream<'a> {
|
||||
pub fn new(buf: &'a [u8]) -> Self {
|
||||
Self { buf, pos: 0 }
|
||||
}
|
||||
|
||||
pub fn read_stream(&mut self, len: usize) -> Result<Self, Error> {
|
||||
let buf = self
|
||||
.buf
|
||||
.get(self.pos..self.pos + len)
|
||||
.ok_or(Error::Missing)?;
|
||||
self.pos += len;
|
||||
Ok(Self::new(buf))
|
||||
}
|
||||
|
||||
pub fn read(&mut self, len: usize) -> Result<&[u8], Error> {
|
||||
let buf = self
|
||||
.buf
|
||||
.get(self.pos..self.pos + len)
|
||||
.ok_or(Error::Missing)?;
|
||||
self.pos += len;
|
||||
Ok(buf)
|
||||
}
|
||||
|
||||
pub fn read_byte(&mut self) -> Result<u8, Error> {
|
||||
let val = self.buf.get(self.pos).copied().ok_or(Error::Missing)?;
|
||||
self.pos += 1;
|
||||
Ok(val)
|
||||
}
|
||||
|
||||
pub fn read_uvarint(&mut self) -> Result<u64, Error> {
|
||||
let mut uint = 0;
|
||||
let mut shift = 0;
|
||||
loop {
|
||||
let byte = self.read_byte()?;
|
||||
uint += (byte as u64 & 0x7F) << shift;
|
||||
shift += 7;
|
||||
if byte & 0x80 == 0 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(uint)
|
||||
}
|
||||
}
|
224
core/embed/rust/src/protobuf/defs.rs
Normal file
224
core/embed/rust/src/protobuf/defs.rs
Normal file
@ -0,0 +1,224 @@
|
||||
use core::{mem, slice};
|
||||
use crate::error::Error;
|
||||
|
||||
pub struct MsgDef {
|
||||
pub fields: &'static [FieldDef],
|
||||
pub defaults: &'static [u8],
|
||||
pub is_experimental: bool,
|
||||
pub wire_id: Option<u16>,
|
||||
pub offset: u16,
|
||||
}
|
||||
|
||||
impl MsgDef {
|
||||
pub fn for_name(msg_name: u16) -> Option<Self> {
|
||||
find_msg_offset_by_name(msg_name).map(|msg_offset| unsafe {
|
||||
// SAFETY: We are taking the offset right out of the definitions so we can be
|
||||
// sure it's to be trusted.
|
||||
get_msg(msg_offset)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn for_wire_id(wire_id: u16) -> Option<Self> {
|
||||
find_msg_offset_by_wire(wire_id).map(|msg_offset| unsafe {
|
||||
// SAFETY: We are taking the offset right out of the definitions so we can be
|
||||
// sure it's to be trusted.
|
||||
get_msg(msg_offset)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn field(&self, tag: u8) -> Option<&FieldDef> {
|
||||
self.fields.iter().find(|field| field.tag == tag)
|
||||
}
|
||||
}
|
||||
|
||||
#[repr(C, packed)]
|
||||
pub struct FieldDef {
|
||||
pub tag: u8,
|
||||
flags_and_type: u8,
|
||||
enum_or_msg_offset: u16,
|
||||
pub name: u16,
|
||||
}
|
||||
|
||||
impl FieldDef {
|
||||
pub fn get_type(&self) -> FieldType {
|
||||
match self.ftype() {
|
||||
0 => FieldType::UVarInt,
|
||||
1 => FieldType::SVarInt,
|
||||
2 => FieldType::Bool,
|
||||
3 => FieldType::Bytes,
|
||||
4 => FieldType::String,
|
||||
5 => FieldType::Enum(unsafe { get_enum(self.enum_or_msg_offset) }),
|
||||
6 => FieldType::Msg(unsafe { get_msg(self.enum_or_msg_offset) }),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_required(&self) -> bool {
|
||||
self.flags() & 0b_1000_0000 != 0
|
||||
}
|
||||
|
||||
pub fn is_repeated(&self) -> bool {
|
||||
self.flags() & 0b_0100_0000 != 0
|
||||
}
|
||||
|
||||
pub fn is_experimental(&self) -> bool {
|
||||
self.flags() & 0b_0010_0000 != 0
|
||||
}
|
||||
|
||||
fn flags(&self) -> u8 {
|
||||
self.flags_and_type & 0xF0
|
||||
}
|
||||
|
||||
fn ftype(&self) -> u8 {
|
||||
self.flags_and_type & 0x0F
|
||||
}
|
||||
}
|
||||
|
||||
pub enum FieldType {
|
||||
UVarInt,
|
||||
SVarInt,
|
||||
Bool,
|
||||
Bytes,
|
||||
String,
|
||||
Enum(EnumDef),
|
||||
Msg(MsgDef),
|
||||
}
|
||||
|
||||
pub const PRIMITIVE_TYPE_VARINT: u8 = 0;
|
||||
pub const PRIMITIVE_TYPE_LENGTH_DELIMITED: u8 = 2;
|
||||
|
||||
impl FieldType {
|
||||
pub fn primitive_type(&self) -> u8 {
|
||||
match self {
|
||||
FieldType::UVarInt | FieldType::SVarInt | FieldType::Bool | FieldType::Enum(_) => {
|
||||
PRIMITIVE_TYPE_VARINT
|
||||
}
|
||||
FieldType::Bytes | FieldType::String | FieldType::Msg(_) => {
|
||||
PRIMITIVE_TYPE_LENGTH_DELIMITED
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct EnumDef {
|
||||
pub values: &'static [u16],
|
||||
}
|
||||
|
||||
#[repr(C, packed)]
|
||||
struct NameDef {
|
||||
msg_name: u16,
|
||||
msg_offset: u16,
|
||||
}
|
||||
|
||||
static ENUM_DEFS: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/../../../../proto_enums.data"));
|
||||
static MSG_DEFS: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/../../../..//proto_msgs.data"));
|
||||
static NAME_DEFS: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/../../../..//proto_names.data"));
|
||||
static WIRE_DEFS: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/../../../..//proto_wire.data"));
|
||||
|
||||
pub fn find_name_by_msg_offset(msg_offset: u16) -> Result<u16, Error> {
|
||||
let name_defs: &[NameDef] = unsafe {
|
||||
slice::from_raw_parts(
|
||||
NAME_DEFS.as_ptr().cast(),
|
||||
NAME_DEFS.len() / mem::size_of::<NameDef>(),
|
||||
)
|
||||
};
|
||||
|
||||
name_defs.iter()
|
||||
.filter(|def| def.msg_offset == msg_offset)
|
||||
.next()
|
||||
.map(|def| def.msg_name)
|
||||
.ok_or(Error::Missing)
|
||||
}
|
||||
|
||||
fn find_msg_offset_by_name(msg_name: u16) -> Option<u16> {
|
||||
let name_defs: &[NameDef] = unsafe {
|
||||
slice::from_raw_parts(
|
||||
NAME_DEFS.as_ptr().cast(),
|
||||
NAME_DEFS.len() / mem::size_of::<NameDef>(),
|
||||
)
|
||||
};
|
||||
name_defs
|
||||
.binary_search_by_key(&msg_name, |def| def.msg_name)
|
||||
.map(|i| name_defs[i].msg_offset)
|
||||
.ok()
|
||||
}
|
||||
|
||||
fn find_msg_offset_by_wire(wire_id: u16) -> Option<u16> {
|
||||
#[repr(C, packed)]
|
||||
struct WireDef {
|
||||
wire_id: u16,
|
||||
msg_offset: u16,
|
||||
}
|
||||
|
||||
let wire_defs: &[WireDef] = unsafe {
|
||||
slice::from_raw_parts(
|
||||
WIRE_DEFS.as_ptr().cast(),
|
||||
WIRE_DEFS.len() / mem::size_of::<WireDef>(),
|
||||
)
|
||||
};
|
||||
wire_defs
|
||||
.binary_search_by_key(&wire_id, |def| def.wire_id)
|
||||
.map(|i| wire_defs[i].msg_offset)
|
||||
.ok()
|
||||
}
|
||||
|
||||
pub unsafe fn get_msg(msg_offset: u16) -> MsgDef {
|
||||
// #[repr(C, packed)]
|
||||
// struct MsgDef {
|
||||
// fields_count: u8,
|
||||
// defaults_size: u8,
|
||||
// flags_and_wire_id: u16,
|
||||
// fields: [Field],
|
||||
// defaults: [u8],
|
||||
// }
|
||||
|
||||
// SAFETY: `msg_offset` has to point to a beginning of a valid message
|
||||
// definition inside `MSG_DEFS`.
|
||||
unsafe {
|
||||
let ptr = MSG_DEFS.as_ptr().add(msg_offset as usize);
|
||||
let fields_count = ptr.offset(0).read() as usize;
|
||||
let defaults_size = ptr.offset(1).read() as usize;
|
||||
|
||||
let flags_and_wire_id_lo = ptr.offset(2).read();
|
||||
let flags_and_wire_id_hi = ptr.offset(3).read();
|
||||
let flags_and_wire_id = u16::from_le_bytes([flags_and_wire_id_lo, flags_and_wire_id_hi]);
|
||||
|
||||
let is_experimental = flags_and_wire_id & 0x8000 != 0;
|
||||
let wire_id = match flags_and_wire_id & 0x7FFF {
|
||||
0x7FFF => None,
|
||||
some_wire_id => Some(some_wire_id),
|
||||
};
|
||||
|
||||
let fields_size = fields_count * mem::size_of::<FieldDef>();
|
||||
let fields_ptr = ptr.offset(4);
|
||||
let defaults_ptr = ptr.offset(4).add(fields_size);
|
||||
|
||||
MsgDef {
|
||||
fields: slice::from_raw_parts(fields_ptr.cast(), fields_count),
|
||||
defaults: slice::from_raw_parts(defaults_ptr.cast(), defaults_size),
|
||||
is_experimental,
|
||||
wire_id,
|
||||
offset: msg_offset,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe fn get_enum(enum_offset: u16) -> EnumDef {
|
||||
// #[repr(C, packed)]
|
||||
// struct EnumDef {
|
||||
// count: u8,
|
||||
// vals: [u16],
|
||||
// }
|
||||
|
||||
// SAFETY: `enum_offset` has to point to a beginning of a valid enum
|
||||
// definition inside `ENUM_DEFS`.
|
||||
unsafe {
|
||||
let ptr = ENUM_DEFS.as_ptr().add(enum_offset as usize);
|
||||
let count = ptr.offset(0).read() as usize;
|
||||
let vals = ptr.offset(1);
|
||||
|
||||
EnumDef {
|
||||
values: slice::from_raw_parts(vals.cast(), count),
|
||||
}
|
||||
}
|
||||
}
|
236
core/embed/rust/src/protobuf/encode.rs
Normal file
236
core/embed/rust/src/protobuf/encode.rs
Normal file
@ -0,0 +1,236 @@
|
||||
use core::convert::TryFrom;
|
||||
|
||||
use crate::{
|
||||
error::Error,
|
||||
micropython::{
|
||||
buffer::{Buffer, BufferMut},
|
||||
gc::Gc,
|
||||
iter::{Iter, IterBuf},
|
||||
list::List,
|
||||
obj::Obj,
|
||||
qstr::Qstr,
|
||||
},
|
||||
util,
|
||||
};
|
||||
|
||||
use super::{
|
||||
defs::{FieldDef, FieldType, MsgDef},
|
||||
obj::{MsgObj},
|
||||
zigzag,
|
||||
};
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn protobuf_len(obj: Obj) -> Obj {
|
||||
util::try_or_raise(|| {
|
||||
let obj = Gc::<MsgObj>::try_from(obj)?;
|
||||
|
||||
let stream = &mut CounterStream { len: 0 };
|
||||
|
||||
Encoder.encode_message(stream, &obj.def(), &obj)?;
|
||||
|
||||
Ok(stream.len.into())
|
||||
})
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn protobuf_encode(buf: Obj, obj: Obj) -> Obj {
|
||||
util::try_or_raise(|| {
|
||||
let obj = Gc::<MsgObj>::try_from(obj)?;
|
||||
|
||||
let buf = &mut BufferMut::try_from(buf)?;
|
||||
let stream = &mut BufferStream::new(unsafe {
|
||||
// SAFETY: We assume there are no other refs into `buf` at this point. This
|
||||
// specifically means that no fields of `obj` should reference `buf` memory.
|
||||
buf.as_mut()
|
||||
});
|
||||
|
||||
Encoder.encode_message(stream, &obj.def(), &obj)?;
|
||||
|
||||
Ok(stream.len().into())
|
||||
})
|
||||
}
|
||||
|
||||
pub struct Encoder;
|
||||
|
||||
impl Encoder {
|
||||
pub fn encode_message(
|
||||
&self,
|
||||
stream: &mut impl OutputStream,
|
||||
msg: &MsgDef,
|
||||
obj: &MsgObj,
|
||||
) -> Result<(), Error> {
|
||||
for field in msg.fields {
|
||||
let field_name = Qstr::from(field.name);
|
||||
|
||||
// Lookup the field by name. If not set or None, skip.
|
||||
let field_value = match obj.map().get(field_name) {
|
||||
Ok(value) => value,
|
||||
Err(_) => continue,
|
||||
};
|
||||
if field_value == Obj::const_none() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let field_key = {
|
||||
let prim_type = field.get_type().primitive_type();
|
||||
let prim_type = prim_type as u64;
|
||||
let field_tag = field.tag as u64;
|
||||
field_tag << 3 | prim_type
|
||||
};
|
||||
|
||||
if field.is_repeated() {
|
||||
let mut iter_buf = IterBuf::new();
|
||||
let iter = Iter::try_from_obj_with_buf(field_value, &mut iter_buf)?;
|
||||
for iter_value in iter {
|
||||
stream.write_uvarint(field_key)?;
|
||||
self.encode_field(stream, field, iter_value)?;
|
||||
}
|
||||
} else {
|
||||
stream.write_uvarint(field_key)?;
|
||||
self.encode_field(stream, field, field_value)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn encode_field(
|
||||
&self,
|
||||
stream: &mut impl OutputStream,
|
||||
field: &FieldDef,
|
||||
value: Obj,
|
||||
) -> Result<(), Error> {
|
||||
match field.get_type() {
|
||||
FieldType::UVarInt | FieldType::Enum(_) => {
|
||||
let uint = u64::try_from(value)?;
|
||||
stream.write_uvarint(uint)?;
|
||||
}
|
||||
FieldType::SVarInt => {
|
||||
let sint = i64::try_from(value)?;
|
||||
let uint = zigzag::to_unsigned(sint);
|
||||
stream.write_uvarint(uint)?;
|
||||
}
|
||||
FieldType::Bool => {
|
||||
let boolean = bool::try_from(value)?;
|
||||
let uint = if boolean { 1 } else { 0 };
|
||||
stream.write_uvarint(uint)?;
|
||||
}
|
||||
FieldType::Bytes | FieldType::String => {
|
||||
if Gc::<List>::try_from(value).is_ok() {
|
||||
// As an optimization, we support defining bytes and string
|
||||
// fields also as a list-of-values. Serialize them as if
|
||||
// concatenated.
|
||||
|
||||
let mut iter_buf = IterBuf::new();
|
||||
|
||||
// Serialize the total length of the buffer.
|
||||
let mut len = 0;
|
||||
let iter = Iter::try_from_obj_with_buf(value, &mut iter_buf)?;
|
||||
for value in iter {
|
||||
let buffer = Buffer::try_from(value)?;
|
||||
len += buffer.len();
|
||||
}
|
||||
stream.write_uvarint(len as u64)?;
|
||||
|
||||
// Serialize the buffers one-by-one.
|
||||
let iter = Iter::try_from_obj_with_buf(value, &mut iter_buf)?;
|
||||
for value in iter {
|
||||
let buffer = Buffer::try_from(value)?;
|
||||
stream.write(&buffer)?;
|
||||
}
|
||||
} else {
|
||||
// Single length-delimited field.
|
||||
let buffer = Buffer::try_from(value)?;
|
||||
stream.write_uvarint(buffer.len() as u64)?;
|
||||
stream.write(&buffer)?;
|
||||
}
|
||||
}
|
||||
FieldType::Msg(msg_type) => {
|
||||
let value = &Gc::<MsgObj>::try_from(value)?;
|
||||
// Calculate the message size by encoding it through `CountingWriter`.
|
||||
let counter = &mut CounterStream { len: 0 };
|
||||
self.encode_message(counter, &msg_type, value)?;
|
||||
|
||||
// Encode the message as length-delimited bytes.
|
||||
stream.write_uvarint(counter.len as u64)?;
|
||||
self.encode_message(stream, &msg_type, value)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub trait OutputStream {
|
||||
fn write(&mut self, buf: &[u8]) -> Result<(), Error>;
|
||||
fn write_byte(&mut self, val: u8) -> Result<(), Error>;
|
||||
|
||||
fn write_uvarint(&mut self, mut num: u64) -> Result<(), Error> {
|
||||
loop {
|
||||
let shifted = num >> 7;
|
||||
let byte = (num & 0x7F) as u8;
|
||||
if shifted != 0 {
|
||||
num = shifted;
|
||||
self.write_byte(byte | 0x80)?;
|
||||
} else {
|
||||
break self.write_byte(byte);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CounterStream {
|
||||
pub len: usize,
|
||||
}
|
||||
|
||||
impl OutputStream for CounterStream {
|
||||
fn write(&mut self, buf: &[u8]) -> Result<(), Error> {
|
||||
self.len += buf.len();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn write_byte(&mut self, _val: u8) -> Result<(), Error> {
|
||||
self.len += 1;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct BufferStream<'a> {
|
||||
buf: &'a mut [u8],
|
||||
pos: usize,
|
||||
}
|
||||
|
||||
impl<'a> BufferStream<'a> {
|
||||
pub fn new(buf: &'a mut [u8]) -> Self {
|
||||
Self { buf, pos: 0 }
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.pos
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> OutputStream for BufferStream<'a> {
|
||||
fn write(&mut self, val: &[u8]) -> Result<(), Error> {
|
||||
let pos = &mut self.pos;
|
||||
let len = val.len();
|
||||
self.buf
|
||||
.get_mut(*pos..*pos + len)
|
||||
.map(|buf| {
|
||||
*pos += len;
|
||||
buf.copy_from_slice(val);
|
||||
})
|
||||
.ok_or(Error::Missing)
|
||||
}
|
||||
|
||||
fn write_byte(&mut self, val: u8) -> Result<(), Error> {
|
||||
let pos = &mut self.pos;
|
||||
self.buf
|
||||
.get_mut(*pos)
|
||||
.map(|buf| {
|
||||
*pos += 1;
|
||||
*buf = val;
|
||||
})
|
||||
.ok_or(Error::Missing)
|
||||
}
|
||||
}
|
5
core/embed/rust/src/protobuf/mod.rs
Normal file
5
core/embed/rust/src/protobuf/mod.rs
Normal file
@ -0,0 +1,5 @@
|
||||
mod decode;
|
||||
mod defs;
|
||||
mod encode;
|
||||
mod obj;
|
||||
mod zigzag;
|
264
core/embed/rust/src/protobuf/obj.rs
Normal file
264
core/embed/rust/src/protobuf/obj.rs
Normal file
@ -0,0 +1,264 @@
|
||||
use core::convert::TryFrom;
|
||||
|
||||
use crate::{
|
||||
error::Error,
|
||||
micropython::{
|
||||
dict::Dict,
|
||||
ffi,
|
||||
gc::Gc,
|
||||
map::Map,
|
||||
obj::{Obj, ObjBase},
|
||||
qstr::Qstr,
|
||||
typ::Type,
|
||||
},
|
||||
util,
|
||||
};
|
||||
|
||||
use super::decode::Decoder;
|
||||
use super::defs::{find_name_by_msg_offset, get_msg, MsgDef};
|
||||
|
||||
#[repr(C)]
|
||||
pub struct MsgObj {
|
||||
base: ObjBase,
|
||||
map: Map,
|
||||
msg_wire_id: Option<u16>,
|
||||
msg_offset: u16,
|
||||
}
|
||||
|
||||
impl MsgObj {
|
||||
pub fn alloc_with_capacity(capacity: usize, msg: &MsgDef) -> Gc<Self> {
|
||||
Gc::new(Self {
|
||||
base: Self::obj_type().to_base(),
|
||||
map: Map::with_capacity(capacity),
|
||||
msg_wire_id: msg.wire_id,
|
||||
msg_offset: msg.offset,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn map(&self) -> &Map {
|
||||
&self.map
|
||||
}
|
||||
|
||||
pub fn map_mut(&mut self) -> &mut Map {
|
||||
&mut self.map
|
||||
}
|
||||
|
||||
pub fn def(&self) -> MsgDef {
|
||||
unsafe { get_msg(self.msg_offset) }
|
||||
}
|
||||
|
||||
fn obj_type() -> &'static Type {
|
||||
static TYPE: Type = obj_type! {
|
||||
name: Qstr::MP_QSTR_Msg,
|
||||
attr_fn: msg_obj_attr,
|
||||
};
|
||||
&TYPE
|
||||
}
|
||||
}
|
||||
|
||||
impl MsgObj {
|
||||
fn getattr(&self, attr: Qstr) -> Result<Obj, Error> {
|
||||
if let Ok(obj) = self.map.get(attr) {
|
||||
// Message field was found, return its value.
|
||||
return Ok(obj);
|
||||
}
|
||||
|
||||
// Built-in attribute
|
||||
match attr {
|
||||
Qstr::MP_QSTR_MESSAGE_WIRE_TYPE => {
|
||||
// Return the wire ID of this message def, or None if not set.
|
||||
Ok(self.msg_wire_id.map_or(Obj::const_none(), |wire_id| wire_id.into()))
|
||||
}
|
||||
Qstr::MP_QSTR_MESSAGE_NAME => {
|
||||
// Return the qstr name of this message def
|
||||
Ok(Qstr::from_u16(find_name_by_msg_offset(self.msg_offset)?).into())
|
||||
}
|
||||
Qstr::MP_QSTR___dict__ => {
|
||||
// Conversion to dict. Allocate a new dict object with a copy of our map
|
||||
// and return it. This is a bit different from how uPy does it now, because
|
||||
// we're returning a mutable dict.
|
||||
Ok(Gc::new(Dict::with_map(self.map.clone())).into())
|
||||
}
|
||||
_ => { Err(Error::Missing) }
|
||||
}
|
||||
}
|
||||
|
||||
fn setattr(&mut self, attr: Qstr, value: Obj) -> Result<(), Error> {
|
||||
if value == Obj::const_null() {
|
||||
// this would be a delattr
|
||||
return Err(Error::InvalidOperation);
|
||||
}
|
||||
|
||||
if self.map.contains_key(attr) {
|
||||
self.map.set(attr, value);
|
||||
Ok(())
|
||||
} else {
|
||||
Err(Error::Missing)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Into<Obj> for Gc<MsgObj> {
|
||||
fn into(self) -> Obj {
|
||||
// SAFETY:
|
||||
// - We are GC-allocated.
|
||||
// - We are `repr(C)`.
|
||||
// - We have a `base` as the first field with the correct type.
|
||||
unsafe { Obj::from_ptr(Self::into_raw(self).cast()) }
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<Obj> for Gc<MsgObj> {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(value: Obj) -> Result<Self, Self::Error> {
|
||||
if MsgObj::obj_type().is_type_of(value) {
|
||||
// SAFETY: We assume that if `value` is an object pointer with the correct type,
|
||||
// it is always GC-allocated.
|
||||
let this = unsafe { Gc::from_raw(value.as_ptr().cast()) };
|
||||
Ok(this)
|
||||
} else {
|
||||
Err(Error::InvalidType)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe extern "C" fn msg_obj_attr(self_in: Obj, attr: ffi::qstr, dest: *mut Obj) {
|
||||
util::try_or_raise(|| {
|
||||
let mut this = Gc::<MsgObj>::try_from(self_in)?;
|
||||
let attr = Qstr::from_u16(attr as _);
|
||||
|
||||
unsafe {
|
||||
if dest.read() == Obj::const_null() {
|
||||
// Load attribute
|
||||
dest.write(this.getattr(attr)?);
|
||||
} else {
|
||||
let value = dest.offset(1).read();
|
||||
// Store attribute
|
||||
Gc::as_mut(&mut this).setattr(attr, value)?;
|
||||
dest.write(Obj::const_null());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
pub struct MsgDefObj {
|
||||
base: ObjBase,
|
||||
def: MsgDef,
|
||||
}
|
||||
|
||||
impl MsgDefObj {
|
||||
pub fn alloc(def: MsgDef) -> Gc<Self> {
|
||||
Gc::new(Self {
|
||||
base: Self::obj_type().to_base(),
|
||||
def,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn msg(&self) -> &MsgDef {
|
||||
&self.def
|
||||
}
|
||||
|
||||
fn obj_type() -> &'static Type {
|
||||
static TYPE: Type = obj_type! {
|
||||
name: Qstr::MP_QSTR_MsgDef,
|
||||
attr_fn: msg_def_obj_attr,
|
||||
call_fn: msg_def_obj_call,
|
||||
};
|
||||
&TYPE
|
||||
}
|
||||
}
|
||||
|
||||
impl Into<Obj> for Gc<MsgDefObj> {
|
||||
fn into(self) -> Obj {
|
||||
// SAFETY:
|
||||
// - We are GC-allocated.
|
||||
// - We are `repr(C)`.
|
||||
// - We have a `base` as the first field with the correct type.
|
||||
unsafe { Obj::from_ptr(Self::into_raw(self).cast()) }
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<Obj> for Gc<MsgDefObj> {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(value: Obj) -> Result<Self, Self::Error> {
|
||||
if MsgDefObj::obj_type().is_type_of(value) {
|
||||
// SAFETY: We assume that if `value` is an object pointer with the correct type,
|
||||
// it is always GC-allocated.
|
||||
let this = unsafe { Gc::from_raw(value.as_ptr().cast()) };
|
||||
Ok(this)
|
||||
} else {
|
||||
Err(Error::InvalidType)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe extern "C" fn msg_def_obj_attr(self_in: Obj, attr: ffi::qstr, dest: *mut Obj) {
|
||||
util::try_or_raise(|| {
|
||||
let this= Gc::<MsgDefObj>::try_from(self_in)?;
|
||||
let attr = Qstr::from_u16(attr as _);
|
||||
|
||||
if unsafe { dest.read() } != Obj::const_null() {
|
||||
return Err(Error::InvalidOperation);
|
||||
}
|
||||
|
||||
match attr {
|
||||
Qstr::MP_QSTR_MESSAGE_NAME => {
|
||||
// Return the qstr name of this message def
|
||||
let name = Qstr::from_u16(find_name_by_msg_offset(this.def.offset)?);
|
||||
unsafe { dest.write(name.into()); };
|
||||
}
|
||||
Qstr::MP_QSTR_MESSAGE_WIRE_TYPE => {
|
||||
// Return the wire type of this message def
|
||||
let wire_id_obj = this
|
||||
.def
|
||||
.wire_id
|
||||
.map_or_else(Obj::const_none, |wire_id| wire_id.into());
|
||||
unsafe { dest.write(wire_id_obj); };
|
||||
}
|
||||
Qstr::MP_QSTR_is_type_of => {
|
||||
// Return the is_type_of bound method
|
||||
// dest[0] = function_obj
|
||||
// dest[1] = self
|
||||
unsafe {
|
||||
dest.write(MSG_DEF_OBJ_IS_TYPE_OF_OBJ.to_obj());
|
||||
dest.offset(1).write(self_in);
|
||||
}
|
||||
}
|
||||
_ => { return Err(Error::Missing); }
|
||||
}
|
||||
Ok(())
|
||||
});
|
||||
}
|
||||
|
||||
unsafe extern "C" fn msg_def_obj_call(
|
||||
self_in: Obj,
|
||||
n_args: usize,
|
||||
n_kw: usize,
|
||||
args: *const Obj,
|
||||
) -> Obj {
|
||||
util::try_with_args_and_kwargs_inline(n_args, n_kw, args, |_args, kwargs| {
|
||||
let this = Gc::<MsgDefObj>::try_from(self_in)?;
|
||||
let decoder = Decoder {
|
||||
enable_experimental: true,
|
||||
};
|
||||
let obj = decoder.message_from_values(kwargs, this.msg())?;
|
||||
Ok(obj)
|
||||
})
|
||||
}
|
||||
|
||||
unsafe extern "C" fn msg_def_obj_is_type_of(self_in: Obj, obj: Obj) -> Obj {
|
||||
util::try_or_raise(|| {
|
||||
let this = Gc::<MsgDefObj>::try_from(self_in)?;
|
||||
let msg = Gc::<MsgObj>::try_from(obj);
|
||||
match msg {
|
||||
Ok(msg) if msg.msg_offset == this.def.offset => Ok(Obj::const_true()),
|
||||
_ => Ok(Obj::const_false()),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
static MSG_DEF_OBJ_IS_TYPE_OF_OBJ: ffi::mp_obj_fun_builtin_fixed_t = obj_fn_2!(msg_def_obj_is_type_of);
|
9
core/embed/rust/src/protobuf/zigzag.rs
Normal file
9
core/embed/rust/src/protobuf/zigzag.rs
Normal file
@ -0,0 +1,9 @@
|
||||
// https://developers.google.com/protocol-buffers/docs/encoding#signed_integers
|
||||
|
||||
pub fn to_unsigned(sint: i64) -> u64 {
|
||||
((sint << 1) ^ (sint >> 63)) as u64
|
||||
}
|
||||
|
||||
pub fn to_signed(uint: u64) -> i64 {
|
||||
((uint >> 1) as i64) ^ (-((uint & 1) as i64))
|
||||
}
|
@ -9,7 +9,7 @@ use crate::{
|
||||
},
|
||||
};
|
||||
|
||||
pub fn try_or_raise(func: impl FnOnce() -> Result<Obj, Error>) -> Obj {
|
||||
pub fn try_or_raise<T>(func: impl FnOnce() -> Result<T, Error>) -> T {
|
||||
func().unwrap_or_else(|err| raise_value_error(err.as_cstr()))
|
||||
}
|
||||
|
||||
|
@ -195,6 +195,7 @@ extern const struct _mp_print_t mp_stderr_print;
|
||||
#define MICROPY_PY_TREZORIO (1)
|
||||
#define MICROPY_PY_TREZORUI (1)
|
||||
#define MICROPY_PY_TREZORUTILS (1)
|
||||
#define MICROPY_PY_TREZORPROTO (1)
|
||||
|
||||
#define MP_STATE_PORT MP_STATE_VM
|
||||
|
||||
|
34
core/mocks/generated/trezorproto.pyi
Normal file
34
core/mocks/generated/trezorproto.pyi
Normal file
@ -0,0 +1,34 @@
|
||||
from typing import *
|
||||
from trezor.protobuf import MessageType
|
||||
T = TypeVar("T", bound=MessageType)
|
||||
|
||||
|
||||
# extmod/rustmods/modtrezorproto.c
|
||||
def type_for_name(name: str) -> Type[MessageType]:
|
||||
"""Find the message definition for the given protobuf name."""
|
||||
|
||||
|
||||
# extmod/rustmods/modtrezorproto.c
|
||||
def type_for_wire(wire_type: int) -> Type[MessageType]:
|
||||
"""Find the message definition for the given wire type (numeric
|
||||
identifier)."""
|
||||
|
||||
|
||||
# extmod/rustmods/modtrezorproto.c
|
||||
def decode(
|
||||
buffer: bytes,
|
||||
msg_type: Type[T],
|
||||
enable_experimental: bool,
|
||||
) -> T:
|
||||
"""Decode data in the buffer into the specified message type."""
|
||||
|
||||
|
||||
# extmod/rustmods/modtrezorproto.c
|
||||
def encoded_length(msg: MessageType) -> int:
|
||||
"""Calculate length of encoding of the specified message."""
|
||||
|
||||
|
||||
# extmod/rustmods/modtrezorproto.c
|
||||
def encode(buffer: bytearray, msg: MessageType) -> int:
|
||||
"""Encode the message into the specified buffer. Return length of
|
||||
encoding."""
|
Loading…
Reference in New Issue
Block a user