# Mirror of https://github.com/trezor/trezor-firmware.git
# (synced 2024-11-26 09:28:13 +00:00; 168 lines, 7.4 KiB, Python)
import sys
|
|
|
|
sys.path.append('../src')
|
|
sys.path.append('../src/lib')
|
|
|
|
from utest import *
|
|
from ustruct import pack, unpack
|
|
from ubinascii import hexlify, unhexlify
|
|
|
|
from trezor import msg
|
|
from trezor.loop import Select, Syscall, READ, WRITE
|
|
from trezor.utils import chunks
|
|
from trezor.wire import codec_v2
|
|
|
|
|
|
def test_reader():
    """Exercise codec_v2.Reader on a 250-byte message delivered in 64-byte reports.

    Each assert_async call drives a reader coroutine with a scripted list of
    (value sent in, syscall or exception expected out) pairs.  Report framing
    used throughout (rep_len = 64):

    * initial report: 13-byte header -- 0x01, session id, message type and
      big-endian message length -- followed by 51 payload bytes;
    * continuation report: 9-byte header -- 0x02, session id and big-endian
      sequence number -- followed by up to 55 payload bytes.
    """
    rep_len = 64
    interface = 0xdeadbeef
    session_id = 0x12345678
    message_type = 0x87654321
    message_len = 250  # kept below 256 so bytearray(range(...)) is valid as-is
    reader = codec_v2.Reader(interface, session_id)

    message = bytearray(range(message_len))
    report_header = bytearray(unhexlify('011234567887654321000000fa'))

    # open, expected one read
    first_report = report_header + message[:rep_len - len(report_header)]
    assert_async(reader.open(), [(None, Select(READ | interface)), (first_report, StopIteration()),])
    assert_eq(reader.type, message_type)
    assert_eq(reader.size, message_len)

    # empty read
    empty_buffer = bytearray()
    assert_async(reader.readinto(empty_buffer), [(None, StopIteration()),])
    assert_eq(len(empty_buffer), 0)
    assert_eq(reader.size, message_len)

    # short read, expected no read
    short_buffer = bytearray(32)
    assert_async(reader.readinto(short_buffer), [(None, StopIteration()),])
    assert_eq(len(short_buffer), 32)
    assert_eq(short_buffer, message[:len(short_buffer)])
    assert_eq(reader.size, message_len - len(short_buffer))

    # aligned read, expected no read
    # (consumes exactly the rest of the first report's payload)
    aligned_buffer = bytearray(rep_len - len(report_header) - len(short_buffer))
    assert_async(reader.readinto(aligned_buffer), [(None, StopIteration()),])
    assert_eq(aligned_buffer, message[len(short_buffer):][:len(aligned_buffer)])
    assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer))

    # one byte read, expected one read
    # (first report is exhausted, so the reader must wait for continuation seq 0)
    next_report_header = bytearray(unhexlify('021234567800000000'))
    next_report = next_report_header + message[rep_len - len(report_header):][:rep_len - len(next_report_header)]
    onebyte_buffer = bytearray(1)
    assert_async(reader.readinto(onebyte_buffer), [(None, Select(READ | interface)), (next_report, StopIteration()),])
    assert_eq(onebyte_buffer, message[len(short_buffer):][len(aligned_buffer):][:len(onebyte_buffer)])
    assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer) - len(onebyte_buffer))

    # too long read, raises eof
    assert_async(reader.readinto(bytearray(reader.size + 1)), [(None, EOFError()),])

    # long read, expect multiple reads
    # The first bytes come from what is still buffered of the seq-0 report;
    # everything after that arrives in fresh continuation reports (seq 1, 2,
    # ...), each carrying a full payload as a real 64-byte transport would,
    # except for the last one which holds only the remaining message bytes.
    start_size = reader.size
    long_buffer = bytearray(start_size)
    cont_payload_len = rep_len - len(next_report_header)
    # skip the payload bytes already carried by the initial and seq-0 reports
    report_payload = message[rep_len - len(report_header) + cont_payload_len:]
    report_payloads = list(chunks(report_payload, cont_payload_len))
    next_reports = [bytearray(unhexlify('0212345678') + pack('>L', i + 1)) + r
                    for i, r in enumerate(report_payloads)]
    # The coroutine yields a Select for every report it needs; each yield is
    # answered with the previous report, and the final report completes it.
    sent_values = [None] + next_reports[:-1]
    expected_syscalls = [(value, Select(READ | interface)) for value in sent_values]
    expected_syscalls.append((next_reports[-1], StopIteration()))
    assert_async(reader.readinto(long_buffer), expected_syscalls)
    assert_eq(long_buffer, message[-start_size:])
    assert_eq(reader.size, 0)

    # one byte read, raises eof
    assert_async(reader.readinto(onebyte_buffer), [(None, EOFError()),])
