refactor(core): Switch to new Protobuf API

pull/1557/head
Jan Pochyla 3 years ago committed by matejcik
parent 8a5cb41060
commit 02aa14fc04

@ -145,6 +145,7 @@ async def handle_DoPreauthorized(
req = await ctx.call_any(PreauthorizedRequest(), *wire_types)
assert req.MESSAGE_WIRE_TYPE is not None
handler = workflow_handlers.find_registered_handler(
ctx.iface, req.MESSAGE_WIRE_TYPE
)

@ -1,7 +1,8 @@
import protobuf
import storage.cache
from trezor import messages, utils
from trezor.messages import MessageType
from trezor import protobuf
from trezor.enums import MessageType
from trezor.utils import ensure
if False:
from typing import Iterable
@ -16,9 +17,12 @@ def is_set() -> bool:
def set(auth_message: protobuf.MessageType) -> None:
buffer = bytearray(protobuf.count_message(auth_message))
writer = utils.BufferWriter(buffer)
protobuf.dump_message(writer, auth_message)
buffer = protobuf.dump_message_buffer(auth_message)
# only wire-level messages can be stored as authorization
# (because only wire-level messages have wire_type, which we use as identifier)
ensure(auth_message.MESSAGE_WIRE_TYPE is not None)
assert auth_message.MESSAGE_WIRE_TYPE is not None # so that mypy knows as well
storage.cache.set(
storage.cache.APP_COMMON_AUTHORIZATION_TYPE,
auth_message.MESSAGE_WIRE_TYPE.to_bytes(2, "big"),
@ -32,11 +36,8 @@ def get() -> protobuf.MessageType | None:
return None
msg_wire_type = int.from_bytes(stored_auth_type, "big")
msg_type = messages.get_type(msg_wire_type)
buffer = storage.cache.get(storage.cache.APP_COMMON_AUTHORIZATION_DATA)
reader = utils.BufferReader(buffer)
return protobuf.load_message(reader, msg_type)
return protobuf.load_message_buffer(buffer, msg_wire_type)
def get_wire_types() -> Iterable[int]:

@ -50,7 +50,7 @@ async def get_tx_keys(ctx, msg: MoneroGetTxKeyRequest, keychain):
# and then is used to store the derivations if applicable
plain_buff = chacha_poly.decrypt_pack(tx_enc_key, msg.tx_enc_keys)
utils.ensure(len(plain_buff) % 32 == 0, "Tx key buffer has invalid size")
del msg.tx_enc_keys
msg.tx_enc_keys = b""
# If return only derivations do tx_priv * view_pub
if do_deriv:

@ -134,7 +134,7 @@ def gen_hmac_vini(
used only once and hard to check. I.e., indices in step 2
are uncheckable, decoy keys in step 9 are just random keys.
"""
import protobuf
from trezor import protobuf
from apps.monero.xmr.keccak_hasher import get_keccak_writer
kwriter = get_keccak_writer()
@ -146,7 +146,7 @@ def gen_hmac_vini(
src_entr.real_out_additional_tx_keys[src_entr.real_output_in_tx_index]
]
protobuf.dump_message(kwriter, src_entr)
kwriter.write(protobuf.dump_message_buffer(src_entr))
src_entr.outputs = real_outputs
src_entr.real_out_additional_tx_keys = real_additional
kwriter.write(vini_bin)
@ -162,11 +162,11 @@ def gen_hmac_vouti(
"""
Generates HMAC for (TxDestinationEntry[i] || tx.vout[i])
"""
import protobuf
from trezor import protobuf
from apps.monero.xmr.keccak_hasher import get_keccak_writer
kwriter = get_keccak_writer()
protobuf.dump_message(kwriter, dst_entr)
kwriter.write(protobuf.dump_message_buffer(dst_entr))
kwriter.write(tx_out_bin)
hmac_key_vouti = hmac_key_txout(key, idx)
@ -180,11 +180,11 @@ def gen_hmac_tsxdest(
"""
Generates HMAC for TxDestinationEntry[i]
"""
import protobuf
from trezor import protobuf
from apps.monero.xmr.keccak_hasher import get_keccak_writer
kwriter = get_keccak_writer()
protobuf.dump_message(kwriter, dst_entr)
kwriter.write(protobuf.dump_message_buffer(dst_entr))
hmac_key = hmac_key_txdst(key, idx)
hmac_tsxdest = crypto.compute_hmac(hmac_key, kwriter.get_digest())

@ -116,7 +116,7 @@ async def init_transaction(
from trezor.messages.MoneroTransactionInitAck import MoneroTransactionInitAck
from trezor.messages.MoneroTransactionRsigData import MoneroTransactionRsigData
rsig_data = MoneroTransactionRsigData(offload_type=state.rsig_offload)
rsig_data = MoneroTransactionRsigData(offload_type=int(state.rsig_offload))
return MoneroTransactionInitAck(hmacs=hmacs, rsig_data=rsig_data)
@ -273,11 +273,11 @@ def _compute_sec_keys(state: State, tsx_data: MoneroTransactionData):
"""
Generate master key H( H(TsxData || tx_priv) || rand )
"""
import protobuf
from trezor import protobuf
from apps.monero.xmr.keccak_hasher import get_keccak_writer
writer = get_keccak_writer()
protobuf.dump_message(writer, tsx_data)
writer.write(protobuf.dump_message_buffer(tsx_data))
writer.write(crypto.encodeint(state.tx_priv))
master_key = crypto.keccak_2hash(

@ -62,9 +62,6 @@ class MemoryReaderWriter:
return nread
async def areadinto(self, buf):
return self.readinto(buf)
def write(self, buf):
nwritten = len(buf)
nall = len(self.buffer)
@ -98,9 +95,6 @@ class MemoryReaderWriter:
self.ndata += nwritten
return nwritten
async def awrite(self, buf):
return self.write(buf)
def get_buffer(self):
mv = memoryview(self.buffer)
return mv[self.offset : self.woffset]

@ -1,456 +0,0 @@
"""
Extremely minimal streaming codec for a subset of protobuf. Supports uint32,
bytes, string, embedded message and repeated fields.
"""
if False:
from typing import (
Any,
Callable,
Iterable,
TypeVar,
Union,
)
from typing_extensions import Protocol
class Reader(Protocol):
def readinto(self, buf: bytearray) -> int:
"""
Reads `len(buf)` bytes into `buf`, or raises `EOFError`.
"""
class Writer(Protocol):
def write(self, buf: bytes) -> int:
"""
Writes all bytes from `buf`, or raises `EOFError`.
"""
WriteMethod = Callable[[bytes], Any]
_UVARINT_BUFFER = bytearray(1)
def load_uvarint(reader: Reader) -> int:
buffer = _UVARINT_BUFFER
result = 0
shift = 0
byte = 0x80
while byte & 0x80:
reader.readinto(buffer)
byte = buffer[0]
result += (byte & 0x7F) << shift
shift += 7
return result
def dump_uvarint(write: WriteMethod, n: int) -> None:
if n < 0:
raise ValueError("Cannot dump signed value, convert it to unsigned first.")
buffer = _UVARINT_BUFFER
shifted = 1
while shifted:
shifted = n >> 7
buffer[0] = (n & 0x7F) | (0x80 if shifted else 0x00)
write(buffer)
n = shifted
def count_uvarint(n: int) -> int:
if n < 0:
raise ValueError("Cannot dump signed value, convert it to unsigned first.")
if n <= 0x7F:
return 1
if n <= 0x3FFF:
return 2
if n <= 0x1F_FFFF:
return 3
if n <= 0xFFF_FFFF:
return 4
if n <= 0x7_FFFF_FFFF:
return 5
if n <= 0x3FF_FFFF_FFFF:
return 6
if n <= 0x1_FFFF_FFFF_FFFF:
return 7
if n <= 0xFF_FFFF_FFFF_FFFF:
return 8
if n <= 0x7FFF_FFFF_FFFF_FFFF:
return 9
raise ValueError
# protobuf interleaved signed encoding:
# https://developers.google.com/protocol-buffers/docs/encoding#structure
# the idea is to save the sign in LSbit instead of twos-complement.
# so counting up, you go: 0, -1, 1, -2, 2, ... (as the first bit changes, sign flips)
#
# To achieve this with a twos-complement number:
# 1. shift left by 1, leaving LSbit free
# 2. if the number is negative, do bitwise negation.
# This keeps positive number the same, and converts negative from twos-complement
# to the appropriate value, while setting the sign bit.
#
# The original algorithm makes use of the fact that arithmetic (signed) shift
# keeps the sign bits, so for a n-bit number, (x >> n) gets us "all sign bits".
# Then you can take "number XOR all-sign-bits", which is XOR 0 (identity) for positive
# and XOR 1 (bitwise negation) for negative. Cute and efficient.
#
# But this is harder in Python because we don't natively know the bit size of the number.
# So we have to branch on whether the number is negative.
def sint_to_uint(sint: int) -> int:
res = sint << 1
if sint < 0:
res = ~res
return res
def uint_to_sint(uint: int) -> int:
sign = uint & 1
res = uint >> 1
if sign:
res = ~res
return res
class UVarintType:
WIRE_TYPE = 0
class SVarintType:
WIRE_TYPE = 0
class BoolType:
WIRE_TYPE = 0
class EnumType:
WIRE_TYPE = 0
def __init__(self, name: str, enum_values: Iterable[int]) -> None:
self.enum_values = enum_values
def validate(self, fvalue: int) -> int:
if fvalue in self.enum_values:
return fvalue
else:
raise TypeError("Invalid enum value")
class BytesType:
WIRE_TYPE = 2
class UnicodeType:
WIRE_TYPE = 2
if False:
MessageTypeDef = Union[
type[UVarintType],
type[SVarintType],
type[BoolType],
EnumType,
type[BytesType],
type[UnicodeType],
type["MessageType"],
]
FieldDef = tuple[str, MessageTypeDef, Any]
FieldDict = dict[int, FieldDef]
FieldCache = dict[type["MessageType"], FieldDict]
LoadedMessageType = TypeVar("LoadedMessageType", bound="MessageType")
class MessageType:
WIRE_TYPE = 2
UNSTABLE = False
# Type id for the wire codec.
# Technically, not every protobuf message has this.
MESSAGE_WIRE_TYPE = -1
@classmethod
def get_fields(cls) -> FieldDict:
return {}
@classmethod
def cache_subordinate_types(cls, field_cache: FieldCache) -> None:
if cls in field_cache:
fields = field_cache[cls]
else:
fields = cls.get_fields()
field_cache[cls] = fields
for _, field_type, _ in fields.values():
if isinstance(field_type, MessageType):
field_type.cache_subordinate_types(field_cache)
def __eq__(self, rhs: Any) -> bool:
return self.__class__ is rhs.__class__ and self.__dict__ == rhs.__dict__
def __repr__(self) -> str:
return "<%s>" % self.__class__.__name__
class LimitedReader:
def __init__(self, reader: Reader, limit: int) -> None:
self.reader = reader
self.limit = limit
def readinto(self, buf: bytearray) -> int:
if self.limit < len(buf):
raise EOFError
else:
nread = self.reader.readinto(buf)
self.limit -= nread
return nread
FLAG_REPEATED = object()
FLAG_REQUIRED = object()
FLAG_EXPERIMENTAL = object()
def load_message(
reader: Reader,
msg_type: type[LoadedMessageType],
field_cache: FieldCache | None = None,
experimental_enabled: bool = True,
) -> LoadedMessageType:
if field_cache is None:
field_cache = {}
fields = field_cache.get(msg_type)
if fields is None:
fields = msg_type.get_fields()
field_cache[msg_type] = fields
if msg_type.UNSTABLE and not experimental_enabled:
raise ValueError # experimental messages not enabled
# we need to avoid calling __init__, which enforces required arguments
msg: LoadedMessageType = object.__new__(msg_type)
# pre-seed the object with defaults
for fname, _, fdefault in fields.values():
if fdefault is FLAG_REPEATED:
fdefault = []
elif fdefault is FLAG_EXPERIMENTAL:
fdefault = None
setattr(msg, fname, fdefault)
if False:
SingularValue = Union[int, bool, bytearray, str, MessageType]
Value = Union[SingularValue, list[SingularValue]]
fvalue: Value = 0
while True:
try:
fkey = load_uvarint(reader)
except EOFError:
break # no more fields to load
ftag = fkey >> 3
wtype = fkey & 7
field = fields.get(ftag, None)
if field is None: # unknown field, skip it
if wtype == 0:
load_uvarint(reader)
elif wtype == 2:
ivalue = load_uvarint(reader)
reader.readinto(bytearray(ivalue))
else:
raise ValueError
continue
fname, ftype, fdefault = field
if wtype != ftype.WIRE_TYPE:
raise TypeError # parsed wire type differs from the schema
if fdefault is FLAG_EXPERIMENTAL and not experimental_enabled:
raise ValueError # experimental fields not enabled
ivalue = load_uvarint(reader)
if ftype is UVarintType:
fvalue = ivalue
elif ftype is SVarintType:
fvalue = uint_to_sint(ivalue)
elif ftype is BoolType:
fvalue = bool(ivalue)
elif isinstance(ftype, EnumType):
fvalue = ftype.validate(ivalue)
elif ftype is BytesType:
fvalue = bytearray(ivalue)
reader.readinto(fvalue)
elif ftype is UnicodeType:
fvalue = bytearray(ivalue)
reader.readinto(fvalue)
fvalue = bytes(fvalue).decode()
elif issubclass(ftype, MessageType):
fvalue = load_message(
LimitedReader(reader, ivalue), ftype, field_cache, experimental_enabled
)
else:
raise TypeError # field type is unknown
if fdefault is FLAG_REPEATED:
getattr(msg, fname).append(fvalue)
else:
setattr(msg, fname, fvalue)
for fname, _, _ in fields.values():
if getattr(msg, fname) is FLAG_REQUIRED:
# The message is intended to be user-facing when decoding from wire,
# but not when used internally.
raise ValueError("Required field '{}' was not received".format(fname))
return msg
def dump_message(
writer: Writer, msg: MessageType, field_cache: FieldCache | None = None
) -> None:
repvalue = [0]
if field_cache is None:
field_cache = {}
fields = field_cache.get(type(msg))
if fields is None:
fields = msg.get_fields()
field_cache[type(msg)] = fields
for ftag in fields:
fname, ftype, fdefault = fields[ftag]
fvalue = getattr(msg, fname, None)
if fvalue is None:
continue
fkey = (ftag << 3) | ftype.WIRE_TYPE
if fdefault is not FLAG_REPEATED:
repvalue[0] = fvalue
fvalue = repvalue
for svalue in fvalue:
dump_uvarint(writer.write, fkey)
if ftype is UVarintType:
dump_uvarint(writer.write, svalue)
elif ftype is SVarintType:
dump_uvarint(writer.write, sint_to_uint(svalue))
elif ftype is BoolType:
dump_uvarint(writer.write, int(svalue))
elif isinstance(ftype, EnumType):
dump_uvarint(writer.write, svalue)
elif ftype is BytesType:
if isinstance(svalue, list):
dump_uvarint(writer.write, _count_bytes_list(svalue))
for sub_svalue in svalue:
writer.write(sub_svalue)
else:
dump_uvarint(writer.write, len(svalue))
writer.write(svalue)
elif ftype is UnicodeType:
svalue = svalue.encode()
dump_uvarint(writer.write, len(svalue))
writer.write(svalue)
elif issubclass(ftype, MessageType):
ffields = field_cache.get(ftype)
if ffields is None:
ffields = ftype.get_fields()
field_cache[ftype] = ffields
dump_uvarint(writer.write, count_message(svalue, field_cache))
dump_message(writer, svalue, field_cache)
else:
raise TypeError
def count_message(msg: MessageType, field_cache: FieldCache | None = None) -> int:
nbytes = 0
repvalue = [0]
if field_cache is None:
field_cache = {}
fields = field_cache.get(type(msg))
if fields is None:
fields = msg.get_fields()
field_cache[type(msg)] = fields
for ftag in fields:
fname, ftype, fdefault = fields[ftag]
fvalue = getattr(msg, fname, None)
if fvalue is None:
continue
fkey = (ftag << 3) | ftype.WIRE_TYPE
if fdefault is not FLAG_REPEATED:
repvalue[0] = fvalue
fvalue = repvalue
# length of all the field keys
nbytes += count_uvarint(fkey) * len(fvalue)
if ftype is UVarintType:
for svalue in fvalue:
nbytes += count_uvarint(svalue)
elif ftype is SVarintType:
for svalue in fvalue:
nbytes += count_uvarint(sint_to_uint(svalue))
elif ftype is BoolType:
for svalue in fvalue:
nbytes += count_uvarint(int(svalue))
elif isinstance(ftype, EnumType):
for svalue in fvalue:
nbytes += count_uvarint(svalue)
elif ftype is BytesType:
for svalue in fvalue:
if isinstance(svalue, list):
svalue = _count_bytes_list(svalue)
else:
svalue = len(svalue)
nbytes += count_uvarint(svalue)
nbytes += svalue
elif ftype is UnicodeType:
for svalue in fvalue:
svalue = len(svalue.encode())
nbytes += count_uvarint(svalue)
nbytes += svalue
elif issubclass(ftype, MessageType):
for svalue in fvalue:
fsize = count_message(svalue, field_cache)
nbytes += count_uvarint(fsize)
nbytes += fsize
else:
raise TypeError
return nbytes
def _count_bytes_list(svalue: list[bytes]) -> int:
res = 0
for x in svalue:
res += len(x)
return res

@ -0,0 +1,42 @@
import trezorproto
decode = trezorproto.decode
encode = trezorproto.encode
encoded_length = trezorproto.encoded_length
type_for_name = trezorproto.type_for_name
type_for_wire = trezorproto.type_for_wire
# XXX
# Note that MessageType "subclasses" are not true subclasses, but instead instances
# of the built-in metaclass MsgDef. MessageType instances are in fact instances of
# the built-in type Msg. That is why isinstance checks do not work, and instead the
# MessageTypeSubclass.is_type_of() method must be used.
if False:
from typing import Type, TypeGuard, TypeVar
T = TypeVar("T", bound="MessageType")
class MsgDef(type):
@classmethod
def is_type_of(cls: Type[Type[T]], msg: "MessageType") -> TypeGuard[T]:
"""Identify if the provided message belongs to this type."""
raise NotImplementedError
class MessageType(metaclass=MsgDef):
MESSAGE_NAME: str = "MessageType"
MESSAGE_WIRE_TYPE: int | None = None
def load_message_buffer(
buffer: bytes,
msg_wire_type: int,
experimental_enabled: bool = True,
) -> MessageType:
msg_type = type_for_wire(msg_wire_type)
return decode(buffer, msg_type, experimental_enabled)
def dump_message_buffer(msg: MessageType) -> bytearray:
buffer = bytearray(encoded_length(msg))
encode(buffer, msg)
return buffer

@ -145,10 +145,6 @@ class HashWriter:
def write(self, buf: bytes) -> None: # alias for extend()
self.ctx.update(buf)
async def awrite(self, buf: bytes) -> int: # AsyncWriter interface
self.ctx.update(buf)
return len(buf)
def get_digest(self) -> bytes:
return self.ctx.digest()

@ -35,11 +35,10 @@ reads the message's header. When the message type is known the first handler is
"""
import protobuf
from storage.cache import InvalidSessionError
from trezor import log, loop, messages, utils, workflow
from trezor.messages import FailureType
from trezor.messages.Failure import Failure
from trezor.enums import FailureType
from trezor.messages import Failure
from trezor import log, loop, protobuf, utils, workflow
from trezor.wire import codec_v1
from trezor.wire.errors import ActionCancelled, DataError, Error
@ -74,17 +73,19 @@ def setup(iface: WireInterface, is_debug_session: bool = False) -> None:
if False:
from typing import Protocol
from typing import Protocol, TypeVar
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
class GenericContext(Protocol):
async def call(
self,
msg: protobuf.MessageType,
expected_type: type[protobuf.LoadedMessageType],
expected_type: type[protobuf.MessageType],
) -> Any:
...
async def read(self, expected_type: type[protobuf.LoadedMessageType]) -> Any:
async def read(self, expected_type: type[protobuf.MessageType]) -> Any:
...
async def write(self, msg: protobuf.MessageType) -> None:
@ -96,15 +97,14 @@ if False:
def _wrap_protobuf_load(
reader: protobuf.Reader,
expected_type: type[protobuf.LoadedMessageType],
field_cache: protobuf.FieldCache | None = None,
) -> protobuf.LoadedMessageType:
buffer: bytes,
expected_type: type[LoadedMessageType],
) -> LoadedMessageType:
try:
return protobuf.load_message(
reader, expected_type, field_cache, experimental_enabled
)
return protobuf.decode(buffer, expected_type, experimental_enabled)
except Exception as e:
if __debug__:
log.exception(__name__, e)
if e.args:
raise DataError("Failed to decode message: {}".format(e.args[0]))
else:
@ -139,19 +139,15 @@ class Context:
self.iface = iface
self.sid = sid
self.buffer = buffer
self.buffer_writer = utils.BufferWriter(self.buffer)
self._field_cache: protobuf.FieldCache = {}
async def call(
self,
msg: protobuf.MessageType,
expected_type: type[protobuf.LoadedMessageType],
field_cache: protobuf.FieldCache | None = None,
) -> protobuf.LoadedMessageType:
await self.write(msg, field_cache)
expected_type: type[LoadedMessageType],
) -> LoadedMessageType:
await self.write(msg)
del msg
return await self.read(expected_type, field_cache)
return await self.read(expected_type)
async def call_any(
self, msg: protobuf.MessageType, *expected_wire_types: int
@ -160,21 +156,17 @@ class Context:
del msg
return await self.read_any(expected_wire_types)
async def read_from_wire(self) -> codec_v1.Message:
return await codec_v1.read_message(self.iface, self.buffer)
def read_from_wire(self) -> Awaitable[codec_v1.Message]:
return codec_v1.read_message(self.iface, self.buffer)
async def read(
self,
expected_type: type[protobuf.LoadedMessageType],
field_cache: protobuf.FieldCache | None = None,
) -> protobuf.LoadedMessageType:
async def read(self, expected_type: type[LoadedMessageType]) -> LoadedMessageType:
if __debug__:
log.debug(
__name__,
"%s:%x expect: %s",
self.iface.iface_num(),
self.sid,
expected_type,
expected_type.MESSAGE_NAME,
)
# Load the full message into a buffer, parse out type and data payload
@ -191,13 +183,13 @@ class Context:
"%s:%x read: %s",
self.iface.iface_num(),
self.sid,
expected_type,
expected_type.MESSAGE_NAME,
)
workflow.idle_timer.touch()
# look up the protobuf class and parse the message
return _wrap_protobuf_load(msg.data, expected_type, field_cache)
return _wrap_protobuf_load(msg.data, expected_type)
async def read_any(
self, expected_wire_types: Iterable[int]
@ -220,11 +212,15 @@ class Context:
raise UnexpectedMessageError(msg)
# find the protobuf type
exptype = messages.get_type(msg.type)
exptype = protobuf.type_for_wire(msg.type)
if __debug__:
log.debug(
__name__, "%s:%x read: %s", self.iface.iface_num(), self.sid, exptype
__name__,
"%s:%x read: %s",
self.iface.iface_num(),
self.sid,
exptype.MESSAGE_NAME,
)
workflow.idle_timer.touch()
@ -232,41 +228,36 @@ class Context:
# parse the message and return it
return _wrap_protobuf_load(msg.data, exptype)
async def write(
self,
msg: protobuf.MessageType,
field_cache: protobuf.FieldCache | None = None,
) -> None:
async def write(self, msg: protobuf.MessageType) -> None:
if __debug__:
log.debug(
__name__, "%s:%x write: %s", self.iface.iface_num(), self.sid, msg
__name__,
"%s:%x write: %s",
self.iface.iface_num(),
self.sid,
msg.MESSAGE_NAME,
)
if field_cache is None:
field_cache = self._field_cache
# cannot write message without wire type
assert msg.MESSAGE_WIRE_TYPE is not None
# write the message
msg_size = protobuf.count_message(msg, field_cache)
msg_size = protobuf.encoded_length(msg)
# prepare buffer
if msg_size <= len(self.buffer_writer.buffer):
if msg_size <= len(self.buffer):
# reuse preallocated
buffer_writer = self.buffer_writer
buffer = self.buffer
else:
# message is too big, we need to allocate a new buffer
buffer_writer = utils.BufferWriter(bytearray(msg_size))
buffer = bytearray(msg_size)
msg_size = protobuf.encode(buffer, msg)
buffer_writer.seek(0)
protobuf.dump_message(buffer_writer, msg, field_cache)
await codec_v1.write_message(
self.iface,
msg.MESSAGE_WIRE_TYPE,
memoryview(buffer_writer.buffer)[:msg_size],
memoryview(buffer)[:msg_size],
)
# make sure we don't keep around fields of all protobuf types ever
self._field_cache.clear()
def wait(self, *tasks: Awaitable) -> Any:
"""
Wait until one of the passed tasks finishes, and return the result,
@ -301,8 +292,8 @@ async def _handle_single_message(
"""
if __debug__:
try:
msg_type = messages.get_type(msg.type).__name__
except KeyError:
msg_type = protobuf.type_for_wire(msg.type).MESSAGE_NAME
except Exception:
msg_type = "%d - unknown message type" % msg.type
log.debug(
__name__,
@ -328,7 +319,7 @@ async def _handle_single_message(
try:
# Find a protobuf.MessageType subclass that describes this
# message. Raises if the type is not found.
req_type = messages.get_type(msg.type)
req_type = protobuf.type_for_wire(msg.type)
# Try to decode the message according to schema from
# `req_type`. Raises if the message is malformed.
@ -424,6 +415,11 @@ async def handle_session(
next_msg = await _handle_single_message(
ctx, msg, use_workflow=not is_debug_session
)
except Exception as exc:
# Log and ignore. The session handler can only exit explicitly in the
# following finally block.
if __debug__:
log.exception(__name__, exc)
finally:
if not __debug__ or not is_debug_session:
# Unload modules imported by the workflow. Should not raise.

@ -23,7 +23,7 @@ class CodecError(Exception):
class Message:
def __init__(self, mtype: int, mdata: utils.BufferReader) -> None:
def __init__(self, mtype: int, mdata: bytes) -> None:
self.type = mtype
self.data = mdata
@ -70,7 +70,7 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag
if read_and_throw_away:
raise CodecError("Message too large")
return Message(mtype, utils.BufferReader(mdata))
return Message(mtype, mdata)
async def write_message(iface: WireInterface, mtype: int, mdata: bytes) -> None:

@ -1,129 +0,0 @@
from common import *
import protobuf
from trezor.utils import BufferReader, BufferWriter
class Message(protobuf.MessageType):
def __init__(self, sint_field: int = 0, enum_field: int = 0) -> None:
self.sint_field = sint_field
self.enum_field = enum_field
@classmethod
def get_fields(cls):
return {
1: ("sint_field", protobuf.SVarintType, 0),
2: ("enum_field", protobuf.EnumType("t", (0, 5, 25)), 0),
}
class MessageWithRequiredAndDefault(protobuf.MessageType):
def __init__(self, required_field, default_field) -> None:
self.required_field = required_field
self.default_field = default_field
@classmethod
def get_fields(cls):
return {
1: ("required_field", protobuf.UVarintType, protobuf.FLAG_REQUIRED),
2: ("default_field", protobuf.SVarintType, -1),
}
def load_uvarint(data: bytes) -> int:
reader = BufferReader(data)
return protobuf.load_uvarint(reader)
def dump_uvarint(value: int) -> bytearray:
w = bytearray()
protobuf.dump_uvarint(w.extend, value)
return w
def dump_message(msg: protobuf.MessageType) -> bytearray:
length = protobuf.count_message(msg)
buffer = bytearray(length)
protobuf.dump_message(BufferWriter(buffer), msg)
return buffer
def load_message(msg_type, buffer: bytearray) -> protobuf.MessageType:
return protobuf.load_message(BufferReader(buffer), msg_type)
class TestProtobuf(unittest.TestCase):
def test_dump_uvarint(self):
self.assertEqual(dump_uvarint(0), b"\x00")
self.assertEqual(dump_uvarint(1), b"\x01")
self.assertEqual(dump_uvarint(0xFF), b"\xff\x01")
self.assertEqual(dump_uvarint(123456), b"\xc0\xc4\x07")
with self.assertRaises(ValueError):
dump_uvarint(-1)
def test_load_uvarint(self):
self.assertEqual(load_uvarint(b"\x00"), 0)
self.assertEqual(load_uvarint(b"\x01"), 1)
self.assertEqual(load_uvarint(b"\xff\x01"), 0xFF)
self.assertEqual(load_uvarint(b"\xc0\xc4\x07"), 123456)
def test_sint_uint(self):
self.assertEqual(protobuf.uint_to_sint(0), 0)
self.assertEqual(protobuf.sint_to_uint(0), 0)
self.assertEqual(protobuf.sint_to_uint(-1), 1)
self.assertEqual(protobuf.sint_to_uint(1), 2)
self.assertEqual(protobuf.uint_to_sint(1), -1)
self.assertEqual(protobuf.uint_to_sint(2), 1)
# roundtrip:
self.assertEqual(
protobuf.uint_to_sint(protobuf.sint_to_uint(1234567891011)), 1234567891011
)
self.assertEqual(
protobuf.uint_to_sint(protobuf.sint_to_uint(-(2 ** 32))), -(2 ** 32)
)
def test_validate_enum(self):
# ok message:
msg = Message(-42, 5)
msg_encoded = dump_message(msg)
nmsg = load_message(Message, msg_encoded)
self.assertEqual(msg.sint_field, nmsg.sint_field)
self.assertEqual(msg.enum_field, nmsg.enum_field)
# bad enum value:
msg = Message(-42, 42)
msg_encoded = dump_message(msg)
with self.assertRaises(TypeError):
load_message(Message, msg_encoded)
def test_required(self):
msg = MessageWithRequiredAndDefault(required_field=1, default_field=2)
msg_encoded = dump_message(msg)
nmsg = load_message(MessageWithRequiredAndDefault, msg_encoded)
self.assertEqual(nmsg.required_field, 1)
self.assertEqual(nmsg.default_field, 2)
# try a message without the required_field
msg = MessageWithRequiredAndDefault(required_field=None, default_field=2)
# encoding always succeeds
msg_encoded = dump_message(msg)
with self.assertRaises(ValueError):
load_message(MessageWithRequiredAndDefault, msg_encoded)
# try a message without the default field
msg = MessageWithRequiredAndDefault(required_field=1, default_field=None)
msg_encoded = dump_message(msg)
nmsg = load_message(MessageWithRequiredAndDefault, msg_encoded)
self.assertEqual(nmsg.required_field, 1)
self.assertEqual(nmsg.default_field, -1)
if __name__ == "__main__":
unittest.main()

@ -0,0 +1,115 @@
from common import *
from trezor import protobuf
from trezor.messages import WebAuthnCredential, EosAsset, Failure, SignMessage
def load_uvarint32(data: bytes) -> int:
# use known uint32 field in an all-optional message
buffer = bytearray(len(data) + 1)
buffer[1:] = data
buffer[0] = (1 << 3) | 0 # field number 1, wire type 0
msg = protobuf.decode(buffer, WebAuthnCredential, False)
return msg.index
def load_uvarint64(data: bytes) -> int:
# use known uint64 field in an all-optional message
buffer = bytearray(len(data) + 1)
buffer[1:] = data
buffer[0] = (2 << 3) | 0 # field number 1, wire type 0
msg = protobuf.decode(buffer, EosAsset, False)
return msg.symbol
def dump_uvarint32(value: int) -> bytearray:
# use known uint32 field in an all-optional message
msg = WebAuthnCredential(index=value)
length = protobuf.encoded_length(msg)
buffer = bytearray(length)
protobuf.encode(buffer, msg)
assert buffer[0] == (1 << 3) | 0 # field number 1, wire type 0
return buffer[1:]
def dump_uvarint64(value: int) -> bytearray:
# use known uint64 field in an all-optional message
msg = EosAsset(symbol=value)
length = protobuf.encoded_length(msg)
buffer = bytearray(length)
protobuf.encode(buffer, msg)
assert buffer[0] == (2 << 3) | 0 # field number 2, wire type 0
return buffer[1:]
def dump_message(msg: protobuf.MessageType) -> bytearray:
length = protobuf.encoded_length(msg)
buffer = bytearray(length)
protobuf.encode(buffer, msg)
return buffer
def load_message(msg_type: Type[protobuf.MessageType], buffer: bytes) -> protobuf.MessageType:
return protobuf.decode(buffer, msg_type, False)
class TestProtobuf(unittest.TestCase):
def test_dump_uvarint(self):
for dump_uvarint in (dump_uvarint32, dump_uvarint64):
self.assertEqual(dump_uvarint(0), b"\x00")
self.assertEqual(dump_uvarint(1), b"\x01")
self.assertEqual(dump_uvarint(0xFF), b"\xff\x01")
self.assertEqual(dump_uvarint(123456), b"\xc0\xc4\x07")
with self.assertRaises(ValueError):
dump_uvarint(-1)
def test_load_uvarint(self):
for load_uvarint in (load_uvarint32, load_uvarint64):
self.assertEqual(load_uvarint(b"\x00"), 0)
self.assertEqual(load_uvarint(b"\x01"), 1)
self.assertEqual(load_uvarint(b"\xff\x01"), 0xFF)
self.assertEqual(load_uvarint(b"\xc0\xc4\x07"), 123456)
def test_validate_enum(self):
# ok message:
msg = Failure(code=7)
msg_encoded = dump_message(msg)
nmsg = load_message(Failure, msg_encoded)
self.assertEqual(msg.code, nmsg.code)
# bad enum value:
msg = Failure(code=1000)
msg_encoded = dump_message(msg)
with self.assertRaises(ValueError):
load_message(Failure, msg_encoded)
def test_required(self):
msg = SignMessage(message=b"hello", coin_name="foo", script_type=1)
msg_encoded = dump_message(msg)
nmsg = load_message(SignMessage, msg_encoded)
self.assertEqual(nmsg.message, b"hello")
self.assertEqual(nmsg.coin_name, "foo")
self.assertEqual(nmsg.script_type, 1)
# try a message without the required_field
msg = SignMessage(message=None)
# encoding always succeeds
msg_encoded = dump_message(msg)
with self.assertRaises(ValueError):
load_message(SignMessage, msg_encoded)
# try a message without the default field
msg = SignMessage(message=b"hello")
msg.coin_name = None
msg_encoded = dump_message(msg)
nmsg = load_message(SignMessage, msg_encoded)
self.assertEqual(nmsg.message, b"hello")
self.assertEqual(nmsg.coin_name, "Bitcoin")
if __name__ == "__main__":
unittest.main()

@ -53,7 +53,7 @@ class TestWireCodecV1(unittest.TestCase):
# e.value is StopIteration. e.value.value is the return value of the call
result = e.value.value
self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data.buffer, b"")
self.assertEqual(result.data, b"")
# message should have been read into the buffer
self.assertEqual(buffer, b"\x00" * 64)
@ -83,7 +83,7 @@ class TestWireCodecV1(unittest.TestCase):
# e.value is StopIteration. e.value.value is the return value of the call
result = e.value.value
self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data.buffer, message)
self.assertEqual(result.data, message)
# message should have been read into the buffer
self.assertEqual(buffer, message)
@ -108,7 +108,7 @@ class TestWireCodecV1(unittest.TestCase):
# e.value is StopIteration. e.value.value is the return value of the call
result = e.value.value
self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data.buffer, message)
self.assertEqual(result.data, message)
# read should have allocated its own buffer and not touch ours
self.assertEqual(buffer, b"\x00")
@ -176,7 +176,7 @@ class TestWireCodecV1(unittest.TestCase):
result = e.value.value
self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data.buffer, message)
self.assertEqual(result.data, message)
def test_read_huge_packet(self):
PACKET_COUNT = 100_000

Loading…
Cancel
Save