1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-15 20:19:23 +00:00
trezor-firmware/core/src/protobuf.py

436 lines
12 KiB
Python
Raw Normal View History

2019-07-03 13:07:04 +00:00
"""
Extremely minimal streaming codec for a subset of protobuf.

Supports unsigned and signed (ZigZag) varints, bools, enums, bytes, strings,
embedded messages and repeated fields.
"""
2017-08-21 11:22:35 +00:00
2019-07-03 13:07:04 +00:00
if False:
    from typing import Any, Dict, Iterable, List, Tuple, Type, TypeVar, Union

    from typing_extensions import Protocol

    class Reader(Protocol):
        """Byte source consumed by the decoding functions of this module."""

        def readinto(self, buf: bytearray) -> int:
            """
            Fill all `len(buf)` bytes of `buf` from the source and return the
            number of bytes read; raise `EOFError` if the source cannot supply
            them.
            """

    class Writer(Protocol):
        """Byte sink consumed by the encoding functions of this module."""

        def write(self, buf: bytes) -> int:
            """
            Write every byte of `buf` to the sink; raise `EOFError` if the sink
            cannot accept them.
            """
# One-byte scratch buffer shared by load_uvarint() and dump_uvarint(); each
# varint byte passes through it instead of a freshly allocated buffer.
# NOTE(review): being module-global, it is not safe for concurrent or
# re-entrant use; presumably fine for single-threaded firmware — confirm.
_UVARINT_BUFFER = bytearray(b"\x00")
2016-09-21 12:14:49 +00:00
2020-06-26 10:30:12 +00:00
def load_uvarint(reader: Reader) -> int:
    """
    Decode one base-128 varint from `reader` and return it.

    Bytes are read one at a time into `_UVARINT_BUFFER`. The low 7 bits of
    each byte are added to the result, least significant group first. A set
    high bit means another byte follows. An `EOFError` raised by `reader`
    propagates unchanged.

    NOTE(review): there is no cap on the varint length; a run of bytes with
    the high bit set keeps reading until the reader raises.
    """
    scratch = _UVARINT_BUFFER
    value = 0
    bit_offset = 0
    more = True
    while more:
        reader.readinto(scratch)
        octet = scratch[0]
        value += (octet & 0x7F) << bit_offset
        bit_offset += 7
        more = octet >= 0x80  # continuation bit
    return value
def dump_uvarint(writer: Writer, n: int) -> None:
    """
    Encode the non-negative integer `n` as a base-128 varint into `writer`.

    Each output byte carries 7 bits of `n`, least significant group first.
    The high bit is set on every byte except the last, and zero is written as
    a single 0x00 byte. Bytes are written one at a time through the shared
    one-byte `_UVARINT_BUFFER`.

    Raises ValueError if `n` is negative.
    """
    if n < 0:
        raise ValueError("Cannot dump signed value, convert it to unsigned first.")
    scratch = _UVARINT_BUFFER
    while True:
        rest = n >> 7
        scratch[0] = (n & 0x7F) | (0x80 if rest else 0x00)
        writer.write(scratch)
        if not rest:
            break
        n = rest
def count_uvarint(n: int) -> int:
    """
    Return the number of bytes `dump_uvarint` writes when encoding `n`.

    Every varint byte carries 7 payload bits, so the result is the number of
    7-bit groups needed to represent `n`, with a minimum of one (for 0). Any
    non-negative size is supported, as in `dump_uvarint`. In particular,
    uint64 values >= 2**63, such as the ZigZag encoding of the int64 minimum,
    take 10 bytes instead of raising.

    Raises ValueError if `n` is negative.
    """
    if n < 0:
        raise ValueError("Cannot dump signed value, convert it to unsigned first.")
    nbytes = 1
    while n > 0x7F:
        n >>= 7
        nbytes += 1
    return nbytes
# Protobuf "ZigZag" encoding for signed integers (sint32/sint64); see the
# "Signed Integers" section of:
# https://developers.google.com/protocol-buffers/docs/encoding#structure
# The idea is to store the sign in the least significant bit instead of using
# two's complement. The encoded values 0, 1, 2, 3, 4, ... therefore decode to
# 0, -1, 1, -2, 2, ..., and the sign flips every time the lowest bit changes.
#
# To achieve this with a two's-complement number x:
# 1. shift left by 1, leaving the lowest bit free;
# 2. if x is negative, apply bitwise negation.
# A non-negative number becomes x << 1, with the lowest bit 0. A negative
# number becomes ~(x << 1) == -2 * x - 1, a non-negative value with the
# lowest (sign) bit set.
#
# The reference algorithm relies on arithmetic (signed) right shift keeping the
# sign bit. For an n-bit number, (x >> (n - 1)) yields "all sign bits": all
# zeros for non-negative x, all ones for negative x. Then
# (x << 1) XOR all-sign-bits is XOR 0 (identity) for non-negative x and
# XOR all-ones (bitwise negation) for negative x. Cute and efficient.
#
# Python ints have no fixed bit width n to shift by, so instead we branch on
# whether the number is negative.
2018-07-03 14:20:58 +00:00
2019-07-03 13:07:04 +00:00
def sint_to_uint(sint: int) -> int:
    """
    ZigZag-encode a signed integer into a non-negative one.

    0, -1, 1, -2, 2, ... map to 0, 1, 2, 3, 4, ...
    """
    doubled = sint << 1
    return ~doubled if sint < 0 else doubled
