1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-12 00:10:58 +00:00
trezor-firmware/core/src/protobuf.py
2019-04-15 19:14:40 +02:00

369 lines
9.5 KiB
Python

'''
Extremely minimal streaming codec for a subset of protobuf. Supports unsigned
varints (e.g. uint32), signed (zigzag) varints, bool, bytes, string, embedded
message and repeated fields.
For de-serializing (loading) protobuf types, an object with the `AsyncReader`
interface is required:
>>> class AsyncReader:
>>> async def areadinto(self, buffer):
>>> """
>>> Reads `len(buffer)` bytes into `buffer`, or raises `EOFError`.
>>> """
For serializing (dumping) protobuf types, an object with the `AsyncWriter`
interface is required:
>>> class AsyncWriter:
>>> async def awrite(self, buffer):
>>> """
>>> Writes all bytes from `buffer`, or raises `EOFError`.
>>> """
'''
from micropython import const
# Single-byte scratch buffer shared by load_uvarint() and dump_uvarint(), which
# move one byte at a time through it (presumably to avoid allocating a new
# buffer on every call).
_UVARINT_BUFFER = bytearray(1)
async def load_uvarint(reader):
    """
    Read one unsigned varint from `reader` and return it as an int.

    Bytes are pulled one at a time through `reader.areadinto()` into the
    shared module-level scratch buffer. Each byte contributes its low 7 bits,
    least-significant group first; a set high bit means another byte follows.
    Any exception raised by `reader.areadinto()` propagates to the caller.
    """
    scratch = _UVARINT_BUFFER
    value = 0
    offset = 0
    while True:
        await reader.areadinto(scratch)
        octet = scratch[0]
        # the 7-bit groups never overlap, so OR-ing them in equals adding them
        value |= (octet & 0x7F) << offset
        if not octet & 0x80:
            # continuation bit is clear: this was the last byte
            return value
        offset += 7
async def dump_uvarint(writer, n):
    """
    Write the non-negative integer `n` to `writer` as an unsigned varint.

    The value is emitted 7 bits per byte, least-significant group first, with
    the high bit set on every byte except the last. Each byte is passed to
    `writer.awrite()` through the shared module-level scratch buffer.

    Raises `ValueError` if `n` is negative.
    """
    if n < 0:
        raise ValueError("Cannot dump signed value, convert it to unsigned first.")
    scratch = _UVARINT_BUFFER
    while True:
        remaining = n >> 7
        octet = n & 0x7F
        if remaining:
            # more groups follow: flag this byte as a continuation
            octet |= 0x80
        scratch[0] = octet
        await writer.awrite(scratch)
        if not remaining:
            return
        n = remaining
def count_uvarint(n):
    """
    Return the number of bytes the unsigned varint encoding of `n` occupies.

    A varint carries 7 payload bits per byte, so a full 64-bit value needs up
    to 10 bytes. Values in the range [2**63, 2**64) were previously rejected
    although `dump_uvarint` encodes them without complaint; they now count as
    10 bytes so that counting and dumping agree.

    Raises `ValueError` if `n` is negative or does not fit into 64 bits.
    """
    if n < 0:
        raise ValueError("Cannot dump signed value, convert it to unsigned first.")
    # Explicit ladder (rather than a loop) keeps the common small values cheap.
    if n <= 0x7F:
        return 1
    if n <= 0x3FFF:
        return 2
    if n <= 0x1FFFFF:
        return 3
    if n <= 0xFFFFFFF:
        return 4
    if n <= 0x7FFFFFFFF:
        return 5
    if n <= 0x3FFFFFFFFFF:
        return 6
    if n <= 0x1FFFFFFFFFFFF:
        return 7
    if n <= 0xFFFFFFFFFFFFFF:
        return 8
    if n <= 0x7FFFFFFFFFFFFFFF:
        return 9
    if n <= 0xFFFFFFFFFFFFFFFF:
        # the 64th bit spills into a tenth byte
        return 10
    raise ValueError("Value does not fit into a 64-bit uvarint.")
# protobuf interleaved signed encoding:
# https://developers.google.com/protocol-buffers/docs/encoding#structure
# the idea is to save the sign in LSbit instead of twos-complement.
# so counting up, you go: 0, -1, 1, -2, 2, ... (as the first bit changes, sign flips)
#
# To achieve this with a twos-complement number:
# 1. shift left by 1, leaving LSbit free
# 2. if the number is negative, do bitwise negation.
# This keeps positive number the same, and converts negative from twos-complement
# to the appropriate value, while setting the sign bit.
#
# The original algorithm makes use of the fact that arithmetic (signed) shift
# keeps the sign bits, so for an n-bit number, (x >> (n - 1)) gets us "all sign bits".
# Then you can take "number XOR all-sign-bits", which is XOR with all zeros (identity)
# for positive and XOR with all ones (bitwise negation) for negative. Cute and efficient.
#
# But this is harder in Python because we don't natively know the bit size of the number.
# So we have to branch on whether the number is negative.
def sint_to_uint(sint):
    """
    Map a signed integer to its interleaved (zigzag) unsigned form.

    The value is shifted left by one bit; for negative inputs the result is
    then bitwise-negated, which yields 0, -1, 1, -2, 2 -> 0, 1, 2, 3, 4.
    """
    doubled = sint << 1
    return ~doubled if sint < 0 else doubled
