mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-23 14:58:09 +00:00
refactor(core): Switch to new Protobuf API
This commit is contained in:
parent
8a5cb41060
commit
02aa14fc04
@ -145,6 +145,7 @@ async def handle_DoPreauthorized(
|
|||||||
|
|
||||||
req = await ctx.call_any(PreauthorizedRequest(), *wire_types)
|
req = await ctx.call_any(PreauthorizedRequest(), *wire_types)
|
||||||
|
|
||||||
|
assert req.MESSAGE_WIRE_TYPE is not None
|
||||||
handler = workflow_handlers.find_registered_handler(
|
handler = workflow_handlers.find_registered_handler(
|
||||||
ctx.iface, req.MESSAGE_WIRE_TYPE
|
ctx.iface, req.MESSAGE_WIRE_TYPE
|
||||||
)
|
)
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
import protobuf
|
import protobuf
|
||||||
import storage.cache
|
import storage.cache
|
||||||
from trezor import messages, utils
|
from trezor import protobuf
|
||||||
from trezor.messages import MessageType
|
from trezor.enums import MessageType
|
||||||
|
from trezor.utils import ensure
|
||||||
|
|
||||||
if False:
|
if False:
|
||||||
from typing import Iterable
|
from typing import Iterable
|
||||||
@ -16,9 +17,12 @@ def is_set() -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def set(auth_message: protobuf.MessageType) -> None:
|
def set(auth_message: protobuf.MessageType) -> None:
|
||||||
buffer = bytearray(protobuf.count_message(auth_message))
|
buffer = protobuf.dump_message_buffer(auth_message)
|
||||||
writer = utils.BufferWriter(buffer)
|
|
||||||
protobuf.dump_message(writer, auth_message)
|
# only wire-level messages can be stored as authorization
|
||||||
|
# (because only wire-level messages have wire_type, which we use as identifier)
|
||||||
|
ensure(auth_message.MESSAGE_WIRE_TYPE is not None)
|
||||||
|
assert auth_message.MESSAGE_WIRE_TYPE is not None # so that mypy knows as well
|
||||||
storage.cache.set(
|
storage.cache.set(
|
||||||
storage.cache.APP_COMMON_AUTHORIZATION_TYPE,
|
storage.cache.APP_COMMON_AUTHORIZATION_TYPE,
|
||||||
auth_message.MESSAGE_WIRE_TYPE.to_bytes(2, "big"),
|
auth_message.MESSAGE_WIRE_TYPE.to_bytes(2, "big"),
|
||||||
@ -32,11 +36,8 @@ def get() -> protobuf.MessageType | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
msg_wire_type = int.from_bytes(stored_auth_type, "big")
|
msg_wire_type = int.from_bytes(stored_auth_type, "big")
|
||||||
msg_type = messages.get_type(msg_wire_type)
|
|
||||||
buffer = storage.cache.get(storage.cache.APP_COMMON_AUTHORIZATION_DATA)
|
buffer = storage.cache.get(storage.cache.APP_COMMON_AUTHORIZATION_DATA)
|
||||||
reader = utils.BufferReader(buffer)
|
return protobuf.load_message_buffer(buffer, msg_wire_type)
|
||||||
|
|
||||||
return protobuf.load_message(reader, msg_type)
|
|
||||||
|
|
||||||
|
|
||||||
def get_wire_types() -> Iterable[int]:
|
def get_wire_types() -> Iterable[int]:
|
||||||
|
@ -50,7 +50,7 @@ async def get_tx_keys(ctx, msg: MoneroGetTxKeyRequest, keychain):
|
|||||||
# and then is used to store the derivations if applicable
|
# and then is used to store the derivations if applicable
|
||||||
plain_buff = chacha_poly.decrypt_pack(tx_enc_key, msg.tx_enc_keys)
|
plain_buff = chacha_poly.decrypt_pack(tx_enc_key, msg.tx_enc_keys)
|
||||||
utils.ensure(len(plain_buff) % 32 == 0, "Tx key buffer has invalid size")
|
utils.ensure(len(plain_buff) % 32 == 0, "Tx key buffer has invalid size")
|
||||||
del msg.tx_enc_keys
|
msg.tx_enc_keys = b""
|
||||||
|
|
||||||
# If return only derivations do tx_priv * view_pub
|
# If return only derivations do tx_priv * view_pub
|
||||||
if do_deriv:
|
if do_deriv:
|
||||||
|
@ -134,7 +134,7 @@ def gen_hmac_vini(
|
|||||||
used only once and hard to check. I.e., indices in step 2
|
used only once and hard to check. I.e., indices in step 2
|
||||||
are uncheckable, decoy keys in step 9 are just random keys.
|
are uncheckable, decoy keys in step 9 are just random keys.
|
||||||
"""
|
"""
|
||||||
import protobuf
|
from trezor import protobuf
|
||||||
from apps.monero.xmr.keccak_hasher import get_keccak_writer
|
from apps.monero.xmr.keccak_hasher import get_keccak_writer
|
||||||
|
|
||||||
kwriter = get_keccak_writer()
|
kwriter = get_keccak_writer()
|
||||||
@ -146,7 +146,7 @@ def gen_hmac_vini(
|
|||||||
src_entr.real_out_additional_tx_keys[src_entr.real_output_in_tx_index]
|
src_entr.real_out_additional_tx_keys[src_entr.real_output_in_tx_index]
|
||||||
]
|
]
|
||||||
|
|
||||||
protobuf.dump_message(kwriter, src_entr)
|
kwriter.write(protobuf.dump_message_buffer(src_entr))
|
||||||
src_entr.outputs = real_outputs
|
src_entr.outputs = real_outputs
|
||||||
src_entr.real_out_additional_tx_keys = real_additional
|
src_entr.real_out_additional_tx_keys = real_additional
|
||||||
kwriter.write(vini_bin)
|
kwriter.write(vini_bin)
|
||||||
@ -162,11 +162,11 @@ def gen_hmac_vouti(
|
|||||||
"""
|
"""
|
||||||
Generates HMAC for (TxDestinationEntry[i] || tx.vout[i])
|
Generates HMAC for (TxDestinationEntry[i] || tx.vout[i])
|
||||||
"""
|
"""
|
||||||
import protobuf
|
from trezor import protobuf
|
||||||
from apps.monero.xmr.keccak_hasher import get_keccak_writer
|
from apps.monero.xmr.keccak_hasher import get_keccak_writer
|
||||||
|
|
||||||
kwriter = get_keccak_writer()
|
kwriter = get_keccak_writer()
|
||||||
protobuf.dump_message(kwriter, dst_entr)
|
kwriter.write(protobuf.dump_message_buffer(dst_entr))
|
||||||
kwriter.write(tx_out_bin)
|
kwriter.write(tx_out_bin)
|
||||||
|
|
||||||
hmac_key_vouti = hmac_key_txout(key, idx)
|
hmac_key_vouti = hmac_key_txout(key, idx)
|
||||||
@ -180,11 +180,11 @@ def gen_hmac_tsxdest(
|
|||||||
"""
|
"""
|
||||||
Generates HMAC for TxDestinationEntry[i]
|
Generates HMAC for TxDestinationEntry[i]
|
||||||
"""
|
"""
|
||||||
import protobuf
|
from trezor import protobuf
|
||||||
from apps.monero.xmr.keccak_hasher import get_keccak_writer
|
from apps.monero.xmr.keccak_hasher import get_keccak_writer
|
||||||
|
|
||||||
kwriter = get_keccak_writer()
|
kwriter = get_keccak_writer()
|
||||||
protobuf.dump_message(kwriter, dst_entr)
|
kwriter.write(protobuf.dump_message_buffer(dst_entr))
|
||||||
|
|
||||||
hmac_key = hmac_key_txdst(key, idx)
|
hmac_key = hmac_key_txdst(key, idx)
|
||||||
hmac_tsxdest = crypto.compute_hmac(hmac_key, kwriter.get_digest())
|
hmac_tsxdest = crypto.compute_hmac(hmac_key, kwriter.get_digest())
|
||||||
|
@ -116,7 +116,7 @@ async def init_transaction(
|
|||||||
from trezor.messages.MoneroTransactionInitAck import MoneroTransactionInitAck
|
from trezor.messages.MoneroTransactionInitAck import MoneroTransactionInitAck
|
||||||
from trezor.messages.MoneroTransactionRsigData import MoneroTransactionRsigData
|
from trezor.messages.MoneroTransactionRsigData import MoneroTransactionRsigData
|
||||||
|
|
||||||
rsig_data = MoneroTransactionRsigData(offload_type=state.rsig_offload)
|
rsig_data = MoneroTransactionRsigData(offload_type=int(state.rsig_offload))
|
||||||
|
|
||||||
return MoneroTransactionInitAck(hmacs=hmacs, rsig_data=rsig_data)
|
return MoneroTransactionInitAck(hmacs=hmacs, rsig_data=rsig_data)
|
||||||
|
|
||||||
@ -273,11 +273,11 @@ def _compute_sec_keys(state: State, tsx_data: MoneroTransactionData):
|
|||||||
"""
|
"""
|
||||||
Generate master key H( H(TsxData || tx_priv) || rand )
|
Generate master key H( H(TsxData || tx_priv) || rand )
|
||||||
"""
|
"""
|
||||||
import protobuf
|
from trezor import protobuf
|
||||||
from apps.monero.xmr.keccak_hasher import get_keccak_writer
|
from apps.monero.xmr.keccak_hasher import get_keccak_writer
|
||||||
|
|
||||||
writer = get_keccak_writer()
|
writer = get_keccak_writer()
|
||||||
protobuf.dump_message(writer, tsx_data)
|
writer.write(protobuf.dump_message_buffer(tsx_data))
|
||||||
writer.write(crypto.encodeint(state.tx_priv))
|
writer.write(crypto.encodeint(state.tx_priv))
|
||||||
|
|
||||||
master_key = crypto.keccak_2hash(
|
master_key = crypto.keccak_2hash(
|
||||||
|
@ -62,9 +62,6 @@ class MemoryReaderWriter:
|
|||||||
|
|
||||||
return nread
|
return nread
|
||||||
|
|
||||||
async def areadinto(self, buf):
|
|
||||||
return self.readinto(buf)
|
|
||||||
|
|
||||||
def write(self, buf):
|
def write(self, buf):
|
||||||
nwritten = len(buf)
|
nwritten = len(buf)
|
||||||
nall = len(self.buffer)
|
nall = len(self.buffer)
|
||||||
@ -98,9 +95,6 @@ class MemoryReaderWriter:
|
|||||||
self.ndata += nwritten
|
self.ndata += nwritten
|
||||||
return nwritten
|
return nwritten
|
||||||
|
|
||||||
async def awrite(self, buf):
|
|
||||||
return self.write(buf)
|
|
||||||
|
|
||||||
def get_buffer(self):
|
def get_buffer(self):
|
||||||
mv = memoryview(self.buffer)
|
mv = memoryview(self.buffer)
|
||||||
return mv[self.offset : self.woffset]
|
return mv[self.offset : self.woffset]
|
||||||
|
@ -1,456 +0,0 @@
|
|||||||
"""
|
|
||||||
Extremely minimal streaming codec for a subset of protobuf. Supports uint32,
|
|
||||||
bytes, string, embedded message and repeated fields.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if False:
|
|
||||||
from typing import (
|
|
||||||
Any,
|
|
||||||
Callable,
|
|
||||||
Iterable,
|
|
||||||
TypeVar,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
from typing_extensions import Protocol
|
|
||||||
|
|
||||||
class Reader(Protocol):
|
|
||||||
def readinto(self, buf: bytearray) -> int:
|
|
||||||
"""
|
|
||||||
Reads `len(buf)` bytes into `buf`, or raises `EOFError`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
class Writer(Protocol):
|
|
||||||
def write(self, buf: bytes) -> int:
|
|
||||||
"""
|
|
||||||
Writes all bytes from `buf`, or raises `EOFError`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
WriteMethod = Callable[[bytes], Any]
|
|
||||||
|
|
||||||
|
|
||||||
_UVARINT_BUFFER = bytearray(1)
|
|
||||||
|
|
||||||
|
|
||||||
def load_uvarint(reader: Reader) -> int:
|
|
||||||
buffer = _UVARINT_BUFFER
|
|
||||||
result = 0
|
|
||||||
shift = 0
|
|
||||||
byte = 0x80
|
|
||||||
while byte & 0x80:
|
|
||||||
reader.readinto(buffer)
|
|
||||||
byte = buffer[0]
|
|
||||||
result += (byte & 0x7F) << shift
|
|
||||||
shift += 7
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def dump_uvarint(write: WriteMethod, n: int) -> None:
|
|
||||||
if n < 0:
|
|
||||||
raise ValueError("Cannot dump signed value, convert it to unsigned first.")
|
|
||||||
buffer = _UVARINT_BUFFER
|
|
||||||
shifted = 1
|
|
||||||
while shifted:
|
|
||||||
shifted = n >> 7
|
|
||||||
buffer[0] = (n & 0x7F) | (0x80 if shifted else 0x00)
|
|
||||||
write(buffer)
|
|
||||||
n = shifted
|
|
||||||
|
|
||||||
|
|
||||||
def count_uvarint(n: int) -> int:
|
|
||||||
if n < 0:
|
|
||||||
raise ValueError("Cannot dump signed value, convert it to unsigned first.")
|
|
||||||
if n <= 0x7F:
|
|
||||||
return 1
|
|
||||||
if n <= 0x3FFF:
|
|
||||||
return 2
|
|
||||||
if n <= 0x1F_FFFF:
|
|
||||||
return 3
|
|
||||||
if n <= 0xFFF_FFFF:
|
|
||||||
return 4
|
|
||||||
if n <= 0x7_FFFF_FFFF:
|
|
||||||
return 5
|
|
||||||
if n <= 0x3FF_FFFF_FFFF:
|
|
||||||
return 6
|
|
||||||
if n <= 0x1_FFFF_FFFF_FFFF:
|
|
||||||
return 7
|
|
||||||
if n <= 0xFF_FFFF_FFFF_FFFF:
|
|
||||||
return 8
|
|
||||||
if n <= 0x7FFF_FFFF_FFFF_FFFF:
|
|
||||||
return 9
|
|
||||||
raise ValueError
|
|
||||||
|
|
||||||
|
|
||||||
# protobuf interleaved signed encoding:
|
|
||||||
# https://developers.google.com/protocol-buffers/docs/encoding#structure
|
|
||||||
# the idea is to save the sign in LSbit instead of twos-complement.
|
|
||||||
# so counting up, you go: 0, -1, 1, -2, 2, ... (as the first bit changes, sign flips)
|
|
||||||
#
|
|
||||||
# To achieve this with a twos-complement number:
|
|
||||||
# 1. shift left by 1, leaving LSbit free
|
|
||||||
# 2. if the number is negative, do bitwise negation.
|
|
||||||
# This keeps positive number the same, and converts negative from twos-complement
|
|
||||||
# to the appropriate value, while setting the sign bit.
|
|
||||||
#
|
|
||||||
# The original algorithm makes use of the fact that arithmetic (signed) shift
|
|
||||||
# keeps the sign bits, so for a n-bit number, (x >> n) gets us "all sign bits".
|
|
||||||
# Then you can take "number XOR all-sign-bits", which is XOR 0 (identity) for positive
|
|
||||||
# and XOR 1 (bitwise negation) for negative. Cute and efficient.
|
|
||||||
#
|
|
||||||
# But this is harder in Python because we don't natively know the bit size of the number.
|
|
||||||
# So we have to branch on whether the number is negative.
|
|
||||||
|
|
||||||
|
|
||||||
def sint_to_uint(sint: int) -> int:
|
|
||||||
res = sint << 1
|
|
||||||
if sint < 0:
|
|
||||||
res = ~res
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
def uint_to_sint(uint: int) -> int:
|
|
||||||
sign = uint & 1
|
|
||||||
res = uint >> 1
|
|
||||||
if sign:
|
|
||||||
res = ~res
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
class UVarintType:
|
|
||||||
WIRE_TYPE = 0
|
|
||||||
|
|
||||||
|
|
||||||
class SVarintType:
|
|
||||||
WIRE_TYPE = 0
|
|
||||||
|
|
||||||
|
|
||||||
class BoolType:
|
|
||||||
WIRE_TYPE = 0
|
|
||||||
|
|
||||||
|
|
||||||
class EnumType:
|
|
||||||
WIRE_TYPE = 0
|
|
||||||
|
|
||||||
def __init__(self, name: str, enum_values: Iterable[int]) -> None:
|
|
||||||
self.enum_values = enum_values
|
|
||||||
|
|
||||||
def validate(self, fvalue: int) -> int:
|
|
||||||
if fvalue in self.enum_values:
|
|
||||||
return fvalue
|
|
||||||
else:
|
|
||||||
raise TypeError("Invalid enum value")
|
|
||||||
|
|
||||||
|
|
||||||
class BytesType:
|
|
||||||
WIRE_TYPE = 2
|
|
||||||
|
|
||||||
|
|
||||||
class UnicodeType:
|
|
||||||
WIRE_TYPE = 2
|
|
||||||
|
|
||||||
|
|
||||||
if False:
|
|
||||||
MessageTypeDef = Union[
|
|
||||||
type[UVarintType],
|
|
||||||
type[SVarintType],
|
|
||||||
type[BoolType],
|
|
||||||
EnumType,
|
|
||||||
type[BytesType],
|
|
||||||
type[UnicodeType],
|
|
||||||
type["MessageType"],
|
|
||||||
]
|
|
||||||
FieldDef = tuple[str, MessageTypeDef, Any]
|
|
||||||
FieldDict = dict[int, FieldDef]
|
|
||||||
|
|
||||||
FieldCache = dict[type["MessageType"], FieldDict]
|
|
||||||
|
|
||||||
LoadedMessageType = TypeVar("LoadedMessageType", bound="MessageType")
|
|
||||||
|
|
||||||
|
|
||||||
class MessageType:
|
|
||||||
WIRE_TYPE = 2
|
|
||||||
UNSTABLE = False
|
|
||||||
|
|
||||||
# Type id for the wire codec.
|
|
||||||
# Technically, not every protobuf message has this.
|
|
||||||
MESSAGE_WIRE_TYPE = -1
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_fields(cls) -> FieldDict:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def cache_subordinate_types(cls, field_cache: FieldCache) -> None:
|
|
||||||
if cls in field_cache:
|
|
||||||
fields = field_cache[cls]
|
|
||||||
else:
|
|
||||||
fields = cls.get_fields()
|
|
||||||
field_cache[cls] = fields
|
|
||||||
|
|
||||||
for _, field_type, _ in fields.values():
|
|
||||||
if isinstance(field_type, MessageType):
|
|
||||||
field_type.cache_subordinate_types(field_cache)
|
|
||||||
|
|
||||||
def __eq__(self, rhs: Any) -> bool:
|
|
||||||
return self.__class__ is rhs.__class__ and self.__dict__ == rhs.__dict__
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return "<%s>" % self.__class__.__name__
|
|
||||||
|
|
||||||
|
|
||||||
class LimitedReader:
|
|
||||||
def __init__(self, reader: Reader, limit: int) -> None:
|
|
||||||
self.reader = reader
|
|
||||||
self.limit = limit
|
|
||||||
|
|
||||||
def readinto(self, buf: bytearray) -> int:
|
|
||||||
if self.limit < len(buf):
|
|
||||||
raise EOFError
|
|
||||||
else:
|
|
||||||
nread = self.reader.readinto(buf)
|
|
||||||
self.limit -= nread
|
|
||||||
return nread
|
|
||||||
|
|
||||||
|
|
||||||
FLAG_REPEATED = object()
|
|
||||||
FLAG_REQUIRED = object()
|
|
||||||
FLAG_EXPERIMENTAL = object()
|
|
||||||
|
|
||||||
|
|
||||||
def load_message(
|
|
||||||
reader: Reader,
|
|
||||||
msg_type: type[LoadedMessageType],
|
|
||||||
field_cache: FieldCache | None = None,
|
|
||||||
experimental_enabled: bool = True,
|
|
||||||
) -> LoadedMessageType:
|
|
||||||
if field_cache is None:
|
|
||||||
field_cache = {}
|
|
||||||
fields = field_cache.get(msg_type)
|
|
||||||
if fields is None:
|
|
||||||
fields = msg_type.get_fields()
|
|
||||||
field_cache[msg_type] = fields
|
|
||||||
|
|
||||||
if msg_type.UNSTABLE and not experimental_enabled:
|
|
||||||
raise ValueError # experimental messages not enabled
|
|
||||||
|
|
||||||
# we need to avoid calling __init__, which enforces required arguments
|
|
||||||
msg: LoadedMessageType = object.__new__(msg_type)
|
|
||||||
# pre-seed the object with defaults
|
|
||||||
for fname, _, fdefault in fields.values():
|
|
||||||
if fdefault is FLAG_REPEATED:
|
|
||||||
fdefault = []
|
|
||||||
elif fdefault is FLAG_EXPERIMENTAL:
|
|
||||||
fdefault = None
|
|
||||||
setattr(msg, fname, fdefault)
|
|
||||||
|
|
||||||
if False:
|
|
||||||
SingularValue = Union[int, bool, bytearray, str, MessageType]
|
|
||||||
Value = Union[SingularValue, list[SingularValue]]
|
|
||||||
fvalue: Value = 0
|
|
||||||
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
fkey = load_uvarint(reader)
|
|
||||||
except EOFError:
|
|
||||||
break # no more fields to load
|
|
||||||
|
|
||||||
ftag = fkey >> 3
|
|
||||||
wtype = fkey & 7
|
|
||||||
|
|
||||||
field = fields.get(ftag, None)
|
|
||||||
|
|
||||||
if field is None: # unknown field, skip it
|
|
||||||
if wtype == 0:
|
|
||||||
load_uvarint(reader)
|
|
||||||
elif wtype == 2:
|
|
||||||
ivalue = load_uvarint(reader)
|
|
||||||
reader.readinto(bytearray(ivalue))
|
|
||||||
else:
|
|
||||||
raise ValueError
|
|
||||||
continue
|
|
||||||
|
|
||||||
fname, ftype, fdefault = field
|
|
||||||
if wtype != ftype.WIRE_TYPE:
|
|
||||||
raise TypeError # parsed wire type differs from the schema
|
|
||||||
|
|
||||||
if fdefault is FLAG_EXPERIMENTAL and not experimental_enabled:
|
|
||||||
raise ValueError # experimental fields not enabled
|
|
||||||
|
|
||||||
ivalue = load_uvarint(reader)
|
|
||||||
|
|
||||||
if ftype is UVarintType:
|
|
||||||
fvalue = ivalue
|
|
||||||
elif ftype is SVarintType:
|
|
||||||
fvalue = uint_to_sint(ivalue)
|
|
||||||
elif ftype is BoolType:
|
|
||||||
fvalue = bool(ivalue)
|
|
||||||
elif isinstance(ftype, EnumType):
|
|
||||||
fvalue = ftype.validate(ivalue)
|
|
||||||
elif ftype is BytesType:
|
|
||||||
fvalue = bytearray(ivalue)
|
|
||||||
reader.readinto(fvalue)
|
|
||||||
elif ftype is UnicodeType:
|
|
||||||
fvalue = bytearray(ivalue)
|
|
||||||
reader.readinto(fvalue)
|
|
||||||
fvalue = bytes(fvalue).decode()
|
|
||||||
elif issubclass(ftype, MessageType):
|
|
||||||
fvalue = load_message(
|
|
||||||
LimitedReader(reader, ivalue), ftype, field_cache, experimental_enabled
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise TypeError # field type is unknown
|
|
||||||
|
|
||||||
if fdefault is FLAG_REPEATED:
|
|
||||||
getattr(msg, fname).append(fvalue)
|
|
||||||
else:
|
|
||||||
setattr(msg, fname, fvalue)
|
|
||||||
|
|
||||||
for fname, _, _ in fields.values():
|
|
||||||
if getattr(msg, fname) is FLAG_REQUIRED:
|
|
||||||
# The message is intended to be user-facing when decoding from wire,
|
|
||||||
# but not when used internally.
|
|
||||||
raise ValueError("Required field '{}' was not received".format(fname))
|
|
||||||
|
|
||||||
return msg
|
|
||||||
|
|
||||||
|
|
||||||
def dump_message(
|
|
||||||
writer: Writer, msg: MessageType, field_cache: FieldCache | None = None
|
|
||||||
) -> None:
|
|
||||||
repvalue = [0]
|
|
||||||
|
|
||||||
if field_cache is None:
|
|
||||||
field_cache = {}
|
|
||||||
fields = field_cache.get(type(msg))
|
|
||||||
if fields is None:
|
|
||||||
fields = msg.get_fields()
|
|
||||||
field_cache[type(msg)] = fields
|
|
||||||
|
|
||||||
for ftag in fields:
|
|
||||||
fname, ftype, fdefault = fields[ftag]
|
|
||||||
|
|
||||||
fvalue = getattr(msg, fname, None)
|
|
||||||
if fvalue is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
fkey = (ftag << 3) | ftype.WIRE_TYPE
|
|
||||||
|
|
||||||
if fdefault is not FLAG_REPEATED:
|
|
||||||
repvalue[0] = fvalue
|
|
||||||
fvalue = repvalue
|
|
||||||
|
|
||||||
for svalue in fvalue:
|
|
||||||
dump_uvarint(writer.write, fkey)
|
|
||||||
|
|
||||||
if ftype is UVarintType:
|
|
||||||
dump_uvarint(writer.write, svalue)
|
|
||||||
|
|
||||||
elif ftype is SVarintType:
|
|
||||||
dump_uvarint(writer.write, sint_to_uint(svalue))
|
|
||||||
|
|
||||||
elif ftype is BoolType:
|
|
||||||
dump_uvarint(writer.write, int(svalue))
|
|
||||||
|
|
||||||
elif isinstance(ftype, EnumType):
|
|
||||||
dump_uvarint(writer.write, svalue)
|
|
||||||
|
|
||||||
elif ftype is BytesType:
|
|
||||||
if isinstance(svalue, list):
|
|
||||||
dump_uvarint(writer.write, _count_bytes_list(svalue))
|
|
||||||
for sub_svalue in svalue:
|
|
||||||
writer.write(sub_svalue)
|
|
||||||
else:
|
|
||||||
dump_uvarint(writer.write, len(svalue))
|
|
||||||
writer.write(svalue)
|
|
||||||
|
|
||||||
elif ftype is UnicodeType:
|
|
||||||
svalue = svalue.encode()
|
|
||||||
dump_uvarint(writer.write, len(svalue))
|
|
||||||
writer.write(svalue)
|
|
||||||
|
|
||||||
elif issubclass(ftype, MessageType):
|
|
||||||
ffields = field_cache.get(ftype)
|
|
||||||
if ffields is None:
|
|
||||||
ffields = ftype.get_fields()
|
|
||||||
field_cache[ftype] = ffields
|
|
||||||
dump_uvarint(writer.write, count_message(svalue, field_cache))
|
|
||||||
dump_message(writer, svalue, field_cache)
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise TypeError
|
|
||||||
|
|
||||||
|
|
||||||
def count_message(msg: MessageType, field_cache: FieldCache | None = None) -> int:
|
|
||||||
nbytes = 0
|
|
||||||
repvalue = [0]
|
|
||||||
|
|
||||||
if field_cache is None:
|
|
||||||
field_cache = {}
|
|
||||||
fields = field_cache.get(type(msg))
|
|
||||||
if fields is None:
|
|
||||||
fields = msg.get_fields()
|
|
||||||
field_cache[type(msg)] = fields
|
|
||||||
|
|
||||||
for ftag in fields:
|
|
||||||
fname, ftype, fdefault = fields[ftag]
|
|
||||||
|
|
||||||
fvalue = getattr(msg, fname, None)
|
|
||||||
if fvalue is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
fkey = (ftag << 3) | ftype.WIRE_TYPE
|
|
||||||
|
|
||||||
if fdefault is not FLAG_REPEATED:
|
|
||||||
repvalue[0] = fvalue
|
|
||||||
fvalue = repvalue
|
|
||||||
|
|
||||||
# length of all the field keys
|
|
||||||
nbytes += count_uvarint(fkey) * len(fvalue)
|
|
||||||
|
|
||||||
if ftype is UVarintType:
|
|
||||||
for svalue in fvalue:
|
|
||||||
nbytes += count_uvarint(svalue)
|
|
||||||
|
|
||||||
elif ftype is SVarintType:
|
|
||||||
for svalue in fvalue:
|
|
||||||
nbytes += count_uvarint(sint_to_uint(svalue))
|
|
||||||
|
|
||||||
elif ftype is BoolType:
|
|
||||||
for svalue in fvalue:
|
|
||||||
nbytes += count_uvarint(int(svalue))
|
|
||||||
|
|
||||||
elif isinstance(ftype, EnumType):
|
|
||||||
for svalue in fvalue:
|
|
||||||
nbytes += count_uvarint(svalue)
|
|
||||||
|
|
||||||
elif ftype is BytesType:
|
|
||||||
for svalue in fvalue:
|
|
||||||
if isinstance(svalue, list):
|
|
||||||
svalue = _count_bytes_list(svalue)
|
|
||||||
else:
|
|
||||||
svalue = len(svalue)
|
|
||||||
nbytes += count_uvarint(svalue)
|
|
||||||
nbytes += svalue
|
|
||||||
|
|
||||||
elif ftype is UnicodeType:
|
|
||||||
for svalue in fvalue:
|
|
||||||
svalue = len(svalue.encode())
|
|
||||||
nbytes += count_uvarint(svalue)
|
|
||||||
nbytes += svalue
|
|
||||||
|
|
||||||
elif issubclass(ftype, MessageType):
|
|
||||||
for svalue in fvalue:
|
|
||||||
fsize = count_message(svalue, field_cache)
|
|
||||||
nbytes += count_uvarint(fsize)
|
|
||||||
nbytes += fsize
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise TypeError
|
|
||||||
|
|
||||||
return nbytes
|
|
||||||
|
|
||||||
|
|
||||||
def _count_bytes_list(svalue: list[bytes]) -> int:
|
|
||||||
res = 0
|
|
||||||
for x in svalue:
|
|
||||||
res += len(x)
|
|
||||||
return res
|
|
42
core/src/trezor/protobuf.py
Normal file
42
core/src/trezor/protobuf.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
import trezorproto
|
||||||
|
|
||||||
|
decode = trezorproto.decode
|
||||||
|
encode = trezorproto.encode
|
||||||
|
encoded_length = trezorproto.encoded_length
|
||||||
|
type_for_name = trezorproto.type_for_name
|
||||||
|
type_for_wire = trezorproto.type_for_wire
|
||||||
|
|
||||||
|
# XXX
|
||||||
|
# Note that MessageType "subclasses" are not true subclasses, but instead instances
|
||||||
|
# of the built-in metaclass MsgDef. MessageType instances are in fact instances of
|
||||||
|
# the built-in type Msg. That is why isinstance checks do not work, and instead the
|
||||||
|
# MessageTypeSubclass.is_type_of() method must be used.
|
||||||
|
if False:
|
||||||
|
from typing import Type, TypeGuard, TypeVar
|
||||||
|
|
||||||
|
T = TypeVar("T", bound="MessageType")
|
||||||
|
|
||||||
|
class MsgDef(type):
|
||||||
|
@classmethod
|
||||||
|
def is_type_of(cls: Type[Type[T]], msg: "MessageType") -> TypeGuard[T]:
|
||||||
|
"""Identify if the provided message belongs to this type."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
class MessageType(metaclass=MsgDef):
|
||||||
|
MESSAGE_NAME: str = "MessageType"
|
||||||
|
MESSAGE_WIRE_TYPE: int | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def load_message_buffer(
|
||||||
|
buffer: bytes,
|
||||||
|
msg_wire_type: int,
|
||||||
|
experimental_enabled: bool = True,
|
||||||
|
) -> MessageType:
|
||||||
|
msg_type = type_for_wire(msg_wire_type)
|
||||||
|
return decode(buffer, msg_type, experimental_enabled)
|
||||||
|
|
||||||
|
|
||||||
|
def dump_message_buffer(msg: MessageType) -> bytearray:
|
||||||
|
buffer = bytearray(encoded_length(msg))
|
||||||
|
encode(buffer, msg)
|
||||||
|
return buffer
|
@ -145,10 +145,6 @@ class HashWriter:
|
|||||||
def write(self, buf: bytes) -> None: # alias for extend()
|
def write(self, buf: bytes) -> None: # alias for extend()
|
||||||
self.ctx.update(buf)
|
self.ctx.update(buf)
|
||||||
|
|
||||||
async def awrite(self, buf: bytes) -> int: # AsyncWriter interface
|
|
||||||
self.ctx.update(buf)
|
|
||||||
return len(buf)
|
|
||||||
|
|
||||||
def get_digest(self) -> bytes:
|
def get_digest(self) -> bytes:
|
||||||
return self.ctx.digest()
|
return self.ctx.digest()
|
||||||
|
|
||||||
|
@ -35,11 +35,10 @@ reads the message's header. When the message type is known the first handler is
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import protobuf
|
|
||||||
from storage.cache import InvalidSessionError
|
from storage.cache import InvalidSessionError
|
||||||
from trezor import log, loop, messages, utils, workflow
|
from trezor.enums import FailureType
|
||||||
from trezor.messages import FailureType
|
from trezor.messages import Failure
|
||||||
from trezor.messages.Failure import Failure
|
from trezor import log, loop, protobuf, utils, workflow
|
||||||
from trezor.wire import codec_v1
|
from trezor.wire import codec_v1
|
||||||
from trezor.wire.errors import ActionCancelled, DataError, Error
|
from trezor.wire.errors import ActionCancelled, DataError, Error
|
||||||
|
|
||||||
@ -74,17 +73,19 @@ def setup(iface: WireInterface, is_debug_session: bool = False) -> None:
|
|||||||
|
|
||||||
|
|
||||||
if False:
|
if False:
|
||||||
from typing import Protocol
|
from typing import Protocol, TypeVar
|
||||||
|
|
||||||
|
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
|
||||||
|
|
||||||
class GenericContext(Protocol):
|
class GenericContext(Protocol):
|
||||||
async def call(
|
async def call(
|
||||||
self,
|
self,
|
||||||
msg: protobuf.MessageType,
|
msg: protobuf.MessageType,
|
||||||
expected_type: type[protobuf.LoadedMessageType],
|
expected_type: type[protobuf.MessageType],
|
||||||
) -> Any:
|
) -> Any:
|
||||||
...
|
...
|
||||||
|
|
||||||
async def read(self, expected_type: type[protobuf.LoadedMessageType]) -> Any:
|
async def read(self, expected_type: type[protobuf.MessageType]) -> Any:
|
||||||
...
|
...
|
||||||
|
|
||||||
async def write(self, msg: protobuf.MessageType) -> None:
|
async def write(self, msg: protobuf.MessageType) -> None:
|
||||||
@ -96,15 +97,14 @@ if False:
|
|||||||
|
|
||||||
|
|
||||||
def _wrap_protobuf_load(
|
def _wrap_protobuf_load(
|
||||||
reader: protobuf.Reader,
|
buffer: bytes,
|
||||||
expected_type: type[protobuf.LoadedMessageType],
|
expected_type: type[LoadedMessageType],
|
||||||
field_cache: protobuf.FieldCache | None = None,
|
) -> LoadedMessageType:
|
||||||
) -> protobuf.LoadedMessageType:
|
|
||||||
try:
|
try:
|
||||||
return protobuf.load_message(
|
return protobuf.decode(buffer, expected_type, experimental_enabled)
|
||||||
reader, expected_type, field_cache, experimental_enabled
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
if __debug__:
|
||||||
|
log.exception(__name__, e)
|
||||||
if e.args:
|
if e.args:
|
||||||
raise DataError("Failed to decode message: {}".format(e.args[0]))
|
raise DataError("Failed to decode message: {}".format(e.args[0]))
|
||||||
else:
|
else:
|
||||||
@ -139,19 +139,15 @@ class Context:
|
|||||||
self.iface = iface
|
self.iface = iface
|
||||||
self.sid = sid
|
self.sid = sid
|
||||||
self.buffer = buffer
|
self.buffer = buffer
|
||||||
self.buffer_writer = utils.BufferWriter(self.buffer)
|
|
||||||
|
|
||||||
self._field_cache: protobuf.FieldCache = {}
|
|
||||||
|
|
||||||
async def call(
|
async def call(
|
||||||
self,
|
self,
|
||||||
msg: protobuf.MessageType,
|
msg: protobuf.MessageType,
|
||||||
expected_type: type[protobuf.LoadedMessageType],
|
expected_type: type[LoadedMessageType],
|
||||||
field_cache: protobuf.FieldCache | None = None,
|
) -> LoadedMessageType:
|
||||||
) -> protobuf.LoadedMessageType:
|
await self.write(msg)
|
||||||
await self.write(msg, field_cache)
|
|
||||||
del msg
|
del msg
|
||||||
return await self.read(expected_type, field_cache)
|
return await self.read(expected_type)
|
||||||
|
|
||||||
async def call_any(
|
async def call_any(
|
||||||
self, msg: protobuf.MessageType, *expected_wire_types: int
|
self, msg: protobuf.MessageType, *expected_wire_types: int
|
||||||
@ -160,21 +156,17 @@ class Context:
|
|||||||
del msg
|
del msg
|
||||||
return await self.read_any(expected_wire_types)
|
return await self.read_any(expected_wire_types)
|
||||||
|
|
||||||
async def read_from_wire(self) -> codec_v1.Message:
|
def read_from_wire(self) -> Awaitable[codec_v1.Message]:
|
||||||
return await codec_v1.read_message(self.iface, self.buffer)
|
return codec_v1.read_message(self.iface, self.buffer)
|
||||||
|
|
||||||
async def read(
|
async def read(self, expected_type: type[LoadedMessageType]) -> LoadedMessageType:
|
||||||
self,
|
|
||||||
expected_type: type[protobuf.LoadedMessageType],
|
|
||||||
field_cache: protobuf.FieldCache | None = None,
|
|
||||||
) -> protobuf.LoadedMessageType:
|
|
||||||
if __debug__:
|
if __debug__:
|
||||||
log.debug(
|
log.debug(
|
||||||
__name__,
|
__name__,
|
||||||
"%s:%x expect: %s",
|
"%s:%x expect: %s",
|
||||||
self.iface.iface_num(),
|
self.iface.iface_num(),
|
||||||
self.sid,
|
self.sid,
|
||||||
expected_type,
|
expected_type.MESSAGE_NAME,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Load the full message into a buffer, parse out type and data payload
|
# Load the full message into a buffer, parse out type and data payload
|
||||||
@ -191,13 +183,13 @@ class Context:
|
|||||||
"%s:%x read: %s",
|
"%s:%x read: %s",
|
||||||
self.iface.iface_num(),
|
self.iface.iface_num(),
|
||||||
self.sid,
|
self.sid,
|
||||||
expected_type,
|
expected_type.MESSAGE_NAME,
|
||||||
)
|
)
|
||||||
|
|
||||||
workflow.idle_timer.touch()
|
workflow.idle_timer.touch()
|
||||||
|
|
||||||
# look up the protobuf class and parse the message
|
# look up the protobuf class and parse the message
|
||||||
return _wrap_protobuf_load(msg.data, expected_type, field_cache)
|
return _wrap_protobuf_load(msg.data, expected_type)
|
||||||
|
|
||||||
async def read_any(
|
async def read_any(
|
||||||
self, expected_wire_types: Iterable[int]
|
self, expected_wire_types: Iterable[int]
|
||||||
@ -220,11 +212,15 @@ class Context:
|
|||||||
raise UnexpectedMessageError(msg)
|
raise UnexpectedMessageError(msg)
|
||||||
|
|
||||||
# find the protobuf type
|
# find the protobuf type
|
||||||
exptype = messages.get_type(msg.type)
|
exptype = protobuf.type_for_wire(msg.type)
|
||||||
|
|
||||||
if __debug__:
|
if __debug__:
|
||||||
log.debug(
|
log.debug(
|
||||||
__name__, "%s:%x read: %s", self.iface.iface_num(), self.sid, exptype
|
__name__,
|
||||||
|
"%s:%x read: %s",
|
||||||
|
self.iface.iface_num(),
|
||||||
|
self.sid,
|
||||||
|
exptype.MESSAGE_NAME,
|
||||||
)
|
)
|
||||||
|
|
||||||
workflow.idle_timer.touch()
|
workflow.idle_timer.touch()
|
||||||
@ -232,41 +228,36 @@ class Context:
|
|||||||
# parse the message and return it
|
# parse the message and return it
|
||||||
return _wrap_protobuf_load(msg.data, exptype)
|
return _wrap_protobuf_load(msg.data, exptype)
|
||||||
|
|
||||||
async def write(
|
async def write(self, msg: protobuf.MessageType) -> None:
|
||||||
self,
|
|
||||||
msg: protobuf.MessageType,
|
|
||||||
field_cache: protobuf.FieldCache | None = None,
|
|
||||||
) -> None:
|
|
||||||
if __debug__:
|
if __debug__:
|
||||||
log.debug(
|
log.debug(
|
||||||
__name__, "%s:%x write: %s", self.iface.iface_num(), self.sid, msg
|
__name__,
|
||||||
|
"%s:%x write: %s",
|
||||||
|
self.iface.iface_num(),
|
||||||
|
self.sid,
|
||||||
|
msg.MESSAGE_NAME,
|
||||||
)
|
)
|
||||||
|
|
||||||
if field_cache is None:
|
# cannot write message without wire type
|
||||||
field_cache = self._field_cache
|
assert msg.MESSAGE_WIRE_TYPE is not None
|
||||||
|
|
||||||
# write the message
|
msg_size = protobuf.encoded_length(msg)
|
||||||
msg_size = protobuf.count_message(msg, field_cache)
|
|
||||||
|
|
||||||
# prepare buffer
|
if msg_size <= len(self.buffer):
|
||||||
if msg_size <= len(self.buffer_writer.buffer):
|
|
||||||
# reuse preallocated
|
# reuse preallocated
|
||||||
buffer_writer = self.buffer_writer
|
buffer = self.buffer
|
||||||
else:
|
else:
|
||||||
# message is too big, we need to allocate a new buffer
|
# message is too big, we need to allocate a new buffer
|
||||||
buffer_writer = utils.BufferWriter(bytearray(msg_size))
|
buffer = bytearray(msg_size)
|
||||||
|
|
||||||
|
msg_size = protobuf.encode(buffer, msg)
|
||||||
|
|
||||||
buffer_writer.seek(0)
|
|
||||||
protobuf.dump_message(buffer_writer, msg, field_cache)
|
|
||||||
await codec_v1.write_message(
|
await codec_v1.write_message(
|
||||||
self.iface,
|
self.iface,
|
||||||
msg.MESSAGE_WIRE_TYPE,
|
msg.MESSAGE_WIRE_TYPE,
|
||||||
memoryview(buffer_writer.buffer)[:msg_size],
|
memoryview(buffer)[:msg_size],
|
||||||
)
|
)
|
||||||
|
|
||||||
# make sure we don't keep around fields of all protobuf types ever
|
|
||||||
self._field_cache.clear()
|
|
||||||
|
|
||||||
def wait(self, *tasks: Awaitable) -> Any:
|
def wait(self, *tasks: Awaitable) -> Any:
|
||||||
"""
|
"""
|
||||||
Wait until one of the passed tasks finishes, and return the result,
|
Wait until one of the passed tasks finishes, and return the result,
|
||||||
@ -301,8 +292,8 @@ async def _handle_single_message(
|
|||||||
"""
|
"""
|
||||||
if __debug__:
|
if __debug__:
|
||||||
try:
|
try:
|
||||||
msg_type = messages.get_type(msg.type).__name__
|
msg_type = protobuf.type_for_wire(msg.type).MESSAGE_NAME
|
||||||
except KeyError:
|
except Exception:
|
||||||
msg_type = "%d - unknown message type" % msg.type
|
msg_type = "%d - unknown message type" % msg.type
|
||||||
log.debug(
|
log.debug(
|
||||||
__name__,
|
__name__,
|
||||||
@ -328,7 +319,7 @@ async def _handle_single_message(
|
|||||||
try:
|
try:
|
||||||
# Find a protobuf.MessageType subclass that describes this
|
# Find a protobuf.MessageType subclass that describes this
|
||||||
# message. Raises if the type is not found.
|
# message. Raises if the type is not found.
|
||||||
req_type = messages.get_type(msg.type)
|
req_type = protobuf.type_for_wire(msg.type)
|
||||||
|
|
||||||
# Try to decode the message according to schema from
|
# Try to decode the message according to schema from
|
||||||
# `req_type`. Raises if the message is malformed.
|
# `req_type`. Raises if the message is malformed.
|
||||||
@ -424,6 +415,11 @@ async def handle_session(
|
|||||||
next_msg = await _handle_single_message(
|
next_msg = await _handle_single_message(
|
||||||
ctx, msg, use_workflow=not is_debug_session
|
ctx, msg, use_workflow=not is_debug_session
|
||||||
)
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
# Log and ignore. The session handler can only exit explicitly in the
|
||||||
|
# following finally block.
|
||||||
|
if __debug__:
|
||||||
|
log.exception(__name__, exc)
|
||||||
finally:
|
finally:
|
||||||
if not __debug__ or not is_debug_session:
|
if not __debug__ or not is_debug_session:
|
||||||
# Unload modules imported by the workflow. Should not raise.
|
# Unload modules imported by the workflow. Should not raise.
|
||||||
|
@ -23,7 +23,7 @@ class CodecError(Exception):
|
|||||||
|
|
||||||
|
|
||||||
class Message:
|
class Message:
|
||||||
def __init__(self, mtype: int, mdata: utils.BufferReader) -> None:
|
def __init__(self, mtype: int, mdata: bytes) -> None:
|
||||||
self.type = mtype
|
self.type = mtype
|
||||||
self.data = mdata
|
self.data = mdata
|
||||||
|
|
||||||
@ -70,7 +70,7 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag
|
|||||||
if read_and_throw_away:
|
if read_and_throw_away:
|
||||||
raise CodecError("Message too large")
|
raise CodecError("Message too large")
|
||||||
|
|
||||||
return Message(mtype, utils.BufferReader(mdata))
|
return Message(mtype, mdata)
|
||||||
|
|
||||||
|
|
||||||
async def write_message(iface: WireInterface, mtype: int, mdata: bytes) -> None:
|
async def write_message(iface: WireInterface, mtype: int, mdata: bytes) -> None:
|
||||||
|
@ -1,129 +0,0 @@
|
|||||||
from common import *
|
|
||||||
|
|
||||||
import protobuf
|
|
||||||
from trezor.utils import BufferReader, BufferWriter
|
|
||||||
|
|
||||||
|
|
||||||
class Message(protobuf.MessageType):
|
|
||||||
def __init__(self, sint_field: int = 0, enum_field: int = 0) -> None:
|
|
||||||
self.sint_field = sint_field
|
|
||||||
self.enum_field = enum_field
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_fields(cls):
|
|
||||||
return {
|
|
||||||
1: ("sint_field", protobuf.SVarintType, 0),
|
|
||||||
2: ("enum_field", protobuf.EnumType("t", (0, 5, 25)), 0),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class MessageWithRequiredAndDefault(protobuf.MessageType):
|
|
||||||
def __init__(self, required_field, default_field) -> None:
|
|
||||||
self.required_field = required_field
|
|
||||||
self.default_field = default_field
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_fields(cls):
|
|
||||||
return {
|
|
||||||
1: ("required_field", protobuf.UVarintType, protobuf.FLAG_REQUIRED),
|
|
||||||
2: ("default_field", protobuf.SVarintType, -1),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def load_uvarint(data: bytes) -> int:
|
|
||||||
reader = BufferReader(data)
|
|
||||||
return protobuf.load_uvarint(reader)
|
|
||||||
|
|
||||||
|
|
||||||
def dump_uvarint(value: int) -> bytearray:
|
|
||||||
w = bytearray()
|
|
||||||
protobuf.dump_uvarint(w.extend, value)
|
|
||||||
return w
|
|
||||||
|
|
||||||
|
|
||||||
def dump_message(msg: protobuf.MessageType) -> bytearray:
|
|
||||||
length = protobuf.count_message(msg)
|
|
||||||
buffer = bytearray(length)
|
|
||||||
protobuf.dump_message(BufferWriter(buffer), msg)
|
|
||||||
return buffer
|
|
||||||
|
|
||||||
|
|
||||||
def load_message(msg_type, buffer: bytearray) -> protobuf.MessageType:
|
|
||||||
return protobuf.load_message(BufferReader(buffer), msg_type)
|
|
||||||
|
|
||||||
|
|
||||||
class TestProtobuf(unittest.TestCase):
|
|
||||||
def test_dump_uvarint(self):
|
|
||||||
self.assertEqual(dump_uvarint(0), b"\x00")
|
|
||||||
self.assertEqual(dump_uvarint(1), b"\x01")
|
|
||||||
self.assertEqual(dump_uvarint(0xFF), b"\xff\x01")
|
|
||||||
self.assertEqual(dump_uvarint(123456), b"\xc0\xc4\x07")
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
dump_uvarint(-1)
|
|
||||||
|
|
||||||
def test_load_uvarint(self):
|
|
||||||
self.assertEqual(load_uvarint(b"\x00"), 0)
|
|
||||||
self.assertEqual(load_uvarint(b"\x01"), 1)
|
|
||||||
self.assertEqual(load_uvarint(b"\xff\x01"), 0xFF)
|
|
||||||
self.assertEqual(load_uvarint(b"\xc0\xc4\x07"), 123456)
|
|
||||||
|
|
||||||
def test_sint_uint(self):
|
|
||||||
self.assertEqual(protobuf.uint_to_sint(0), 0)
|
|
||||||
self.assertEqual(protobuf.sint_to_uint(0), 0)
|
|
||||||
|
|
||||||
self.assertEqual(protobuf.sint_to_uint(-1), 1)
|
|
||||||
self.assertEqual(protobuf.sint_to_uint(1), 2)
|
|
||||||
|
|
||||||
self.assertEqual(protobuf.uint_to_sint(1), -1)
|
|
||||||
self.assertEqual(protobuf.uint_to_sint(2), 1)
|
|
||||||
|
|
||||||
# roundtrip:
|
|
||||||
self.assertEqual(
|
|
||||||
protobuf.uint_to_sint(protobuf.sint_to_uint(1234567891011)), 1234567891011
|
|
||||||
)
|
|
||||||
self.assertEqual(
|
|
||||||
protobuf.uint_to_sint(protobuf.sint_to_uint(-(2 ** 32))), -(2 ** 32)
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_validate_enum(self):
|
|
||||||
# ok message:
|
|
||||||
msg = Message(-42, 5)
|
|
||||||
msg_encoded = dump_message(msg)
|
|
||||||
nmsg = load_message(Message, msg_encoded)
|
|
||||||
|
|
||||||
self.assertEqual(msg.sint_field, nmsg.sint_field)
|
|
||||||
self.assertEqual(msg.enum_field, nmsg.enum_field)
|
|
||||||
|
|
||||||
# bad enum value:
|
|
||||||
msg = Message(-42, 42)
|
|
||||||
msg_encoded = dump_message(msg)
|
|
||||||
with self.assertRaises(TypeError):
|
|
||||||
load_message(Message, msg_encoded)
|
|
||||||
|
|
||||||
def test_required(self):
|
|
||||||
msg = MessageWithRequiredAndDefault(required_field=1, default_field=2)
|
|
||||||
msg_encoded = dump_message(msg)
|
|
||||||
nmsg = load_message(MessageWithRequiredAndDefault, msg_encoded)
|
|
||||||
|
|
||||||
self.assertEqual(nmsg.required_field, 1)
|
|
||||||
self.assertEqual(nmsg.default_field, 2)
|
|
||||||
|
|
||||||
# try a message without the required_field
|
|
||||||
msg = MessageWithRequiredAndDefault(required_field=None, default_field=2)
|
|
||||||
# encoding always succeeds
|
|
||||||
msg_encoded = dump_message(msg)
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
load_message(MessageWithRequiredAndDefault, msg_encoded)
|
|
||||||
|
|
||||||
# try a message without the default field
|
|
||||||
msg = MessageWithRequiredAndDefault(required_field=1, default_field=None)
|
|
||||||
msg_encoded = dump_message(msg)
|
|
||||||
nmsg = load_message(MessageWithRequiredAndDefault, msg_encoded)
|
|
||||||
|
|
||||||
self.assertEqual(nmsg.required_field, 1)
|
|
||||||
self.assertEqual(nmsg.default_field, -1)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
115
core/tests/test_trezor.protobuf.py
Normal file
115
core/tests/test_trezor.protobuf.py
Normal file
@ -0,0 +1,115 @@
|
|||||||
|
from common import *
|
||||||
|
|
||||||
|
from trezor import protobuf
|
||||||
|
from trezor.messages import WebAuthnCredential, EosAsset, Failure, SignMessage
|
||||||
|
|
||||||
|
|
||||||
|
def load_uvarint32(data: bytes) -> int:
|
||||||
|
# use known uint32 field in an all-optional message
|
||||||
|
buffer = bytearray(len(data) + 1)
|
||||||
|
buffer[1:] = data
|
||||||
|
buffer[0] = (1 << 3) | 0 # field number 1, wire type 0
|
||||||
|
msg = protobuf.decode(buffer, WebAuthnCredential, False)
|
||||||
|
return msg.index
|
||||||
|
|
||||||
|
|
||||||
|
def load_uvarint64(data: bytes) -> int:
|
||||||
|
# use known uint64 field in an all-optional message
|
||||||
|
buffer = bytearray(len(data) + 1)
|
||||||
|
buffer[1:] = data
|
||||||
|
buffer[0] = (2 << 3) | 0 # field number 1, wire type 0
|
||||||
|
msg = protobuf.decode(buffer, EosAsset, False)
|
||||||
|
return msg.symbol
|
||||||
|
|
||||||
|
|
||||||
|
def dump_uvarint32(value: int) -> bytearray:
|
||||||
|
# use known uint32 field in an all-optional message
|
||||||
|
msg = WebAuthnCredential(index=value)
|
||||||
|
length = protobuf.encoded_length(msg)
|
||||||
|
buffer = bytearray(length)
|
||||||
|
protobuf.encode(buffer, msg)
|
||||||
|
assert buffer[0] == (1 << 3) | 0 # field number 1, wire type 0
|
||||||
|
return buffer[1:]
|
||||||
|
|
||||||
|
|
||||||
|
def dump_uvarint64(value: int) -> bytearray:
|
||||||
|
# use known uint64 field in an all-optional message
|
||||||
|
msg = EosAsset(symbol=value)
|
||||||
|
length = protobuf.encoded_length(msg)
|
||||||
|
buffer = bytearray(length)
|
||||||
|
protobuf.encode(buffer, msg)
|
||||||
|
assert buffer[0] == (2 << 3) | 0 # field number 2, wire type 0
|
||||||
|
return buffer[1:]
|
||||||
|
|
||||||
|
|
||||||
|
def dump_message(msg: protobuf.MessageType) -> bytearray:
|
||||||
|
length = protobuf.encoded_length(msg)
|
||||||
|
buffer = bytearray(length)
|
||||||
|
protobuf.encode(buffer, msg)
|
||||||
|
return buffer
|
||||||
|
|
||||||
|
|
||||||
|
def load_message(msg_type: Type[protobuf.MessageType], buffer: bytes) -> protobuf.MessageType:
|
||||||
|
return protobuf.decode(buffer, msg_type, False)
|
||||||
|
|
||||||
|
|
||||||
|
class TestProtobuf(unittest.TestCase):
|
||||||
|
def test_dump_uvarint(self):
|
||||||
|
for dump_uvarint in (dump_uvarint32, dump_uvarint64):
|
||||||
|
self.assertEqual(dump_uvarint(0), b"\x00")
|
||||||
|
self.assertEqual(dump_uvarint(1), b"\x01")
|
||||||
|
self.assertEqual(dump_uvarint(0xFF), b"\xff\x01")
|
||||||
|
self.assertEqual(dump_uvarint(123456), b"\xc0\xc4\x07")
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
dump_uvarint(-1)
|
||||||
|
|
||||||
|
def test_load_uvarint(self):
|
||||||
|
for load_uvarint in (load_uvarint32, load_uvarint64):
|
||||||
|
self.assertEqual(load_uvarint(b"\x00"), 0)
|
||||||
|
self.assertEqual(load_uvarint(b"\x01"), 1)
|
||||||
|
self.assertEqual(load_uvarint(b"\xff\x01"), 0xFF)
|
||||||
|
self.assertEqual(load_uvarint(b"\xc0\xc4\x07"), 123456)
|
||||||
|
|
||||||
|
def test_validate_enum(self):
|
||||||
|
# ok message:
|
||||||
|
msg = Failure(code=7)
|
||||||
|
msg_encoded = dump_message(msg)
|
||||||
|
nmsg = load_message(Failure, msg_encoded)
|
||||||
|
|
||||||
|
self.assertEqual(msg.code, nmsg.code)
|
||||||
|
|
||||||
|
# bad enum value:
|
||||||
|
msg = Failure(code=1000)
|
||||||
|
msg_encoded = dump_message(msg)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
load_message(Failure, msg_encoded)
|
||||||
|
|
||||||
|
def test_required(self):
|
||||||
|
msg = SignMessage(message=b"hello", coin_name="foo", script_type=1)
|
||||||
|
msg_encoded = dump_message(msg)
|
||||||
|
nmsg = load_message(SignMessage, msg_encoded)
|
||||||
|
|
||||||
|
self.assertEqual(nmsg.message, b"hello")
|
||||||
|
self.assertEqual(nmsg.coin_name, "foo")
|
||||||
|
self.assertEqual(nmsg.script_type, 1)
|
||||||
|
|
||||||
|
# try a message without the required_field
|
||||||
|
msg = SignMessage(message=None)
|
||||||
|
# encoding always succeeds
|
||||||
|
msg_encoded = dump_message(msg)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
load_message(SignMessage, msg_encoded)
|
||||||
|
|
||||||
|
# try a message without the default field
|
||||||
|
msg = SignMessage(message=b"hello")
|
||||||
|
msg.coin_name = None
|
||||||
|
msg_encoded = dump_message(msg)
|
||||||
|
nmsg = load_message(SignMessage, msg_encoded)
|
||||||
|
|
||||||
|
self.assertEqual(nmsg.message, b"hello")
|
||||||
|
self.assertEqual(nmsg.coin_name, "Bitcoin")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
@ -53,7 +53,7 @@ class TestWireCodecV1(unittest.TestCase):
|
|||||||
# e.value is StopIteration. e.value.value is the return value of the call
|
# e.value is StopIteration. e.value.value is the return value of the call
|
||||||
result = e.value.value
|
result = e.value.value
|
||||||
self.assertEqual(result.type, MESSAGE_TYPE)
|
self.assertEqual(result.type, MESSAGE_TYPE)
|
||||||
self.assertEqual(result.data.buffer, b"")
|
self.assertEqual(result.data, b"")
|
||||||
|
|
||||||
# message should have been read into the buffer
|
# message should have been read into the buffer
|
||||||
self.assertEqual(buffer, b"\x00" * 64)
|
self.assertEqual(buffer, b"\x00" * 64)
|
||||||
@ -83,7 +83,7 @@ class TestWireCodecV1(unittest.TestCase):
|
|||||||
# e.value is StopIteration. e.value.value is the return value of the call
|
# e.value is StopIteration. e.value.value is the return value of the call
|
||||||
result = e.value.value
|
result = e.value.value
|
||||||
self.assertEqual(result.type, MESSAGE_TYPE)
|
self.assertEqual(result.type, MESSAGE_TYPE)
|
||||||
self.assertEqual(result.data.buffer, message)
|
self.assertEqual(result.data, message)
|
||||||
|
|
||||||
# message should have been read into the buffer
|
# message should have been read into the buffer
|
||||||
self.assertEqual(buffer, message)
|
self.assertEqual(buffer, message)
|
||||||
@ -108,7 +108,7 @@ class TestWireCodecV1(unittest.TestCase):
|
|||||||
# e.value is StopIteration. e.value.value is the return value of the call
|
# e.value is StopIteration. e.value.value is the return value of the call
|
||||||
result = e.value.value
|
result = e.value.value
|
||||||
self.assertEqual(result.type, MESSAGE_TYPE)
|
self.assertEqual(result.type, MESSAGE_TYPE)
|
||||||
self.assertEqual(result.data.buffer, message)
|
self.assertEqual(result.data, message)
|
||||||
|
|
||||||
# read should have allocated its own buffer and not touch ours
|
# read should have allocated its own buffer and not touch ours
|
||||||
self.assertEqual(buffer, b"\x00")
|
self.assertEqual(buffer, b"\x00")
|
||||||
@ -176,7 +176,7 @@ class TestWireCodecV1(unittest.TestCase):
|
|||||||
|
|
||||||
result = e.value.value
|
result = e.value.value
|
||||||
self.assertEqual(result.type, MESSAGE_TYPE)
|
self.assertEqual(result.type, MESSAGE_TYPE)
|
||||||
self.assertEqual(result.data.buffer, message)
|
self.assertEqual(result.data, message)
|
||||||
|
|
||||||
def test_read_huge_packet(self):
|
def test_read_huge_packet(self):
|
||||||
PACKET_COUNT = 100_000
|
PACKET_COUNT = 100_000
|
||||||
|
Loading…
Reference in New Issue
Block a user