2019-08-02 17:06:01 +00:00
|
|
|
from common import *
|
|
|
|
|
|
|
|
import protobuf
|
2020-07-13 13:04:30 +00:00
|
|
|
from trezor.utils import BufferReader, BufferWriter
|
2019-08-02 17:06:01 +00:00
|
|
|
|
|
|
|
|
|
|
|
class Message(protobuf.MessageType):
    """Test message with one signed-varint field and one enum field.

    The enum field's type (named "t") only lists the values 0, 5 and 25;
    ``TestProtobuf.test_validate_enum`` expects any other value to be
    rejected when the encoded message is loaded.
    """

    def __init__(self, sint_field: int = 0, enum_field: int = 0) -> None:
        self.enum_field = enum_field
        self.sint_field = sint_field

    @classmethod
    def get_fields(cls):
        """Return the field table: tag -> (attribute name, field type, default/flags)."""
        allowed_enum_values = (0, 5, 25)
        enum_type = protobuf.EnumType("t", allowed_enum_values)
        fields = {}
        fields[1] = ("sint_field", protobuf.SVarintType, 0)
        fields[2] = ("enum_field", enum_type, 0)
        return fields
|
|
|
|
|
|
|
|
|
2020-09-14 10:47:06 +00:00
|
|
|
class MessageWithRequiredAndDefault(protobuf.MessageType):
    """Test message with a required uvarint field and a defaulted sint field.

    Field 1 carries ``protobuf.FLAG_REQUIRED``; field 2 declares -1 in the
    third slot of its field descriptor, which ``TestProtobuf.test_required``
    expects to be used when the field is absent from the encoding.
    """

    def __init__(self, required_field, default_field) -> None:
        self.default_field = default_field
        self.required_field = required_field

    @classmethod
    def get_fields(cls):
        """Return the field table: tag -> (attribute name, field type, default/flags)."""
        required_entry = (
            "required_field",
            protobuf.UVarintType,
            protobuf.FLAG_REQUIRED,
        )
        defaulted_entry = ("default_field", protobuf.SVarintType, -1)
        return {1: required_entry, 2: defaulted_entry}
|
|
|
|
|
|
|
|
|
2019-08-02 17:06:01 +00:00
|
|
|
def load_uvarint(data: bytes) -> int:
    """Decode one unsigned varint from the start of *data* via ``protobuf.load_uvarint``."""
    return protobuf.load_uvarint(BufferReader(data))
|
2019-08-02 17:06:01 +00:00
|
|
|
|
|
|
|
|
|
|
|
def dump_uvarint(value: int) -> memoryview:
    """Encode *value* as an unsigned varint using ``protobuf.dump_uvarint``.

    Returns a ``memoryview`` over the written prefix of a 16-byte scratch
    buffer rather than a copy. The previous ``-> bytearray`` annotation did
    not match what the function actually returns. The tests compare the
    result directly against ``bytes`` literals.

    Negative *value*s are expected to raise ``ValueError`` (asserted by
    ``TestProtobuf.test_dump_uvarint``).
    """
    writer = BufferWriter(bytearray(16))
    protobuf.dump_uvarint(writer, value)
    # Only the first ``writer.offset`` bytes were filled by the encoder.
    return memoryview(writer.buffer)[: writer.offset]
|
|
|
|
|
|
|
|
|
|
|
|
def dump_message(msg: protobuf.MessageType) -> bytearray:
    """Serialize *msg* into a new bytearray sized by ``protobuf.count_message``."""
    encoded = bytearray(protobuf.count_message(msg))
    protobuf.dump_message(BufferWriter(encoded), msg)
    return encoded
|
|
|
|
|
|
|
|
|
|
|
|
def load_message(msg_type, buffer: bytearray) -> protobuf.MessageType:
    """Decode *buffer* as an instance of *msg_type* using ``protobuf.load_message``."""
    reader = BufferReader(buffer)
    return protobuf.load_message(reader, msg_type)
|
2019-08-02 17:06:01 +00:00
|
|
|
|
|
|
|
|
|
|
|
class TestProtobuf(unittest.TestCase):
    """Tests for varint helpers, signed/unsigned mapping and message validation."""

    def test_dump_uvarint(self):
        expected_encodings = (
            (0, b"\x00"),
            (1, b"\x01"),
            (0xFF, b"\xff\x01"),
            (123456, b"\xc0\xc4\x07"),
        )
        for value, encoded in expected_encodings:
            self.assertEqual(dump_uvarint(value), encoded)
        # negative input must be rejected by the unsigned encoder
        with self.assertRaises(ValueError):
            dump_uvarint(-1)

    def test_load_uvarint(self):
        expected_values = (
            (b"\x00", 0),
            (b"\x01", 1),
            (b"\xff\x01", 0xFF),
            (b"\xc0\xc4\x07", 123456),
        )
        for encoded, value in expected_values:
            self.assertEqual(load_uvarint(encoded), value)

    def test_sint_uint(self):
        # zero maps to zero in both directions
        self.assertEqual(protobuf.uint_to_sint(0), 0)
        self.assertEqual(protobuf.sint_to_uint(0), 0)

        # small values interleave: -1 <-> 1, 1 <-> 2
        signed_unsigned_pairs = ((-1, 1), (1, 2))
        for signed, unsigned in signed_unsigned_pairs:
            self.assertEqual(protobuf.sint_to_uint(signed), unsigned)
        for signed, unsigned in signed_unsigned_pairs:
            self.assertEqual(protobuf.uint_to_sint(unsigned), signed)

        # roundtrip:
        for value in (1234567891011, -(2 ** 32)):
            self.assertEqual(
                protobuf.uint_to_sint(protobuf.sint_to_uint(value)), value
            )

    def test_validate_enum(self):
        # a value listed in the enum survives a dump/load roundtrip
        original = Message(-42, 5)
        decoded = load_message(Message, dump_message(original))
        self.assertEqual(original.sint_field, decoded.sint_field)
        self.assertEqual(original.enum_field, decoded.enum_field)

        # a value outside the enum still encodes, but loading rejects it
        encoded_bad = dump_message(Message(-42, 42))
        with self.assertRaises(TypeError):
            load_message(Message, encoded_bad)

    def test_required(self):
        msg_cls = MessageWithRequiredAndDefault

        # both fields present: values roundtrip unchanged
        encoded_full = dump_message(msg_cls(required_field=1, default_field=2))
        decoded_full = load_message(msg_cls, encoded_full)
        self.assertEqual(decoded_full.required_field, 1)
        self.assertEqual(decoded_full.default_field, 2)

        # required field missing: encoding always succeeds, decoding fails
        encoded_no_required = dump_message(
            msg_cls(required_field=None, default_field=2)
        )
        with self.assertRaises(ValueError):
            load_message(msg_cls, encoded_no_required)

        # default field missing: decoding falls back to the declared -1
        encoded_no_default = dump_message(
            msg_cls(required_field=1, default_field=None)
        )
        decoded_no_default = load_message(msg_cls, encoded_no_default)
        self.assertEqual(decoded_no_default.required_field, 1)
        self.assertEqual(decoded_no_default.default_field, -1)
|
|
|
|
|
2019-08-02 17:06:01 +00:00
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run this module's test cases when it is executed directly as a script.
    unittest.main()
|