def uint_to_sint(uint: int) -> int:
    """
    Decode a ZigZag-encoded non-negative integer back to a signed one.

    0, 1, 2, 3, 4, ... map to 0, -1, 1, -2, 2, ... The lowest bit carries
    the sign.
    """
    magnitude = uint >> 1
    return ~magnitude if uint & 1 else magnitude
class UVarintType:
    """Schema marker: unsigned integer field, written as a plain varint."""

    WIRE_TYPE = 0
class SVarintType:
    """Schema marker: signed integer field, ZigZag-encoded into a varint."""

    WIRE_TYPE = 0
class BoolType:
    """
    Schema marker: boolean field.

    Written as the varint 0 or 1 and decoded with bool(), so any non-zero
    varint reads as True.
    """

    WIRE_TYPE = 0
class EnumType:
    """
    Schema descriptor for an enum field, carried on the wire as a varint.

    `name` is accepted for the caller's benefit but is not stored. NOTE:
    `validate` runs an `in` membership test on every call, so `enum_values`
    should be a re-iterable container such as a tuple, list or set. A
    one-shot iterator would be consumed by the first lookup.
    """

    WIRE_TYPE = 0

    def __init__(self, name: str, enum_values: Iterable[int]) -> None:
        self.enum_values = enum_values

    def validate(self, fvalue: int) -> int:
        """Return `fvalue` if it is a declared enum value; raise TypeError otherwise."""
        if fvalue not in self.enum_values:
            raise TypeError("Invalid enum value")
        return fvalue
class BytesType:
    """Schema marker: raw bytes field, written as a length-prefixed payload."""

    WIRE_TYPE = 2
class UnicodeType:
    """
    Schema marker: text field, written as a length-prefixed payload.

    Uses the default codec of str.encode() / bytes.decode(), which is UTF-8.
    """

    WIRE_TYPE = 2
class MessageType:
    """
    Base class for protobuf message types.

    Subclasses override `get_fields()` to describe their schema as a mapping
    of field tag -> (name, type, default).
    """

    WIRE_TYPE = 2  # embedded messages are length-delimited
    UNSTABLE = False  # if True, decoding is refused unless experimental is enabled

    # Type id for the wire codec.
    # Technically, not every protobuf message has this.
    MESSAGE_WIRE_TYPE = -1

    @classmethod
    def get_fields(cls) -> "FieldDict":
        return {}

    def __eq__(self, rhs: Any) -> bool:
        # equal only to an instance of exactly the same class with equal attributes
        if self.__class__ is not rhs.__class__:
            return False
        return self.__dict__ == rhs.__dict__

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        return "<%s>" % class_name
class LimitedReader:
    """
    Reader wrapper that allows at most `limit` more bytes to be read.

    A read that would exceed the remaining budget raises `EOFError` without
    touching the wrapped reader. Otherwise the read is delegated, and the
    budget shrinks by the number of bytes the wrapped reader reports.
    """

    def __init__(self, reader: Reader, limit: int) -> None:
        self.reader = reader
        self.limit = limit

    def readinto(self, buf: bytearray) -> int:
        if len(buf) > self.limit:
            raise EOFError
        count = self.reader.readinto(buf)
        self.limit -= count
        return count
# Sentinel values used in the default slot of a field definition
# (name, type, default) in place of a real default value:
# - FLAG_REPEATED: repeated field. Decoded values are appended to a list that
#   starts empty, and each list element is encoded as its own field.
# - FLAG_REQUIRED: load_message raises ValueError if the field never appears.
# - FLAG_EXPERIMENTAL: defaults to None. load_message raises ValueError if the
#   field appears while experimental_enabled is False.
FLAG_REPEATED, FLAG_REQUIRED, FLAG_EXPERIMENTAL = object(), object(), object()
2017-07-04 16:09:08 +00:00
2019-07-03 13:07:04 +00:00
if False:
    # Allowed values for the type slot of a field definition: one of the
    # marker classes, an EnumType instance, or a MessageType subclass.
    MessageTypeDef = Union[
        Type[UVarintType],
        Type[SVarintType],
        Type[BoolType],
        EnumType,
        Type[BytesType],
        Type[UnicodeType],
        Type[MessageType],
    ]

    # (field name, field type, default value or FLAG_* sentinel)
    FieldDef = Tuple[str, MessageTypeDef, Any]
    # field tag -> field definition, as returned by MessageType.get_fields()
    FieldDict = Dict[int, FieldDef]
    # message class -> its FieldDict, shared across nested encode/decode calls
    FieldCache = Dict[Type[MessageType], FieldDict]

    LoadedMessageType = TypeVar("LoadedMessageType", bound=MessageType)
