mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-07-04 13:52:35 +00:00
tests(core): fix tests and add FlashArea unit test
This commit is contained in:
parent
1ee2095c23
commit
dd7ca9ebf1
113
core/tests/test_trezor.io.flash_area.py
Normal file
113
core/tests/test_trezor.io.flash_area.py
Normal file
@ -0,0 +1,113 @@
|
||||
from common import * # isort:skip
|
||||
|
||||
from trezor import io
|
||||
from trezor.crypto import hashlib
|
||||
|
||||
|
||||
class TestTrezorIoFlashArea(unittest.TestCase):
|
||||
def test_firmware_hash(self):
|
||||
area = io.flash_area.FIRMWARE
|
||||
area.erase()
|
||||
self.assertEqual(
|
||||
area.hash(0, area.size()),
|
||||
b"\xd2\xdb\x90\xa7jV6\xa7\x00N\xc3\xb4\x8eq\xa9U\xe0\xcb\xb2\xcbZo\xd7\xae\x9f\xbe\xf8F\xbc\x16l\x8c",
|
||||
)
|
||||
self.assertEqual(
|
||||
area.hash(0, area.size(), b"0123456789abcdef"),
|
||||
b"\xa0\x93@\x98\xa6\x80\xdb\x07m\xdf~\xe2'E\xf1\x19\xd8\xfd\xa4`\x10H\xf0_\xdbf\xa6N\xdd\xc0\xcf\xed",
|
||||
)
|
||||
|
||||
def test_write(self):
|
||||
# let's trash the firmware :shrug:
|
||||
area = io.flash_area.FIRMWARE
|
||||
size = area.size()
|
||||
|
||||
area.write(0, b"")
|
||||
area.write(1024, b"")
|
||||
area.write(size, b"")
|
||||
with self.assertRaises(ValueError):
|
||||
area.write(size + 16, b"")
|
||||
|
||||
# fill whole area
|
||||
area.write(0, b"\x01" * size)
|
||||
# do it again
|
||||
area.write(0, b"\x01" * size)
|
||||
# can't write more
|
||||
with self.assertRaises(ValueError):
|
||||
area.write(0, b"\x01" * (size + 16))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
area.write(1, b"\x01" * size)
|
||||
|
||||
def test_overwrite(self):
|
||||
area = io.flash_area.FIRMWARE
|
||||
size = area.size()
|
||||
|
||||
area.erase()
|
||||
area.write(0, b"\x00" * 1024)
|
||||
|
||||
# try writing the same thing
|
||||
area.write(0, b"\x00" * 1024)
|
||||
|
||||
# try writing some ones
|
||||
with self.assertRaises(ValueError):
|
||||
area.write(0, b"\x01" * 1024)
|
||||
|
||||
def test_read_write(self):
|
||||
area = io.flash_area.FIRMWARE
|
||||
size = area.size()
|
||||
area.erase()
|
||||
|
||||
buf = bytearray(256)
|
||||
for i in range(256):
|
||||
buf[i] = i
|
||||
|
||||
for start in range(0, size, 256):
|
||||
area.write(start, buf)
|
||||
|
||||
all_data = bytearray(size)
|
||||
area.read(0, all_data)
|
||||
for i in range(size):
|
||||
# avoid super-slow assertEqual
|
||||
if i % 256 != all_data[i]:
|
||||
self.fail(f"at {i} expected {i % 256}, found {all_data[i]}")
|
||||
|
||||
chunk = bytearray(1024)
|
||||
for start in range(0, size, 1024):
|
||||
area.read(start, chunk)
|
||||
for j in range(1024):
|
||||
# avoid super-slow assertEqual
|
||||
if (start + j) % 256 != chunk[j]:
|
||||
self.fail(
|
||||
f"at {start + j} expected {(start + j) % 256}, found {chunk[j]}"
|
||||
)
|
||||
|
||||
def test_hash(self):
|
||||
area = io.flash_area.FIRMWARE
|
||||
size = area.size()
|
||||
|
||||
|
||||
all_data = bytearray(size)
|
||||
for i in range(size):
|
||||
all_data[i] = i % 256
|
||||
|
||||
hasher = hashlib.blake2s()
|
||||
hasher.update(all_data)
|
||||
digest = hasher.digest()
|
||||
digest2 = area.hash(0, size)
|
||||
self.assertEqual(digest, digest2)
|
||||
|
||||
hasher = hashlib.blake2s()
|
||||
digest = hasher.digest()
|
||||
digest2 = area.hash(0, 0)
|
||||
self.assertEqual(digest, digest2)
|
||||
|
||||
hasher = hashlib.blake2s()
|
||||
hasher.update(all_data[1024:2048])
|
||||
digest = hasher.digest()
|
||||
digest2 = area.hash(1024, 1024)
|
||||
self.assertEqual(digest, digest2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
@ -2,7 +2,7 @@ from common import * # isort:skip
|
||||
|
||||
from trezor import protobuf
|
||||
from trezor.messages import (
|
||||
DebugLinkMemoryRead,
|
||||
PrevOutput,
|
||||
Failure,
|
||||
SignMessage,
|
||||
WebAuthnCredential,
|
||||
@ -19,12 +19,20 @@ def load_uvarint32(data: bytes) -> int:
|
||||
|
||||
|
||||
def load_uvarint64(data: bytes) -> int:
|
||||
# use known uint64 field in an all-optional message
|
||||
buffer = bytearray(len(data) + 1)
|
||||
buffer[1:] = data
|
||||
buffer[0] = (2 << 3) | 0 # field number 1, wire type 0
|
||||
msg = protobuf.decode(buffer, DebugLinkMemoryRead, False)
|
||||
return msg.length
|
||||
# use known uint64 field
|
||||
# message PrevOutput {
|
||||
# required uint64 amount = 1;
|
||||
# required bytes script_pubkey = 2;
|
||||
# optional uint32 decred_script_version = 3;
|
||||
# }
|
||||
buffer = bytearray(len(data) + 1 + 2)
|
||||
buffer[1:-2] = data
|
||||
buffer[0] = (1 << 3) | 0 # field number 1, wire type 0
|
||||
# create a zero-length script-pubkey field
|
||||
buffer[-2] = (2 << 3) | 2 # field number 2, wire type 2
|
||||
buffer[-1] = 0 # length of the script-pubkey
|
||||
msg = protobuf.decode(buffer, PrevOutput, False)
|
||||
return msg.amount
|
||||
|
||||
|
||||
def dump_uvarint32(value: int) -> bytearray:
|
||||
@ -39,12 +47,14 @@ def dump_uvarint32(value: int) -> bytearray:
|
||||
|
||||
def dump_uvarint64(value: int) -> bytearray:
|
||||
# use known uint64 field in an all-optional message
|
||||
msg = DebugLinkMemoryRead(length=value)
|
||||
msg = PrevOutput(amount=value, script_pubkey=b"")
|
||||
length = protobuf.encoded_length(msg)
|
||||
buffer = bytearray(length)
|
||||
protobuf.encode(buffer, msg)
|
||||
assert buffer[0] == (2 << 3) | 0 # field number 2, wire type 0
|
||||
return buffer[1:]
|
||||
assert buffer[0] == (1 << 3) | 0 # field number 1, wire type 0
|
||||
assert buffer[-2] == (2 << 3) | 2 # field number 2, wire type 2
|
||||
assert buffer[-1] == 0 # length of the script-pubkey
|
||||
return buffer[1:-2]
|
||||
|
||||
|
||||
def dump_message(msg: protobuf.MessageType) -> bytearray:
|
||||
|
@ -52,16 +52,6 @@ class TestUtils(unittest.TestCase):
|
||||
utils.truncate_utf8("\u1234\u5678", 7), "\u1234\u5678"
|
||||
) # b'\xe1\x88\xb4\xe5\x99\xb8
|
||||
|
||||
def test_firmware_hash(self):
|
||||
self.assertEqual(
|
||||
utils.firmware_hash(),
|
||||
b"\xd2\xdb\x90\xa7jV6\xa7\x00N\xc3\xb4\x8eq\xa9U\xe0\xcb\xb2\xcbZo\xd7\xae\x9f\xbe\xf8F\xbc\x16l\x8c",
|
||||
)
|
||||
self.assertEqual(
|
||||
utils.firmware_hash(b"0123456789abcdef"),
|
||||
b"\xa0\x93@\x98\xa6\x80\xdb\x07m\xdf~\xe2'E\xf1\x19\xd8\xfd\xa4`\x10H\xf0_\xdbf\xa6N\xdd\xc0\xcf\xed",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user