mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-23 14:58:09 +00:00
refactor(core): Switch to new Protobuf API
This commit is contained in:
parent
8a5cb41060
commit
02aa14fc04
@ -145,6 +145,7 @@ async def handle_DoPreauthorized(
|
||||
|
||||
req = await ctx.call_any(PreauthorizedRequest(), *wire_types)
|
||||
|
||||
assert req.MESSAGE_WIRE_TYPE is not None
|
||||
handler = workflow_handlers.find_registered_handler(
|
||||
ctx.iface, req.MESSAGE_WIRE_TYPE
|
||||
)
|
||||
|
@ -1,7 +1,8 @@
|
||||
import protobuf
|
||||
import storage.cache
|
||||
from trezor import messages, utils
|
||||
from trezor.messages import MessageType
|
||||
from trezor import protobuf
|
||||
from trezor.enums import MessageType
|
||||
from trezor.utils import ensure
|
||||
|
||||
if False:
|
||||
from typing import Iterable
|
||||
@ -16,9 +17,12 @@ def is_set() -> bool:
|
||||
|
||||
|
||||
def set(auth_message: protobuf.MessageType) -> None:
|
||||
buffer = bytearray(protobuf.count_message(auth_message))
|
||||
writer = utils.BufferWriter(buffer)
|
||||
protobuf.dump_message(writer, auth_message)
|
||||
buffer = protobuf.dump_message_buffer(auth_message)
|
||||
|
||||
# only wire-level messages can be stored as authorization
|
||||
# (because only wire-level messages have wire_type, which we use as identifier)
|
||||
ensure(auth_message.MESSAGE_WIRE_TYPE is not None)
|
||||
assert auth_message.MESSAGE_WIRE_TYPE is not None # so that mypy knows as well
|
||||
storage.cache.set(
|
||||
storage.cache.APP_COMMON_AUTHORIZATION_TYPE,
|
||||
auth_message.MESSAGE_WIRE_TYPE.to_bytes(2, "big"),
|
||||
@ -32,11 +36,8 @@ def get() -> protobuf.MessageType | None:
|
||||
return None
|
||||
|
||||
msg_wire_type = int.from_bytes(stored_auth_type, "big")
|
||||
msg_type = messages.get_type(msg_wire_type)
|
||||
buffer = storage.cache.get(storage.cache.APP_COMMON_AUTHORIZATION_DATA)
|
||||
reader = utils.BufferReader(buffer)
|
||||
|
||||
return protobuf.load_message(reader, msg_type)
|
||||
return protobuf.load_message_buffer(buffer, msg_wire_type)
|
||||
|
||||
|
||||
def get_wire_types() -> Iterable[int]:
|
||||
|
@ -50,7 +50,7 @@ async def get_tx_keys(ctx, msg: MoneroGetTxKeyRequest, keychain):
|
||||
# and then is used to store the derivations if applicable
|
||||
plain_buff = chacha_poly.decrypt_pack(tx_enc_key, msg.tx_enc_keys)
|
||||
utils.ensure(len(plain_buff) % 32 == 0, "Tx key buffer has invalid size")
|
||||
del msg.tx_enc_keys
|
||||
msg.tx_enc_keys = b""
|
||||
|
||||
# If return only derivations do tx_priv * view_pub
|
||||
if do_deriv:
|
||||
|
@ -134,7 +134,7 @@ def gen_hmac_vini(
|
||||
used only once and hard to check. I.e., indices in step 2
|
||||
are uncheckable, decoy keys in step 9 are just random keys.
|
||||
"""
|
||||
import protobuf
|
||||
from trezor import protobuf
|
||||
from apps.monero.xmr.keccak_hasher import get_keccak_writer
|
||||
|
||||
kwriter = get_keccak_writer()
|
||||
@ -146,7 +146,7 @@ def gen_hmac_vini(
|
||||
src_entr.real_out_additional_tx_keys[src_entr.real_output_in_tx_index]
|
||||
]
|
||||
|
||||
protobuf.dump_message(kwriter, src_entr)
|
||||
kwriter.write(protobuf.dump_message_buffer(src_entr))
|
||||
src_entr.outputs = real_outputs
|
||||
src_entr.real_out_additional_tx_keys = real_additional
|
||||
kwriter.write(vini_bin)
|
||||
@ -162,11 +162,11 @@ def gen_hmac_vouti(
|
||||
"""
|
||||
Generates HMAC for (TxDestinationEntry[i] || tx.vout[i])
|
||||
"""
|
||||
import protobuf
|
||||
from trezor import protobuf
|
||||
from apps.monero.xmr.keccak_hasher import get_keccak_writer
|
||||
|
||||
kwriter = get_keccak_writer()
|
||||
protobuf.dump_message(kwriter, dst_entr)
|
||||
kwriter.write(protobuf.dump_message_buffer(dst_entr))
|
||||
kwriter.write(tx_out_bin)
|
||||
|
||||
hmac_key_vouti = hmac_key_txout(key, idx)
|
||||
@ -180,11 +180,11 @@ def gen_hmac_tsxdest(
|
||||
"""
|
||||
Generates HMAC for TxDestinationEntry[i]
|
||||
"""
|
||||
import protobuf
|
||||
from trezor import protobuf
|
||||
from apps.monero.xmr.keccak_hasher import get_keccak_writer
|
||||
|
||||
kwriter = get_keccak_writer()
|
||||
protobuf.dump_message(kwriter, dst_entr)
|
||||
kwriter.write(protobuf.dump_message_buffer(dst_entr))
|
||||
|
||||
hmac_key = hmac_key_txdst(key, idx)
|
||||
hmac_tsxdest = crypto.compute_hmac(hmac_key, kwriter.get_digest())
|
||||
|
@ -116,7 +116,7 @@ async def init_transaction(
|
||||
from trezor.messages.MoneroTransactionInitAck import MoneroTransactionInitAck
|
||||
from trezor.messages.MoneroTransactionRsigData import MoneroTransactionRsigData
|
||||
|
||||
rsig_data = MoneroTransactionRsigData(offload_type=state.rsig_offload)
|
||||
rsig_data = MoneroTransactionRsigData(offload_type=int(state.rsig_offload))
|
||||
|
||||
return MoneroTransactionInitAck(hmacs=hmacs, rsig_data=rsig_data)
|
||||
|
||||
@ -273,11 +273,11 @@ def _compute_sec_keys(state: State, tsx_data: MoneroTransactionData):
|
||||
"""
|
||||
Generate master key H( H(TsxData || tx_priv) || rand )
|
||||
"""
|
||||
import protobuf
|
||||
from trezor import protobuf
|
||||
from apps.monero.xmr.keccak_hasher import get_keccak_writer
|
||||
|
||||
writer = get_keccak_writer()
|
||||
protobuf.dump_message(writer, tsx_data)
|
||||
writer.write(protobuf.dump_message_buffer(tsx_data))
|
||||
writer.write(crypto.encodeint(state.tx_priv))
|
||||
|
||||
master_key = crypto.keccak_2hash(
|
||||
|
@ -62,9 +62,6 @@ class MemoryReaderWriter:
|
||||
|
||||
return nread
|
||||
|
||||
async def areadinto(self, buf):
|
||||
return self.readinto(buf)
|
||||
|
||||
def write(self, buf):
|
||||
nwritten = len(buf)
|
||||
nall = len(self.buffer)
|
||||
@ -98,9 +95,6 @@ class MemoryReaderWriter:
|
||||
self.ndata += nwritten
|
||||
return nwritten
|
||||
|
||||
async def awrite(self, buf):
|
||||
return self.write(buf)
|
||||
|
||||
def get_buffer(self):
|
||||
mv = memoryview(self.buffer)
|
||||
return mv[self.offset : self.woffset]
|
||||
|
@ -1,456 +0,0 @@
|
||||
"""
|
||||
Extremely minimal streaming codec for a subset of protobuf. Supports uint32,
|
||||
bytes, string, embedded message and repeated fields.
|
||||
"""
|
||||
|
||||
if False:
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Iterable,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from typing_extensions import Protocol
|
||||
|
||||
class Reader(Protocol):
|
||||
def readinto(self, buf: bytearray) -> int:
|
||||
"""
|
||||
Reads `len(buf)` bytes into `buf`, or raises `EOFError`.
|
||||
"""
|
||||
|
||||
class Writer(Protocol):
|
||||
def write(self, buf: bytes) -> int:
|
||||
"""
|
||||
Writes all bytes from `buf`, or raises `EOFError`.
|
||||
"""
|
||||
|
||||
WriteMethod = Callable[[bytes], Any]
|
||||
|
||||
|
||||
_UVARINT_BUFFER = bytearray(1)
|
||||
|
||||
|
||||
def load_uvarint(reader: Reader) -> int:
|
||||
buffer = _UVARINT_BUFFER
|
||||
result = 0
|
||||
shift = 0
|
||||
byte = 0x80
|
||||
while byte & 0x80:
|
||||
reader.readinto(buffer)
|
||||
byte = buffer[0]
|
||||
result += (byte & 0x7F) << shift
|
||||
shift += 7
|
||||
return result
|
||||
|
||||
|
||||
def dump_uvarint(write: WriteMethod, n: int) -> None:
|
||||
if n < 0:
|
||||
raise ValueError("Cannot dump signed value, convert it to unsigned first.")
|
||||
buffer = _UVARINT_BUFFER
|
||||
shifted = 1
|
||||
while shifted:
|
||||
shifted = n >> 7
|
||||
buffer[0] = (n & 0x7F) | (0x80 if shifted else 0x00)
|
||||
write(buffer)
|
||||
n = shifted
|
||||
|
||||
|
||||
def count_uvarint(n: int) -> int:
|
||||
if n < 0:
|
||||
raise ValueError("Cannot dump signed value, convert it to unsigned first.")
|
||||
if n <= 0x7F:
|
||||
return 1
|
||||
if n <= 0x3FFF:
|
||||
return 2
|
||||
if n <= 0x1F_FFFF:
|
||||
return 3
|
||||
if n <= 0xFFF_FFFF:
|
||||
return 4
|
||||
if n <= 0x7_FFFF_FFFF:
|
||||
return 5
|
||||
if n <= 0x3FF_FFFF_FFFF:
|
||||
return 6
|
||||
if n <= 0x1_FFFF_FFFF_FFFF:
|
||||
return 7
|
||||
if n <= 0xFF_FFFF_FFFF_FFFF:
|
||||
return 8
|
||||
if n <= 0x7FFF_FFFF_FFFF_FFFF:
|
||||
return 9
|
||||
raise ValueError
|
||||
|
||||
|
||||
# protobuf interleaved signed encoding:
|
||||
# https://developers.google.com/protocol-buffers/docs/encoding#structure
|
||||
# the idea is to save the sign in LSbit instead of twos-complement.
|
||||
# so counting up, you go: 0, -1, 1, -2, 2, ... (as the first bit changes, sign flips)
|
||||
#
|
||||
# To achieve this with a twos-complement number:
|
||||
# 1. shift left by 1, leaving LSbit free
|
||||
# 2. if the number is negative, do bitwise negation.
|
||||
# This keeps positive number the same, and converts negative from twos-complement
|
||||
# to the appropriate value, while setting the sign bit.
|
||||
#
|
||||
# The original algorithm makes use of the fact that arithmetic (signed) shift
|
||||
# keeps the sign bits, so for a n-bit number, (x >> n) gets us "all sign bits".
|
||||
# Then you can take "number XOR all-sign-bits", which is XOR 0 (identity) for positive
|
||||
# and XOR 1 (bitwise negation) for negative. Cute and efficient.
|
||||
#
|
||||
# But this is harder in Python because we don't natively know the bit size of the number.
|
||||
# So we have to branch on whether the number is negative.
|
||||
|
||||
|
||||
def sint_to_uint(sint: int) -> int:
|
||||
res = sint << 1
|
||||
if sint < 0:
|
||||
res = ~res
|
||||
return res
|
||||
|
||||
|
||||
def uint_to_sint(uint: int) -> int:
|
||||
sign = uint & 1
|
||||
res = uint >> 1
|
||||
if sign:
|
||||
res = ~res
|
||||
return res
|
||||
|
||||
|
||||
class UVarintType:
|
||||
WIRE_TYPE = 0
|
||||
|
||||
|
||||
class SVarintType:
|
||||
WIRE_TYPE = 0
|
||||
|
||||
|
||||
class BoolType:
|
||||
WIRE_TYPE = 0
|
||||
|
||||
|
||||
class EnumType:
|
||||
WIRE_TYPE = 0
|
||||
|
||||
def __init__(self, name: str, enum_values: Iterable[int]) -> None:
|
||||
self.enum_values = enum_values
|
||||
|
||||
def validate(self, fvalue: int) -> int:
|
||||
if fvalue in self.enum_values:
|
||||
return fvalue
|
||||
else:
|
||||
raise TypeError("Invalid enum value")
|
||||
|
||||
|
||||
class BytesType:
|
||||
WIRE_TYPE = 2
|
||||
|
||||
|
||||
class UnicodeType:
|
||||
WIRE_TYPE = 2
|
||||
|
||||
|
||||
if False:
|
||||
MessageTypeDef = Union[
|
||||
type[UVarintType],
|
||||
type[SVarintType],
|
||||
type[BoolType],
|
||||
EnumType,
|
||||
type[BytesType],
|
||||
type[UnicodeType],
|
||||
type["MessageType"],
|
||||
]
|
||||
FieldDef = tuple[str, MessageTypeDef, Any]
|
||||
FieldDict = dict[int, FieldDef]
|
||||
|
||||
FieldCache = dict[type["MessageType"], FieldDict]
|
||||
|
||||
LoadedMessageType = TypeVar("LoadedMessageType", bound="MessageType")
|
||||
|
||||
|
||||
class MessageType:
|
||||
WIRE_TYPE = 2
|
||||
UNSTABLE = False
|
||||
|
||||
# Type id for the wire codec.
|
||||
# Technically, not every protobuf message has this.
|
||||
MESSAGE_WIRE_TYPE = -1
|
||||
|
||||
@classmethod
|
||||
def get_fields(cls) -> FieldDict:
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def cache_subordinate_types(cls, field_cache: FieldCache) -> None:
|
||||
if cls in field_cache:
|
||||
fields = field_cache[cls]
|
||||
else:
|
||||
fields = cls.get_fields()
|
||||
field_cache[cls] = fields
|
||||
|
||||
for _, field_type, _ in fields.values():
|
||||
if isinstance(field_type, MessageType):
|
||||
field_type.cache_subordinate_types(field_cache)
|
||||
|
||||
def __eq__(self, rhs: Any) -> bool:
|
||||
return self.__class__ is rhs.__class__ and self.__dict__ == rhs.__dict__
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "<%s>" % self.__class__.__name__
|
||||
|
||||
|
||||
class LimitedReader:
|
||||
def __init__(self, reader: Reader, limit: int) -> None:
|
||||
self.reader = reader
|
||||
self.limit = limit
|
||||
|
||||
def readinto(self, buf: bytearray) -> int:
|
||||
if self.limit < len(buf):
|
||||
raise EOFError
|
||||
else:
|
||||
nread = self.reader.readinto(buf)
|
||||
self.limit -= nread
|
||||
return nread
|
||||
|
||||
|
||||
FLAG_REPEATED = object()
|
||||
FLAG_REQUIRED = object()
|
||||
FLAG_EXPERIMENTAL = object()
|
||||
|
||||
|
||||
def load_message(
|
||||
reader: Reader,
|
||||
msg_type: type[LoadedMessageType],
|
||||
field_cache: FieldCache | None = None,
|
||||
experimental_enabled: bool = True,
|
||||
) -> LoadedMessageType:
|
||||
if field_cache is None:
|
||||
field_cache = {}
|
||||
fields = field_cache.get(msg_type)
|
||||
if fields is None:
|
||||
fields = msg_type.get_fields()
|
||||
field_cache[msg_type] = fields
|
||||
|
||||
if msg_type.UNSTABLE and not experimental_enabled:
|
||||
raise ValueError # experimental messages not enabled
|
||||
|
||||
# we need to avoid calling __init__, which enforces required arguments
|
||||
msg: LoadedMessageType = object.__new__(msg_type)
|
||||
# pre-seed the object with defaults
|
||||
for fname, _, fdefault in fields.values():
|
||||
if fdefault is FLAG_REPEATED:
|
||||
fdefault = []
|
||||
elif fdefault is FLAG_EXPERIMENTAL:
|
||||
fdefault = None
|
||||
setattr(msg, fname, fdefault)
|
||||
|
||||
if False:
|
||||
SingularValue = Union[int, bool, bytearray, str, MessageType]
|
||||
Value = Union[SingularValue, list[SingularValue]]
|
||||
fvalue: Value = 0
|
||||
|
||||
while True:
|
||||
try:
|
||||
fkey = load_uvarint(reader)
|
||||
except EOFError:
|
||||
break # no more fields to load
|
||||
|
||||
ftag = fkey >> 3
|
||||
wtype = fkey & 7
|
||||
|
||||
field = fields.get(ftag, None)
|
||||
|
||||
if field is None: # unknown field, skip it
|
||||
if wtype == 0:
|
||||
load_uvarint(reader)
|
||||
elif wtype == 2:
|
||||
ivalue = load_uvarint(reader)
|
||||
reader.readinto(bytearray(ivalue))
|
||||
else:
|
||||
raise ValueError
|
||||
continue
|
||||
|
||||
fname, ftype, fdefault = field
|
||||
if wtype != ftype.WIRE_TYPE:
|
||||
raise TypeError # parsed wire type differs from the schema
|
||||
|
||||
if fdefault is FLAG_EXPERIMENTAL and not experimental_enabled:
|
||||
raise ValueError # experimental fields not enabled
|
||||
|
||||
ivalue = load_uvarint(reader)
|
||||
|
||||
if ftype is UVarintType:
|
||||
fvalue = ivalue
|
||||
elif ftype is SVarintType:
|
||||
fvalue = uint_to_sint(ivalue)
|
||||
elif ftype is BoolType:
|
||||
fvalue = bool(ivalue)
|
||||
elif isinstance(ftype, EnumType):
|
||||
fvalue = ftype.validate(ivalue)
|
||||
elif ftype is BytesType:
|
||||
fvalue = bytearray(ivalue)
|
||||
reader.readinto(fvalue)
|
||||
elif ftype is UnicodeType:
|
||||
fvalue = bytearray(ivalue)
|
||||
reader.readinto(fvalue)
|
||||
fvalue = bytes(fvalue).decode()
|
||||
elif issubclass(ftype, MessageType):
|
||||
fvalue = load_message(
|
||||
LimitedReader(reader, ivalue), ftype, field_cache, experimental_enabled
|
||||
)
|
||||
else:
|
||||
raise TypeError # field type is unknown
|
||||
|
||||
if fdefault is FLAG_REPEATED:
|
||||
getattr(msg, fname).append(fvalue)
|
||||
else:
|
||||
setattr(msg, fname, fvalue)
|
||||
|
||||
for fname, _, _ in fields.values():
|
||||
if getattr(msg, fname) is FLAG_REQUIRED:
|
||||
# The message is intended to be user-facing when decoding from wire,
|
||||
# but not when used internally.
|
||||
raise ValueError("Required field '{}' was not received".format(fname))
|
||||
|
||||
return msg
|
||||
|
||||
|
||||
def dump_message(
|
||||
writer: Writer, msg: MessageType, field_cache: FieldCache | None = None
|
||||
) -> None:
|
||||
repvalue = [0]
|
||||
|
||||
if field_cache is None:
|
||||
field_cache = {}
|
||||
fields = field_cache.get(type(msg))
|
||||
if fields is None:
|
||||
fields = msg.get_fields()
|
||||
field_cache[type(msg)] = fields
|
||||
|
||||
for ftag in fields:
|
||||
fname, ftype, fdefault = fields[ftag]
|
||||
|
||||
fvalue = getattr(msg, fname, None)
|
||||
if fvalue is None:
|
||||
continue
|
||||
|
||||
fkey = (ftag << 3) | ftype.WIRE_TYPE
|
||||
|
||||
if fdefault is not FLAG_REPEATED:
|
||||
repvalue[0] = fvalue
|
||||
fvalue = repvalue
|
||||
|
||||
for svalue in fvalue:
|
||||
dump_uvarint(writer.write, fkey)
|
||||
|
||||
if ftype is UVarintType:
|
||||
dump_uvarint(writer.write, svalue)
|
||||
|
||||
elif ftype is SVarintType:
|
||||
dump_uvarint(writer.write, sint_to_uint(svalue))
|
||||
|
||||
elif ftype is BoolType:
|
||||
dump_uvarint(writer.write, int(svalue))
|
||||
|
||||
elif isinstance(ftype, EnumType):
|
||||
dump_uvarint(writer.write, svalue)
|
||||
|
||||
elif ftype is BytesType:
|
||||
if isinstance(svalue, list):
|
||||
dump_uvarint(writer.write, _count_bytes_list(svalue))
|
||||
for sub_svalue in svalue:
|
||||
writer.write(sub_svalue)
|
||||
else:
|
||||
dump_uvarint(writer.write, len(svalue))
|
||||
writer.write(svalue)
|
||||
|
||||
elif ftype is UnicodeType:
|
||||
svalue = svalue.encode()
|
||||
dump_uvarint(writer.write, len(svalue))
|
||||
writer.write(svalue)
|
||||
|
||||
elif issubclass(ftype, MessageType):
|
||||
ffields = field_cache.get(ftype)
|
||||
if ffields is None:
|
||||
ffields = ftype.get_fields()
|
||||
field_cache[ftype] = ffields
|
||||
dump_uvarint(writer.write, count_message(svalue, field_cache))
|
||||
dump_message(writer, svalue, field_cache)
|
||||
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
|
||||
def count_message(msg: MessageType, field_cache: FieldCache | None = None) -> int:
|
||||
nbytes = 0
|
||||
repvalue = [0]
|
||||
|
||||
if field_cache is None:
|
||||
field_cache = {}
|
||||
fields = field_cache.get(type(msg))
|
||||
if fields is None:
|
||||
fields = msg.get_fields()
|
||||
field_cache[type(msg)] = fields
|
||||
|
||||
for ftag in fields:
|
||||
fname, ftype, fdefault = fields[ftag]
|
||||
|
||||
fvalue = getattr(msg, fname, None)
|
||||
if fvalue is None:
|
||||
continue
|
||||
|
||||
fkey = (ftag << 3) | ftype.WIRE_TYPE
|
||||
|
||||
if fdefault is not FLAG_REPEATED:
|
||||
repvalue[0] = fvalue
|
||||
fvalue = repvalue
|
||||
|
||||
# length of all the field keys
|
||||
nbytes += count_uvarint(fkey) * len(fvalue)
|
||||
|
||||
if ftype is UVarintType:
|
||||
for svalue in fvalue:
|
||||
nbytes += count_uvarint(svalue)
|
||||
|
||||
elif ftype is SVarintType:
|
||||
for svalue in fvalue:
|
||||
nbytes += count_uvarint(sint_to_uint(svalue))
|
||||
|
||||
elif ftype is BoolType:
|
||||
for svalue in fvalue:
|
||||
nbytes += count_uvarint(int(svalue))
|
||||
|
||||
elif isinstance(ftype, EnumType):
|
||||
for svalue in fvalue:
|
||||
nbytes += count_uvarint(svalue)
|
||||
|
||||
elif ftype is BytesType:
|
||||
for svalue in fvalue:
|
||||
if isinstance(svalue, list):
|
||||
svalue = _count_bytes_list(svalue)
|
||||
else:
|
||||
svalue = len(svalue)
|
||||
nbytes += count_uvarint(svalue)
|
||||
nbytes += svalue
|
||||
|
||||
elif ftype is UnicodeType:
|
||||
for svalue in fvalue:
|
||||
svalue = len(svalue.encode())
|
||||
nbytes += count_uvarint(svalue)
|
||||
nbytes += svalue
|
||||
|
||||
elif issubclass(ftype, MessageType):
|
||||
for svalue in fvalue:
|
||||
fsize = count_message(svalue, field_cache)
|
||||
nbytes += count_uvarint(fsize)
|
||||
nbytes += fsize
|
||||
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
return nbytes
|
||||
|
||||
|
||||
def _count_bytes_list(svalue: list[bytes]) -> int:
|
||||
res = 0
|
||||
for x in svalue:
|
||||
res += len(x)
|
||||
return res
|
42
core/src/trezor/protobuf.py
Normal file
42
core/src/trezor/protobuf.py
Normal file
@ -0,0 +1,42 @@
|
||||
import trezorproto
|
||||
|
||||
decode = trezorproto.decode
|
||||
encode = trezorproto.encode
|
||||
encoded_length = trezorproto.encoded_length
|
||||
type_for_name = trezorproto.type_for_name
|
||||
type_for_wire = trezorproto.type_for_wire
|
||||
|
||||
# XXX
|
||||
# Note that MessageType "subclasses" are not true subclasses, but instead instances
|
||||
# of the built-in metaclass MsgDef. MessageType instances are in fact instances of
|
||||
# the built-in type Msg. That is why isinstance checks do not work, and instead the
|
||||
# MessageTypeSubclass.is_type_of() method must be used.
|
||||
if False:
|
||||
from typing import Type, TypeGuard, TypeVar
|
||||
|
||||
T = TypeVar("T", bound="MessageType")
|
||||
|
||||
class MsgDef(type):
|
||||
@classmethod
|
||||
def is_type_of(cls: Type[Type[T]], msg: "MessageType") -> TypeGuard[T]:
|
||||
"""Identify if the provided message belongs to this type."""
|
||||
raise NotImplementedError
|
||||
|
||||
class MessageType(metaclass=MsgDef):
|
||||
MESSAGE_NAME: str = "MessageType"
|
||||
MESSAGE_WIRE_TYPE: int | None = None
|
||||
|
||||
|
||||
def load_message_buffer(
|
||||
buffer: bytes,
|
||||
msg_wire_type: int,
|
||||
experimental_enabled: bool = True,
|
||||
) -> MessageType:
|
||||
msg_type = type_for_wire(msg_wire_type)
|
||||
return decode(buffer, msg_type, experimental_enabled)
|
||||
|
||||
|
||||
def dump_message_buffer(msg: MessageType) -> bytearray:
|
||||
buffer = bytearray(encoded_length(msg))
|
||||
encode(buffer, msg)
|
||||
return buffer
|
@ -145,10 +145,6 @@ class HashWriter:
|
||||
def write(self, buf: bytes) -> None: # alias for extend()
|
||||
self.ctx.update(buf)
|
||||
|
||||
async def awrite(self, buf: bytes) -> int: # AsyncWriter interface
|
||||
self.ctx.update(buf)
|
||||
return len(buf)
|
||||
|
||||
def get_digest(self) -> bytes:
|
||||
return self.ctx.digest()
|
||||
|
||||
|
@ -35,11 +35,10 @@ reads the message's header. When the message type is known the first handler is
|
||||
|
||||
"""
|
||||
|
||||
import protobuf
|
||||
from storage.cache import InvalidSessionError
|
||||
from trezor import log, loop, messages, utils, workflow
|
||||
from trezor.messages import FailureType
|
||||
from trezor.messages.Failure import Failure
|
||||
from trezor.enums import FailureType
|
||||
from trezor.messages import Failure
|
||||
from trezor import log, loop, protobuf, utils, workflow
|
||||
from trezor.wire import codec_v1
|
||||
from trezor.wire.errors import ActionCancelled, DataError, Error
|
||||
|
||||
@ -74,17 +73,19 @@ def setup(iface: WireInterface, is_debug_session: bool = False) -> None:
|
||||
|
||||
|
||||
if False:
|
||||
from typing import Protocol
|
||||
from typing import Protocol, TypeVar
|
||||
|
||||
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
|
||||
|
||||
class GenericContext(Protocol):
|
||||
async def call(
|
||||
self,
|
||||
msg: protobuf.MessageType,
|
||||
expected_type: type[protobuf.LoadedMessageType],
|
||||
expected_type: type[protobuf.MessageType],
|
||||
) -> Any:
|
||||
...
|
||||
|
||||
async def read(self, expected_type: type[protobuf.LoadedMessageType]) -> Any:
|
||||
async def read(self, expected_type: type[protobuf.MessageType]) -> Any:
|
||||
...
|
||||
|
||||
async def write(self, msg: protobuf.MessageType) -> None:
|
||||
@ -96,15 +97,14 @@ if False:
|
||||
|
||||
|
||||
def _wrap_protobuf_load(
|
||||
reader: protobuf.Reader,
|
||||
expected_type: type[protobuf.LoadedMessageType],
|
||||
field_cache: protobuf.FieldCache | None = None,
|
||||
) -> protobuf.LoadedMessageType:
|
||||
buffer: bytes,
|
||||
expected_type: type[LoadedMessageType],
|
||||
) -> LoadedMessageType:
|
||||
try:
|
||||
return protobuf.load_message(
|
||||
reader, expected_type, field_cache, experimental_enabled
|
||||
)
|
||||
return protobuf.decode(buffer, expected_type, experimental_enabled)
|
||||
except Exception as e:
|
||||
if __debug__:
|
||||
log.exception(__name__, e)
|
||||
if e.args:
|
||||
raise DataError("Failed to decode message: {}".format(e.args[0]))
|
||||
else:
|
||||
@ -139,19 +139,15 @@ class Context:
|
||||
self.iface = iface
|
||||
self.sid = sid
|
||||
self.buffer = buffer
|
||||
self.buffer_writer = utils.BufferWriter(self.buffer)
|
||||
|
||||
self._field_cache: protobuf.FieldCache = {}
|
||||
|
||||
async def call(
|
||||
self,
|
||||
msg: protobuf.MessageType,
|
||||
expected_type: type[protobuf.LoadedMessageType],
|
||||
field_cache: protobuf.FieldCache | None = None,
|
||||
) -> protobuf.LoadedMessageType:
|
||||
await self.write(msg, field_cache)
|
||||
expected_type: type[LoadedMessageType],
|
||||
) -> LoadedMessageType:
|
||||
await self.write(msg)
|
||||
del msg
|
||||
return await self.read(expected_type, field_cache)
|
||||
return await self.read(expected_type)
|
||||
|
||||
async def call_any(
|
||||
self, msg: protobuf.MessageType, *expected_wire_types: int
|
||||
@ -160,21 +156,17 @@ class Context:
|
||||
del msg
|
||||
return await self.read_any(expected_wire_types)
|
||||
|
||||
async def read_from_wire(self) -> codec_v1.Message:
|
||||
return await codec_v1.read_message(self.iface, self.buffer)
|
||||
def read_from_wire(self) -> Awaitable[codec_v1.Message]:
|
||||
return codec_v1.read_message(self.iface, self.buffer)
|
||||
|
||||
async def read(
|
||||
self,
|
||||
expected_type: type[protobuf.LoadedMessageType],
|
||||
field_cache: protobuf.FieldCache | None = None,
|
||||
) -> protobuf.LoadedMessageType:
|
||||
async def read(self, expected_type: type[LoadedMessageType]) -> LoadedMessageType:
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"%s:%x expect: %s",
|
||||
self.iface.iface_num(),
|
||||
self.sid,
|
||||
expected_type,
|
||||
expected_type.MESSAGE_NAME,
|
||||
)
|
||||
|
||||
# Load the full message into a buffer, parse out type and data payload
|
||||
@ -191,13 +183,13 @@ class Context:
|
||||
"%s:%x read: %s",
|
||||
self.iface.iface_num(),
|
||||
self.sid,
|
||||
expected_type,
|
||||
expected_type.MESSAGE_NAME,
|
||||
)
|
||||
|
||||
workflow.idle_timer.touch()
|
||||
|
||||
# look up the protobuf class and parse the message
|
||||
return _wrap_protobuf_load(msg.data, expected_type, field_cache)
|
||||
return _wrap_protobuf_load(msg.data, expected_type)
|
||||
|
||||
async def read_any(
|
||||
self, expected_wire_types: Iterable[int]
|
||||
@ -220,11 +212,15 @@ class Context:
|
||||
raise UnexpectedMessageError(msg)
|
||||
|
||||
# find the protobuf type
|
||||
exptype = messages.get_type(msg.type)
|
||||
exptype = protobuf.type_for_wire(msg.type)
|
||||
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__, "%s:%x read: %s", self.iface.iface_num(), self.sid, exptype
|
||||
__name__,
|
||||
"%s:%x read: %s",
|
||||
self.iface.iface_num(),
|
||||
self.sid,
|
||||
exptype.MESSAGE_NAME,
|
||||
)
|
||||
|
||||
workflow.idle_timer.touch()
|
||||
@ -232,41 +228,36 @@ class Context:
|
||||
# parse the message and return it
|
||||
return _wrap_protobuf_load(msg.data, exptype)
|
||||
|
||||
async def write(
|
||||
self,
|
||||
msg: protobuf.MessageType,
|
||||
field_cache: protobuf.FieldCache | None = None,
|
||||
) -> None:
|
||||
async def write(self, msg: protobuf.MessageType) -> None:
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__, "%s:%x write: %s", self.iface.iface_num(), self.sid, msg
|
||||
__name__,
|
||||
"%s:%x write: %s",
|
||||
self.iface.iface_num(),
|
||||
self.sid,
|
||||
msg.MESSAGE_NAME,
|
||||
)
|
||||
|
||||
if field_cache is None:
|
||||
field_cache = self._field_cache
|
||||
# cannot write message without wire type
|
||||
assert msg.MESSAGE_WIRE_TYPE is not None
|
||||
|
||||
# write the message
|
||||
msg_size = protobuf.count_message(msg, field_cache)
|
||||
msg_size = protobuf.encoded_length(msg)
|
||||
|
||||
# prepare buffer
|
||||
if msg_size <= len(self.buffer_writer.buffer):
|
||||
if msg_size <= len(self.buffer):
|
||||
# reuse preallocated
|
||||
buffer_writer = self.buffer_writer
|
||||
buffer = self.buffer
|
||||
else:
|
||||
# message is too big, we need to allocate a new buffer
|
||||
buffer_writer = utils.BufferWriter(bytearray(msg_size))
|
||||
buffer = bytearray(msg_size)
|
||||
|
||||
msg_size = protobuf.encode(buffer, msg)
|
||||
|
||||
buffer_writer.seek(0)
|
||||
protobuf.dump_message(buffer_writer, msg, field_cache)
|
||||
await codec_v1.write_message(
|
||||
self.iface,
|
||||
msg.MESSAGE_WIRE_TYPE,
|
||||
memoryview(buffer_writer.buffer)[:msg_size],
|
||||
memoryview(buffer)[:msg_size],
|
||||
)
|
||||
|
||||
# make sure we don't keep around fields of all protobuf types ever
|
||||
self._field_cache.clear()
|
||||
|
||||
def wait(self, *tasks: Awaitable) -> Any:
|
||||
"""
|
||||
Wait until one of the passed tasks finishes, and return the result,
|
||||
@ -301,8 +292,8 @@ async def _handle_single_message(
|
||||
"""
|
||||
if __debug__:
|
||||
try:
|
||||
msg_type = messages.get_type(msg.type).__name__
|
||||
except KeyError:
|
||||
msg_type = protobuf.type_for_wire(msg.type).MESSAGE_NAME
|
||||
except Exception:
|
||||
msg_type = "%d - unknown message type" % msg.type
|
||||
log.debug(
|
||||
__name__,
|
||||
@ -328,7 +319,7 @@ async def _handle_single_message(
|
||||
try:
|
||||
# Find a protobuf.MessageType subclass that describes this
|
||||
# message. Raises if the type is not found.
|
||||
req_type = messages.get_type(msg.type)
|
||||
req_type = protobuf.type_for_wire(msg.type)
|
||||
|
||||
# Try to decode the message according to schema from
|
||||
# `req_type`. Raises if the message is malformed.
|
||||
@ -424,6 +415,11 @@ async def handle_session(
|
||||
next_msg = await _handle_single_message(
|
||||
ctx, msg, use_workflow=not is_debug_session
|
||||
)
|
||||
except Exception as exc:
|
||||
# Log and ignore. The session handler can only exit explicitly in the
|
||||
# following finally block.
|
||||
if __debug__:
|
||||
log.exception(__name__, exc)
|
||||
finally:
|
||||
if not __debug__ or not is_debug_session:
|
||||
# Unload modules imported by the workflow. Should not raise.
|
||||
|
@ -23,7 +23,7 @@ class CodecError(Exception):
|
||||
|
||||
|
||||
class Message:
|
||||
def __init__(self, mtype: int, mdata: utils.BufferReader) -> None:
|
||||
def __init__(self, mtype: int, mdata: bytes) -> None:
|
||||
self.type = mtype
|
||||
self.data = mdata
|
||||
|
||||
@ -70,7 +70,7 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag
|
||||
if read_and_throw_away:
|
||||
raise CodecError("Message too large")
|
||||
|
||||
return Message(mtype, utils.BufferReader(mdata))
|
||||
return Message(mtype, mdata)
|
||||
|
||||
|
||||
async def write_message(iface: WireInterface, mtype: int, mdata: bytes) -> None:
|
||||
|
@ -1,129 +0,0 @@
|
||||
from common import *
|
||||
|
||||
import protobuf
|
||||
from trezor.utils import BufferReader, BufferWriter
|
||||
|
||||
|
||||
class Message(protobuf.MessageType):
|
||||
def __init__(self, sint_field: int = 0, enum_field: int = 0) -> None:
|
||||
self.sint_field = sint_field
|
||||
self.enum_field = enum_field
|
||||
|
||||
@classmethod
|
||||
def get_fields(cls):
|
||||
return {
|
||||
1: ("sint_field", protobuf.SVarintType, 0),
|
||||
2: ("enum_field", protobuf.EnumType("t", (0, 5, 25)), 0),
|
||||
}
|
||||
|
||||
|
||||
class MessageWithRequiredAndDefault(protobuf.MessageType):
|
||||
def __init__(self, required_field, default_field) -> None:
|
||||
self.required_field = required_field
|
||||
self.default_field = default_field
|
||||
|
||||
@classmethod
|
||||
def get_fields(cls):
|
||||
return {
|
||||
1: ("required_field", protobuf.UVarintType, protobuf.FLAG_REQUIRED),
|
||||
2: ("default_field", protobuf.SVarintType, -1),
|
||||
}
|
||||
|
||||
|
||||
def load_uvarint(data: bytes) -> int:
|
||||
reader = BufferReader(data)
|
||||
return protobuf.load_uvarint(reader)
|
||||
|
||||
|
||||
def dump_uvarint(value: int) -> bytearray:
|
||||
w = bytearray()
|
||||
protobuf.dump_uvarint(w.extend, value)
|
||||
return w
|
||||
|
||||
|
||||
def dump_message(msg: protobuf.MessageType) -> bytearray:
|
||||
length = protobuf.count_message(msg)
|
||||
buffer = bytearray(length)
|
||||
protobuf.dump_message(BufferWriter(buffer), msg)
|
||||
return buffer
|
||||
|
||||
|
||||
def load_message(msg_type, buffer: bytearray) -> protobuf.MessageType:
|
||||
return protobuf.load_message(BufferReader(buffer), msg_type)
|
||||
|
||||
|
||||
class TestProtobuf(unittest.TestCase):
|
||||
def test_dump_uvarint(self):
|
||||
self.assertEqual(dump_uvarint(0), b"\x00")
|
||||
self.assertEqual(dump_uvarint(1), b"\x01")
|
||||
self.assertEqual(dump_uvarint(0xFF), b"\xff\x01")
|
||||
self.assertEqual(dump_uvarint(123456), b"\xc0\xc4\x07")
|
||||
with self.assertRaises(ValueError):
|
||||
dump_uvarint(-1)
|
||||
|
||||
def test_load_uvarint(self):
|
||||
self.assertEqual(load_uvarint(b"\x00"), 0)
|
||||
self.assertEqual(load_uvarint(b"\x01"), 1)
|
||||
self.assertEqual(load_uvarint(b"\xff\x01"), 0xFF)
|
||||
self.assertEqual(load_uvarint(b"\xc0\xc4\x07"), 123456)
|
||||
|
||||
def test_sint_uint(self):
|
||||
self.assertEqual(protobuf.uint_to_sint(0), 0)
|
||||
self.assertEqual(protobuf.sint_to_uint(0), 0)
|
||||
|
||||
self.assertEqual(protobuf.sint_to_uint(-1), 1)
|
||||
self.assertEqual(protobuf.sint_to_uint(1), 2)
|
||||
|
||||
self.assertEqual(protobuf.uint_to_sint(1), -1)
|
||||
self.assertEqual(protobuf.uint_to_sint(2), 1)
|
||||
|
||||
# roundtrip:
|
||||
self.assertEqual(
|
||||
protobuf.uint_to_sint(protobuf.sint_to_uint(1234567891011)), 1234567891011
|
||||
)
|
||||
self.assertEqual(
|
||||
protobuf.uint_to_sint(protobuf.sint_to_uint(-(2 ** 32))), -(2 ** 32)
|
||||
)
|
||||
|
||||
def test_validate_enum(self):
|
||||
# ok message:
|
||||
msg = Message(-42, 5)
|
||||
msg_encoded = dump_message(msg)
|
||||
nmsg = load_message(Message, msg_encoded)
|
||||
|
||||
self.assertEqual(msg.sint_field, nmsg.sint_field)
|
||||
self.assertEqual(msg.enum_field, nmsg.enum_field)
|
||||
|
||||
# bad enum value:
|
||||
msg = Message(-42, 42)
|
||||
msg_encoded = dump_message(msg)
|
||||
with self.assertRaises(TypeError):
|
||||
load_message(Message, msg_encoded)
|
||||
|
||||
def test_required(self):
|
||||
msg = MessageWithRequiredAndDefault(required_field=1, default_field=2)
|
||||
msg_encoded = dump_message(msg)
|
||||
nmsg = load_message(MessageWithRequiredAndDefault, msg_encoded)
|
||||
|
||||
self.assertEqual(nmsg.required_field, 1)
|
||||
self.assertEqual(nmsg.default_field, 2)
|
||||
|
||||
# try a message without the required_field
|
||||
msg = MessageWithRequiredAndDefault(required_field=None, default_field=2)
|
||||
# encoding always succeeds
|
||||
msg_encoded = dump_message(msg)
|
||||
with self.assertRaises(ValueError):
|
||||
load_message(MessageWithRequiredAndDefault, msg_encoded)
|
||||
|
||||
# try a message without the default field
|
||||
msg = MessageWithRequiredAndDefault(required_field=1, default_field=None)
|
||||
msg_encoded = dump_message(msg)
|
||||
nmsg = load_message(MessageWithRequiredAndDefault, msg_encoded)
|
||||
|
||||
self.assertEqual(nmsg.required_field, 1)
|
||||
self.assertEqual(nmsg.default_field, -1)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
115
core/tests/test_trezor.protobuf.py
Normal file
115
core/tests/test_trezor.protobuf.py
Normal file
@ -0,0 +1,115 @@
|
||||
from common import *
|
||||
|
||||
from trezor import protobuf
|
||||
from trezor.messages import WebAuthnCredential, EosAsset, Failure, SignMessage
|
||||
|
||||
|
||||
def load_uvarint32(data: bytes) -> int:
|
||||
# use known uint32 field in an all-optional message
|
||||
buffer = bytearray(len(data) + 1)
|
||||
buffer[1:] = data
|
||||
buffer[0] = (1 << 3) | 0 # field number 1, wire type 0
|
||||
msg = protobuf.decode(buffer, WebAuthnCredential, False)
|
||||
return msg.index
|
||||
|
||||
|
||||
def load_uvarint64(data: bytes) -> int:
|
||||
# use known uint64 field in an all-optional message
|
||||
buffer = bytearray(len(data) + 1)
|
||||
buffer[1:] = data
|
||||
buffer[0] = (2 << 3) | 0 # field number 1, wire type 0
|
||||
msg = protobuf.decode(buffer, EosAsset, False)
|
||||
return msg.symbol
|
||||
|
||||
|
||||
def dump_uvarint32(value: int) -> bytearray:
|
||||
# use known uint32 field in an all-optional message
|
||||
msg = WebAuthnCredential(index=value)
|
||||
length = protobuf.encoded_length(msg)
|
||||
buffer = bytearray(length)
|
||||
protobuf.encode(buffer, msg)
|
||||
assert buffer[0] == (1 << 3) | 0 # field number 1, wire type 0
|
||||
return buffer[1:]
|
||||
|
||||
|
||||
def dump_uvarint64(value: int) -> bytearray:
|
||||
# use known uint64 field in an all-optional message
|
||||
msg = EosAsset(symbol=value)
|
||||
length = protobuf.encoded_length(msg)
|
||||
buffer = bytearray(length)
|
||||
protobuf.encode(buffer, msg)
|
||||
assert buffer[0] == (2 << 3) | 0 # field number 2, wire type 0
|
||||
return buffer[1:]
|
||||
|
||||
|
||||
def dump_message(msg: protobuf.MessageType) -> bytearray:
|
||||
length = protobuf.encoded_length(msg)
|
||||
buffer = bytearray(length)
|
||||
protobuf.encode(buffer, msg)
|
||||
return buffer
|
||||
|
||||
|
||||
def load_message(msg_type: Type[protobuf.MessageType], buffer: bytes) -> protobuf.MessageType:
|
||||
return protobuf.decode(buffer, msg_type, False)
|
||||
|
||||
|
||||
class TestProtobuf(unittest.TestCase):
|
||||
def test_dump_uvarint(self):
|
||||
for dump_uvarint in (dump_uvarint32, dump_uvarint64):
|
||||
self.assertEqual(dump_uvarint(0), b"\x00")
|
||||
self.assertEqual(dump_uvarint(1), b"\x01")
|
||||
self.assertEqual(dump_uvarint(0xFF), b"\xff\x01")
|
||||
self.assertEqual(dump_uvarint(123456), b"\xc0\xc4\x07")
|
||||
with self.assertRaises(ValueError):
|
||||
dump_uvarint(-1)
|
||||
|
||||
def test_load_uvarint(self):
|
||||
for load_uvarint in (load_uvarint32, load_uvarint64):
|
||||
self.assertEqual(load_uvarint(b"\x00"), 0)
|
||||
self.assertEqual(load_uvarint(b"\x01"), 1)
|
||||
self.assertEqual(load_uvarint(b"\xff\x01"), 0xFF)
|
||||
self.assertEqual(load_uvarint(b"\xc0\xc4\x07"), 123456)
|
||||
|
||||
def test_validate_enum(self):
|
||||
# ok message:
|
||||
msg = Failure(code=7)
|
||||
msg_encoded = dump_message(msg)
|
||||
nmsg = load_message(Failure, msg_encoded)
|
||||
|
||||
self.assertEqual(msg.code, nmsg.code)
|
||||
|
||||
# bad enum value:
|
||||
msg = Failure(code=1000)
|
||||
msg_encoded = dump_message(msg)
|
||||
with self.assertRaises(ValueError):
|
||||
load_message(Failure, msg_encoded)
|
||||
|
||||
def test_required(self):
|
||||
msg = SignMessage(message=b"hello", coin_name="foo", script_type=1)
|
||||
msg_encoded = dump_message(msg)
|
||||
nmsg = load_message(SignMessage, msg_encoded)
|
||||
|
||||
self.assertEqual(nmsg.message, b"hello")
|
||||
self.assertEqual(nmsg.coin_name, "foo")
|
||||
self.assertEqual(nmsg.script_type, 1)
|
||||
|
||||
# try a message without the required_field
|
||||
msg = SignMessage(message=None)
|
||||
# encoding always succeeds
|
||||
msg_encoded = dump_message(msg)
|
||||
with self.assertRaises(ValueError):
|
||||
load_message(SignMessage, msg_encoded)
|
||||
|
||||
# try a message without the default field
|
||||
msg = SignMessage(message=b"hello")
|
||||
msg.coin_name = None
|
||||
msg_encoded = dump_message(msg)
|
||||
nmsg = load_message(SignMessage, msg_encoded)
|
||||
|
||||
self.assertEqual(nmsg.message, b"hello")
|
||||
self.assertEqual(nmsg.coin_name, "Bitcoin")
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
@ -53,7 +53,7 @@ class TestWireCodecV1(unittest.TestCase):
|
||||
# e.value is StopIteration. e.value.value is the return value of the call
|
||||
result = e.value.value
|
||||
self.assertEqual(result.type, MESSAGE_TYPE)
|
||||
self.assertEqual(result.data.buffer, b"")
|
||||
self.assertEqual(result.data, b"")
|
||||
|
||||
# message should have been read into the buffer
|
||||
self.assertEqual(buffer, b"\x00" * 64)
|
||||
@ -83,7 +83,7 @@ class TestWireCodecV1(unittest.TestCase):
|
||||
# e.value is StopIteration. e.value.value is the return value of the call
|
||||
result = e.value.value
|
||||
self.assertEqual(result.type, MESSAGE_TYPE)
|
||||
self.assertEqual(result.data.buffer, message)
|
||||
self.assertEqual(result.data, message)
|
||||
|
||||
# message should have been read into the buffer
|
||||
self.assertEqual(buffer, message)
|
||||
@ -108,7 +108,7 @@ class TestWireCodecV1(unittest.TestCase):
|
||||
# e.value is StopIteration. e.value.value is the return value of the call
|
||||
result = e.value.value
|
||||
self.assertEqual(result.type, MESSAGE_TYPE)
|
||||
self.assertEqual(result.data.buffer, message)
|
||||
self.assertEqual(result.data, message)
|
||||
|
||||
# read should have allocated its own buffer and not touch ours
|
||||
self.assertEqual(buffer, b"\x00")
|
||||
@ -176,7 +176,7 @@ class TestWireCodecV1(unittest.TestCase):
|
||||
|
||||
result = e.value.value
|
||||
self.assertEqual(result.type, MESSAGE_TYPE)
|
||||
self.assertEqual(result.data.buffer, message)
|
||||
self.assertEqual(result.data, message)
|
||||
|
||||
def test_read_huge_packet(self):
|
||||
PACKET_COUNT = 100_000
|
||||
|
Loading…
Reference in New Issue
Block a user