def load_message(
    reader: Reader,
    msg_type: Type[LoadedMessageType],
    field_cache: FieldCache = None,
    experimental_enabled: bool = True,
) -> LoadedMessageType:
    """
    Decode a message of type `msg_type` from `reader`.

    Fields are read until `reader` raises `EOFError` where the next field key
    would start. Embedded messages are therefore decoded through a
    `LimitedReader` bounded by their declared length. An `EOFError` in the
    middle of a field value propagates to the caller.

    Fields whose tag is not in the schema are skipped when they use the
    varint (0), 64-bit (1), length-delimited (2) or 32-bit (5) wire type.

    `field_cache` maps message classes to their `get_fields()` result and is
    shared with the recursive calls for embedded messages.

    Raises:
        ValueError: `msg_type` is UNSTABLE, or an experimental field is
            present, while `experimental_enabled` is False; an unknown field
            uses any other wire type; or a required field was not received.
        TypeError: a field's wire type differs from the schema, an enum value
            is not declared, or the schema names an unsupported field type.
    """
    if field_cache is None:
        field_cache = {}
    fields = field_cache.get(msg_type)
    if fields is None:
        fields = msg_type.get_fields()
        field_cache[msg_type] = fields

    if msg_type.UNSTABLE and not experimental_enabled:
        raise ValueError  # experimental messages not enabled

    # we need to avoid calling __init__, which enforces required arguments
    msg: LoadedMessageType = object.__new__(msg_type)

    # Pre-seed the object with defaults. FLAG_REQUIRED is left in place so
    # that missing required fields can be detected after decoding.
    for fname, _, fdefault in fields.values():
        if fdefault is FLAG_REPEATED:
            fdefault = []
        elif fdefault is FLAG_EXPERIMENTAL:
            fdefault = None
        setattr(msg, fname, fdefault)

    if False:
        SingularValue = Union[int, bool, bytearray, str, MessageType]
        Value = Union[SingularValue, List[SingularValue]]
        fvalue: Value = 0

    while True:
        try:
            fkey = load_uvarint(reader)
        except EOFError:
            break  # no more fields to load

        ftag = fkey >> 3
        wtype = fkey & 7

        field = fields.get(ftag, None)
        if field is None:  # unknown field, skip it
            _skip_unknown_field(reader, wtype)
            continue

        fname, ftype, fdefault = field
        if wtype != ftype.WIRE_TYPE:
            raise TypeError  # parsed wire type differs from the schema

        if fdefault is FLAG_EXPERIMENTAL and not experimental_enabled:
            raise ValueError  # experimental fields not enabled

        # Every supported schema type starts with a varint: the value itself
        # for wire type 0, the payload length for wire type 2.
        ivalue = load_uvarint(reader)

        if ftype is UVarintType:
            fvalue = ivalue
        elif ftype is SVarintType:
            fvalue = uint_to_sint(ivalue)
        elif ftype is BoolType:
            fvalue = bool(ivalue)
        elif isinstance(ftype, EnumType):
            fvalue = ftype.validate(ivalue)
        elif ftype is BytesType:
            fvalue = bytearray(ivalue)
            reader.readinto(fvalue)
        elif ftype is UnicodeType:
            fvalue = bytearray(ivalue)
            reader.readinto(fvalue)
            fvalue = bytes(fvalue).decode()
        elif issubclass(ftype, MessageType):
            fvalue = load_message(
                LimitedReader(reader, ivalue), ftype, field_cache, experimental_enabled
            )
        else:
            raise TypeError  # field type is unknown

        if fdefault is FLAG_REPEATED:
            getattr(msg, fname).append(fvalue)
        else:
            setattr(msg, fname, fvalue)

    for fname, _, _ in fields.values():
        if getattr(msg, fname) is FLAG_REQUIRED:
            # The message is intended to be user-facing when decoding from wire,
            # but not when used internally.
            raise ValueError("Required field '{}' was not received".format(fname))

    return msg


