diff --git a/core/tests/test_trezor.io.flash_area.py b/core/tests/test_trezor.io.flash_area.py new file mode 100644 index 000000000..f780852bc --- /dev/null +++ b/core/tests/test_trezor.io.flash_area.py @@ -0,0 +1,113 @@ +from common import * # isort:skip + +from trezor import io +from trezor.crypto import hashlib + + +class TestTrezorIoFlashArea(unittest.TestCase): + def test_firmware_hash(self): + area = io.flash_area.FIRMWARE + area.erase() + self.assertEqual( + area.hash(0, area.size()), + b"\xd2\xdb\x90\xa7jV6\xa7\x00N\xc3\xb4\x8eq\xa9U\xe0\xcb\xb2\xcbZo\xd7\xae\x9f\xbe\xf8F\xbc\x16l\x8c", + ) + self.assertEqual( + area.hash(0, area.size(), b"0123456789abcdef"), + b"\xa0\x93@\x98\xa6\x80\xdb\x07m\xdf~\xe2'E\xf1\x19\xd8\xfd\xa4`\x10H\xf0_\xdbf\xa6N\xdd\xc0\xcf\xed", + ) + + def test_write(self): + # let's trash the firmware :shrug: + area = io.flash_area.FIRMWARE + size = area.size() + + area.write(0, b"") + area.write(1024, b"") + area.write(size, b"") + with self.assertRaises(ValueError): + area.write(size + 16, b"") + + # fill whole area + area.write(0, b"\x01" * size) + # do it again + area.write(0, b"\x01" * size) + # can't write more + with self.assertRaises(ValueError): + area.write(0, b"\x01" * (size + 16)) + + with self.assertRaises(ValueError): + area.write(1, b"\x01" * size) + + def test_overwrite(self): + area = io.flash_area.FIRMWARE + size = area.size() + + area.erase() + area.write(0, b"\x00" * 1024) + + # try writing the same thing + area.write(0, b"\x00" * 1024) + + # try writing some ones + with self.assertRaises(ValueError): + area.write(0, b"\x01" * 1024) + + def test_read_write(self): + area = io.flash_area.FIRMWARE + size = area.size() + area.erase() + + buf = bytearray(256) + for i in range(256): + buf[i] = i + + for start in range(0, size, 256): + area.write(start, buf) + + all_data = bytearray(size) + area.read(0, all_data) + for i in range(size): + # avoid super-slow assertEqual + if i % 256 != all_data[i]: + self.fail(f"at {i} expected {i % 256}, found {all_data[i]}") + + chunk = bytearray(1024) + for start in range(0, size, 1024): + area.read(start, chunk) + for j in range(1024): + # avoid super-slow assertEqual + if (start + j) % 256 != chunk[j]: + self.fail( + f"at {start + j} expected {(start + j) % 256}, found {chunk[j]}" + ) + + def test_hash(self): + area = io.flash_area.FIRMWARE + size = area.size() + + + all_data = bytearray(size) + for i in range(size): + all_data[i] = i % 256 + + hasher = hashlib.blake2s() + hasher.update(all_data) + digest = hasher.digest() + digest2 = area.hash(0, size) + self.assertEqual(digest, digest2) + + hasher = hashlib.blake2s() + digest = hasher.digest() + digest2 = area.hash(0, 0) + self.assertEqual(digest, digest2) + + hasher = hashlib.blake2s() + hasher.update(all_data[1024:2048]) + digest = hasher.digest() + digest2 = area.hash(1024, 1024) + self.assertEqual(digest, digest2) + + +if __name__ == "__main__": + unittest.main() diff --git a/core/tests/test_trezor.protobuf.py b/core/tests/test_trezor.protobuf.py index 6418819bf..609c9a685 100644 --- a/core/tests/test_trezor.protobuf.py +++ b/core/tests/test_trezor.protobuf.py @@ -2,7 +2,7 @@ from common import * # isort:skip from trezor import protobuf from trezor.messages import ( - DebugLinkMemoryRead, + PrevOutput, Failure, SignMessage, WebAuthnCredential, @@ -19,12 +19,20 @@ def load_uvarint32(data: bytes) -> int: def load_uvarint64(data: bytes) -> int: - # use known uint64 field in an all-optional message - buffer = bytearray(len(data) + 1) - buffer[1:] = data - buffer[0] = (2 << 3) | 0 # field number 1, wire type 0 - msg = protobuf.decode(buffer, DebugLinkMemoryRead, False) - return msg.length + # use known uint64 field + # message PrevOutput { + # required uint64 amount = 1; + # required bytes script_pubkey = 2; + # optional uint32 decred_script_version = 3; + # } + buffer = bytearray(len(data) + 1 + 2) + buffer[1:-2] = data + buffer[0] = (1 << 3) | 0 # field number 1, wire type 0 + # create a zero-length script-pubkey field + buffer[-2] = (2 << 3) | 2 # field number 2, wire type 2 + buffer[-1] = 0 # length of the script-pubkey + msg = protobuf.decode(buffer, PrevOutput, False) + return msg.amount def dump_uvarint32(value: int) -> bytearray: @@ -39,12 +47,14 @@ def dump_uvarint32(value: int) -> bytearray: def dump_uvarint64(value: int) -> bytearray: # use known uint64 field in an all-optional message - msg = DebugLinkMemoryRead(length=value) + msg = PrevOutput(amount=value, script_pubkey=b"") length = protobuf.encoded_length(msg) buffer = bytearray(length) protobuf.encode(buffer, msg) - assert buffer[0] == (2 << 3) | 0 # field number 2, wire type 0 - return buffer[1:] + assert buffer[0] == (1 << 3) | 0 # field number 1, wire type 0 + assert buffer[-2] == (2 << 3) | 2 # field number 2, wire type 2 + assert buffer[-1] == 0 # length of the script-pubkey + return buffer[1:-2] def dump_message(msg: protobuf.MessageType) -> bytearray: diff --git a/core/tests/test_trezor.utils.py b/core/tests/test_trezor.utils.py index 8f503b2b1..7c73473ee 100644 --- a/core/tests/test_trezor.utils.py +++ b/core/tests/test_trezor.utils.py @@ -52,16 +52,6 @@ class TestUtils(unittest.TestCase): utils.truncate_utf8("\u1234\u5678", 7), "\u1234\u5678" ) # b'\xe1\x88\xb4\xe5\x99\xb8 - def test_firmware_hash(self): - self.assertEqual( - utils.firmware_hash(), - b"\xd2\xdb\x90\xa7jV6\xa7\x00N\xc3\xb4\x8eq\xa9U\xe0\xcb\xb2\xcbZo\xd7\xae\x9f\xbe\xf8F\xbc\x16l\x8c", - ) - self.assertEqual( - utils.firmware_hash(b"0123456789abcdef"), - b"\xa0\x93@\x98\xa6\x80\xdb\x07m\xdf~\xe2'E\xf1\x19\xd8\xfd\xa4`\x10H\xf0_\xdbf\xa6N\xdd\xc0\xcf\xed", - ) - if __name__ == "__main__": unittest.main()