mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-22 23:48:12 +00:00
Initial version of protobuf library
This commit is contained in:
parent
ee9b9ca351
commit
ddfde9a0ad
0
src/lib/protobuf/__init__.py
Normal file
0
src/lib/protobuf/__init__.py
Normal file
55
src/lib/protobuf/loader.py
Normal file
55
src/lib/protobuf/loader.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
# Describing messages themselves. ----------------------------------------------
|
||||||
|
from . import protobuf
|
||||||
|
|
||||||
|
class TypeMetadataType:
|
||||||
|
|
||||||
|
WIRE_TYPE = 2
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# Field description.
|
||||||
|
self.__field_metadata_type = protobuf.MessageType()
|
||||||
|
self.__field_metadata_type.add_field(1, 'tag', protobuf.UVarintType, flags=protobuf.Flags.REQUIRED)
|
||||||
|
self.__field_metadata_type.add_field(2, 'name', protobuf.BytesType, flags=protobuf.Flags.REQUIRED)
|
||||||
|
self.__field_metadata_type.add_field(3, 'type', protobuf.BytesType, flags=protobuf.Flags.REQUIRED)
|
||||||
|
self.__field_metadata_type.add_field(4, 'flags', protobuf.UVarintType, flags=protobuf.Flags.REQUIRED)
|
||||||
|
self.__field_metadata_type.add_field(5, 'embedded', protobuf.EmbeddedMessage(self)) # Used to describe embedded messages.
|
||||||
|
# Metadata message description.
|
||||||
|
self.__self_type = protobuf.EmbeddedMessage(protobuf.MessageType())
|
||||||
|
self.__self_type.message_type.add_field(1, 'fields', protobuf.EmbeddedMessage(self.__field_metadata_type), flags=(Flags.REPEATED | Flags.REQUIRED))
|
||||||
|
|
||||||
|
def __create_message(self, message_type):
|
||||||
|
# Creates a message that contains info about the message_type.
|
||||||
|
message, message.fields = self.__self_type(), list()
|
||||||
|
for field in iter(message_type):
|
||||||
|
field_meta = self.__field_metadata_type()
|
||||||
|
field_meta.tag, field_meta.name, field_type, field_meta.flags = field
|
||||||
|
field_meta.type = type_str = field_type.__class__.__name__
|
||||||
|
if isinstance(field_type, protobuf.EmbeddedMessage):
|
||||||
|
field_meta.flags |= protobuf.Flags.EMBEDDED
|
||||||
|
field_meta.embedded_metadata = self.__create_message(field_type.message_type)
|
||||||
|
elif not type_str.endswith('Type'):
|
||||||
|
raise TypeError('Type name of type singleton object should end with \'Type\'. Actual: \'%s\'.' % type_str)
|
||||||
|
else:
|
||||||
|
field_meta.type = type_str[:-4]
|
||||||
|
message.fields.append(field_meta)
|
||||||
|
return message
|
||||||
|
|
||||||
|
def dump(self, fp, message_type):
|
||||||
|
self.__self_type.dump(fp, self.__create_message(message_type))
|
||||||
|
|
||||||
|
def __restore_type(self, message):
|
||||||
|
# Restores a message type by the information in the message.
|
||||||
|
message_type, g = protobuf.MessageType(), globals()
|
||||||
|
for field in message.fields:
|
||||||
|
field_type = field['type']
|
||||||
|
if not field_type in g:
|
||||||
|
raise TypeError('Primitive type \'%s\' not found in this protobuf module.' % field_type)
|
||||||
|
field_info = (field.tag, field.name, g[field_type], field.flags)
|
||||||
|
if field.flags & protobuf.Flags.EMBEDDED_MASK == protobuf.Flags.EMBEDDED:
|
||||||
|
field_info[3] -= protobuf.Flags.EMBEDDED
|
||||||
|
field_info[2] = protobuf.EmbeddedMessage(self.__restore_type(field.embedded))
|
||||||
|
message_type.add_field(*field_info)
|
||||||
|
return message_type
|
||||||
|
|
||||||
|
def load(self, fp):
|
||||||
|
return self.__restore_type(self.__self_type.load(fp))
|
236
src/lib/protobuf/protobuf.py
Normal file
236
src/lib/protobuf/protobuf.py
Normal file
@ -0,0 +1,236 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
# Implements the Google's protobuf encoding.
|
||||||
|
# eigenein (c) 2011
|
||||||
|
# http://eigenein.me/protobuf/
|
||||||
|
|
||||||
|
from _io import BytesIO
|
||||||
|
import ustruct
|
||||||
|
|
||||||
|
# Types. -----------------------------------------------------------------------
|
||||||
|
|
||||||
|
class UVarintType:
|
||||||
|
# Represents an unsigned Varint type.
|
||||||
|
WIRE_TYPE = 0
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def dump(fp, value):
|
||||||
|
shifted_value = True
|
||||||
|
while shifted_value:
|
||||||
|
shifted_value = value >> 7
|
||||||
|
fp.write(chr((value & 0x7F) | (0x80 if shifted_value != 0 else 0x00)))
|
||||||
|
value = shifted_value
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load(fp):
|
||||||
|
value, shift, quantum = 0, 0, 0x80
|
||||||
|
while (quantum & 0x80) == 0x80:
|
||||||
|
quantum = ord(fp.read(1))
|
||||||
|
value, shift = value + ((quantum & 0x7F) << shift), shift + 7
|
||||||
|
return value
|
||||||
|
|
||||||
|
# class UInt32Type(UVarintType): pass
|
||||||
|
|
||||||
|
class BoolType:
|
||||||
|
# Represents a boolean type. Encodes True as UVarint 1, and False as UVarint 0.
|
||||||
|
WIRE_TYPE = 0
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def dump(fp, value):
|
||||||
|
fp.write('\x01' if value else '\x00')
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load(fp):
|
||||||
|
return UVarintType.load(fp) != 0
|
||||||
|
|
||||||
|
class BytesType:
|
||||||
|
# Represents a raw bytes type.
|
||||||
|
|
||||||
|
WIRE_TYPE = 2
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def dump(fp, value):
|
||||||
|
UVarintType.dump(fp, len(value))
|
||||||
|
fp.write(value)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load(fp):
|
||||||
|
return fp.read(UVarintType.load(fp))
|
||||||
|
|
||||||
|
class UnicodeType:
|
||||||
|
WIRE_TYPE = 2
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def dump(fp, value):
|
||||||
|
BytesType.dump(fp, bytes(value, 'utf-8'))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load(fp):
|
||||||
|
return BytesType.load(fp).decode('utf-8', 'strict')
|
||||||
|
|
||||||
|
# Messages. --------------------------------------------------------------------
|
||||||
|
|
||||||
|
FLAG_SIMPLE = 0
|
||||||
|
FLAG_REQUIRED = 1
|
||||||
|
FLAG_REQUIRED_MASK = 1
|
||||||
|
FLAG_SINGLE = 0
|
||||||
|
FLAG_REPEATED = 2
|
||||||
|
FLAG_PACKED_REPEATED = 6
|
||||||
|
FLAG_REPEATED_MASK = 6
|
||||||
|
FLAG_PRIMITIVE = 0
|
||||||
|
FLAG_EMBEDDED = 8
|
||||||
|
FLAG_EMBEDDED_MASK = 8
|
||||||
|
|
||||||
|
class EofWrapper:
|
||||||
|
# Wraps a stream to raise EOFError instead of just returning of ''.
|
||||||
|
def __init__(self, fp, limit=None):
|
||||||
|
self.__fp = fp
|
||||||
|
self.__limit = limit
|
||||||
|
|
||||||
|
def read(self, size=None):
|
||||||
|
# Reads a string. Raises EOFError on end of stream.
|
||||||
|
if self.__limit is not None:
|
||||||
|
size = min(size, self.__limit)
|
||||||
|
self.__limit -= size
|
||||||
|
s = self.__fp.read(size)
|
||||||
|
if len(s) == 0:
|
||||||
|
raise EOFError()
|
||||||
|
return s
|
||||||
|
|
||||||
|
# Packs a tag and a wire_type into single int according to the protobuf spec.
|
||||||
|
_pack_key = lambda tag, wire_type: (tag << 3) | wire_type
|
||||||
|
# Unpacks a key into a tag and a wire_type according to the protobuf spec.
|
||||||
|
_unpack_key = lambda key: (key >> 3, key & 7)
|
||||||
|
|
||||||
|
class MessageType:
|
||||||
|
# Represents a message type.
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# Creates a new message type.
|
||||||
|
self.__tags_to_types = dict() # Maps a tag to a type instance.
|
||||||
|
self.__tags_to_names = dict() # Maps a tag to a given field name.
|
||||||
|
self.__flags = dict() # Maps a tag to FLAG_
|
||||||
|
|
||||||
|
def add_field(self, tag, name, field_type, flags=FLAG_SIMPLE):
|
||||||
|
# Adds a field to the message type.
|
||||||
|
if tag in self.__tags_to_names or tag in self.__tags_to_types:
|
||||||
|
raise ValueError('The tag %s is already used.' % tag)
|
||||||
|
self.__tags_to_names[tag] = name
|
||||||
|
self.__tags_to_types[tag] = field_type
|
||||||
|
self.__flags[tag] = flags
|
||||||
|
return self # Allow add_field chaining.
|
||||||
|
|
||||||
|
def __call__(self):
|
||||||
|
# Creates an instance of this message type.
|
||||||
|
return Message(self)
|
||||||
|
|
||||||
|
def __has_flag(self, tag, flag, mask):
|
||||||
|
# Checks whether the field with the specified tag has the specified flag.
|
||||||
|
return (self.__flags[tag] & mask) == flag
|
||||||
|
|
||||||
|
def dump(self, fp, value):
|
||||||
|
if self != value.message_type:
|
||||||
|
raise TypeError("Incompatible type")
|
||||||
|
for tag, field_type in iter(self.__tags_to_types.items()):
|
||||||
|
if self.__tags_to_names[tag] in value.__dict__:
|
||||||
|
if self.__has_flag(tag, FLAG_SINGLE, FLAG_REPEATED_MASK):
|
||||||
|
# Single value.
|
||||||
|
UVarintType.dump(fp, _pack_key(tag, field_type.WIRE_TYPE))
|
||||||
|
field_type.dump(fp, value.__dict__[self.__tags_to_names[tag]])
|
||||||
|
elif self.__has_flag(tag, FLAG_PACKED_REPEATED, FLAG_REPEATED_MASK):
|
||||||
|
# Repeated packed value.
|
||||||
|
UVarintType.dump(fp, _pack_key(tag, BytesType.WIRE_TYPE))
|
||||||
|
internal_fp = BytesIO()
|
||||||
|
for single_value in value[self.__tags_to_names[tag]]:
|
||||||
|
field_type.dump(internal_fp, single_value)
|
||||||
|
BytesType.dump(fp, internal_fp.getvalue())
|
||||||
|
elif self.__has_flag(tag, FLAG_REPEATED, FLAG_REPEATED_MASK):
|
||||||
|
# Repeated value.
|
||||||
|
key = _pack_key(tag, field_type.WIRE_TYPE)
|
||||||
|
# Put it together sequently.
|
||||||
|
for single_value in value[self.__tags_to_names[tag]]:
|
||||||
|
UVarintType.dump(fp, key)
|
||||||
|
field_type.dump(fp, single_value)
|
||||||
|
elif self.__has_flag(tag, FLAG_REQUIRED, FLAG_REQUIRED_MASK):
|
||||||
|
raise ValueError('The field with the tag %s is required but a value is missing.' % tag)
|
||||||
|
|
||||||
|
def load(self, fp):
|
||||||
|
fp, message = EofWrapper(fp), self.__call__() # Wrap fp and create a new instance.
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
tag, wire_type = _unpack_key(UVarintType.load(fp))
|
||||||
|
|
||||||
|
if tag in self.__tags_to_types:
|
||||||
|
field_type = self.__tags_to_types[tag]
|
||||||
|
if not self.__has_flag(tag, FLAG_PACKED_REPEATED, FLAG_REPEATED_MASK):
|
||||||
|
if wire_type != field_type.WIRE_TYPE:
|
||||||
|
raise TypeError(
|
||||||
|
'Value of tag %s has incorrect wiretype %s, %s expected.' % \
|
||||||
|
(tag, wire_type, field_type.WIRE_TYPE))
|
||||||
|
elif wire_type != BytesType.WIRE_TYPE:
|
||||||
|
raise TypeError('Tag %s has wiretype %s while the field is packed repeated.' % (tag, wire_type))
|
||||||
|
if self.__has_flag(tag, FLAG_SINGLE, FLAG_REPEATED_MASK):
|
||||||
|
# Single value.
|
||||||
|
setattr(message, self.__tags_to_names[tag], field_type.load(fp))
|
||||||
|
elif self.__has_flag(tag, FLAG_PACKED_REPEATED, FLAG_REPEATED_MASK):
|
||||||
|
# Repeated packed value.
|
||||||
|
repeated_value = message[self.__tags_to_names[tag]] = list()
|
||||||
|
internal_fp = EofWrapper(fp, UVarintType.load(fp)) # Limit with value length.
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
repeated_value.append(field_type.load(internal_fp))
|
||||||
|
except EOFError:
|
||||||
|
break
|
||||||
|
elif self.__has_flag(tag, FLAG_REPEATED, FLAG_REPEATED_MASK):
|
||||||
|
# Repeated value.
|
||||||
|
if not self.__tags_to_names[tag] in message:
|
||||||
|
repeated_value = message[self.__tags_to_names[tag]] = list()
|
||||||
|
repeated_value.append(field_type.load(fp))
|
||||||
|
else:
|
||||||
|
# Skip this field.
|
||||||
|
|
||||||
|
# This used to correctly determine the length of unknown tags when loading a message.
|
||||||
|
{0: UVarintType, 2: BytesType}[wire_type].load(fp)
|
||||||
|
|
||||||
|
except EOFError:
|
||||||
|
# Check if all required fields are present.
|
||||||
|
for tag, name in iter(self.__tags_to_names.items()):
|
||||||
|
if self.__has_flag(tag, FLAG_REQUIRED, FLAG_REQUIRED_MASK) and not name in message:
|
||||||
|
if self.__has_flag(tag, FLAG_REPEATED, FLAG_REPEATED_MASK):
|
||||||
|
message[name] = list() # Empty list (no values was in input stream). But required field.
|
||||||
|
else:
|
||||||
|
raise ValueError('The field %s (\'%s\') is required but missing.' % (tag, name))
|
||||||
|
return message
|
||||||
|
|
||||||
|
class Message:
|
||||||
|
# Represents a message instance.
|
||||||
|
|
||||||
|
def __init__(self, message_type):
|
||||||
|
# Initializes a new instance of the specified message type.
|
||||||
|
self.message_type = message_type
|
||||||
|
|
||||||
|
def dump(self, fp):
|
||||||
|
# Dumps the message into a write-like object.
|
||||||
|
return self.message_type.dump(fp, self)
|
||||||
|
|
||||||
|
# Embedded message. ------------------------------------------------------------
|
||||||
|
|
||||||
|
class EmbeddedMessage:
|
||||||
|
# Represents an embedded message type.
|
||||||
|
|
||||||
|
WIRE_TYPE = 2
|
||||||
|
|
||||||
|
def __init__(self, message_type):
|
||||||
|
# Initializes a new instance. The argument is an underlying message type.
|
||||||
|
self.message_type = message_type
|
||||||
|
|
||||||
|
def __call__(self):
|
||||||
|
# Creates a message of the underlying message type.
|
||||||
|
return self.message_type()
|
||||||
|
|
||||||
|
def dump(self, fp, value):
|
||||||
|
BytesType.dump(fp, self.message_type.dumps(value))
|
||||||
|
|
||||||
|
def load(self, fp):
|
||||||
|
return self.message_type.load(EofWrapper(fp, UVarintType.load(fp))) # Limit with embedded message length.
|
Loading…
Reference in New Issue
Block a user