1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-26 09:28:13 +00:00

tests: fix wire tests, remove msg

This commit is contained in:
Jan Pochyla 2017-08-15 15:30:23 +02:00
parent 3562ffdc54
commit 520de105a6
3 changed files with 65 additions and 63 deletions

View File

@ -1,13 +0,0 @@
from common import *
from trezor.crypto import random
from trezor import msg
class TestMsg(unittest.TestCase):
def test_usb(self):
pass
if __name__ == '__main__':
unittest.main()

View File

@ -1,32 +1,45 @@
import sys import sys
sys.path.append('../src') sys.path.append('../src')
sys.path.append('../src/lib')
from utest import * from utest import *
from ustruct import pack, unpack from ustruct import pack, unpack
from ubinascii import hexlify, unhexlify from ubinascii import hexlify, unhexlify
from trezor import msg
from trezor.loop import Select, Syscall, READ, WRITE from trezor.loop import Select, Syscall, READ, WRITE
from trezor.crypto import random from trezor.crypto import random
from trezor.utils import chunks from trezor.utils import chunks
from trezor.wire import codec_v1 from trezor.wire import codec_v1
class MockHID:
def __init__(self, num):
self.num = num
self.data = []
def iface_num(self):
return self.num
def write(self, msg):
self.data.append(bytearray(msg))
return len(msg)
def test_reader(): def test_reader():
rep_len = 64 rep_len = 64
interface = 0xdeadbeef interface_num = 0xdeadbeef
message_type = 0x4321 message_type = 0x4321
message_len = 250 message_len = 250
reader = codec_v1.Reader(interface, codec_v1.SESSION_ID) interface = MockHID(interface_num)
reader = codec_v1.Reader(interface)
message = bytearray(range(message_len)) message = bytearray(range(message_len))
report_header = bytearray(unhexlify('3f23234321000000fa')) report_header = bytearray(unhexlify('3f23234321000000fa'))
# open, expected one read # open, expected one read
first_report = report_header + message[:rep_len - len(report_header)] first_report = report_header + message[:rep_len - len(report_header)]
assert_async(reader.aopen(), [(None, Select(READ | interface)), (first_report, StopIteration()),]) assert_async(reader.aopen(), [(None, Select(READ | interface_num)), (first_report, StopIteration()),])
assert_eq(reader.type, message_type) assert_eq(reader.type, message_type)
assert_eq(reader.size, message_len) assert_eq(reader.size, message_len)
@ -53,7 +66,7 @@ def test_reader():
next_report_header = bytearray(unhexlify('3f')) next_report_header = bytearray(unhexlify('3f'))
next_report = next_report_header + message[rep_len - len(report_header):][:rep_len - len(next_report_header)] next_report = next_report_header + message[rep_len - len(report_header):][:rep_len - len(next_report_header)]
onebyte_buffer = bytearray(1) onebyte_buffer = bytearray(1)
assert_async(reader.areadinto(onebyte_buffer), [(None, Select(READ | interface)), (next_report, StopIteration()),]) assert_async(reader.areadinto(onebyte_buffer), [(None, Select(READ | interface_num)), (next_report, StopIteration()),])
assert_eq(onebyte_buffer, message[len(short_buffer):][len(aligned_buffer):][:len(onebyte_buffer)]) assert_eq(onebyte_buffer, message[len(short_buffer):][len(aligned_buffer):][:len(onebyte_buffer)])
assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer) - len(onebyte_buffer)) assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer) - len(onebyte_buffer))
@ -72,7 +85,7 @@ def test_reader():
expected_syscalls = [] expected_syscalls = []
for i, _ in enumerate(next_reports): for i, _ in enumerate(next_reports):
prev_report = next_reports[i - 1] if i > 0 else None prev_report = next_reports[i - 1] if i > 0 else None
expected_syscalls.append((prev_report, Select(READ | interface))) expected_syscalls.append((prev_report, Select(READ | interface_num)))
expected_syscalls.append((next_reports[-1], StopIteration())) expected_syscalls.append((next_reports[-1], StopIteration()))
assert_async(reader.areadinto(long_buffer), expected_syscalls) assert_async(reader.areadinto(long_buffer), expected_syscalls)
assert_eq(long_buffer, message[-start_size:]) assert_eq(long_buffer, message[-start_size:])
@ -84,10 +97,11 @@ def test_reader():
def test_writer(): def test_writer():
rep_len = 64 rep_len = 64
interface = 0xdeadbeef interface_num = 0xdeadbeef
message_type = 0x87654321 message_type = 0x87654321
message_len = 1024 message_len = 1024
writer = codec_v1.Writer(interface, codec_v1.SESSION_ID) interface = MockHID(interface_num)
writer = codec_v1.Writer(interface)
writer.setheader(message_type, message_len) writer.setheader(message_type, message_len)
# init header corresponding to the data above # init header corresponding to the data above
@ -114,15 +128,13 @@ def test_writer():
# aligned write, expected one report # aligned write, expected one report
start_size = writer.size start_size = writer.size
aligned_payload = bytearray(range(rep_len - len(report_header) - len(short_payload))) aligned_payload = bytearray(range(rep_len - len(report_header) - len(short_payload)))
msg.send = mock_call(msg.send, [ assert_async(writer.awrite(aligned_payload), [(None, Select(WRITE | interface_num)), (None, StopIteration()),])
(interface, report_header assert_eq(interface.data, [report_header
+ short_payload + short_payload
+ aligned_payload + aligned_payload
+ bytearray(rep_len - len(report_header) - len(short_payload) - len(aligned_payload))), ]) + bytearray(rep_len - len(report_header) - len(short_payload) - len(aligned_payload)), ])
assert_async(writer.awrite(aligned_payload), [(None, Select(WRITE | interface)), (None, StopIteration()),])
assert_eq(writer.size, start_size - len(aligned_payload)) assert_eq(writer.size, start_size - len(aligned_payload))
msg.send.assert_called_n_times(1) interface.data.clear()
msg.send = msg.send.original
# short write, expected no report, but data starts with correct seq and cont marker # short write, expected no report, but data starts with correct seq and cont marker
report_header = bytearray(unhexlify('3f')) report_header = bytearray(unhexlify('3f'))
@ -142,23 +154,18 @@ def test_writer():
expected_reports[-1] += bytearray(bytes(1) * (rep_len - len(expected_reports[-1]))) expected_reports[-1] += bytearray(bytes(1) * (rep_len - len(expected_reports[-1])))
# test write # test write
expected_write_reports = expected_reports[:-1] expected_write_reports = expected_reports[:-1]
msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_write_reports]) assert_async(writer.awrite(long_payload), len(expected_write_reports) * [(None, Select(WRITE | interface_num))] + [(None, StopIteration())])
assert_async(writer.awrite(long_payload), len(expected_write_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())]) assert_eq(interface.data, expected_write_reports)
assert_eq(writer.size, start_size - len(long_payload)) assert_eq(writer.size, start_size - len(long_payload))
msg.send.assert_called_n_times(len(expected_write_reports)) interface.data.clear()
msg.send = msg.send.original
# test write raises eof # test write raises eof
msg.send = mock_call(msg.send, [])
assert_async(writer.awrite(bytearray(1)), [(None, EOFError())]) assert_async(writer.awrite(bytearray(1)), [(None, EOFError())])
msg.send.assert_called_n_times(0) assert_eq(interface.data, [])
msg.send = msg.send.original
# test close # test close
expected_close_reports = expected_reports[-1:] expected_close_reports = expected_reports[-1:]
msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_close_reports]) assert_async(writer.aclose(), len(expected_close_reports) * [(None, Select(WRITE | interface_num))] + [(None, StopIteration())])
assert_async(writer.aclose(), len(expected_close_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())]) assert_eq(interface.data, expected_close_reports)
assert_eq(writer.size, 0) assert_eq(writer.size, 0)
msg.send.assert_called_n_times(len(expected_close_reports))
msg.send = msg.send.original
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -1,24 +1,37 @@
import sys import sys
sys.path.append('../src') sys.path.append('../src')
sys.path.append('../src/lib')
from utest import * from utest import *
from ustruct import pack, unpack from ustruct import pack, unpack
from ubinascii import hexlify, unhexlify from ubinascii import hexlify, unhexlify
from trezor import msg
from trezor.loop import Select, Syscall, READ, WRITE from trezor.loop import Select, Syscall, READ, WRITE
from trezor.utils import chunks from trezor.utils import chunks
from trezor.wire import codec_v2 from trezor.wire import codec_v2
class MockHID:
def __init__(self, num):
self.num = num
self.data = []
def iface_num(self):
return self.num
def write(self, msg):
self.data.append(bytearray(msg))
return len(msg)
def test_reader(): def test_reader():
rep_len = 64 rep_len = 64
interface = 0xdeadbeef interface_num = 0xdeadbeef
session_id = 0x12345678 session_id = 0x12345678
message_type = 0x87654321 message_type = 0x87654321
message_len = 250 message_len = 250
interface = MockHID(interface_num)
reader = codec_v2.Reader(interface, session_id) reader = codec_v2.Reader(interface, session_id)
message = bytearray(range(message_len)) message = bytearray(range(message_len))
@ -26,7 +39,7 @@ def test_reader():
# open, expected one read # open, expected one read
first_report = report_header + message[:rep_len - len(report_header)] first_report = report_header + message[:rep_len - len(report_header)]
assert_async(reader.aopen(), [(None, Select(READ | interface)), (first_report, StopIteration()),]) assert_async(reader.aopen(), [(None, Select(READ | interface_num)), (first_report, StopIteration()),])
assert_eq(reader.type, message_type) assert_eq(reader.type, message_type)
assert_eq(reader.size, message_len) assert_eq(reader.size, message_len)
@ -53,7 +66,7 @@ def test_reader():
next_report_header = bytearray(unhexlify('021234567800000000')) next_report_header = bytearray(unhexlify('021234567800000000'))
next_report = next_report_header + message[rep_len - len(report_header):][:rep_len - len(next_report_header)] next_report = next_report_header + message[rep_len - len(report_header):][:rep_len - len(next_report_header)]
onebyte_buffer = bytearray(1) onebyte_buffer = bytearray(1)
assert_async(reader.areadinto(onebyte_buffer), [(None, Select(READ | interface)), (next_report, StopIteration()),]) assert_async(reader.areadinto(onebyte_buffer), [(None, Select(READ | interface_num)), (next_report, StopIteration()),])
assert_eq(onebyte_buffer, message[len(short_buffer):][len(aligned_buffer):][:len(onebyte_buffer)]) assert_eq(onebyte_buffer, message[len(short_buffer):][len(aligned_buffer):][:len(onebyte_buffer)])
assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer) - len(onebyte_buffer)) assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer) - len(onebyte_buffer))
@ -72,7 +85,7 @@ def test_reader():
expected_syscalls = [] expected_syscalls = []
for i, _ in enumerate(next_reports): for i, _ in enumerate(next_reports):
prev_report = next_reports[i - 1] if i > 0 else None prev_report = next_reports[i - 1] if i > 0 else None
expected_syscalls.append((prev_report, Select(READ | interface))) expected_syscalls.append((prev_report, Select(READ | interface_num)))
expected_syscalls.append((next_reports[-1], StopIteration())) expected_syscalls.append((next_reports[-1], StopIteration()))
assert_async(reader.areadinto(long_buffer), expected_syscalls) assert_async(reader.areadinto(long_buffer), expected_syscalls)
assert_eq(long_buffer, message[-start_size:]) assert_eq(long_buffer, message[-start_size:])
@ -84,10 +97,11 @@ def test_reader():
def test_writer(): def test_writer():
rep_len = 64 rep_len = 64
interface = 0xdeadbeef interface_num = 0xdeadbeef
session_id = 0x12345678 session_id = 0x12345678
message_type = 0x87654321 message_type = 0x87654321
message_len = 1024 message_len = 1024
interface = MockHID(interface_num)
writer = codec_v2.Writer(interface, session_id) writer = codec_v2.Writer(interface, session_id)
writer.setheader(message_type, message_len) writer.setheader(message_type, message_len)
@ -115,15 +129,13 @@ def test_writer():
# aligned write, expected one report # aligned write, expected one report
start_size = writer.size start_size = writer.size
aligned_payload = bytearray(range(rep_len - len(report_header) - len(short_payload))) aligned_payload = bytearray(range(rep_len - len(report_header) - len(short_payload)))
msg.send = mock_call(msg.send, [ assert_async(writer.awrite(aligned_payload), [(None, Select(WRITE | interface_num)), (None, StopIteration()),])
(interface, report_header assert_eq(interface.data, [report_header
+ short_payload + short_payload
+ aligned_payload + aligned_payload
+ bytearray(rep_len - len(report_header) - len(short_payload) - len(aligned_payload))), ]) + bytearray(rep_len - len(report_header) - len(short_payload) - len(aligned_payload)), ])
assert_async(writer.awrite(aligned_payload), [(None, Select(WRITE | interface)), (None, StopIteration()),])
assert_eq(writer.size, start_size - len(aligned_payload)) assert_eq(writer.size, start_size - len(aligned_payload))
msg.send.assert_called_n_times(1) interface.data.clear()
msg.send = msg.send.original
# short write, expected no report, but data starts with correct seq and cont marker # short write, expected no report, but data starts with correct seq and cont marker
report_header = bytearray(unhexlify('021234567800000000')) report_header = bytearray(unhexlify('021234567800000000'))
@ -145,23 +157,19 @@ def test_writer():
expected_reports[-1] += bytearray(bytes(1) * (rep_len - len(expected_reports[-1]))) expected_reports[-1] += bytearray(bytes(1) * (rep_len - len(expected_reports[-1])))
# test write # test write
expected_write_reports = expected_reports[:-1] expected_write_reports = expected_reports[:-1]
msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_write_reports]) assert_async(writer.awrite(long_payload), len(expected_write_reports) * [(None, Select(WRITE | interface_num))] + [(None, StopIteration())])
assert_async(writer.awrite(long_payload), len(expected_write_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())]) assert_eq(interface.data, expected_write_reports)
assert_eq(writer.size, start_size - len(long_payload)) assert_eq(writer.size, start_size - len(long_payload))
msg.send.assert_called_n_times(len(expected_write_reports)) interface.data.clear()
msg.send = msg.send.original
# test write raises eof # test write raises eof
msg.send = mock_call(msg.send, [])
assert_async(writer.awrite(bytearray(1)), [(None, EOFError())]) assert_async(writer.awrite(bytearray(1)), [(None, EOFError())])
msg.send.assert_called_n_times(0) assert_eq(interface.data, [])
msg.send = msg.send.original
# test close # test close
expected_close_reports = expected_reports[-1:] expected_close_reports = expected_reports[-1:]
msg.send = mock_call(msg.send, [(interface, rep) for rep in expected_close_reports]) assert_async(writer.aclose(), len(expected_close_reports) * [(None, Select(WRITE | interface_num))] + [(None, StopIteration())])
assert_async(writer.aclose(), len(expected_close_reports) * [(None, Select(WRITE | interface))] + [(None, StopIteration())]) assert_eq(interface.data, expected_close_reports)
assert_eq(writer.size, 0) assert_eq(writer.size, 0)
msg.send.assert_called_n_times(len(expected_close_reports)) interface.data.clear()
msg.send = msg.send.original
if __name__ == '__main__': if __name__ == '__main__':