feat(core): Add Rust Protobuf codec

pull/1557/head
Jan Pochyla 3 years ago committed by matejcik
parent f9d4be268e
commit 8a21e3fc73

@ -174,6 +174,11 @@ SOURCE_MOD += [
'embed/extmod/modtrezorutils/modtrezorutils.c',
]
# rust mods
SOURCE_MOD += [
'embed/extmod/rustmods/modtrezorproto.c',
]
# modutime
SOURCE_MOD += [
'embed/firmware/modutime.c',

@ -171,6 +171,11 @@ SOURCE_MOD += [
'embed/extmod/modtrezorutils/modtrezorutils.c',
]
# rust mods
SOURCE_MOD += [
'embed/extmod/rustmods/modtrezorproto.c',
]
# modutime
SOURCE_MOD += [
'vendor/micropython/ports/unix/modtime.c',

@ -0,0 +1,86 @@
/*
* This file is part of the Trezor project, https://trezor.io/
*
* Copyright (c) SatoshiLabs
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include "py/runtime.h"
#if MICROPY_PY_TREZORPROTO
#include "librust.h"
/// from trezor.protobuf import MessageType
/// T = TypeVar("T", bound=MessageType)
/// def type_for_name(name: str) -> Type[MessageType]:
/// """Find the message definition for the given protobuf name."""
STATIC MP_DEFINE_CONST_FUN_OBJ_1(mod_trezorutils_protobuf_type_for_name_obj,
protobuf_type_for_name);
/// def type_for_wire(wire_type: int) -> Type[MessageType]:
/// """Find the message definition for the given wire type (numeric
/// identifier)."""
STATIC MP_DEFINE_CONST_FUN_OBJ_1(mod_trezorutils_protobuf_type_for_wire_obj,
protobuf_type_for_wire);
/// def decode(
/// buffer: bytes,
/// msg_type: Type[T],
/// enable_experimental: bool,
/// ) -> T:
/// """Decode data in the buffer into the specified message type."""
STATIC MP_DEFINE_CONST_FUN_OBJ_3(mod_trezorutils_protobuf_decode_obj,
protobuf_decode);
/// def encoded_length(msg: MessageType) -> int:
/// """Calculate length of encoding of the specified message."""
STATIC MP_DEFINE_CONST_FUN_OBJ_1(mod_trezorutils_protobuf_encoded_length_obj,
protobuf_len);
/// def encode(buffer: bytearray, msg: MessageType) -> int:
/// """Encode the message into the specified buffer. Return length of
/// encoding."""
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorutils_protobuf_encode_obj,
protobuf_encode);
STATIC const mp_rom_map_elem_t mp_module_trezorproto_globals_table[] = {
{MP_ROM_QSTR(MP_QSTR___name__), MP_ROM_QSTR(MP_QSTR_trezorproto)},
{MP_ROM_QSTR(MP_QSTR_type_for_name),
MP_ROM_PTR(&mod_trezorutils_protobuf_type_for_name_obj)},
{MP_ROM_QSTR(MP_QSTR_type_for_wire),
MP_ROM_PTR(&mod_trezorutils_protobuf_type_for_wire_obj)},
{MP_ROM_QSTR(MP_QSTR_decode),
MP_ROM_PTR(&mod_trezorutils_protobuf_decode_obj)},
{MP_ROM_QSTR(MP_QSTR_encoded_length),
MP_ROM_PTR(&mod_trezorutils_protobuf_encoded_length_obj)},
{MP_ROM_QSTR(MP_QSTR_encode),
MP_ROM_PTR(&mod_trezorutils_protobuf_encode_obj)},
};
STATIC MP_DEFINE_CONST_DICT(mp_module_trezorproto_globals,
mp_module_trezorproto_globals_table);
const mp_obj_module_t mp_module_trezorproto = {
.base = {&mp_type_module},
.globals = (mp_obj_dict_t *)&mp_module_trezorproto_globals,
};
MP_REGISTER_MODULE(MP_QSTR_trezorproto, mp_module_trezorproto,
MICROPY_PY_TREZORPROTO);
#endif // MICROPY_PY_TREZORPROTO

@ -156,6 +156,7 @@
#define MICROPY_PY_TREZORIO (1)
#define MICROPY_PY_TREZORUI (1)
#define MICROPY_PY_TREZORUTILS (1)
#define MICROPY_PY_TREZORPROTO (1)
#ifdef SYSTEM_VIEW
#define MP_PLAT_PRINT_STRN(str, len) segger_print(str, len)

@ -1,2 +1,2 @@
[build]
target-dir = "../../build/rust"
target-dir = "../../build/unix/rust"

@ -0,0 +1,8 @@
#include "librust_qstr.h"
// Protobuf codec entry points. They are implemented in Rust as
// `#[no_mangle] pub extern "C"` functions; failures are reported by raising a
// MicroPython ValueError from inside the call.

// Returns the message definition object registered under the given name.
mp_obj_t protobuf_type_for_name(mp_obj_t name);
// Returns the message definition object registered under the given numeric
// wire identifier.
mp_obj_t protobuf_type_for_wire(mp_obj_t wire_id);
// Decodes the bytes readable through `buf` into a new message of definition
// `def`. Experimental messages and fields are rejected unless
// `enable_experimental` is true.
mp_obj_t protobuf_decode(mp_obj_t buf, mp_obj_t def,
mp_obj_t enable_experimental);
// Returns the number of bytes the encoding of message `obj` occupies.
mp_obj_t protobuf_len(mp_obj_t obj);
// Encodes message `obj` into the writable buffer `buf` and returns the number
// of bytes written.
mp_obj_t protobuf_encode(mp_obj_t buf, mp_obj_t obj);

@ -0,0 +1,11 @@
#pragma GCC diagnostic ignored "-Wunused-value"
#pragma GCC diagnostic ignored "-Wunused-function"
// The body only names MP_QSTR_* identifiers as expression statements; it does
// no work. The pragmas above silence the resulting unused-value and
// unused-function warnings.
// NOTE(review): presumably this exists so that the qstr extraction step of the
// build sees the identifiers that are otherwise referenced only from Rust
// code - confirm against the build scripts.
static void _librust_qstrs(void) {
// protobuf
MP_QSTR_Msg;
MP_QSTR_MsgDef;
MP_QSTR_is_type_of;
MP_QSTR_MESSAGE_WIRE_TYPE;
MP_QSTR_MESSAGE_NAME;
}

@ -9,6 +9,7 @@ pub enum Error {
InvalidType,
NotBuffer,
NotInt,
InvalidOperation,
}
impl Error {
@ -22,6 +23,7 @@ impl Error {
Error::InvalidType => cstr("InvalidType\0"),
Error::NotBuffer => cstr("NotBuffer\0"),
Error::NotInt => cstr("NotInt\0"),
Error::InvalidOperation => cstr("InvalidOperation\0"),
}
}
}

@ -6,6 +6,7 @@
mod error;
#[macro_use]
mod micropython;
mod protobuf;
mod trezorhal;
mod util;

@ -1,17 +1,22 @@
use core::{convert::TryFrom, ops::Deref, ptr, slice};
use core::{
convert::TryFrom,
ops::{Deref, DerefMut},
ptr, slice,
};
use crate::{error::Error, micropython::obj::Obj};
use super::ffi;
/// Represents an immutable slice of bytes stored on the MicroPython heap and
/// owned by values that obey the buffer protocol, such as `bytes`, `str`,
/// `bytearray` or `memoryview`.
/// owned by values that obey the `MP_BUFFER_READ` buffer protocol, such as
/// `bytes`, `str`, `bytearray` or `memoryview`.
///
/// # Safety
///
/// In most cases, it is unsound to store `Buffer` values in a GC-unreachable
/// location, such as static data.
/// location, such as static data. It is also unsound to let the contents be
/// modified while a reference to them is being held.
pub struct Buffer {
ptr: *const u8,
len: usize,
@ -21,24 +26,12 @@ impl TryFrom<Obj> for Buffer {
type Error = Error;
fn try_from(obj: Obj) -> Result<Self, Self::Error> {
let mut bufinfo = ffi::mp_buffer_info_t {
buf: ptr::null_mut(),
len: 0,
typecode: 0,
};
// SAFETY: We assume that if `ffi::mp_get_buffer` returns successfully,
// `bufinfo.buf` contains a pointer to data of `bufinfo.len` bytes. Here
// we consider this data either GC-allocated or effectively 'static, and
// store the pointer directly in `Buffer`. It is unsound to store
// `Buffer` values in a GC-unreachable location, such as static data.
if unsafe { ffi::mp_get_buffer(obj, &mut bufinfo, ffi::MP_BUFFER_READ as _) } {
Ok(Self {
ptr: bufinfo.buf as _,
len: bufinfo.len as _,
})
} else {
Err(Error::NotBuffer)
}
let bufinfo = get_buffer_info(obj, ffi::MP_BUFFER_READ)?;
Ok(Self {
ptr: bufinfo.buf as _,
len: bufinfo.len as _,
})
}
}
@ -52,15 +45,102 @@ impl Deref for Buffer {
impl AsRef<[u8]> for Buffer {
fn as_ref(&self) -> &[u8] {
if self.ptr.is_null() {
// `self.ptr` can be null if len == 0.
&[]
} else {
// SAFETY: We assume that `self.ptr` is pointing to memory:
// - without any mutable references,
// - immutable for the whole lifetime of `&self`,
// - with at least `self.len` bytes.
unsafe { slice::from_raw_parts(self.ptr, self.len) }
}
buffer_as_ref(self.ptr, self.len)
}
}
/// Represents a mutable slice of bytes stored on the MicroPython heap and
/// owned by values that obey the `MP_BUFFER_WRITE` buffer protocol, such as
/// `bytearray` or `memoryview`.
///
/// The pointer and length are taken verbatim from the `mp_buffer_info_t`
/// filled in by `mp_get_buffer`; nothing here keeps the owning object alive.
///
/// # Safety
///
/// In most cases, it is unsound to store `BufferMut` values in a
/// GC-unreachable location, such as static data. It is also unsound to let
/// the contents be accessed through any other path while a reference obtained
/// from this `BufferMut` is being held.
pub struct BufferMut {
// Start of the writable region; may be null when `len == 0`.
ptr: *mut u8,
// Size of the region in bytes.
len: usize,
}
impl TryFrom<Obj> for BufferMut {
    type Error = Error;

    /// Borrows the writable storage of `obj` through the `MP_BUFFER_WRITE`
    /// buffer protocol. Fails with `Error::NotBuffer` if the object does not
    /// provide it.
    fn try_from(obj: Obj) -> Result<Self, Self::Error> {
        get_buffer_info(obj, ffi::MP_BUFFER_WRITE).map(|info| Self {
            ptr: info.buf as _,
            len: info.len as _,
        })
    }
}
impl Deref for BufferMut {
    type Target = [u8];

    /// Read-only view of the underlying bytes (same view as `AsRef<[u8]>`).
    fn deref(&self) -> &Self::Target {
        AsRef::<[u8]>::as_ref(self)
    }
}
impl DerefMut for BufferMut {
    /// Writable view of the underlying bytes (same view as `AsMut<[u8]>`).
    fn deref_mut(&mut self) -> &mut Self::Target {
        AsMut::<[u8]>::as_mut(self)
    }
}
impl AsRef<[u8]> for BufferMut {
    /// Shared byte slice over the buffer; empty when the pointer is null.
    fn as_ref(&self) -> &[u8] {
        let Self { ptr, len } = *self;
        buffer_as_ref(ptr, len)
    }
}
impl AsMut<[u8]> for BufferMut {
    /// Exclusive byte slice over the buffer; empty when the pointer is null.
    fn as_mut(&mut self) -> &mut [u8] {
        let Self { ptr, len } = *self;
        buffer_as_mut(ptr, len)
    }
}
fn get_buffer_info(obj: Obj, flags: u32) -> Result<ffi::mp_buffer_info_t, Error> {
let mut bufinfo = ffi::mp_buffer_info_t {
buf: ptr::null_mut(),
len: 0,
typecode: 0,
};
// SAFETY: We assume that if `ffi::mp_get_buffer` returns successfully,
// `bufinfo.buf` contains a pointer to data of `bufinfo.len` bytes. Later
// we consider this data either GC-allocated or effectively `'static`, embedding
// them in `Buffer`/`BufferMut`.
if unsafe { ffi::mp_get_buffer(obj, &mut bufinfo, flags as _) } {
Ok(bufinfo)
} else {
Err(Error::NotBuffer)
}
}
/// Builds a shared byte slice from a raw pointer and a length.
///
/// A null `ptr` yields an empty slice regardless of `len`. The lifetime `'a`
/// is chosen by the caller, who is responsible for the assumptions listed
/// below.
fn buffer_as_ref<'a>(ptr: *const u8, len: usize) -> &'a [u8] {
    // `ptr` can be null if len == 0.
    if ptr.is_null() {
        return &[];
    }
    // SAFETY: We assume that `ptr` is pointing to memory:
    //  - without any mutable references,
    //  - valid and immutable in `'a`,
    //  - of at least `len` bytes.
    unsafe { slice::from_raw_parts(ptr, len) }
}
/// Builds an exclusive byte slice from a raw pointer and a length.
///
/// A null `ptr` yields an empty slice regardless of `len`. The lifetime `'a`
/// is chosen by the caller, who is responsible for the assumptions listed
/// below.
fn buffer_as_mut<'a>(ptr: *mut u8, len: usize) -> &'a mut [u8] {
    // `ptr` can be null if len == 0.
    if ptr.is_null() {
        return &mut [];
    }
    // SAFETY: We assume that `ptr` is pointing to memory:
    //  - without any mutable references,
    //  - valid and mutable in `'a`,
    //  - of at least `len` bytes.
    unsafe { slice::from_raw_parts_mut(ptr, len) }
}

@ -0,0 +1,313 @@
use core::convert::{TryFrom, TryInto};
use core::str;
use crate::{
error::Error,
micropython::{buffer::Buffer, gc::Gc, list::List, map::Map, obj::Obj, qstr::Qstr},
util,
};
use super::{
defs::{self, FieldDef, FieldType, MsgDef},
obj::{MsgDefObj, MsgObj},
zigzag,
};
#[no_mangle]
pub extern "C" fn protobuf_type_for_name(name: Obj) -> Obj {
util::try_or_raise(|| {
let name = Qstr::try_from(name)?;
let def = MsgDef::for_name(name.to_u16()).ok_or(Error::Missing)?;
let obj = MsgDefObj::alloc(def).into();
Ok(obj)
})
}
#[no_mangle]
pub extern "C" fn protobuf_type_for_wire(wire_id: Obj) -> Obj {
util::try_or_raise(|| {
let wire_id = u16::try_from(wire_id)?;
let def = MsgDef::for_wire_id(wire_id).ok_or(Error::Missing)?;
let obj = MsgDefObj::alloc(def).into();
Ok(obj)
})
}
/// C entry point: decodes the bytes readable through `buf` into a new message
/// object of the definition `msg_def`.
///
/// Raises a MicroPython `ValueError` when an argument has the wrong type, when
/// the definition is experimental and `enable_experimental` is false, or when
/// decoding fails.
#[no_mangle]
pub extern "C" fn protobuf_decode(buf: Obj, msg_def: Obj, enable_experimental: Obj) -> Obj {
    util::try_or_raise(|| {
        let data = Buffer::try_from(buf)?;
        let def_obj = Gc::<MsgDefObj>::try_from(msg_def)?;
        let enable_experimental = bool::try_from(enable_experimental)?;

        let msg = def_obj.msg();
        // Refuse to decode message defs marked as experimental if not
        // explicitly allowed. Messages can also mark certain fields as
        // experimental (not the whole message). This is enforced during the
        // decoding.
        if msg.is_experimental && !enable_experimental {
            return Err(Error::InvalidType);
        }

        let mut stream = InputStream::new(&data);
        let decoder = Decoder {
            enable_experimental,
        };
        decoder.message_from_stream(&mut stream, msg)
    })
}
pub struct Decoder {
pub enable_experimental: bool,
}
impl Decoder {
/// Create a new message instance and decode `stream` into it, handling the
/// default and required fields correctly.
pub fn message_from_stream(
&self,
stream: &mut InputStream,
msg: &MsgDef,
) -> Result<Obj, Error> {
let mut obj = self.empty_message(msg);
// SAFETY: We assume that `obj` is not aliased here.
let map = unsafe { Gc::as_mut(&mut obj) }.map_mut();
self.decode_fields_into(stream, msg, map)?;
self.decode_defaults_into(msg, map)?;
self.assign_required_into(msg, map)?;
Ok(obj.into())
}
/// Create a new message instance and fill it from `values`, handling the
/// default and required fields correctly.
pub fn message_from_values(&self, values: &Map, msg: &MsgDef) -> Result<Obj, Error> {
let mut obj = self.empty_message(msg);
// SAFETY: We assume that `obj` is not aliased here.
let map = unsafe { Gc::as_mut(&mut obj) }.map_mut();
for elem in values.elems() {
map.set(elem.key, elem.value);
}
self.decode_defaults_into(msg, map)?;
self.assign_required_into(msg, map)?;
Ok(obj.into())
}
/// Allocate the backing message object with enough pre-allocated space for
/// all fields.
pub fn empty_message(&self, msg: &MsgDef) -> Gc<MsgObj> {
MsgObj::alloc_with_capacity(msg.fields.len(), msg)
}
/// Decode message fields one-by-one from the input stream, assigning them
/// into `map`.
fn decode_fields_into(
&self,
stream: &mut InputStream,
msg: &MsgDef,
map: &mut Map,
) -> Result<(), Error> {
// Loop, trying to read the field key that contains the tag and primitive value
// type. If we fail to read the key, we are at the end of the stream.
while let Ok(field_key) = stream.read_uvarint() {
let field_tag = u8::try_from(field_key >> 3)?;
let prim_type = u8::try_from(field_key & 7)?;
match msg.field(field_tag) {
Some(field) => {
let field_value = self.decode_field(stream, field)?;
let field_name = Qstr::from(field.name);
if field.is_repeated() {
// Repeated field, values are stored in a list. First, look up the list
// object. If it exists, append to it. If it doesn't, create a new list with
// this field's value and assign it.
if let Ok(obj) = map.get(field_name) {
let mut list = Gc::<List>::try_from(obj)?;
// SAFETY: We assume that `list` is not aliased here. This holds for
// uses in `message_from_stream` and `message_from_values`, because we
// start with an empty `Map` and fill with unique lists.
unsafe { Gc::as_mut(&mut list) }.append(field_value);
} else {
let list = List::alloc(&[field_value]);
map.set(field_name, list);
}
} else {
// Singular field, assign the value directly.
map.set(field_name, field_value);
}
}
None => {
// Unknown field, skip it.
match prim_type {
defs::PRIMITIVE_TYPE_VARINT => {
stream.read_uvarint()?;
}
defs::PRIMITIVE_TYPE_LENGTH_DELIMITED => {
let num = stream.read_uvarint()?;
let len = num.try_into()?;
stream.read(len)?;
}
_ => {
return Err(Error::InvalidType);
}
}
}
}
}
Ok(())
}
/// Fill in the default values by decoding them from the defaults stream.
/// Only singular fields are allowed to have a default value, this is
/// enforced in the blob compilation.
fn decode_defaults_into(&self, msg: &MsgDef, map: &mut Map) -> Result<(), Error> {
let stream = &mut InputStream::new(msg.defaults);
// The format of the defaults stream is a sequence of records:
// - one-byte field tag (without the primitive type).
// - Protobuf-encoded default value.
// We need to look to the field descriptor to know how to interpret the value
// after the field tag.
while let Ok(field_tag) = stream.read_byte() {
let field = msg.field(field_tag).ok_or(Error::Missing)?;
let field_name = Qstr::from(field.name);
if map.contains_key(field_name) {
// Field already has a value assigned, skip it.
match field.get_type().primitive_type() {
defs::PRIMITIVE_TYPE_VARINT => {
stream.read_uvarint()?;
}
defs::PRIMITIVE_TYPE_LENGTH_DELIMITED => {
let num = stream.read_uvarint()?;
let len = num.try_into()?;
stream.read(len)?;
}
_ => {
return Err(Error::InvalidType);
}
}
} else {
// Decode the value and assign it.
let field_value = self.decode_field(stream, field)?;
map.set(field_name, field_value);
}
}
Ok(())
}
/// Walk the fields definitions and make sure that all required fields are
/// assigned and all optional missing fields are set to `None`.
fn assign_required_into(&self, msg: &MsgDef, map: &mut Map) -> Result<(), Error> {
for field in msg.fields {
let field_name = Qstr::from(field.name);
if map.contains_key(field_name) {
// Field is assigned, skip.
continue;
}
if field.is_required() {
// Required field is missing, abort.
return Err(Error::Missing);
}
if field.is_repeated() {
// Optional repeated field, set to a new empty list.
map.set(field_name, List::alloc(&[]));
} else {
// Optional singular field, set to None.
map.set(field_name, Obj::const_none());
}
}
Ok(())
}
/// Decode one field value from the input stream.
fn decode_field(&self, stream: &mut InputStream, field: &FieldDef) -> Result<Obj, Error> {
if field.is_experimental() && !self.enable_experimental {
return Err(Error::InvalidType);
}
let num = stream.read_uvarint()?;
match field.get_type() {
FieldType::UVarInt => Ok(num.into()),
FieldType::SVarInt => {
let signed_int = zigzag::to_signed(num);
Ok(signed_int.into())
}
FieldType::Bool => {
let boolean = num != 0;
Ok(boolean.into())
}
FieldType::Bytes => {
let buf_len = num.try_into()?;
let buf = stream.read(buf_len)?;
Ok(buf.into())
}
FieldType::String => {
let buf_len = num.try_into()?;
let buf = stream.read(buf_len)?;
let unicode = str::from_utf8(buf).map_err(|_| Error::InvalidType)?;
Ok(unicode.into())
}
FieldType::Enum(enum_type) => {
let enum_val = num.try_into()?;
if enum_type.values.contains(&enum_val) {
Ok(enum_val.into())
} else {
Err(Error::InvalidType)
}
}
FieldType::Msg(msg_type) => {
let msg_len = num.try_into()?;
let sub_stream = &mut stream.read_stream(msg_len)?;
self.message_from_stream(sub_stream, &msg_type)
}
}
}
}
pub struct InputStream<'a> {
buf: &'a [u8],
pos: usize,
}
impl<'a> InputStream<'a> {
pub fn new(buf: &'a [u8]) -> Self {
Self { buf, pos: 0 }
}
pub fn read_stream(&mut self, len: usize) -> Result<Self, Error> {
let buf = self
.buf
.get(self.pos..self.pos + len)
.ok_or(Error::Missing)?;
self.pos += len;
Ok(Self::new(buf))
}
pub fn read(&mut self, len: usize) -> Result<&[u8], Error> {
let buf = self
.buf
.get(self.pos..self.pos + len)
.ok_or(Error::Missing)?;
self.pos += len;
Ok(buf)
}
pub fn read_byte(&mut self) -> Result<u8, Error> {
let val = self.buf.get(self.pos).copied().ok_or(Error::Missing)?;
self.pos += 1;
Ok(val)
}
pub fn read_uvarint(&mut self) -> Result<u64, Error> {
let mut uint = 0;
let mut shift = 0;
loop {
let byte = self.read_byte()?;
uint += (byte as u64 & 0x7F) << shift;
shift += 7;
if byte & 0x80 == 0 {
break;
}
}
Ok(uint)
}
}

@ -0,0 +1,224 @@
use core::{mem, slice};
use crate::error::Error;
/// Decoded view of one message definition from the embedded definition blob.
pub struct MsgDef {
    /// Field descriptors of the message.
    pub fields: &'static [FieldDef],
    /// Encoded default values (see `Decoder::decode_defaults_into`).
    pub defaults: &'static [u8],
    /// Whether the whole message is flagged as experimental.
    pub is_experimental: bool,
    /// Numeric wire identifier, if the message has one.
    pub wire_id: Option<u16>,
    /// Offset of this definition inside the message definition blob.
    pub offset: u16,
}

impl MsgDef {
    /// Finds the definition of the message whose name is the qstr `msg_name`.
    pub fn for_name(msg_name: u16) -> Option<Self> {
        let msg_offset = find_msg_offset_by_name(msg_name)?;
        // SAFETY: We are taking the offset right out of the definitions so we can be
        // sure it's to be trusted.
        Some(unsafe { get_msg(msg_offset) })
    }

    /// Finds the definition of the message with the given wire identifier.
    pub fn for_wire_id(wire_id: u16) -> Option<Self> {
        let msg_offset = find_msg_offset_by_wire(wire_id)?;
        // SAFETY: We are taking the offset right out of the definitions so we can be
        // sure it's to be trusted.
        Some(unsafe { get_msg(msg_offset) })
    }

    /// Returns the first field descriptor with the given tag, if any.
    pub fn field(&self, tag: u8) -> Option<&FieldDef> {
        for field in self.fields {
            if field.tag == tag {
                return Some(field);
            }
        }
        None
    }
}
/// One field descriptor, laid out exactly as stored in the message definition
/// blob (packed, six bytes).
#[repr(C, packed)]
pub struct FieldDef {
    /// Protobuf field tag.
    pub tag: u8,
    /// High nibble: flags; low nibble: type code.
    flags_and_type: u8,
    /// Offset of the enum or message definition, for fields of those types.
    enum_or_msg_offset: u16,
    /// Field name (converted to a qstr by the users of this struct).
    pub name: u16,
}

impl FieldDef {
    const FLAGS_MASK: u8 = 0xF0;
    const TYPE_MASK: u8 = 0x0F;
    const FLAG_REQUIRED: u8 = 0b_1000_0000;
    const FLAG_REPEATED: u8 = 0b_0100_0000;
    const FLAG_EXPERIMENTAL: u8 = 0b_0010_0000;

    /// Translates the stored type code into a `FieldType`, resolving the
    /// referenced enum or message definition where applicable.
    ///
    /// Panics on a type code outside `0..=6`.
    pub fn get_type(&self) -> FieldType {
        let offset = self.enum_or_msg_offset;
        match self.ftype() {
            0 => FieldType::UVarInt,
            1 => FieldType::SVarInt,
            2 => FieldType::Bool,
            3 => FieldType::Bytes,
            4 => FieldType::String,
            5 => FieldType::Enum(unsafe { get_enum(offset) }),
            6 => FieldType::Msg(unsafe { get_msg(offset) }),
            _ => unreachable!(),
        }
    }

    /// Whether the field must be present in every message.
    pub fn is_required(&self) -> bool {
        self.has_flag(Self::FLAG_REQUIRED)
    }

    /// Whether the field may occur multiple times (decoded into a list).
    pub fn is_repeated(&self) -> bool {
        self.has_flag(Self::FLAG_REPEATED)
    }

    /// Whether the field is flagged as experimental.
    pub fn is_experimental(&self) -> bool {
        self.has_flag(Self::FLAG_EXPERIMENTAL)
    }

    /// Tests a single bit of the flags nibble.
    fn has_flag(&self, flag: u8) -> bool {
        self.flags() & flag != 0
    }

    /// Flags nibble, left in its original bit positions.
    fn flags(&self) -> u8 {
        self.flags_and_type & Self::FLAGS_MASK
    }

    /// Type code (low nibble).
    fn ftype(&self) -> u8 {
        self.flags_and_type & Self::TYPE_MASK
    }
}
/// Declared type of a message field.
pub enum FieldType {
    /// Unsigned varint.
    UVarInt,
    /// Zigzag-encoded signed varint.
    SVarInt,
    /// Boolean stored as a varint.
    Bool,
    /// Length-delimited raw bytes.
    Bytes,
    /// Length-delimited UTF-8 string.
    String,
    /// Varint restricted to the values of the given enum definition.
    Enum(EnumDef),
    /// Length-delimited nested message of the given definition.
    Msg(MsgDef),
}

/// Wire type of values encoded as a single varint.
pub const PRIMITIVE_TYPE_VARINT: u8 = 0;
/// Wire type of values encoded as a varint length followed by that many bytes.
pub const PRIMITIVE_TYPE_LENGTH_DELIMITED: u8 = 2;

impl FieldType {
    /// Returns the wire type (`PRIMITIVE_TYPE_*`) used to encode this field
    /// type.
    pub fn primitive_type(&self) -> u8 {
        match self {
            FieldType::UVarInt => PRIMITIVE_TYPE_VARINT,
            FieldType::SVarInt => PRIMITIVE_TYPE_VARINT,
            FieldType::Bool => PRIMITIVE_TYPE_VARINT,
            FieldType::Enum(_) => PRIMITIVE_TYPE_VARINT,
            FieldType::Bytes => PRIMITIVE_TYPE_LENGTH_DELIMITED,
            FieldType::String => PRIMITIVE_TYPE_LENGTH_DELIMITED,
            FieldType::Msg(_) => PRIMITIVE_TYPE_LENGTH_DELIMITED,
        }
    }
}
/// View of one enum definition from the embedded enum blob.
pub struct EnumDef {
// The values belonging to the enum, as produced by `get_enum`.
pub values: &'static [u16],
}
/// One record of the name table (`NAME_DEFS`), which is reinterpreted in
/// place as a slice of these records. The packed layout gives the type an
/// alignment of one, so it can overlay a plain byte slice.
#[repr(C, packed)]
struct NameDef {
// Message name; compared against the `u16` form of a qstr by the lookups.
msg_name: u16,
// Offset of the message definition, as accepted by `get_msg`.
msg_offset: u16,
}
// Definition tables embedded into the binary at compile time, located
// relative to Cargo's `OUT_DIR`.
// NOTE(review): the `.data` files are presumably produced by a build step
// that is not part of this file - confirm where they are generated.
// Enum definitions, read by `get_enum`.
static ENUM_DEFS: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/../../../../proto_enums.data"));
// Message definitions, read by `get_msg`.
static MSG_DEFS: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/../../../..//proto_msgs.data"));
// Name table, reinterpreted as `[NameDef]` by the name lookups.
static NAME_DEFS: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/../../../..//proto_names.data"));
// Wire-id table, reinterpreted as `[WireDef]` by `find_msg_offset_by_wire`.
static WIRE_DEFS: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/../../../..//proto_wire.data"));
/// Reverse lookup in the name table: returns the name of the message whose
/// definition starts at `msg_offset`.
///
/// The table is scanned linearly and the first matching record wins. Fails
/// with `Error::Missing` if no record refers to `msg_offset`.
pub fn find_name_by_msg_offset(msg_offset: u16) -> Result<u16, Error> {
    // SAFETY: `NameDef` is `repr(C, packed)`, so it has no alignment
    // requirement beyond that of `u8`, and the element count is rounded down
    // so the slice never extends past the end of `NAME_DEFS`.
    let name_defs: &[NameDef] = unsafe {
        slice::from_raw_parts(
            NAME_DEFS.as_ptr().cast(),
            NAME_DEFS.len() / mem::size_of::<NameDef>(),
        )
    };
    name_defs
        .iter()
        .find(|def| def.msg_offset == msg_offset)
        .map(|def| def.msg_name)
        .ok_or(Error::Missing)
}
/// Looks up the definition offset of the message named `msg_name` in the name
/// table.
///
/// Uses binary search, so the table has to be ordered by `msg_name` —
/// presumably guaranteed by whatever generates the blob.
fn find_msg_offset_by_name(msg_name: u16) -> Option<u16> {
    let count = NAME_DEFS.len() / mem::size_of::<NameDef>();
    // SAFETY: `NameDef` is `repr(C, packed)`, so it has no alignment
    // requirement beyond that of `u8`, and `count` is rounded down so the
    // slice never extends past the end of `NAME_DEFS`.
    let table: &[NameDef] = unsafe { slice::from_raw_parts(NAME_DEFS.as_ptr().cast(), count) };
    let index = table
        .binary_search_by_key(&msg_name, |entry| entry.msg_name)
        .ok()?;
    Some(table[index].msg_offset)
}
/// Looks up the definition offset of the message with the given wire
/// identifier in the wire table.
///
/// Uses binary search, so the table has to be ordered by `wire_id` —
/// presumably guaranteed by whatever generates the blob.
fn find_msg_offset_by_wire(wire_id: u16) -> Option<u16> {
    /// One record of the wire table.
    #[repr(C, packed)]
    struct WireDef {
        wire_id: u16,
        msg_offset: u16,
    }

    let count = WIRE_DEFS.len() / mem::size_of::<WireDef>();
    // SAFETY: `WireDef` is `repr(C, packed)`, so it has no alignment
    // requirement beyond that of `u8`, and `count` is rounded down so the
    // slice never extends past the end of `WIRE_DEFS`.
    let table: &[WireDef] = unsafe { slice::from_raw_parts(WIRE_DEFS.as_ptr().cast(), count) };
    let index = table
        .binary_search_by_key(&wire_id, |entry| entry.wire_id)
        .ok()?;
    Some(table[index].msg_offset)
}
/// Builds a `MsgDef` view of the message definition stored at `msg_offset`
/// inside `MSG_DEFS`.
///
/// The record consists of a four-byte header (field count, size of the
/// defaults blob, little-endian flags-and-wire-id word), followed by the field
/// descriptors and then by the encoded defaults. The top bit of the 16-bit
/// word is the experimental flag; the remaining 15 bits hold the wire id,
/// where `0x7FFF` means the message has none.
///
/// # Safety
///
/// `msg_offset` has to point to the beginning of a valid message definition
/// inside `MSG_DEFS`. No bounds checks are performed on any of the reads.
pub unsafe fn get_msg(msg_offset: u16) -> MsgDef {
// #[repr(C, packed)]
// struct MsgDef {
// fields_count: u8,
// defaults_size: u8,
// flags_and_wire_id: u16,
// fields: [Field],
// defaults: [u8],
// }
// SAFETY: `msg_offset` has to point to a beginning of a valid message
// definition inside `MSG_DEFS`.
unsafe {
let ptr = MSG_DEFS.as_ptr().add(msg_offset as usize);
let fields_count = ptr.offset(0).read() as usize;
let defaults_size = ptr.offset(1).read() as usize;
// The 16-bit word is assembled from single bytes, so no aligned 16-bit
// read is needed.
let flags_and_wire_id_lo = ptr.offset(2).read();
let flags_and_wire_id_hi = ptr.offset(3).read();
let flags_and_wire_id = u16::from_le_bytes([flags_and_wire_id_lo, flags_and_wire_id_hi]);
let is_experimental = flags_and_wire_id & 0x8000 != 0;
let wire_id = match flags_and_wire_id & 0x7FFF {
0x7FFF => None,
some_wire_id => Some(some_wire_id),
};
// Field descriptors start right after the header, the defaults right
// after the descriptors.
let fields_size = fields_count * mem::size_of::<FieldDef>();
let fields_ptr = ptr.offset(4);
let defaults_ptr = ptr.offset(4).add(fields_size);
MsgDef {
fields: slice::from_raw_parts(fields_ptr.cast(), fields_count),
defaults: slice::from_raw_parts(defaults_ptr.cast(), defaults_size),
is_experimental,
wire_id,
offset: msg_offset,
}
}
}
/// Builds an `EnumDef` view of the enum definition stored at `enum_offset`
/// inside `ENUM_DEFS`: a one-byte value count followed by that many `u16`
/// values.
///
/// # Safety
///
/// `enum_offset` has to point to the beginning of a valid enum definition
/// inside `ENUM_DEFS`. No bounds checks are performed.
///
/// NOTE(review): the values start one byte past `enum_offset` inside a plain
/// byte slice, and nothing visible here guarantees the two-byte alignment
/// that a `&[u16]` requires - confirm that the blob layout and placement
/// ensure it, or switch to unaligned reads.
unsafe fn get_enum(enum_offset: u16) -> EnumDef {
// #[repr(C, packed)]
// struct EnumDef {
// count: u8,
// vals: [u16],
// }
// SAFETY: `enum_offset` has to point to a beginning of a valid enum
// definition inside `ENUM_DEFS`.
unsafe {
let ptr = ENUM_DEFS.as_ptr().add(enum_offset as usize);
let count = ptr.offset(0).read() as usize;
let vals = ptr.offset(1);
EnumDef {
values: slice::from_raw_parts(vals.cast(), count),
}
}
}

@ -0,0 +1,236 @@
use core::convert::TryFrom;
use crate::{
error::Error,
micropython::{
buffer::{Buffer, BufferMut},
gc::Gc,
iter::{Iter, IterBuf},
list::List,
obj::Obj,
qstr::Qstr,
},
util,
};
use super::{
defs::{FieldDef, FieldType, MsgDef},
obj::{MsgObj},
zigzag,
};
/// C entry point: returns the number of bytes the encoding of message `obj`
/// occupies.
///
/// The message is run through the encoder with a stream that only counts
/// bytes. Raises a MicroPython `ValueError` if `obj` is not a message object
/// or a field cannot be encoded.
#[no_mangle]
pub extern "C" fn protobuf_len(obj: Obj) -> Obj {
    util::try_or_raise(|| {
        let msg = Gc::<MsgObj>::try_from(obj)?;
        let mut counter = CounterStream { len: 0 };
        Encoder.encode_message(&mut counter, &msg.def(), &msg)?;
        Ok(counter.len.into())
    })
}
/// C entry point: encodes message `obj` into the writable buffer `buf` and
/// returns the number of bytes written.
///
/// Raises a MicroPython `ValueError` if `obj` is not a message object, if
/// `buf` does not provide a writable buffer, if a field cannot be encoded, or
/// if the encoding does not fit into `buf`.
#[no_mangle]
pub extern "C" fn protobuf_encode(buf: Obj, obj: Obj) -> Obj {
    util::try_or_raise(|| {
        let obj = Gc::<MsgObj>::try_from(obj)?;
        let buf = &mut BufferMut::try_from(buf)?;
        // `AsMut::as_mut` is a safe call, so no `unsafe` block is needed here.
        // The aliasing assumption it rests on still applies: we assume there
        // are no other refs into `buf` at this point. This specifically means
        // that no fields of `obj` should reference `buf` memory.
        let stream = &mut BufferStream::new(buf.as_mut());
        Encoder.encode_message(stream, &obj.def(), &obj)?;
        Ok(stream.len().into())
    })
}
/// Stateless Protobuf encoder for message objects.
pub struct Encoder;

impl Encoder {
    /// Writes all assigned fields of `obj` to `stream`, in the order in which
    /// the definition `msg` lists them.
    ///
    /// Fields that are absent from the object or set to `None` are omitted.
    /// Each element of a repeated field is written as a separate key/value
    /// pair.
    pub fn encode_message(
        &self,
        stream: &mut impl OutputStream,
        msg: &MsgDef,
        obj: &MsgObj,
    ) -> Result<(), Error> {
        for field in msg.fields {
            // Lookup the field by name. If not set or None, skip.
            let value = match obj.map().get(Qstr::from(field.name)) {
                Ok(value) if value != Obj::const_none() => value,
                _ => continue,
            };

            // Field key: tag in the upper bits, wire type in the low three.
            let wire_type = field.get_type().primitive_type() as u64;
            let key = (field.tag as u64) << 3 | wire_type;

            if !field.is_repeated() {
                stream.write_uvarint(key)?;
                self.encode_field(stream, field, value)?;
                continue;
            }

            let mut iter_buf = IterBuf::new();
            for item in Iter::try_from_obj_with_buf(value, &mut iter_buf)? {
                stream.write_uvarint(key)?;
                self.encode_field(stream, field, item)?;
            }
        }
        Ok(())
    }

    /// Writes a single value of `field` (without the field key) to `stream`.
    pub fn encode_field(
        &self,
        stream: &mut impl OutputStream,
        field: &FieldDef,
        value: Obj,
    ) -> Result<(), Error> {
        match field.get_type() {
            FieldType::UVarInt | FieldType::Enum(_) => {
                stream.write_uvarint(u64::try_from(value)?)?;
            }
            FieldType::SVarInt => {
                let signed = i64::try_from(value)?;
                stream.write_uvarint(zigzag::to_unsigned(signed))?;
            }
            FieldType::Bool => {
                let flag = bool::try_from(value)?;
                stream.write_uvarint(u64::from(flag))?;
            }
            FieldType::Bytes | FieldType::String => {
                if Gc::<List>::try_from(value).is_err() {
                    // Single length-delimited field.
                    let chunk = Buffer::try_from(value)?;
                    stream.write_uvarint(chunk.len() as u64)?;
                    stream.write(&chunk)?;
                } else {
                    // As an optimization, we support defining bytes and string
                    // fields also as a list-of-values. Serialize them as if
                    // concatenated.
                    let mut iter_buf = IterBuf::new();

                    // First pass: total length of all chunks.
                    let mut total = 0;
                    for item in Iter::try_from_obj_with_buf(value, &mut iter_buf)? {
                        total += Buffer::try_from(item)?.len();
                    }
                    stream.write_uvarint(total as u64)?;

                    // Second pass: the chunks themselves, one-by-one.
                    for item in Iter::try_from_obj_with_buf(value, &mut iter_buf)? {
                        let chunk = Buffer::try_from(item)?;
                        stream.write(&chunk)?;
                    }
                }
            }
            FieldType::Msg(msg_type) => {
                let nested = &Gc::<MsgObj>::try_from(value)?;
                // Calculate the message size by encoding it through
                // `CounterStream`.
                let mut counter = CounterStream { len: 0 };
                self.encode_message(&mut counter, &msg_type, nested)?;
                // Encode the message as length-delimited bytes.
                stream.write_uvarint(counter.len as u64)?;
                self.encode_message(stream, &msg_type, nested)?;
            }
        }
        Ok(())
    }
}
/// Sink for encoded Protobuf data.
pub trait OutputStream {
    /// Appends all bytes of `buf`.
    fn write(&mut self, buf: &[u8]) -> Result<(), Error>;

    /// Appends a single byte.
    fn write_byte(&mut self, val: u8) -> Result<(), Error>;

    /// Appends `num` as a base-128 varint: seven bits per byte, least
    /// significant group first, with the top bit set on every byte except the
    /// last.
    fn write_uvarint(&mut self, mut num: u64) -> Result<(), Error> {
        while num >= 0x80 {
            self.write_byte((num & 0x7F) as u8 | 0x80)?;
            num >>= 7;
        }
        self.write_byte(num as u8)
    }
}
pub struct CounterStream {
pub len: usize,
}
impl OutputStream for CounterStream {
fn write(&mut self, buf: &[u8]) -> Result<(), Error> {
self.len += buf.len();
Ok(())
}
fn write_byte(&mut self, _val: u8) -> Result<(), Error> {
self.len += 1;
Ok(())
}
}
pub struct BufferStream<'a> {
buf: &'a mut [u8],
pos: usize,
}
impl<'a> BufferStream<'a> {
pub fn new(buf: &'a mut [u8]) -> Self {
Self { buf, pos: 0 }
}
pub fn len(&self) -> usize {
self.pos
}
}
impl<'a> OutputStream for BufferStream<'a> {
fn write(&mut self, val: &[u8]) -> Result<(), Error> {
let pos = &mut self.pos;
let len = val.len();
self.buf
.get_mut(*pos..*pos + len)
.map(|buf| {
*pos += len;
buf.copy_from_slice(val);
})
.ok_or(Error::Missing)
}
fn write_byte(&mut self, val: u8) -> Result<(), Error> {
let pos = &mut self.pos;
self.buf
.get_mut(*pos)
.map(|buf| {
*pos += 1;
*buf = val;
})
.ok_or(Error::Missing)
}
}

@ -0,0 +1,5 @@
mod decode;
mod defs;
mod encode;
mod obj;
mod zigzag;

@ -0,0 +1,264 @@
use core::convert::TryFrom;
use crate::{
error::Error,
micropython::{
dict::Dict,
ffi,
gc::Gc,
map::Map,
obj::{Obj, ObjBase},
qstr::Qstr,
typ::Type,
},
util,
};
use super::decode::Decoder;
use super::defs::{find_name_by_msg_offset, get_msg, MsgDef};
/// MicroPython object holding one Protobuf message instance.
///
/// `repr(C)` with `base` as the first field is what the `Obj` conversions
/// below rely on.
#[repr(C)]
pub struct MsgObj {
// Object header carrying the `Msg` type.
base: ObjBase,
// Field values, keyed by field name.
map: Map,
// Wire identifier copied from the message definition, if it has one.
msg_wire_id: Option<u16>,
// Offset of the message definition, as accepted by `get_msg`.
msg_offset: u16,
}
impl MsgObj {
    /// Allocates an empty message of definition `msg` on the GC heap, with
    /// room for `capacity` fields in its map.
    pub fn alloc_with_capacity(capacity: usize, msg: &MsgDef) -> Gc<Self> {
        let this = Self {
            base: Self::obj_type().to_base(),
            map: Map::with_capacity(capacity),
            msg_wire_id: msg.wire_id,
            msg_offset: msg.offset,
        };
        Gc::new(this)
    }

    /// Field values of the message.
    pub fn map(&self) -> &Map {
        &self.map
    }

    /// Mutable access to the field values of the message.
    pub fn map_mut(&mut self) -> &mut Map {
        &mut self.map
    }

    /// Re-creates the message definition this object was allocated for.
    pub fn def(&self) -> MsgDef {
        let offset = self.msg_offset;
        // SAFETY: We assume `msg_offset` still refers to a valid definition;
        // it is copied from a `MsgDef` in `alloc_with_capacity`.
        unsafe { get_msg(offset) }
    }

    /// The MicroPython type object shared by all message instances.
    fn obj_type() -> &'static Type {
        static TYPE: Type = obj_type! {
            name: Qstr::MP_QSTR_Msg,
            attr_fn: msg_obj_attr,
        };
        &TYPE
    }
}
impl MsgObj {
fn getattr(&self, attr: Qstr) -> Result<Obj, Error> {
if let Ok(obj) = self.map.get(attr) {
// Message field was found, return its value.
return Ok(obj);
}
// Built-in attribute
match attr {
Qstr::MP_QSTR_MESSAGE_WIRE_TYPE => {
// Return the wire ID of this message def, or None if not set.
Ok(self.msg_wire_id.map_or(Obj::const_none(), |wire_id| wire_id.into()))
}
Qstr::MP_QSTR_MESSAGE_NAME => {
// Return the qstr name of this message def
Ok(Qstr::from_u16(find_name_by_msg_offset(self.msg_offset)?).into())
}
Qstr::MP_QSTR___dict__ => {
// Conversion to dict. Allocate a new dict object with a copy of our map
// and return it. This is a bit different from how uPy does it now, because
// we're returning a mutable dict.
Ok(Gc::new(Dict::with_map(self.map.clone())).into())
}
_ => { Err(Error::Missing) }
}
}
fn setattr(&mut self, attr: Qstr, value: Obj) -> Result<(), Error> {
if value == Obj::const_null() {
// this would be a delattr
return Err(Error::InvalidOperation);
}
if self.map.contains_key(attr) {
self.map.set(attr, value);
Ok(())
} else {
Err(Error::Missing)
}
}
}
impl Into<Obj> for Gc<MsgObj> {
    /// Hands the GC-allocated message over to MicroPython as a generic object
    /// reference.
    fn into(self) -> Obj {
        let raw = Self::into_raw(self);
        // SAFETY:
        //  - We are GC-allocated.
        //  - We are `repr(C)`.
        //  - We have a `base` as the first field with the correct type.
        unsafe { Obj::from_ptr(raw.cast()) }
    }
}
impl TryFrom<Obj> for Gc<MsgObj> {
type Error = Error;
fn try_from(value: Obj) -> Result<Self, Self::Error> {
if MsgObj::obj_type().is_type_of(value) {
// SAFETY: We assume that if `value` is an object pointer with the correct type,
// it is always GC-allocated.
let this = unsafe { Gc::from_raw(value.as_ptr().cast()) };
Ok(this)
} else {
Err(Error::InvalidType)
}
}
}
/// Attribute handler of the `Msg` type (installed as `attr_fn`).
///
/// `dest` is read and written as a two-slot array: a null `dest[0]` requests
/// an attribute load, whose result is written back to `dest[0]`. Otherwise
/// the value in `dest[1]` is stored into the attribute and `dest[0]` is set
/// to null afterwards. Errors from `getattr`/`setattr` are raised as
/// MicroPython exceptions by `try_or_raise`.
///
/// # Safety
///
/// `dest` must be valid for reading and writing two `Obj` values.
unsafe extern "C" fn msg_obj_attr(self_in: Obj, attr: ffi::qstr, dest: *mut Obj) {
util::try_or_raise(|| {
let mut this = Gc::<MsgObj>::try_from(self_in)?;
let attr = Qstr::from_u16(attr as _);
unsafe {
if dest.read() == Obj::const_null() {
// Load attribute
dest.write(this.getattr(attr)?);
} else {
// The value to store is passed in the second slot.
let value = dest.offset(1).read();
// Store attribute
Gc::as_mut(&mut this).setattr(attr, value)?;
// Null in the first slot reports the store as handled.
dest.write(Obj::const_null());
}
Ok(())
}
})
}
/// MicroPython object wrapping a message definition; calling it creates a
/// message instance (see `msg_def_obj_call`).
///
/// `repr(C)` with `base` as the first field is what the `Obj` conversions
/// below rely on.
#[repr(C)]
pub struct MsgDefObj {
// Object header carrying the `MsgDef` type.
base: ObjBase,
// The wrapped message definition.
def: MsgDef,
}
impl MsgDefObj {
    /// Allocates a definition object wrapping `def` on the GC heap.
    pub fn alloc(def: MsgDef) -> Gc<Self> {
        let base = Self::obj_type().to_base();
        Gc::new(Self { base, def })
    }

    /// The wrapped message definition.
    pub fn msg(&self) -> &MsgDef {
        &self.def
    }

    /// The MicroPython type object shared by all definition objects.
    fn obj_type() -> &'static Type {
        static TYPE: Type = obj_type! {
            name: Qstr::MP_QSTR_MsgDef,
            attr_fn: msg_def_obj_attr,
            call_fn: msg_def_obj_call,
        };
        &TYPE
    }
}
impl Into<Obj> for Gc<MsgDefObj> {
    /// Hands the GC-allocated definition object over to MicroPython as a
    /// generic object reference.
    fn into(self) -> Obj {
        let raw = Self::into_raw(self);
        // SAFETY:
        //  - We are GC-allocated.
        //  - We are `repr(C)`.
        //  - We have a `base` as the first field with the correct type.
        unsafe { Obj::from_ptr(raw.cast()) }
    }
}
impl TryFrom<Obj> for Gc<MsgDefObj> {
type Error = Error;
fn try_from(value: Obj) -> Result<Self, Self::Error> {
if MsgDefObj::obj_type().is_type_of(value) {
// SAFETY: We assume that if `value` is an object pointer with the correct type,
// it is always GC-allocated.
let this = unsafe { Gc::from_raw(value.as_ptr().cast()) };
Ok(this)
} else {
Err(Error::InvalidType)
}
}
}
/// Attribute handler of the `MsgDef` type (installed as `attr_fn`).
///
/// Only attribute loads are supported (`dest[0]` must be null on entry); any
/// other request fails with `Error::InvalidOperation`. The attributes
/// provided are `MESSAGE_NAME`, `MESSAGE_WIRE_TYPE` and the bound method
/// `is_type_of`; any other name fails with `Error::Missing`. Errors are
/// raised as MicroPython exceptions by `try_or_raise`.
///
/// # Safety
///
/// `dest` must be valid for reading and writing two `Obj` values.
unsafe extern "C" fn msg_def_obj_attr(self_in: Obj, attr: ffi::qstr, dest: *mut Obj) {
util::try_or_raise(|| {
let this= Gc::<MsgDefObj>::try_from(self_in)?;
let attr = Qstr::from_u16(attr as _);
// A non-null first slot means this is not an attribute load.
if unsafe { dest.read() } != Obj::const_null() {
return Err(Error::InvalidOperation);
}
match attr {
Qstr::MP_QSTR_MESSAGE_NAME => {
// Return the qstr name of this message def
let name = Qstr::from_u16(find_name_by_msg_offset(this.def.offset)?);
unsafe { dest.write(name.into()); };
}
Qstr::MP_QSTR_MESSAGE_WIRE_TYPE => {
// Return the wire type of this message def
let wire_id_obj = this
.def
.wire_id
.map_or_else(Obj::const_none, |wire_id| wire_id.into());
unsafe { dest.write(wire_id_obj); };
}
Qstr::MP_QSTR_is_type_of => {
// Return the is_type_of bound method
// dest[0] = function_obj
// dest[1] = self
unsafe {
dest.write(MSG_DEF_OBJ_IS_TYPE_OF_OBJ.to_obj());
dest.offset(1).write(self_in);
}
}
_ => { return Err(Error::Missing); }
}
Ok(())
});
}
/// Call handler of the `MsgDef` type (installed as `call_fn`): constructs a
/// new message of this definition from the keyword arguments.
///
/// Positional arguments are ignored. Experimental fields are always allowed
/// here. Defaults and required-field checks are applied by
/// `Decoder::message_from_values`.
unsafe extern "C" fn msg_def_obj_call(
    self_in: Obj,
    n_args: usize,
    n_kw: usize,
    args: *const Obj,
) -> Obj {
    util::try_with_args_and_kwargs_inline(n_args, n_kw, args, |_args, kwargs| {
        let def_obj = Gc::<MsgDefObj>::try_from(self_in)?;
        Decoder {
            enable_experimental: true,
        }
        .message_from_values(kwargs, def_obj.msg())
    })
}
/// Implementation of `MsgDef.is_type_of(obj)`: true if `obj` is a message
/// object created for this very definition (same definition offset), false
/// for any other object.
unsafe extern "C" fn msg_def_obj_is_type_of(self_in: Obj, obj: Obj) -> Obj {
    util::try_or_raise(|| {
        let this = Gc::<MsgDefObj>::try_from(self_in)?;
        let same_def = match Gc::<MsgObj>::try_from(obj) {
            Ok(msg) => msg.msg_offset == this.def.offset,
            Err(_) => false,
        };
        Ok(if same_def {
            Obj::const_true()
        } else {
            Obj::const_false()
        })
    })
}

/// Function object through which `is_type_of` is exposed as a bound method.
static MSG_DEF_OBJ_IS_TYPE_OF_OBJ: ffi::mp_obj_fun_builtin_fixed_t =
    obj_fn_2!(msg_def_obj_is_type_of);

@ -0,0 +1,9 @@
// https://developers.google.com/protocol-buffers/docs/encoding#signed_integers
/// Zigzag-encodes a signed integer so that values of small magnitude map to
/// small unsigned numbers: 0, -1, 1, -2, ... become 0, 1, 2, 3, ...
pub fn to_unsigned(sint: i64) -> u64 {
    // Magnitude bits shifted up by one, leaving bit 0 for the sign.
    let doubled = (sint as u64) << 1;
    // All ones for negative input, all zeros otherwise (arithmetic shift).
    let sign_mask = (sint >> 63) as u64;
    doubled ^ sign_mask
}
/// Decodes a zigzag-encoded integer: 0, 1, 2, 3, ... become 0, -1, 1, -2, ...
pub fn to_signed(uint: u64) -> i64 {
    let magnitude = (uint >> 1) as i64;
    // Bit 0 carries the sign; negative values are stored bit-inverted.
    if uint & 1 == 0 {
        magnitude
    } else {
        !magnitude
    }
}

@ -9,7 +9,7 @@ use crate::{
},
};
pub fn try_or_raise(func: impl FnOnce() -> Result<Obj, Error>) -> Obj {
pub fn try_or_raise<T>(func: impl FnOnce() -> Result<T, Error>) -> T {
func().unwrap_or_else(|err| raise_value_error(err.as_cstr()))
}

@ -195,6 +195,7 @@ extern const struct _mp_print_t mp_stderr_print;
#define MICROPY_PY_TREZORIO (1)
#define MICROPY_PY_TREZORUI (1)
#define MICROPY_PY_TREZORUTILS (1)
#define MICROPY_PY_TREZORPROTO (1)
#define MP_STATE_PORT MP_STATE_VM

@ -0,0 +1,34 @@
from typing import *
from trezor.protobuf import MessageType
T = TypeVar("T", bound=MessageType)
# NOTE(review): the signatures and docstrings below match the `///` comments
# in modtrezorproto.c word for word, so this stub is presumably generated from
# them - change the C comments rather than this file; confirm with the build.
# extmod/rustmods/modtrezorproto.c
def type_for_name(name: str) -> Type[MessageType]:
"""Find the message definition for the given protobuf name."""
# extmod/rustmods/modtrezorproto.c
def type_for_wire(wire_type: int) -> Type[MessageType]:
"""Find the message definition for the given wire type (numeric
identifier)."""
# extmod/rustmods/modtrezorproto.c
def decode(
buffer: bytes,
msg_type: Type[T],
enable_experimental: bool,
) -> T:
"""Decode data in the buffer into the specified message type."""
# extmod/rustmods/modtrezorproto.c
def encoded_length(msg: MessageType) -> int:
"""Calculate length of encoding of the specified message."""
# extmod/rustmods/modtrezorproto.c
def encode(buffer: bytearray, msg: MessageType) -> int:
"""Encode the message into the specified buffer. Return length of
encoding."""
Loading…
Cancel
Save