1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-27 09:58:27 +00:00
trezor-firmware/trezorlib/protobuf.py
matejcik afb3e04c24 trezorlib/protobuf.py: return BytesType from wire as bytes, not bytearray.
This makes more sense, because bytes are immutable and callers have no business
mutating structures from the wire anyway.

Incidentally this should fix issue #236, where rlp library would treat
bytes and bytearrays differently and produce invalid structures in our usecase.

Also very minor nitpicks and code cleanup for neater typing.
2018-03-20 13:00:36 +01:00

320 lines
8.1 KiB
Python

# This file is part of the TREZOR project.
#
# Copyright (C) 2012-2016 Marek Palatinus <slush@satoshilabs.com>
# Copyright (C) 2012-2016 Pavol Rusnak <stick@satoshilabs.com>
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this library. If not, see <http://www.gnu.org/licenses/>.
'''
Extremely minimal streaming codec for a subset of protobuf. Supports uint32,
bytes, string, embedded message and repeated fields.
For de-sererializing (loading) protobuf types, object with `Reader`
interface is required:
>>> class Reader:
>>> def readinto(self, buffer):
>>> """
>>> Reads `len(buffer)` bytes into `buffer`, or raises `EOFError`.
>>> """
For serializing (dumping) protobuf types, object with `Writer` interface is
required:
>>> class Writer:
>>> def write(self, buffer):
>>> """
>>> Writes all bytes from `buffer`, or raises `EOFError`.
>>> """
'''
from io import BytesIO
_UVARINT_BUFFER = bytearray(1)
def load_uvarint(reader):
buffer = _UVARINT_BUFFER
result = 0
shift = 0
byte = 0x80
while byte & 0x80:
if reader.readinto(buffer) == 0:
raise EOFError
byte = buffer[0]
result += (byte & 0x7F) << shift
shift += 7
return result
def dump_uvarint(writer, n):
buffer = _UVARINT_BUFFER
shifted = True
while shifted:
shifted = n >> 7
buffer[0] = (n & 0x7F) | (0x80 if shifted else 0x00)
writer.write(buffer)
n = shifted
class UVarintType:
WIRE_TYPE = 0
class Sint32Type:
WIRE_TYPE = 0
class Sint64Type:
WIRE_TYPE = 0
class BoolType:
WIRE_TYPE = 0
class BytesType:
WIRE_TYPE = 2
class UnicodeType:
WIRE_TYPE = 2
class MessageType:
WIRE_TYPE = 2
FIELDS = {}
def __init__(self, **kwargs):
for kw in kwargs:
setattr(self, kw, kwargs[kw])
self._fill_missing()
def __eq__(self, rhs):
return (self.__class__ is rhs.__class__ and
self.__dict__ == rhs.__dict__)
def __repr__(self):
d = {}
for key, value in self.__dict__.items():
if value is None or value == []:
continue
d[key] = value
return '<%s: %s>' % (self.__class__.__name__, d)
def __iter__(self):
return self.__dict__.__iter__()
def __getattr__(self, attr):
if attr.startswith('_add_'):
return self._additem(attr[5:])
if attr.startswith('_extend_'):
return self._extenditem(attr[8:])
raise AttributeError(attr)
def _extenditem(self, attr):
def f(param):
try:
l = getattr(self, attr)
except AttributeError:
l = []
setattr(self, attr, l)
l += param
return f
def _additem(self, attr):
# Add new item for repeated field type
for v in self.FIELDS.values():
if v[0] != attr:
continue
if not (v[2] & FLAG_REPEATED):
raise AttributeError
try:
l = getattr(self, v[0])
except AttributeError:
l = []
setattr(self, v[0], l)
item = v[1]()
l.append(item)
return lambda: item
raise AttributeError
def _fill_missing(self):
# fill missing fields
for fname, ftype, fflags in self.FIELDS.values():
if not hasattr(self, fname):
if fflags & FLAG_REPEATED:
setattr(self, fname, [])
else:
setattr(self, fname, None)
def CopyFrom(self, obj):
self.__dict__ = obj.__dict__.copy()
def ByteSize(self):
data = BytesIO()
dump_message(data, self)
return len(data.getvalue())
class LimitedReader:
def __init__(self, reader, limit):
self.reader = reader
self.limit = limit
def readinto(self, buf):
if self.limit < len(buf):
raise EOFError
else:
nread = self.reader.readinto(buf)
self.limit -= nread
return nread
class CountingWriter:
def __init__(self):
self.size = 0
def write(self, buf):
nwritten = len(buf)
self.size += nwritten
return nwritten
FLAG_REPEATED = 1
def load_message(reader, msg_type):
fields = msg_type.FIELDS
msg = msg_type()
while True:
try:
fkey = load_uvarint(reader)
except EOFError:
break # no more fields to load
ftag = fkey >> 3
wtype = fkey & 7
field = fields.get(ftag, None)
if field is None: # unknown field, skip it
if wtype == 0:
load_uvarint(reader)
elif wtype == 2:
ivalue = load_uvarint(reader)
reader.readinto(bytearray(ivalue))
else:
raise ValueError
continue
fname, ftype, fflags = field
if wtype != ftype.WIRE_TYPE:
raise TypeError # parsed wire type differs from the schema
ivalue = load_uvarint(reader)
if ftype is UVarintType:
fvalue = ivalue
elif ftype is Sint32Type:
fvalue = (ivalue >> 1) ^ ((ivalue << 31) & 0xffffffff)
elif ftype is Sint64Type:
fvalue = (ivalue >> 1) ^ ((ivalue << 63) & 0xffffffffffffffff)
elif ftype is BoolType:
fvalue = bool(ivalue)
elif ftype is BytesType:
buf = bytearray(ivalue)
reader.readinto(buf)
fvalue = bytes(buf)
elif ftype is UnicodeType:
buf = bytearray(ivalue)
reader.readinto(buf)
fvalue = buf.decode()
elif issubclass(ftype, MessageType):
fvalue = load_message(LimitedReader(reader, ivalue), ftype)
else:
raise TypeError # field type is unknown
if fflags & FLAG_REPEATED:
pvalue = getattr(msg, fname)
pvalue.append(fvalue)
fvalue = pvalue
setattr(msg, fname, fvalue)
return msg
def dump_message(writer, msg):
repvalue = [0]
mtype = msg.__class__
fields = mtype.FIELDS
for ftag in fields:
fname, ftype, fflags = fields[ftag]
fvalue = getattr(msg, fname, None)
if fvalue is None:
continue
fkey = (ftag << 3) | ftype.WIRE_TYPE
if not fflags & FLAG_REPEATED:
repvalue[0] = fvalue
fvalue = repvalue
for svalue in fvalue:
dump_uvarint(writer, fkey)
if ftype is UVarintType:
dump_uvarint(writer, svalue)
elif ftype is Sint32Type:
dump_uvarint(writer, ((svalue << 1) & 0xffffffff) ^ (svalue >> 31))
elif ftype is Sint64Type:
dump_uvarint(writer, ((svalue << 1) & 0xffffffffffffffff) ^ (svalue >> 63))
elif ftype is BoolType:
dump_uvarint(writer, int(svalue))
elif ftype is BytesType:
dump_uvarint(writer, len(svalue))
writer.write(svalue)
elif ftype is UnicodeType:
if not isinstance(svalue, bytes):
svalue = svalue.encode('utf-8')
dump_uvarint(writer, len(svalue))
writer.write(svalue)
elif issubclass(ftype, MessageType):
counter = CountingWriter()
dump_message(counter, svalue)
dump_uvarint(writer, counter.size)
dump_message(writer, svalue)
else:
raise TypeError