def uint_to_sint(uint):
    """
    Map an interleaved (zigzag) unsigned integer back to a signed integer.

    The lowest bit carries the sign: when it is set, the remaining bits are
    bitwise-negated; otherwise they are returned unchanged.
    """
    magnitude = uint >> 1
    return ~magnitude if uint & 1 else magnitude
class UVarintType:
    """Schema marker for a field carried as an unsigned varint."""

    WIRE_TYPE = 0  # protobuf wire type 0: varint
class SVarintType:
    """Schema marker for a signed field carried as a zigzag-encoded varint."""

    WIRE_TYPE = 0  # protobuf wire type 0: varint
class BoolType:
    """Schema marker for a boolean field carried as a varint."""

    WIRE_TYPE = 0  # protobuf wire type 0: varint
class BytesType:
    """Schema marker for a raw bytes field (length prefix, then payload)."""

    WIRE_TYPE = 2  # protobuf wire type 2: length-delimited
class UnicodeType:
    """Schema marker for a string field (length prefix, then encoded text)."""

    WIRE_TYPE = 2  # protobuf wire type 2: length-delimited
class MessageType:
    """
    Base class for protobuf messages.

    Subclasses override `get_fields()` to describe their schema as a mapping
    of field tag to a `(name, type, flags)` tuple. Field values are stored as
    plain instance attributes.
    """

    WIRE_TYPE = 2  # embedded messages are length-delimited

    @classmethod
    def get_fields(cls):
        """Return the field schema; the base class has no fields."""
        return dict()

    def __init__(self, **kwargs):
        """Store every keyword argument as an attribute of the same name."""
        for name, value in kwargs.items():
            setattr(self, name, value)

    def __eq__(self, rhs):
        """Messages are equal when they share a class and all attributes."""
        if self.__class__ is not rhs.__class__:
            return False
        return self.__dict__ == rhs.__dict__

    def __repr__(self):
        """Return the class name in angle brackets."""
        return "<{}>".format(self.__class__.__name__)
class LimitedReader:
    """
    Reader wrapper that caps how many more bytes may be read.

    `limit` is the remaining byte budget; every successful read reduces it by
    the count the wrapped reader reports.
    """

    def __init__(self, reader, limit):
        self.limit = limit
        self.reader = reader

    async def areadinto(self, buf):
        """
        Fill `buf` from the wrapped reader and return the count it reports.

        Raises `EOFError`, without touching the wrapped reader, when `buf` is
        larger than the remaining budget.
        """
        if self.limit < len(buf):
            raise EOFError
        consumed = await self.reader.areadinto(buf)
        self.limit -= consumed
        return consumed
class CountingWriter:
    """
    Writer that discards its input and only tallies how many bytes it was given.

    The running total is kept in `size`.
    """

    def __init__(self):
        self.size = 0

    async def awrite(self, buf):
        """Add `len(buf)` to `size` and return that length."""
        length = len(buf)
        self.size = self.size + length
        return length
# Bit in a field's flags marking it as repeated: load_message() collects such
# values into a list, and dump_message()/count_message() iterate over that list.
FLAG_REPEATED = const(1)
async def load_message(reader, msg_type):
    """
    Decode one message of class `msg_type` from `reader` and return it.

    Fields are read until loading the next field key raises `EOFError`.
    Unknown tags are skipped when their wire type is 0 or 2; any other wire
    type on an unknown tag raises `ValueError`. Values of repeated fields are
    appended to a list. Schema fields that never appeared are set to `None`.

    Raises `TypeError` when the wire type on the stream disagrees with the
    schema, or when the schema names an unsupported field type.
    """
    schema = msg_type.get_fields()
    message = msg_type()

    while True:
        try:
            key = await load_uvarint(reader)
        except EOFError:
            # no further field key could be read: the message is complete
            break

        tag, wire = key >> 3, key & 7
        spec = schema.get(tag)

        if spec is None:
            # field is not in the schema: consume its payload and move on
            if wire == 0:
                await load_uvarint(reader)
            elif wire == 2:
                skip_len = await load_uvarint(reader)
                await reader.areadinto(bytearray(skip_len))
            else:
                raise ValueError
            continue

        name, ftype, flags = spec
        if ftype.WIRE_TYPE != wire:
            raise TypeError  # parsed wire type differs from the schema

        # for wire type 0 this is the value itself, for wire type 2 the length
        raw = await load_uvarint(reader)

        if ftype is UVarintType:
            value = raw
        elif ftype is SVarintType:
            value = uint_to_sint(raw)
        elif ftype is BoolType:
            value = bool(raw)
        elif ftype is BytesType or ftype is UnicodeType:
            value = bytearray(raw)
            await reader.areadinto(value)
            if ftype is UnicodeType:
                value = bytes(value).decode()
        elif issubclass(ftype, MessageType):
            # the nested message may consume at most `raw` bytes
            value = await load_message(LimitedReader(reader, raw), ftype)
        else:
            raise TypeError  # field type is unknown

        if flags & FLAG_REPEATED:
            collected = getattr(message, name, [])
            collected.append(value)
            value = collected
        setattr(message, name, value)

    # every schema field must exist on the result; absent ones become None
    for spec in schema.values():
        if not hasattr(message, spec[0]):
            setattr(message, spec[0], None)
    return message
