You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
trezor-firmware/core/tests/test_trezor.wire.codec_v1.py

182 lines
6.3 KiB

from common import *
from ubinascii import unhexlify
import ustruct
from trezor import io
from trezor.loop import wait
from trezor.utils import chunks, BufferIO
from trezor.wire import codec_v1
class MockHID:
    """Minimal stand-in for a HID interface that records every written packet.

    ``num`` is the interface number passed to the constructor; ``data`` is the
    list of packets handed to ``write``, in the order they were written.
    """

    def __init__(self, num):
        self.data = []
        self.num = num

    def iface_num(self):
        """Return the interface number given at construction."""
        return self.num

    def write(self, msg):
        """Record a copy of ``msg`` and return its length."""
        # Copy into a fresh bytearray so that later changes to the caller's
        # buffer do not alter the recorded packet.
        snapshot = bytearray(msg)
        self.data.append(snapshot)
        return len(msg)

    def wait_object(self, mode):
        """Return the ``wait`` object for ``mode`` OR-ed with the interface number."""
        channel = mode | self.num
        return wait(channel)
# Message type used for every message constructed in the tests in this file.
MESSAGE_TYPE = 0x4242
def make_header(mtype, length):
    """Build the header that starts the first packet of a message.

    The header is the magic bytes ``?##`` followed by ``mtype`` packed as a
    big-endian unsigned 16-bit value and ``length`` packed as a big-endian
    unsigned 32-bit value.
    """
    magic = b"?##"
    type_and_length = ustruct.pack(">HL", mtype, length)
    return magic + type_and_length
class TestWireCodecV1(unittest.TestCase):
    """Tests for reading and writing messages with ``codec_v1`` over a mock interface."""

    def setUp(self):
        # the tests talk to a mock interface that records written packets
        self.interface = MockHID(0xDEADBEEF)

    def test_read_one_packet(self):
        """A zero-length message is read from a single header-only packet."""
        # zero length message - just a header
        message_packet = make_header(mtype=MESSAGE_TYPE, length=0)
        buffer = bytearray(64)

        gen = codec_v1.read_message(self.interface, buffer)

        query = gen.send(None)
        self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))

        with self.assertRaises(StopIteration) as e:
            gen.send(message_packet)

        # e.value is StopIteration. e.value.value is the return value of the call
        result = e.value.value
        self.assertEqual(result.type, MESSAGE_TYPE)
        self.assertEqual(result.data.buffer, b"")

        # the message has no payload, so the supplied buffer must still be zeroed
        self.assertEqual(buffer, b"\x00" * 64)

    def test_read_many_packets(self):
        """A message spanning several packets is reassembled into the buffer."""
        message = bytes(range(256))

        header = make_header(mtype=MESSAGE_TYPE, length=len(message))
        first_len = codec_v1._REP_LEN - len(header)
        # first packet is header + (remaining)data
        # other packets are "?" + 63 bytes of data
        packets = [header + message[:first_len]] + [
            b"?" + chunk
            for chunk in chunks(message[first_len:], codec_v1._REP_LEN - 1)
        ]

        buffer = bytearray(256)
        gen = codec_v1.read_message(self.interface, buffer)
        query = gen.send(None)
        for packet in packets[:-1]:
            self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
            query = gen.send(packet)

        # last packet will stop
        with self.assertRaises(StopIteration) as e:
            gen.send(packets[-1])

        # e.value is StopIteration. e.value.value is the return value of the call
        result = e.value.value
        self.assertEqual(result.type, MESSAGE_TYPE)
        self.assertEqual(result.data.buffer, message)

        # message should have been read into the buffer
        self.assertEqual(buffer, message)

    def test_read_large_message(self):
        """A message larger than the supplied buffer leaves that buffer untouched."""
        message = b"hello world"
        header = make_header(mtype=MESSAGE_TYPE, length=len(message))

        packet = header + message
        # make sure we fit into one packet, to make this easier
        self.assertTrue(len(packet) <= codec_v1._REP_LEN)

        buffer = bytearray(1)
        self.assertTrue(len(buffer) <= len(packet))

        gen = codec_v1.read_message(self.interface, buffer)
        query = gen.send(None)
        self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
        with self.assertRaises(StopIteration) as e:
            gen.send(packet)

        # e.value is StopIteration. e.value.value is the return value of the call
        result = e.value.value
        self.assertEqual(result.type, MESSAGE_TYPE)
        self.assertEqual(result.data.buffer, message)

        # read should have allocated its own buffer and not touch ours
        self.assertEqual(buffer, b"\x00")

    def test_write_one_packet(self):
        """An empty message is written as one zero-padded, header-only packet."""
        gen = codec_v1.write_message(self.interface, MESSAGE_TYPE, b"")

        query = gen.send(None)
        self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
        with self.assertRaises(StopIteration):
            gen.send(None)

        header = make_header(mtype=MESSAGE_TYPE, length=0)
        expected_message = header + b"\x00" * (codec_v1._REP_LEN - len(header))
        # assertEqual (rather than assertTrue on ==) reports both values on failure
        self.assertEqual(self.interface.data, [expected_message])

    def test_write_multiple_packets(self):
        """A long message is split into a header packet plus continuation packets."""
        message = bytes(range(256))
        gen = codec_v1.write_message(self.interface, MESSAGE_TYPE, message)

        header = make_header(mtype=MESSAGE_TYPE, length=len(message))
        first_len = codec_v1._REP_LEN - len(header)
        # first packet is header + (remaining)data
        # other packets are "?" + 63 bytes of data
        packets = [header + message[:first_len]] + [
            b"?" + chunk
            for chunk in chunks(message[first_len:], codec_v1._REP_LEN - 1)
        ]

        for _ in packets:
            # we receive as many queries as there are packets
            query = gen.send(None)
            self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))

        # the first sent None only started the generator. the len(packets)-th None
        # will finish writing and raise StopIteration
        with self.assertRaises(StopIteration):
            gen.send(None)

        # packets must be identical up to the last one
        self.assertListEqual(packets[:-1], self.interface.data[:-1])
        # last packet must be identical up to message length. remaining bytes in
        # the 64-byte packets are garbage -- in particular, it's the bytes of the
        # previous packet
        last_packet = packets[-1] + packets[-2][len(packets[-1]):]
        self.assertEqual(last_packet, self.interface.data[-1])

    def test_roundtrip(self):
        """Packets produced by write_message are read back as the same message."""
        message = bytes(range(256))

        gen = codec_v1.write_message(self.interface, MESSAGE_TYPE, message)

        # exhaust the iterator:
        # (XXX we can only do this because the iterator is only accepting None and returns None)
        for query in gen:
            self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))

        buffer = bytearray(1024)
        gen = codec_v1.read_message(self.interface, buffer)
        query = gen.send(None)
        for packet in self.interface.data[:-1]:
            self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
            query = gen.send(packet)

        with self.assertRaises(StopIteration) as e:
            gen.send(self.interface.data[-1])

        # e.value is StopIteration. e.value.value is the return value of the call
        result = e.value.value
        # use the framework assertions, as the other tests do, instead of bare
        # asserts, so a failure reports the values that were compared
        self.assertEqual(result.type, MESSAGE_TYPE)
        self.assertEqual(result.data.buffer, message)
# Run the tests in this module when the file is executed directly.
if __name__ == "__main__":
    unittest.main()