def _skip_unknown_field(reader: Reader, wtype: int) -> None:
    """Consume and discard the value of a field whose tag is not in the schema."""
    if wtype == 0:  # varint
        load_uvarint(reader)
    elif wtype == 1:  # fixed 64-bit value
        reader.readinto(bytearray(8))
    elif wtype == 2:  # length-delimited payload
        reader.readinto(bytearray(load_uvarint(reader)))
    elif wtype == 5:  # fixed 32-bit value
        reader.readinto(bytearray(4))
    else:
        # Groups (3/4) and reserved wire types cannot be skipped without
        # parsing their contents.
        raise ValueError
def dump_message(
    writer: Writer, msg: MessageType, field_cache: FieldCache = None
) -> None:
    """
    Serialize `msg` to `writer` in protobuf wire format.

    Fields are emitted in the iteration order of the schema returned by
    `get_fields()`. Attributes that are None or missing are omitted, and a
    repeated field emits one key/value pair per list element. A BytesType
    value may be a list of byte chunks, which is written as one
    length-prefixed payload. Embedded messages are length-prefixed using
    `count_message`.

    Raises TypeError for a schema field type this codec does not support.
    """
    if field_cache is None:
        field_cache = {}
    msg_class = type(msg)
    schema = field_cache.get(msg_class)
    if schema is None:
        schema = msg.get_fields()
        field_cache[msg_class] = schema

    # reusable one-element list so scalar fields can share the per-item loop
    single = [0]

    for tag, (name, ftype, default) in schema.items():
        value = getattr(msg, name, None)
        if value is None:
            continue
        key = (tag << 3) | ftype.WIRE_TYPE

        if default is FLAG_REPEATED:
            items = value
        else:
            single[0] = value
            items = single

        for item in items:
            dump_uvarint(writer, key)

            if ftype is UVarintType or isinstance(ftype, EnumType):
                dump_uvarint(writer, item)
            elif ftype is SVarintType:
                dump_uvarint(writer, sint_to_uint(item))
            elif ftype is BoolType:
                dump_uvarint(writer, int(item))
            elif ftype is BytesType:
                if isinstance(item, list):
                    # byte chunks are concatenated into a single payload
                    dump_uvarint(writer, _count_bytes_list(item))
                    for chunk in item:
                        writer.write(chunk)
                else:
                    dump_uvarint(writer, len(item))
                    writer.write(item)
            elif ftype is UnicodeType:
                encoded = item.encode()
                dump_uvarint(writer, len(encoded))
                writer.write(encoded)
            elif issubclass(ftype, MessageType):
                # make sure the embedded schema type itself is cached
                if field_cache.get(ftype) is None:
                    field_cache[ftype] = ftype.get_fields()
                dump_uvarint(writer, count_message(item, field_cache))
                dump_message(writer, item, field_cache)
            else:
                raise TypeError
def count_message(msg: MessageType, field_cache: FieldCache = None) -> int:
    """
    Return the number of bytes `dump_message` writes for `msg`.

    The field handling mirrors `dump_message`:
    - None or missing attributes are skipped.
    - A repeated field contributes one key plus one value per list element.
    - A BytesType value may be a list of byte chunks counted as one payload.
    - Embedded messages add their own size plus a varint length prefix.

    Raises TypeError for a schema field type this codec does not support.
    """
    if field_cache is None:
        field_cache = {}
    msg_class = type(msg)
    schema = field_cache.get(msg_class)
    if schema is None:
        schema = msg.get_fields()
        field_cache[msg_class] = schema

    total = 0
    # reusable one-element list so scalar fields can share the per-item logic
    single = [0]

    for tag, (name, ftype, default) in schema.items():
        value = getattr(msg, name, None)
        if value is None:
            continue

        if default is FLAG_REPEATED:
            items = value
        else:
            single[0] = value
            items = single

        # every item is preceded by its own copy of the field key
        total += count_uvarint((tag << 3) | ftype.WIRE_TYPE) * len(items)

        if ftype is UVarintType or isinstance(ftype, EnumType):
            total += sum(count_uvarint(item) for item in items)
        elif ftype is SVarintType:
            total += sum(count_uvarint(sint_to_uint(item)) for item in items)
        elif ftype is BoolType:
            total += sum(count_uvarint(int(item)) for item in items)
        elif ftype is BytesType:
            for item in items:
                size = _count_bytes_list(item) if isinstance(item, list) else len(item)
                total += count_uvarint(size) + size
        elif ftype is UnicodeType:
            for item in items:
                size = len(item.encode())
                total += count_uvarint(size) + size
        elif issubclass(ftype, MessageType):
            for item in items:
                size = count_message(item, field_cache)
                total += count_uvarint(size) + size
        else:
            raise TypeError

    return total
def _count_bytes_list(svalue: List[bytes]) -> int:
    """Return the combined length of all byte chunks in `svalue`."""
    return sum(map(len, svalue))