1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-07 05:51:38 +00:00
trezor-firmware/core/tests/test_trezor.protobuf.py

123 lines
3.9 KiB
Python
Raw Normal View History

# flake8: noqa: F403,F405
from common import * # isort:skip
from trezor import protobuf
2023-06-28 10:46:29 +00:00
from trezor.messages import (
DebugLinkMemoryRead,
2023-06-28 10:46:29 +00:00
Failure,
SignMessage,
WebAuthnCredential,
2023-06-28 10:46:29 +00:00
)
def load_uvarint32(data: bytes) -> int:
    """Decode *data* as the value of the ``index`` field of WebAuthnCredential.

    The raw varint bytes are prefixed with a one-byte field tag so that the
    regular message decoder can be used to parse them; a known uint32 field
    in an all-optional message serves as the carrier.
    """
    # Tag byte: field number 1, wire type 0.
    tag = (1 << 3) | 0
    framed = bytearray([tag])
    framed.extend(data)
    decoded = protobuf.decode(framed, WebAuthnCredential, False)
    return decoded.index
def load_uvarint64(data: bytes) -> int:
    """Decode *data* as the value of the ``length`` field of DebugLinkMemoryRead.

    The raw varint bytes are prefixed with a one-byte field tag so that the
    regular message decoder can be used to parse them; a known uint64 field
    in an all-optional message serves as the carrier.
    """
    # Tag byte: field number 2, wire type 0 (the same tag that
    # dump_uvarint64 expects at the start of the encoded message).
    tag = (2 << 3) | 0
    framed = bytearray([tag])
    framed.extend(data)
    decoded = protobuf.decode(framed, DebugLinkMemoryRead, False)
    return decoded.length
def dump_uvarint32(value: int) -> bytearray:
    """Encode *value* through the ``index`` field of WebAuthnCredential.

    Returns the encoded message with its leading tag byte removed, i.e. only
    the bytes that follow the field tag. A known uint32 field in an
    all-optional message serves as the carrier.
    """
    carrier = WebAuthnCredential(index=value)
    encoded = bytearray(protobuf.encoded_length(carrier))
    protobuf.encode(encoded, carrier)
    # The first byte must be the tag for field number 1, wire type 0.
    expected_tag = (1 << 3) | 0
    assert encoded[0] == expected_tag
    return encoded[1:]
def dump_uvarint64(value: int) -> bytearray:
    """Encode *value* through the ``length`` field of DebugLinkMemoryRead.

    Returns the encoded message with its leading tag byte removed, i.e. only
    the bytes that follow the field tag. A known uint64 field in an
    all-optional message serves as the carrier.
    """
    carrier = DebugLinkMemoryRead(length=value)
    encoded = bytearray(protobuf.encoded_length(carrier))
    protobuf.encode(encoded, carrier)
    # The first byte must be the tag for field number 2, wire type 0.
    expected_tag = (2 << 3) | 0
    assert encoded[0] == expected_tag
    return encoded[1:]
def dump_message(msg: protobuf.MessageType) -> bytearray:
    """Encode *msg* into a freshly allocated buffer and return that buffer.

    The buffer is sized with ``protobuf.encoded_length`` before encoding.
    """
    encoded = bytearray(protobuf.encoded_length(msg))
    protobuf.encode(encoded, msg)
    return encoded
2023-06-28 10:46:29 +00:00
def load_message(
    msg_type: Type[protobuf.MessageType], buffer: bytes
) -> protobuf.MessageType:
    """Decode *buffer* as a message of type *msg_type* and return the result.

    Thin wrapper around ``protobuf.decode`` that takes the message type
    first and always passes ``False`` as the decoder's third argument.
    """
    decoded = protobuf.decode(buffer, msg_type, False)
    return decoded
class TestProtobuf(unittest.TestCase):
    """Round-trip and validation tests for the protobuf encoder/decoder."""

    def test_dump_uvarint(self):
        """Both uvarint dumpers emit the expected bytes and reject -1."""
        expectations = (
            (0, b"\x00"),
            (1, b"\x01"),
            (0xFF, b"\xff\x01"),
            (123456, b"\xc0\xc4\x07"),
        )
        for dump_uvarint in (dump_uvarint32, dump_uvarint64):
            for value, encoded in expectations:
                self.assertEqual(dump_uvarint(value), encoded)

            # negative values cannot be encoded as unsigned varints
            with self.assertRaises(OverflowError):
                dump_uvarint(-1)

    def test_load_uvarint(self):
        """Both uvarint loaders parse the expected values."""
        expectations = (
            (b"\x00", 0),
            (b"\x01", 1),
            (b"\xff\x01", 0xFF),
            (b"\xc0\xc4\x07", 123456),
        )
        for load_uvarint in (load_uvarint32, load_uvarint64):
            for encoded, value in expectations:
                self.assertEqual(load_uvarint(encoded), value)

    def test_validate_enum(self):
        """Decoding accepts a known enum value and rejects an unknown one."""
        # a code that survives the round trip unchanged
        original = Failure(code=7)
        decoded = load_message(Failure, dump_message(original))
        self.assertEqual(original.code, decoded.code)

        # encoding a bad enum value works, decoding it must fail
        invalid_bytes = dump_message(Failure(code=1000))
        with self.assertRaises(ValueError):
            load_message(Failure, invalid_bytes)

    def test_required(self):
        """Required fields are enforced and defaults filled in on decode."""
        # fully populated message round-trips as-is
        complete = SignMessage(message=b"hello", coin_name="foo", script_type=1)
        decoded = load_message(SignMessage, dump_message(complete))
        self.assertEqual(decoded.message, b"hello")
        self.assertEqual(decoded.coin_name, "foo")
        self.assertEqual(decoded.script_type, 1)

        # required field missing: encoding always succeeds,
        # but decoding must raise
        incomplete = SignMessage(message=None)
        incomplete_bytes = dump_message(incomplete)
        with self.assertRaises(ValueError):
            load_message(SignMessage, incomplete_bytes)

        # field with a default missing: the decoder supplies the default
        defaulted = SignMessage(message=b"hello")
        defaulted.coin_name = None
        decoded = load_message(SignMessage, dump_message(defaulted))
        self.assertEqual(decoded.message, b"hello")
        self.assertEqual(decoded.coin_name, "Bitcoin")
# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()