def test_writer():
    """Exercise codec_v2.Writer on a 1024-byte message split into 64-byte reports.

    Reports leaving the writer are intercepted by temporarily replacing
    msg.send with a utest mock_call primed with the expected
    (interface, report) calls.  Every replacement is undone in a `finally`
    so a failing assertion cannot leave msg.send mocked for later tests.
    """
    rep_len = 64
    interface = 0xdeadbeef
    session_id = 0x12345678
    message_type = 0x87654321
    message_len = 1024
    writer = codec_v2.Writer(interface, session_id, message_type, message_len)

    def byte_pattern(n):
        # Bytes 0, 1, ..., 255, 0, 1, ... of length n.  bytearray() only
        # accepts values in range(256) on CPython, so mask explicitly instead
        # of relying on interpreter-specific truncation for n > 256.
        return bytearray(i & 0xff for i in range(n))

    # init header corresponding to the data above
    # (0x01, session id, message type, big-endian length 0x400 == 1024)
    report_header = bytearray(unhexlify('01123456788765432100000400'))
    assert_eq(writer.data, report_header + bytearray(rep_len - len(report_header)))

    # empty write
    start_size = writer.size
    assert_async(writer.write(bytearray()), [(None, StopIteration()),])
    assert_eq(writer.data, report_header + bytearray(rep_len - len(report_header)))
    assert_eq(writer.size, start_size)

    # short write, expected no report
    start_size = writer.size
    short_payload = byte_pattern(4)
    assert_async(writer.write(short_payload), [(None, StopIteration()),])
    assert_eq(writer.size, start_size - len(short_payload))
    assert_eq(writer.data,
              report_header
              + short_payload
              + bytearray(rep_len - len(report_header) - len(short_payload)))

    # aligned write, expected one report
    # (fills the initial report exactly, so it is sent right away)
    start_size = writer.size
    aligned_payload = byte_pattern(rep_len - len(report_header) - len(short_payload))
    msg.send = mock_call(msg.send, [
        (interface, report_header
         + short_payload
         + aligned_payload
         + bytearray(rep_len - len(report_header) - len(short_payload) - len(aligned_payload))), ])
    try:
        assert_async(writer.write(aligned_payload), [(None, Select(WRITE | interface)), (None, StopIteration()),])
        assert_eq(writer.size, start_size - len(aligned_payload))
        msg.send.assert_called_n_times(1)
    finally:
        msg.send = msg.send.original

    # short write, expected no report, but data starts with correct seq and cont marker
    report_header = bytearray(unhexlify('021234567800000000'))
    start_size = writer.size
    assert_async(writer.write(short_payload), [(None, StopIteration()),])
    assert_eq(writer.size, start_size - len(short_payload))
    assert_eq(writer.data[:len(report_header) + len(short_payload)],
              report_header + short_payload)

    # long write, expected multiple reports
    start_size = writer.size
    long_payload_head = byte_pattern(rep_len - len(report_header) - len(short_payload))
    long_payload_rest = byte_pattern(start_size - len(long_payload_head))
    long_payload = long_payload_head + long_payload_rest
    expected_payloads = ([short_payload + long_payload_head]
                         + list(chunks(long_payload_rest, rep_len - len(report_header))))
    expected_reports = [
        bytearray(unhexlify('0212345678') + pack('>L', seq)) + rep
        for seq, rep in enumerate(expected_payloads)]
    # the last report is only flushed by close(), zero-padded to rep_len
    expected_reports[-1] += bytearray(rep_len - len(expected_reports[-1]))

    # test write
    expected_write_reports = expected_reports[:-1]
    msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_write_reports])
    try:
        assert_async(writer.write(long_payload),
                     len(expected_write_reports) * [(None, Select(WRITE | interface))]
                     + [(None, StopIteration())])
        assert_eq(writer.size, start_size - len(long_payload))
        msg.send.assert_called_n_times(len(expected_write_reports))
    finally:
        msg.send = msg.send.original

    # test write raises eof (the whole message length has been written)
    msg.send = mock_call(msg.send, [])
    try:
        assert_async(writer.write(bytearray(1)), [(None, EOFError())])
        msg.send.assert_called_n_times(0)
    finally:
        msg.send = msg.send.original

    # test close
    expected_close_reports = expected_reports[-1:]
    msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_close_reports])
    try:
        assert_async(writer.close(),
                     len(expected_close_reports) * [(None, Select(WRITE | interface))]
                     + [(None, StopIteration())])
        assert_eq(writer.size, 0)
        msg.send.assert_called_n_times(len(expected_close_reports))
    finally:
        msg.send = msg.send.original
# Script entry point: when executed directly, hand control to run_tests,
# which this module gets from `from utest import *`.
if __name__ == '__main__':
    run_tests()