async def dump_message(writer, msg, fields=None):
    """
    Encode `msg` to `writer`.

    `fields` is the schema mapping (tag -> `(name, type, flags)`); when it is
    omitted, `msg.get_fields()` is used. Attributes that are missing or `None`
    are not written. Repeated fields emit one key/value pair per element. A
    bytes value may also be a list of chunks, which is written as a single
    field whose length is the sum of the chunk lengths.

    Raises `TypeError` when a value is to be written for an unsupported
    field type.
    """
    # reusable one-element list so scalar fields can share the loop below
    slot = [0]
    if fields is None:
        fields = msg.get_fields()

    for tag in fields:
        name, ftype, flags = fields[tag]
        values = getattr(msg, name, None)
        if values is None:
            continue

        key = (tag << 3) | ftype.WIRE_TYPE
        if not flags & FLAG_REPEATED:
            slot[0] = values
            values = slot

        # look the nested schema up once per field, not once per element
        nested = ftype.get_fields() if issubclass(ftype, MessageType) else None

        for item in values:
            await dump_uvarint(writer, key)

            if ftype is UVarintType:
                await dump_uvarint(writer, item)
            elif ftype is SVarintType:
                await dump_uvarint(writer, sint_to_uint(item))
            elif ftype is BoolType:
                await dump_uvarint(writer, int(item))
            elif ftype is BytesType:
                if not isinstance(item, list):
                    await dump_uvarint(writer, len(item))
                    await writer.awrite(item)
                else:
                    # chunked payload: one length prefix, then every chunk in order
                    await dump_uvarint(writer, _count_bytes_list(item))
                    for chunk in item:
                        await writer.awrite(chunk)
            elif ftype is UnicodeType:
                encoded = item.encode()
                await dump_uvarint(writer, len(encoded))
                await writer.awrite(encoded)
            elif issubclass(ftype, MessageType):
                # the length prefix requires sizing the nested message first
                await dump_uvarint(writer, count_message(item, nested))
                await dump_message(writer, item, nested)
            else:
                raise TypeError
def count_message(msg, fields=None):
    """
    Return the number of bytes `dump_message` writes for `msg`.

    `fields` is the schema mapping (tag -> `(name, type, flags)`); when it is
    omitted, `msg.get_fields()` is used. Attributes that are missing or `None`
    contribute nothing.

    Raises `TypeError` when a present field has an unsupported type.
    """
    total = 0
    # reusable one-element list so scalar fields can be treated like repeated ones
    slot = [0]
    if fields is None:
        fields = msg.get_fields()

    for tag in fields:
        name, ftype, flags = fields[tag]
        values = getattr(msg, name, None)
        if values is None:
            continue

        key = (tag << 3) | ftype.WIRE_TYPE
        if not flags & FLAG_REPEATED:
            slot[0] = values
            values = slot

        # every element is preceded by the same field key
        total += count_uvarint(key) * len(values)

        if ftype is UVarintType:
            for item in values:
                total += count_uvarint(item)
        elif ftype is SVarintType:
            for item in values:
                total += count_uvarint(sint_to_uint(item))
        elif ftype is BoolType:
            for item in values:
                total += count_uvarint(int(item))
        elif ftype is BytesType:
            for item in values:
                # a list of chunks counts as one payload of their combined length
                if isinstance(item, list):
                    length = _count_bytes_list(item)
                else:
                    length = len(item)
                total += count_uvarint(length) + length
        elif ftype is UnicodeType:
            for item in values:
                length = len(item.encode())
                total += count_uvarint(length) + length
        elif issubclass(ftype, MessageType):
            nested = ftype.get_fields()
            for item in values:
                length = count_message(item, nested)
                total += count_uvarint(length) + length
            del nested
        else:
            raise TypeError

    return total
def _count_bytes_list(svalue):
    """Return the combined length of all chunks in the iterable `svalue`."""
    return sum(len(chunk) for chunk in svalue)