1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-11 07:50:57 +00:00

refactor(core): Switch to new Protobuf API

This commit is contained in:
Jan Pochyla 2021-03-23 13:36:03 +01:00 committed by matejcik
parent 8a5cb41060
commit 02aa14fc04
14 changed files with 237 additions and 677 deletions

View File

@ -145,6 +145,7 @@ async def handle_DoPreauthorized(
req = await ctx.call_any(PreauthorizedRequest(), *wire_types) req = await ctx.call_any(PreauthorizedRequest(), *wire_types)
assert req.MESSAGE_WIRE_TYPE is not None
handler = workflow_handlers.find_registered_handler( handler = workflow_handlers.find_registered_handler(
ctx.iface, req.MESSAGE_WIRE_TYPE ctx.iface, req.MESSAGE_WIRE_TYPE
) )

View File

@ -1,7 +1,8 @@
import protobuf import protobuf
import storage.cache import storage.cache
from trezor import messages, utils from trezor import protobuf
from trezor.messages import MessageType from trezor.enums import MessageType
from trezor.utils import ensure
if False: if False:
from typing import Iterable from typing import Iterable
@ -16,9 +17,12 @@ def is_set() -> bool:
def set(auth_message: protobuf.MessageType) -> None: def set(auth_message: protobuf.MessageType) -> None:
buffer = bytearray(protobuf.count_message(auth_message)) buffer = protobuf.dump_message_buffer(auth_message)
writer = utils.BufferWriter(buffer)
protobuf.dump_message(writer, auth_message) # only wire-level messages can be stored as authorization
# (because only wire-level messages have wire_type, which we use as identifier)
ensure(auth_message.MESSAGE_WIRE_TYPE is not None)
assert auth_message.MESSAGE_WIRE_TYPE is not None # so that mypy knows as well
storage.cache.set( storage.cache.set(
storage.cache.APP_COMMON_AUTHORIZATION_TYPE, storage.cache.APP_COMMON_AUTHORIZATION_TYPE,
auth_message.MESSAGE_WIRE_TYPE.to_bytes(2, "big"), auth_message.MESSAGE_WIRE_TYPE.to_bytes(2, "big"),
@ -32,11 +36,8 @@ def get() -> protobuf.MessageType | None:
return None return None
msg_wire_type = int.from_bytes(stored_auth_type, "big") msg_wire_type = int.from_bytes(stored_auth_type, "big")
msg_type = messages.get_type(msg_wire_type)
buffer = storage.cache.get(storage.cache.APP_COMMON_AUTHORIZATION_DATA) buffer = storage.cache.get(storage.cache.APP_COMMON_AUTHORIZATION_DATA)
reader = utils.BufferReader(buffer) return protobuf.load_message_buffer(buffer, msg_wire_type)
return protobuf.load_message(reader, msg_type)
def get_wire_types() -> Iterable[int]: def get_wire_types() -> Iterable[int]:

View File

@ -50,7 +50,7 @@ async def get_tx_keys(ctx, msg: MoneroGetTxKeyRequest, keychain):
# and then is used to store the derivations if applicable # and then is used to store the derivations if applicable
plain_buff = chacha_poly.decrypt_pack(tx_enc_key, msg.tx_enc_keys) plain_buff = chacha_poly.decrypt_pack(tx_enc_key, msg.tx_enc_keys)
utils.ensure(len(plain_buff) % 32 == 0, "Tx key buffer has invalid size") utils.ensure(len(plain_buff) % 32 == 0, "Tx key buffer has invalid size")
del msg.tx_enc_keys msg.tx_enc_keys = b""
# If return only derivations do tx_priv * view_pub # If return only derivations do tx_priv * view_pub
if do_deriv: if do_deriv:

View File

@ -134,7 +134,7 @@ def gen_hmac_vini(
used only once and hard to check. I.e., indices in step 2 used only once and hard to check. I.e., indices in step 2
are uncheckable, decoy keys in step 9 are just random keys. are uncheckable, decoy keys in step 9 are just random keys.
""" """
import protobuf from trezor import protobuf
from apps.monero.xmr.keccak_hasher import get_keccak_writer from apps.monero.xmr.keccak_hasher import get_keccak_writer
kwriter = get_keccak_writer() kwriter = get_keccak_writer()
@ -146,7 +146,7 @@ def gen_hmac_vini(
src_entr.real_out_additional_tx_keys[src_entr.real_output_in_tx_index] src_entr.real_out_additional_tx_keys[src_entr.real_output_in_tx_index]
] ]
protobuf.dump_message(kwriter, src_entr) kwriter.write(protobuf.dump_message_buffer(src_entr))
src_entr.outputs = real_outputs src_entr.outputs = real_outputs
src_entr.real_out_additional_tx_keys = real_additional src_entr.real_out_additional_tx_keys = real_additional
kwriter.write(vini_bin) kwriter.write(vini_bin)
@ -162,11 +162,11 @@ def gen_hmac_vouti(
""" """
Generates HMAC for (TxDestinationEntry[i] || tx.vout[i]) Generates HMAC for (TxDestinationEntry[i] || tx.vout[i])
""" """
import protobuf from trezor import protobuf
from apps.monero.xmr.keccak_hasher import get_keccak_writer from apps.monero.xmr.keccak_hasher import get_keccak_writer
kwriter = get_keccak_writer() kwriter = get_keccak_writer()
protobuf.dump_message(kwriter, dst_entr) kwriter.write(protobuf.dump_message_buffer(dst_entr))
kwriter.write(tx_out_bin) kwriter.write(tx_out_bin)
hmac_key_vouti = hmac_key_txout(key, idx) hmac_key_vouti = hmac_key_txout(key, idx)
@ -180,11 +180,11 @@ def gen_hmac_tsxdest(
""" """
Generates HMAC for TxDestinationEntry[i] Generates HMAC for TxDestinationEntry[i]
""" """
import protobuf from trezor import protobuf
from apps.monero.xmr.keccak_hasher import get_keccak_writer from apps.monero.xmr.keccak_hasher import get_keccak_writer
kwriter = get_keccak_writer() kwriter = get_keccak_writer()
protobuf.dump_message(kwriter, dst_entr) kwriter.write(protobuf.dump_message_buffer(dst_entr))
hmac_key = hmac_key_txdst(key, idx) hmac_key = hmac_key_txdst(key, idx)
hmac_tsxdest = crypto.compute_hmac(hmac_key, kwriter.get_digest()) hmac_tsxdest = crypto.compute_hmac(hmac_key, kwriter.get_digest())

View File

@ -116,7 +116,7 @@ async def init_transaction(
from trezor.messages.MoneroTransactionInitAck import MoneroTransactionInitAck from trezor.messages.MoneroTransactionInitAck import MoneroTransactionInitAck
from trezor.messages.MoneroTransactionRsigData import MoneroTransactionRsigData from trezor.messages.MoneroTransactionRsigData import MoneroTransactionRsigData
rsig_data = MoneroTransactionRsigData(offload_type=state.rsig_offload) rsig_data = MoneroTransactionRsigData(offload_type=int(state.rsig_offload))
return MoneroTransactionInitAck(hmacs=hmacs, rsig_data=rsig_data) return MoneroTransactionInitAck(hmacs=hmacs, rsig_data=rsig_data)
@ -273,11 +273,11 @@ def _compute_sec_keys(state: State, tsx_data: MoneroTransactionData):
""" """
Generate master key H( H(TsxData || tx_priv) || rand ) Generate master key H( H(TsxData || tx_priv) || rand )
""" """
import protobuf from trezor import protobuf
from apps.monero.xmr.keccak_hasher import get_keccak_writer from apps.monero.xmr.keccak_hasher import get_keccak_writer
writer = get_keccak_writer() writer = get_keccak_writer()
protobuf.dump_message(writer, tsx_data) writer.write(protobuf.dump_message_buffer(tsx_data))
writer.write(crypto.encodeint(state.tx_priv)) writer.write(crypto.encodeint(state.tx_priv))
master_key = crypto.keccak_2hash( master_key = crypto.keccak_2hash(

View File

@ -62,9 +62,6 @@ class MemoryReaderWriter:
return nread return nread
async def areadinto(self, buf):
return self.readinto(buf)
def write(self, buf): def write(self, buf):
nwritten = len(buf) nwritten = len(buf)
nall = len(self.buffer) nall = len(self.buffer)
@ -98,9 +95,6 @@ class MemoryReaderWriter:
self.ndata += nwritten self.ndata += nwritten
return nwritten return nwritten
async def awrite(self, buf):
return self.write(buf)
def get_buffer(self): def get_buffer(self):
mv = memoryview(self.buffer) mv = memoryview(self.buffer)
return mv[self.offset : self.woffset] return mv[self.offset : self.woffset]

View File

@ -1,456 +0,0 @@
"""
Extremely minimal streaming codec for a subset of protobuf. Supports uint32,
bytes, string, embedded message and repeated fields.
"""
if False:
from typing import (
Any,
Callable,
Iterable,
TypeVar,
Union,
)
from typing_extensions import Protocol
class Reader(Protocol):
def readinto(self, buf: bytearray) -> int:
"""
Reads `len(buf)` bytes into `buf`, or raises `EOFError`.
"""
class Writer(Protocol):
def write(self, buf: bytes) -> int:
"""
Writes all bytes from `buf`, or raises `EOFError`.
"""
WriteMethod = Callable[[bytes], Any]
_UVARINT_BUFFER = bytearray(1)
def load_uvarint(reader: Reader) -> int:
buffer = _UVARINT_BUFFER
result = 0
shift = 0
byte = 0x80
while byte & 0x80:
reader.readinto(buffer)
byte = buffer[0]
result += (byte & 0x7F) << shift
shift += 7
return result
def dump_uvarint(write: WriteMethod, n: int) -> None:
if n < 0:
raise ValueError("Cannot dump signed value, convert it to unsigned first.")
buffer = _UVARINT_BUFFER
shifted = 1
while shifted:
shifted = n >> 7
buffer[0] = (n & 0x7F) | (0x80 if shifted else 0x00)
write(buffer)
n = shifted
def count_uvarint(n: int) -> int:
if n < 0:
raise ValueError("Cannot dump signed value, convert it to unsigned first.")
if n <= 0x7F:
return 1
if n <= 0x3FFF:
return 2
if n <= 0x1F_FFFF:
return 3
if n <= 0xFFF_FFFF:
return 4
if n <= 0x7_FFFF_FFFF:
return 5
if n <= 0x3FF_FFFF_FFFF:
return 6
if n <= 0x1_FFFF_FFFF_FFFF:
return 7
if n <= 0xFF_FFFF_FFFF_FFFF:
return 8
if n <= 0x7FFF_FFFF_FFFF_FFFF:
return 9
raise ValueError
# protobuf interleaved signed encoding:
# https://developers.google.com/protocol-buffers/docs/encoding#structure
# the idea is to save the sign in LSbit instead of twos-complement.
# so counting up, you go: 0, -1, 1, -2, 2, ... (as the first bit changes, sign flips)
#
# To achieve this with a twos-complement number:
# 1. shift left by 1, leaving LSbit free
# 2. if the number is negative, do bitwise negation.
# This keeps positive number the same, and converts negative from twos-complement
# to the appropriate value, while setting the sign bit.
#
# The original algorithm makes use of the fact that arithmetic (signed) shift
# keeps the sign bits, so for a n-bit number, (x >> n) gets us "all sign bits".
# Then you can take "number XOR all-sign-bits", which is XOR 0 (identity) for positive
# and XOR 1 (bitwise negation) for negative. Cute and efficient.
#
# But this is harder in Python because we don't natively know the bit size of the number.
# So we have to branch on whether the number is negative.
def sint_to_uint(sint: int) -> int:
res = sint << 1
if sint < 0:
res = ~res
return res
def uint_to_sint(uint: int) -> int:
sign = uint & 1
res = uint >> 1
if sign:
res = ~res
return res
class UVarintType:
WIRE_TYPE = 0
class SVarintType:
WIRE_TYPE = 0
class BoolType:
WIRE_TYPE = 0
class EnumType:
WIRE_TYPE = 0
def __init__(self, name: str, enum_values: Iterable[int]) -> None:
self.enum_values = enum_values
def validate(self, fvalue: int) -> int:
if fvalue in self.enum_values:
return fvalue
else:
raise TypeError("Invalid enum value")
class BytesType:
WIRE_TYPE = 2
class UnicodeType:
WIRE_TYPE = 2
if False:
MessageTypeDef = Union[
type[UVarintType],
type[SVarintType],
type[BoolType],
EnumType,
type[BytesType],
type[UnicodeType],
type["MessageType"],
]
FieldDef = tuple[str, MessageTypeDef, Any]
FieldDict = dict[int, FieldDef]
FieldCache = dict[type["MessageType"], FieldDict]
LoadedMessageType = TypeVar("LoadedMessageType", bound="MessageType")
class MessageType:
WIRE_TYPE = 2
UNSTABLE = False
# Type id for the wire codec.
# Technically, not every protobuf message has this.
MESSAGE_WIRE_TYPE = -1
@classmethod
def get_fields(cls) -> FieldDict:
return {}
@classmethod
def cache_subordinate_types(cls, field_cache: FieldCache) -> None:
if cls in field_cache:
fields = field_cache[cls]
else:
fields = cls.get_fields()
field_cache[cls] = fields
for _, field_type, _ in fields.values():
if isinstance(field_type, MessageType):
field_type.cache_subordinate_types(field_cache)
def __eq__(self, rhs: Any) -> bool:
return self.__class__ is rhs.__class__ and self.__dict__ == rhs.__dict__
def __repr__(self) -> str:
return "<%s>" % self.__class__.__name__
class LimitedReader:
def __init__(self, reader: Reader, limit: int) -> None:
self.reader = reader
self.limit = limit
def readinto(self, buf: bytearray) -> int:
if self.limit < len(buf):
raise EOFError
else:
nread = self.reader.readinto(buf)
self.limit -= nread
return nread
FLAG_REPEATED = object()
FLAG_REQUIRED = object()
FLAG_EXPERIMENTAL = object()
def load_message(
reader: Reader,
msg_type: type[LoadedMessageType],
field_cache: FieldCache | None = None,
experimental_enabled: bool = True,
) -> LoadedMessageType:
if field_cache is None:
field_cache = {}
fields = field_cache.get(msg_type)
if fields is None:
fields = msg_type.get_fields()
field_cache[msg_type] = fields
if msg_type.UNSTABLE and not experimental_enabled:
raise ValueError # experimental messages not enabled
# we need to avoid calling __init__, which enforces required arguments
msg: LoadedMessageType = object.__new__(msg_type)
# pre-seed the object with defaults
for fname, _, fdefault in fields.values():
if fdefault is FLAG_REPEATED:
fdefault = []
elif fdefault is FLAG_EXPERIMENTAL:
fdefault = None
setattr(msg, fname, fdefault)
if False:
SingularValue = Union[int, bool, bytearray, str, MessageType]
Value = Union[SingularValue, list[SingularValue]]
fvalue: Value = 0
while True:
try:
fkey = load_uvarint(reader)
except EOFError:
break # no more fields to load
ftag = fkey >> 3
wtype = fkey & 7
field = fields.get(ftag, None)
if field is None: # unknown field, skip it
if wtype == 0:
load_uvarint(reader)
elif wtype == 2:
ivalue = load_uvarint(reader)
reader.readinto(bytearray(ivalue))
else:
raise ValueError
continue
fname, ftype, fdefault = field
if wtype != ftype.WIRE_TYPE:
raise TypeError # parsed wire type differs from the schema
if fdefault is FLAG_EXPERIMENTAL and not experimental_enabled:
raise ValueError # experimental fields not enabled
ivalue = load_uvarint(reader)
if ftype is UVarintType:
fvalue = ivalue
elif ftype is SVarintType:
fvalue = uint_to_sint(ivalue)
elif ftype is BoolType:
fvalue = bool(ivalue)
elif isinstance(ftype, EnumType):
fvalue = ftype.validate(ivalue)
elif ftype is BytesType:
fvalue = bytearray(ivalue)
reader.readinto(fvalue)
elif ftype is UnicodeType:
fvalue = bytearray(ivalue)
reader.readinto(fvalue)
fvalue = bytes(fvalue).decode()
elif issubclass(ftype, MessageType):
fvalue = load_message(
LimitedReader(reader, ivalue), ftype, field_cache, experimental_enabled
)
else:
raise TypeError # field type is unknown
if fdefault is FLAG_REPEATED:
getattr(msg, fname).append(fvalue)
else:
setattr(msg, fname, fvalue)
for fname, _, _ in fields.values():
if getattr(msg, fname) is FLAG_REQUIRED:
# The message is intended to be user-facing when decoding from wire,
# but not when used internally.
raise ValueError("Required field '{}' was not received".format(fname))
return msg
def dump_message(
writer: Writer, msg: MessageType, field_cache: FieldCache | None = None
) -> None:
repvalue = [0]
if field_cache is None:
field_cache = {}
fields = field_cache.get(type(msg))
if fields is None:
fields = msg.get_fields()
field_cache[type(msg)] = fields
for ftag in fields:
fname, ftype, fdefault = fields[ftag]
fvalue = getattr(msg, fname, None)
if fvalue is None:
continue
fkey = (ftag << 3) | ftype.WIRE_TYPE
if fdefault is not FLAG_REPEATED:
repvalue[0] = fvalue
fvalue = repvalue
for svalue in fvalue:
dump_uvarint(writer.write, fkey)
if ftype is UVarintType:
dump_uvarint(writer.write, svalue)
elif ftype is SVarintType:
dump_uvarint(writer.write, sint_to_uint(svalue))
elif ftype is BoolType:
dump_uvarint(writer.write, int(svalue))
elif isinstance(ftype, EnumType):
dump_uvarint(writer.write, svalue)
elif ftype is BytesType:
if isinstance(svalue, list):
dump_uvarint(writer.write, _count_bytes_list(svalue))
for sub_svalue in svalue:
writer.write(sub_svalue)
else:
dump_uvarint(writer.write, len(svalue))
writer.write(svalue)
elif ftype is UnicodeType:
svalue = svalue.encode()
dump_uvarint(writer.write, len(svalue))
writer.write(svalue)
elif issubclass(ftype, MessageType):
ffields = field_cache.get(ftype)
if ffields is None:
ffields = ftype.get_fields()
field_cache[ftype] = ffields
dump_uvarint(writer.write, count_message(svalue, field_cache))
dump_message(writer, svalue, field_cache)
else:
raise TypeError
def count_message(msg: MessageType, field_cache: FieldCache | None = None) -> int:
nbytes = 0
repvalue = [0]
if field_cache is None:
field_cache = {}
fields = field_cache.get(type(msg))
if fields is None:
fields = msg.get_fields()
field_cache[type(msg)] = fields
for ftag in fields:
fname, ftype, fdefault = fields[ftag]
fvalue = getattr(msg, fname, None)
if fvalue is None:
continue
fkey = (ftag << 3) | ftype.WIRE_TYPE
if fdefault is not FLAG_REPEATED:
repvalue[0] = fvalue
fvalue = repvalue
# length of all the field keys
nbytes += count_uvarint(fkey) * len(fvalue)
if ftype is UVarintType:
for svalue in fvalue:
nbytes += count_uvarint(svalue)
elif ftype is SVarintType:
for svalue in fvalue:
nbytes += count_uvarint(sint_to_uint(svalue))
elif ftype is BoolType:
for svalue in fvalue:
nbytes += count_uvarint(int(svalue))
elif isinstance(ftype, EnumType):
for svalue in fvalue:
nbytes += count_uvarint(svalue)
elif ftype is BytesType:
for svalue in fvalue:
if isinstance(svalue, list):
svalue = _count_bytes_list(svalue)
else:
svalue = len(svalue)
nbytes += count_uvarint(svalue)
nbytes += svalue
elif ftype is UnicodeType:
for svalue in fvalue:
svalue = len(svalue.encode())
nbytes += count_uvarint(svalue)
nbytes += svalue
elif issubclass(ftype, MessageType):
for svalue in fvalue:
fsize = count_message(svalue, field_cache)
nbytes += count_uvarint(fsize)
nbytes += fsize
else:
raise TypeError
return nbytes
def _count_bytes_list(svalue: list[bytes]) -> int:
res = 0
for x in svalue:
res += len(x)
return res

View File

@ -0,0 +1,42 @@
import trezorproto
decode = trezorproto.decode
encode = trezorproto.encode
encoded_length = trezorproto.encoded_length
type_for_name = trezorproto.type_for_name
type_for_wire = trezorproto.type_for_wire
# XXX
# Note that MessageType "subclasses" are not true subclasses, but instead instances
# of the built-in metaclass MsgDef. MessageType instances are in fact instances of
# the built-in type Msg. That is why isinstance checks do not work, and instead the
# MessageTypeSubclass.is_type_of() method must be used.
if False:
from typing import Type, TypeGuard, TypeVar
T = TypeVar("T", bound="MessageType")
class MsgDef(type):
@classmethod
def is_type_of(cls: Type[Type[T]], msg: "MessageType") -> TypeGuard[T]:
"""Identify if the provided message belongs to this type."""
raise NotImplementedError
class MessageType(metaclass=MsgDef):
MESSAGE_NAME: str = "MessageType"
MESSAGE_WIRE_TYPE: int | None = None
def load_message_buffer(
buffer: bytes,
msg_wire_type: int,
experimental_enabled: bool = True,
) -> MessageType:
msg_type = type_for_wire(msg_wire_type)
return decode(buffer, msg_type, experimental_enabled)
def dump_message_buffer(msg: MessageType) -> bytearray:
buffer = bytearray(encoded_length(msg))
encode(buffer, msg)
return buffer

View File

@ -145,10 +145,6 @@ class HashWriter:
def write(self, buf: bytes) -> None: # alias for extend() def write(self, buf: bytes) -> None: # alias for extend()
self.ctx.update(buf) self.ctx.update(buf)
async def awrite(self, buf: bytes) -> int: # AsyncWriter interface
self.ctx.update(buf)
return len(buf)
def get_digest(self) -> bytes: def get_digest(self) -> bytes:
return self.ctx.digest() return self.ctx.digest()

View File

@ -35,11 +35,10 @@ reads the message's header. When the message type is known the first handler is
""" """
import protobuf
from storage.cache import InvalidSessionError from storage.cache import InvalidSessionError
from trezor import log, loop, messages, utils, workflow from trezor.enums import FailureType
from trezor.messages import FailureType from trezor.messages import Failure
from trezor.messages.Failure import Failure from trezor import log, loop, protobuf, utils, workflow
from trezor.wire import codec_v1 from trezor.wire import codec_v1
from trezor.wire.errors import ActionCancelled, DataError, Error from trezor.wire.errors import ActionCancelled, DataError, Error
@ -74,17 +73,19 @@ def setup(iface: WireInterface, is_debug_session: bool = False) -> None:
if False: if False:
from typing import Protocol from typing import Protocol, TypeVar
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
class GenericContext(Protocol): class GenericContext(Protocol):
async def call( async def call(
self, self,
msg: protobuf.MessageType, msg: protobuf.MessageType,
expected_type: type[protobuf.LoadedMessageType], expected_type: type[protobuf.MessageType],
) -> Any: ) -> Any:
... ...
async def read(self, expected_type: type[protobuf.LoadedMessageType]) -> Any: async def read(self, expected_type: type[protobuf.MessageType]) -> Any:
... ...
async def write(self, msg: protobuf.MessageType) -> None: async def write(self, msg: protobuf.MessageType) -> None:
@ -96,15 +97,14 @@ if False:
def _wrap_protobuf_load( def _wrap_protobuf_load(
reader: protobuf.Reader, buffer: bytes,
expected_type: type[protobuf.LoadedMessageType], expected_type: type[LoadedMessageType],
field_cache: protobuf.FieldCache | None = None, ) -> LoadedMessageType:
) -> protobuf.LoadedMessageType:
try: try:
return protobuf.load_message( return protobuf.decode(buffer, expected_type, experimental_enabled)
reader, expected_type, field_cache, experimental_enabled
)
except Exception as e: except Exception as e:
if __debug__:
log.exception(__name__, e)
if e.args: if e.args:
raise DataError("Failed to decode message: {}".format(e.args[0])) raise DataError("Failed to decode message: {}".format(e.args[0]))
else: else:
@ -139,19 +139,15 @@ class Context:
self.iface = iface self.iface = iface
self.sid = sid self.sid = sid
self.buffer = buffer self.buffer = buffer
self.buffer_writer = utils.BufferWriter(self.buffer)
self._field_cache: protobuf.FieldCache = {}
async def call( async def call(
self, self,
msg: protobuf.MessageType, msg: protobuf.MessageType,
expected_type: type[protobuf.LoadedMessageType], expected_type: type[LoadedMessageType],
field_cache: protobuf.FieldCache | None = None, ) -> LoadedMessageType:
) -> protobuf.LoadedMessageType: await self.write(msg)
await self.write(msg, field_cache)
del msg del msg
return await self.read(expected_type, field_cache) return await self.read(expected_type)
async def call_any( async def call_any(
self, msg: protobuf.MessageType, *expected_wire_types: int self, msg: protobuf.MessageType, *expected_wire_types: int
@ -160,21 +156,17 @@ class Context:
del msg del msg
return await self.read_any(expected_wire_types) return await self.read_any(expected_wire_types)
async def read_from_wire(self) -> codec_v1.Message: def read_from_wire(self) -> Awaitable[codec_v1.Message]:
return await codec_v1.read_message(self.iface, self.buffer) return codec_v1.read_message(self.iface, self.buffer)
async def read( async def read(self, expected_type: type[LoadedMessageType]) -> LoadedMessageType:
self,
expected_type: type[protobuf.LoadedMessageType],
field_cache: protobuf.FieldCache | None = None,
) -> protobuf.LoadedMessageType:
if __debug__: if __debug__:
log.debug( log.debug(
__name__, __name__,
"%s:%x expect: %s", "%s:%x expect: %s",
self.iface.iface_num(), self.iface.iface_num(),
self.sid, self.sid,
expected_type, expected_type.MESSAGE_NAME,
) )
# Load the full message into a buffer, parse out type and data payload # Load the full message into a buffer, parse out type and data payload
@ -191,13 +183,13 @@ class Context:
"%s:%x read: %s", "%s:%x read: %s",
self.iface.iface_num(), self.iface.iface_num(),
self.sid, self.sid,
expected_type, expected_type.MESSAGE_NAME,
) )
workflow.idle_timer.touch() workflow.idle_timer.touch()
# look up the protobuf class and parse the message # look up the protobuf class and parse the message
return _wrap_protobuf_load(msg.data, expected_type, field_cache) return _wrap_protobuf_load(msg.data, expected_type)
async def read_any( async def read_any(
self, expected_wire_types: Iterable[int] self, expected_wire_types: Iterable[int]
@ -220,11 +212,15 @@ class Context:
raise UnexpectedMessageError(msg) raise UnexpectedMessageError(msg)
# find the protobuf type # find the protobuf type
exptype = messages.get_type(msg.type) exptype = protobuf.type_for_wire(msg.type)
if __debug__: if __debug__:
log.debug( log.debug(
__name__, "%s:%x read: %s", self.iface.iface_num(), self.sid, exptype __name__,
"%s:%x read: %s",
self.iface.iface_num(),
self.sid,
exptype.MESSAGE_NAME,
) )
workflow.idle_timer.touch() workflow.idle_timer.touch()
@ -232,41 +228,36 @@ class Context:
# parse the message and return it # parse the message and return it
return _wrap_protobuf_load(msg.data, exptype) return _wrap_protobuf_load(msg.data, exptype)
async def write( async def write(self, msg: protobuf.MessageType) -> None:
self,
msg: protobuf.MessageType,
field_cache: protobuf.FieldCache | None = None,
) -> None:
if __debug__: if __debug__:
log.debug( log.debug(
__name__, "%s:%x write: %s", self.iface.iface_num(), self.sid, msg __name__,
"%s:%x write: %s",
self.iface.iface_num(),
self.sid,
msg.MESSAGE_NAME,
) )
if field_cache is None: # cannot write message without wire type
field_cache = self._field_cache assert msg.MESSAGE_WIRE_TYPE is not None
# write the message msg_size = protobuf.encoded_length(msg)
msg_size = protobuf.count_message(msg, field_cache)
# prepare buffer if msg_size <= len(self.buffer):
if msg_size <= len(self.buffer_writer.buffer):
# reuse preallocated # reuse preallocated
buffer_writer = self.buffer_writer buffer = self.buffer
else: else:
# message is too big, we need to allocate a new buffer # message is too big, we need to allocate a new buffer
buffer_writer = utils.BufferWriter(bytearray(msg_size)) buffer = bytearray(msg_size)
msg_size = protobuf.encode(buffer, msg)
buffer_writer.seek(0)
protobuf.dump_message(buffer_writer, msg, field_cache)
await codec_v1.write_message( await codec_v1.write_message(
self.iface, self.iface,
msg.MESSAGE_WIRE_TYPE, msg.MESSAGE_WIRE_TYPE,
memoryview(buffer_writer.buffer)[:msg_size], memoryview(buffer)[:msg_size],
) )
# make sure we don't keep around fields of all protobuf types ever
self._field_cache.clear()
def wait(self, *tasks: Awaitable) -> Any: def wait(self, *tasks: Awaitable) -> Any:
""" """
Wait until one of the passed tasks finishes, and return the result, Wait until one of the passed tasks finishes, and return the result,
@ -301,8 +292,8 @@ async def _handle_single_message(
""" """
if __debug__: if __debug__:
try: try:
msg_type = messages.get_type(msg.type).__name__ msg_type = protobuf.type_for_wire(msg.type).MESSAGE_NAME
except KeyError: except Exception:
msg_type = "%d - unknown message type" % msg.type msg_type = "%d - unknown message type" % msg.type
log.debug( log.debug(
__name__, __name__,
@ -328,7 +319,7 @@ async def _handle_single_message(
try: try:
# Find a protobuf.MessageType subclass that describes this # Find a protobuf.MessageType subclass that describes this
# message. Raises if the type is not found. # message. Raises if the type is not found.
req_type = messages.get_type(msg.type) req_type = protobuf.type_for_wire(msg.type)
# Try to decode the message according to schema from # Try to decode the message according to schema from
# `req_type`. Raises if the message is malformed. # `req_type`. Raises if the message is malformed.
@ -424,6 +415,11 @@ async def handle_session(
next_msg = await _handle_single_message( next_msg = await _handle_single_message(
ctx, msg, use_workflow=not is_debug_session ctx, msg, use_workflow=not is_debug_session
) )
except Exception as exc:
# Log and ignore. The session handler can only exit explicitly in the
# following finally block.
if __debug__:
log.exception(__name__, exc)
finally: finally:
if not __debug__ or not is_debug_session: if not __debug__ or not is_debug_session:
# Unload modules imported by the workflow. Should not raise. # Unload modules imported by the workflow. Should not raise.

View File

@ -23,7 +23,7 @@ class CodecError(Exception):
class Message: class Message:
def __init__(self, mtype: int, mdata: utils.BufferReader) -> None: def __init__(self, mtype: int, mdata: bytes) -> None:
self.type = mtype self.type = mtype
self.data = mdata self.data = mdata
@ -70,7 +70,7 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag
if read_and_throw_away: if read_and_throw_away:
raise CodecError("Message too large") raise CodecError("Message too large")
return Message(mtype, utils.BufferReader(mdata)) return Message(mtype, mdata)
async def write_message(iface: WireInterface, mtype: int, mdata: bytes) -> None: async def write_message(iface: WireInterface, mtype: int, mdata: bytes) -> None:

View File

@ -1,129 +0,0 @@
from common import *
import protobuf
from trezor.utils import BufferReader, BufferWriter
class Message(protobuf.MessageType):
def __init__(self, sint_field: int = 0, enum_field: int = 0) -> None:
self.sint_field = sint_field
self.enum_field = enum_field
@classmethod
def get_fields(cls):
return {
1: ("sint_field", protobuf.SVarintType, 0),
2: ("enum_field", protobuf.EnumType("t", (0, 5, 25)), 0),
}
class MessageWithRequiredAndDefault(protobuf.MessageType):
def __init__(self, required_field, default_field) -> None:
self.required_field = required_field
self.default_field = default_field
@classmethod
def get_fields(cls):
return {
1: ("required_field", protobuf.UVarintType, protobuf.FLAG_REQUIRED),
2: ("default_field", protobuf.SVarintType, -1),
}
def load_uvarint(data: bytes) -> int:
reader = BufferReader(data)
return protobuf.load_uvarint(reader)
def dump_uvarint(value: int) -> bytearray:
w = bytearray()
protobuf.dump_uvarint(w.extend, value)
return w
def dump_message(msg: protobuf.MessageType) -> bytearray:
length = protobuf.count_message(msg)
buffer = bytearray(length)
protobuf.dump_message(BufferWriter(buffer), msg)
return buffer
def load_message(msg_type, buffer: bytearray) -> protobuf.MessageType:
return protobuf.load_message(BufferReader(buffer), msg_type)
class TestProtobuf(unittest.TestCase):
def test_dump_uvarint(self):
self.assertEqual(dump_uvarint(0), b"\x00")
self.assertEqual(dump_uvarint(1), b"\x01")
self.assertEqual(dump_uvarint(0xFF), b"\xff\x01")
self.assertEqual(dump_uvarint(123456), b"\xc0\xc4\x07")
with self.assertRaises(ValueError):
dump_uvarint(-1)
def test_load_uvarint(self):
self.assertEqual(load_uvarint(b"\x00"), 0)
self.assertEqual(load_uvarint(b"\x01"), 1)
self.assertEqual(load_uvarint(b"\xff\x01"), 0xFF)
self.assertEqual(load_uvarint(b"\xc0\xc4\x07"), 123456)
def test_sint_uint(self):
self.assertEqual(protobuf.uint_to_sint(0), 0)
self.assertEqual(protobuf.sint_to_uint(0), 0)
self.assertEqual(protobuf.sint_to_uint(-1), 1)
self.assertEqual(protobuf.sint_to_uint(1), 2)
self.assertEqual(protobuf.uint_to_sint(1), -1)
self.assertEqual(protobuf.uint_to_sint(2), 1)
# roundtrip:
self.assertEqual(
protobuf.uint_to_sint(protobuf.sint_to_uint(1234567891011)), 1234567891011
)
self.assertEqual(
protobuf.uint_to_sint(protobuf.sint_to_uint(-(2 ** 32))), -(2 ** 32)
)
def test_validate_enum(self):
# ok message:
msg = Message(-42, 5)
msg_encoded = dump_message(msg)
nmsg = load_message(Message, msg_encoded)
self.assertEqual(msg.sint_field, nmsg.sint_field)
self.assertEqual(msg.enum_field, nmsg.enum_field)
# bad enum value:
msg = Message(-42, 42)
msg_encoded = dump_message(msg)
with self.assertRaises(TypeError):
load_message(Message, msg_encoded)
def test_required(self):
msg = MessageWithRequiredAndDefault(required_field=1, default_field=2)
msg_encoded = dump_message(msg)
nmsg = load_message(MessageWithRequiredAndDefault, msg_encoded)
self.assertEqual(nmsg.required_field, 1)
self.assertEqual(nmsg.default_field, 2)
# try a message without the required_field
msg = MessageWithRequiredAndDefault(required_field=None, default_field=2)
# encoding always succeeds
msg_encoded = dump_message(msg)
with self.assertRaises(ValueError):
load_message(MessageWithRequiredAndDefault, msg_encoded)
# try a message without the default field
msg = MessageWithRequiredAndDefault(required_field=1, default_field=None)
msg_encoded = dump_message(msg)
nmsg = load_message(MessageWithRequiredAndDefault, msg_encoded)
self.assertEqual(nmsg.required_field, 1)
self.assertEqual(nmsg.default_field, -1)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,115 @@
from common import *
from trezor import protobuf
from trezor.messages import WebAuthnCredential, EosAsset, Failure, SignMessage
def load_uvarint32(data: bytes) -> int:
# use known uint32 field in an all-optional message
buffer = bytearray(len(data) + 1)
buffer[1:] = data
buffer[0] = (1 << 3) | 0 # field number 1, wire type 0
msg = protobuf.decode(buffer, WebAuthnCredential, False)
return msg.index
def load_uvarint64(data: bytes) -> int:
# use known uint64 field in an all-optional message
buffer = bytearray(len(data) + 1)
buffer[1:] = data
buffer[0] = (2 << 3) | 0 # field number 1, wire type 0
msg = protobuf.decode(buffer, EosAsset, False)
return msg.symbol
def dump_uvarint32(value: int) -> bytearray:
# use known uint32 field in an all-optional message
msg = WebAuthnCredential(index=value)
length = protobuf.encoded_length(msg)
buffer = bytearray(length)
protobuf.encode(buffer, msg)
assert buffer[0] == (1 << 3) | 0 # field number 1, wire type 0
return buffer[1:]
def dump_uvarint64(value: int) -> bytearray:
# use known uint64 field in an all-optional message
msg = EosAsset(symbol=value)
length = protobuf.encoded_length(msg)
buffer = bytearray(length)
protobuf.encode(buffer, msg)
assert buffer[0] == (2 << 3) | 0 # field number 2, wire type 0
return buffer[1:]
def dump_message(msg: protobuf.MessageType) -> bytearray:
length = protobuf.encoded_length(msg)
buffer = bytearray(length)
protobuf.encode(buffer, msg)
return buffer
def load_message(msg_type: Type[protobuf.MessageType], buffer: bytes) -> protobuf.MessageType:
return protobuf.decode(buffer, msg_type, False)
class TestProtobuf(unittest.TestCase):
def test_dump_uvarint(self):
for dump_uvarint in (dump_uvarint32, dump_uvarint64):
self.assertEqual(dump_uvarint(0), b"\x00")
self.assertEqual(dump_uvarint(1), b"\x01")
self.assertEqual(dump_uvarint(0xFF), b"\xff\x01")
self.assertEqual(dump_uvarint(123456), b"\xc0\xc4\x07")
with self.assertRaises(ValueError):
dump_uvarint(-1)
def test_load_uvarint(self):
for load_uvarint in (load_uvarint32, load_uvarint64):
self.assertEqual(load_uvarint(b"\x00"), 0)
self.assertEqual(load_uvarint(b"\x01"), 1)
self.assertEqual(load_uvarint(b"\xff\x01"), 0xFF)
self.assertEqual(load_uvarint(b"\xc0\xc4\x07"), 123456)
def test_validate_enum(self):
# ok message:
msg = Failure(code=7)
msg_encoded = dump_message(msg)
nmsg = load_message(Failure, msg_encoded)
self.assertEqual(msg.code, nmsg.code)
# bad enum value:
msg = Failure(code=1000)
msg_encoded = dump_message(msg)
with self.assertRaises(ValueError):
load_message(Failure, msg_encoded)
def test_required(self):
msg = SignMessage(message=b"hello", coin_name="foo", script_type=1)
msg_encoded = dump_message(msg)
nmsg = load_message(SignMessage, msg_encoded)
self.assertEqual(nmsg.message, b"hello")
self.assertEqual(nmsg.coin_name, "foo")
self.assertEqual(nmsg.script_type, 1)
# try a message without the required_field
msg = SignMessage(message=None)
# encoding always succeeds
msg_encoded = dump_message(msg)
with self.assertRaises(ValueError):
load_message(SignMessage, msg_encoded)
# try a message without the default field
msg = SignMessage(message=b"hello")
msg.coin_name = None
msg_encoded = dump_message(msg)
nmsg = load_message(SignMessage, msg_encoded)
self.assertEqual(nmsg.message, b"hello")
self.assertEqual(nmsg.coin_name, "Bitcoin")
if __name__ == "__main__":
unittest.main()

View File

@ -53,7 +53,7 @@ class TestWireCodecV1(unittest.TestCase):
# e.value is StopIteration. e.value.value is the return value of the call # e.value is StopIteration. e.value.value is the return value of the call
result = e.value.value result = e.value.value
self.assertEqual(result.type, MESSAGE_TYPE) self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data.buffer, b"") self.assertEqual(result.data, b"")
# message should have been read into the buffer # message should have been read into the buffer
self.assertEqual(buffer, b"\x00" * 64) self.assertEqual(buffer, b"\x00" * 64)
@ -83,7 +83,7 @@ class TestWireCodecV1(unittest.TestCase):
# e.value is StopIteration. e.value.value is the return value of the call # e.value is StopIteration. e.value.value is the return value of the call
result = e.value.value result = e.value.value
self.assertEqual(result.type, MESSAGE_TYPE) self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data.buffer, message) self.assertEqual(result.data, message)
# message should have been read into the buffer # message should have been read into the buffer
self.assertEqual(buffer, message) self.assertEqual(buffer, message)
@ -108,7 +108,7 @@ class TestWireCodecV1(unittest.TestCase):
# e.value is StopIteration. e.value.value is the return value of the call # e.value is StopIteration. e.value.value is the return value of the call
result = e.value.value result = e.value.value
self.assertEqual(result.type, MESSAGE_TYPE) self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data.buffer, message) self.assertEqual(result.data, message)
# read should have allocated its own buffer and not touch ours # read should have allocated its own buffer and not touch ours
self.assertEqual(buffer, b"\x00") self.assertEqual(buffer, b"\x00")
@ -176,7 +176,7 @@ class TestWireCodecV1(unittest.TestCase):
result = e.value.value result = e.value.value
self.assertEqual(result.type, MESSAGE_TYPE) self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data.buffer, message) self.assertEqual(result.data, message)
def test_read_huge_packet(self): def test_read_huge_packet(self):
PACKET_COUNT = 100_000 PACKET_COUNT = 100_000