You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
trezor-firmware/core/tests/test_trezor.protobuf.py

132 lines
4.3 KiB

from common import * # isort:skip
from trezor import protobuf
from trezor.messages import (
PrevOutput,
Failure,
SignMessage,
WebAuthnCredential,
)
def load_uvarint32(data: bytes) -> int:
    """Decode raw varint bytes ``data`` through a real uint32 message field.

    The bytes are prefixed with the tag for field 1 / wire type 0 (varint)
    and decoded as a WebAuthnCredential, which (per the original note) is an
    all-optional message, so no other fields need to be present. The decoded
    ``index`` attribute is returned.
    """
    tag = (1 << 3) | 0  # field number 1, wire type 0
    encoded = bytearray([tag])
    encoded.extend(data)
    decoded = protobuf.decode(encoded, WebAuthnCredential, False)
    return decoded.index
def load_uvarint64(data: bytes) -> int:
    """Decode raw varint bytes ``data`` through a real uint64 message field.

    PrevOutput is used as the carrier message::

        message PrevOutput {
            required uint64 amount = 1;
            required bytes script_pubkey = 2;
            optional uint32 decred_script_version = 3;
        }

    The varint is tagged as field 1 (``amount``). Because ``script_pubkey``
    is required, a zero-length field 2 is appended after it. The decoded
    ``amount`` is returned.
    """
    amount_tag = (1 << 3) | 0  # field number 1, wire type 0
    script_tag = (2 << 3) | 2  # field number 2, wire type 2
    encoded = bytearray([amount_tag])
    encoded.extend(data)
    # empty script_pubkey: its tag byte followed by a length of zero
    encoded.extend(bytes([script_tag, 0]))
    return protobuf.decode(encoded, PrevOutput, False).amount
def dump_uvarint32(value: int) -> bytearray:
    """Encode ``value`` as raw varint bytes through a real uint32 field.

    A WebAuthnCredential carrying only ``index`` (field 1) is serialized.
    The leading tag byte is checked to be field 1 / wire type 0 and then
    stripped, leaving just the varint encoding of ``value``.
    """
    credential = WebAuthnCredential(index=value)
    out = bytearray(protobuf.encoded_length(credential))
    protobuf.encode(out, credential)
    expected_tag = (1 << 3) | 0  # field number 1, wire type 0
    assert out[0] == expected_tag
    return out[1:]
def dump_uvarint64(value: int) -> bytearray:
    """Encode ``value`` as raw varint bytes through a real uint64 field.

    The value is carried in ``PrevOutput.amount`` (field 1). PrevOutput is
    *not* an all-optional message: ``script_pubkey`` (field 2) is required
    as well, so it is set to an empty byte string. The serialized message is
    therefore checked to be laid out as::

        tag(field 1, varint) | <varint bytes> | tag(field 2, bytes) | 0x00

    Both tags and the zero length are asserted, and only the varint bytes
    in between are returned.
    """
    msg = PrevOutput(amount=value, script_pubkey=b"")
    length = protobuf.encoded_length(msg)
    buffer = bytearray(length)
    protobuf.encode(buffer, msg)
    assert buffer[0] == (1 << 3) | 0  # field number 1, wire type 0
    # trailing empty script_pubkey: tag byte followed by a zero length
    assert buffer[-2] == (2 << 3) | 2  # field number 2, wire type 2
    assert buffer[-1] == 0  # length of the script_pubkey
    return buffer[1:-2]
def dump_message(msg: protobuf.MessageType) -> bytearray:
    """Serialize ``msg`` into a new bytearray sized by protobuf.encoded_length."""
    out = bytearray(protobuf.encoded_length(msg))
    protobuf.encode(out, msg)
    return out
def load_message(
    msg_type: Type[protobuf.MessageType], buffer: bytes
) -> protobuf.MessageType:
    """Decode ``buffer`` into an instance of ``msg_type`` via protobuf.decode.

    The trailing ``False`` is forwarded unchanged to protobuf.decode.
    NOTE(review): presumably this is the experimental-fields switch; confirm
    against trezor.protobuf.
    NOTE(review): ``Type`` is not imported in the visible part of this file;
    presumably it comes from ``from common import *`` or the annotation is
    never evaluated at runtime. Confirm.
    """
    decoded = protobuf.decode(buffer, msg_type, False)
    return decoded
class TestProtobuf(unittest.TestCase):
    """Tests for varint encoding, enum validation and required/default fields."""

    # (integer value, expected unsigned varint bytes), shared by the
    # dump and load tests below
    _UVARINT_VECTORS = (
        (0, b"\x00"),
        (1, b"\x01"),
        (0xFF, b"\xff\x01"),
        (123456, b"\xc0\xc4\x07"),
    )

    def test_dump_uvarint(self):
        for encode in (dump_uvarint32, dump_uvarint64):
            for value, expected in self._UVARINT_VECTORS:
                self.assertEqual(encode(value), expected)
            # negative numbers are rejected by the unsigned encoders
            with self.assertRaises(OverflowError):
                encode(-1)

    def test_load_uvarint(self):
        for decode in (load_uvarint32, load_uvarint64):
            for expected, raw in self._UVARINT_VECTORS:
                self.assertEqual(decode(raw), expected)

    def test_validate_enum(self):
        # an accepted enum value survives an encode/decode round trip
        good = Failure(code=7)
        decoded = load_message(Failure, dump_message(good))
        self.assertEqual(good.code, decoded.code)

        # an out-of-range value still encodes, but decoding rejects it
        bad_bytes = dump_message(Failure(code=1000))
        with self.assertRaises(ValueError):
            load_message(Failure, bad_bytes)

    def test_required(self):
        # every field set explicitly: all of them round-trip
        full = SignMessage(message=b"hello", coin_name="foo", script_type=1)
        decoded = load_message(SignMessage, dump_message(full))
        self.assertEqual(decoded.message, b"hello")
        self.assertEqual(decoded.coin_name, "foo")
        self.assertEqual(decoded.script_type, 1)

        # required field missing: encoding always succeeds, decoding raises
        missing_required = dump_message(SignMessage(message=None))
        with self.assertRaises(ValueError):
            load_message(SignMessage, missing_required)

        # field with a default cleared: decoding yields the default value
        no_default = SignMessage(message=b"hello")
        no_default.coin_name = None
        decoded = load_message(SignMessage, dump_message(no_default))
        self.assertEqual(decoded.message, b"hello")
        self.assertEqual(decoded.coin_name, "Bitcoin")
if __name__ == "__main__":
    # run the test cases in this module when executed directly
    unittest.main()