mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-31 18:40:56 +00:00
wire: add tests, fix missing 0-padding
This commit is contained in:
parent
e62e8dbe6f
commit
d0b29d4caa
@ -1,87 +0,0 @@
|
||||
import sys
|
||||
sys.path.append('..')
|
||||
sys.path.append('../lib')
|
||||
import unittest
|
||||
|
||||
from trezor import loop
|
||||
from trezor import msg
|
||||
from trezor.wire import read_wire_msg, write_wire_msg
|
||||
from trezor.utils import chunks
|
||||
|
||||
|
||||
class TestWire(unittest.TestCase):
|
||||
|
||||
def test_read_wire_msg(self):
|
||||
|
||||
# Reading empty message returns correct type and empty bytes
|
||||
|
||||
reader = read_wire_msg()
|
||||
reader.send(None)
|
||||
|
||||
empty_message = b'\x3f##\xab\xcd\x00\x00\x00\x00' + b'\x00' * 55
|
||||
try:
|
||||
reader.send((empty_message,))
|
||||
except StopIteration as e:
|
||||
restype, resmsg = e.value
|
||||
self.assertEqual(restype, int('0xabcd', 16))
|
||||
self.assertEqual(resmsg, b'')
|
||||
|
||||
# Reading message from one report
|
||||
|
||||
reader = read_wire_msg()
|
||||
reader.send(None)
|
||||
|
||||
content = bytes([x for x in range(0, 55)])
|
||||
message = b'\x3f##\xab\xcd\x00\x00\x00\x37' + content
|
||||
try:
|
||||
reader.send((message,))
|
||||
except StopIteration as e:
|
||||
restype, resmsg = e.value
|
||||
self.assertEqual(restype, int('0xabcd', 16))
|
||||
self.assertEqual(resmsg, content)
|
||||
|
||||
# Reading message spanning multiple reports
|
||||
|
||||
reader = read_wire_msg()
|
||||
reader.send(None)
|
||||
|
||||
content = bytes([x for x in range(0, 256)])
|
||||
message = b'##\xab\xcd\x00\x00\x01\00' + content
|
||||
reports = [b'\x3f' + ch + '\x00' * (63 - len(ch)) for ch in chunks(message, 63)]
|
||||
try:
|
||||
for report in reports:
|
||||
reader.send((report,))
|
||||
except StopIteration as e:
|
||||
restype, resmsg = e.value
|
||||
self.assertEqual(restype, int('0xabcd', 16))
|
||||
self.assertEqual(resmsg, content)
|
||||
|
||||
def test_write_wire_msg(self):
|
||||
|
||||
# Writing message spanning multiple reports calls msg.send() with correct data
|
||||
|
||||
sent_reps = []
|
||||
|
||||
def dummy_send(iface, rep):
|
||||
sent_reps.append(bytes(rep))
|
||||
return len(rep)
|
||||
|
||||
msg.send = dummy_send
|
||||
|
||||
content = bytes([x for x in range(0, 256)])
|
||||
message = b'##\xab\xcd\x00\x00\x01\00' + content
|
||||
reports = [b'\x3f' + ch + '\x00' * (63 - len(ch)) for ch in chunks(message, 63)]
|
||||
|
||||
writer = write_wire_msg(int('0xabcd'), content)
|
||||
res = 1 # Something not None
|
||||
try:
|
||||
while True:
|
||||
writer.send(None)
|
||||
except StopIteration as e:
|
||||
res = e.value
|
||||
self.assertEqual(res, None)
|
||||
self.assertEqual(sent_reps, reports)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
250
src/tests/test_wire_codec.py
Normal file
250
src/tests/test_wire_codec.py
Normal file
@ -0,0 +1,250 @@
|
||||
import sys
|
||||
sys.path.append('..')
|
||||
sys.path.append('../lib')
|
||||
import unittest
|
||||
import ustruct
|
||||
import ubinascii
|
||||
|
||||
from trezor.wire import wire_codec
|
||||
from trezor.utils import chunks
|
||||
from trezor.crypto import random
|
||||
|
||||
|
||||
class TestWireCodec(unittest.TestCase):
|
||||
# pylint: disable=C0301
|
||||
|
||||
def test_parse(self):
|
||||
d = b'O' + b'\x01\x23\x45\x67' + bytes(range(0, 59))
|
||||
|
||||
m, s, d = wire_codec.parse_report(d)
|
||||
self.assertEqual(m, b'O'[0])
|
||||
self.assertEqual(s, 0x01234567)
|
||||
self.assertEqual(d, bytes(range(0, 59)))
|
||||
|
||||
t, l, d = wire_codec.parse_message(d)
|
||||
self.assertEqual(t, 0x00010203)
|
||||
self.assertEqual(l, 0x04050607)
|
||||
self.assertEqual(d, bytes(range(8, 59)))
|
||||
|
||||
f, = wire_codec.parse_message_footer(d[0:4])
|
||||
self.assertEqual(f, 0x08090a0b)
|
||||
|
||||
for i in range(0, 1024):
|
||||
if i != 64:
|
||||
with self.assertRaises(ValueError):
|
||||
wire_codec.parse_report(bytes(range(0, i)))
|
||||
if i != 59:
|
||||
with self.assertRaises(ValueError):
|
||||
wire_codec.parse_message(bytes(range(0, i)))
|
||||
if i != 4:
|
||||
with self.assertRaises(ValueError):
|
||||
wire_codec.parse_message_footer(bytes(range(0, i)))
|
||||
|
||||
def test_serialize(self):
|
||||
data = bytearray(range(0, 6))
|
||||
wire_codec.serialize_report_header(data, 0x12, 0x3456789a)
|
||||
self.assertEqual(data, b'\x12\x34\x56\x78\x9a\x05')
|
||||
|
||||
data = bytearray(range(0, 6))
|
||||
wire_codec.serialize_opened_session(data, 0x3456789a)
|
||||
self.assertEqual(data, bytes([wire_codec.REP_MARKER_OPEN]) + b'\x34\x56\x78\x9a\x05')
|
||||
|
||||
data = bytearray(range(0, 14))
|
||||
wire_codec.serialize_message_header(data, 0x01234567, 0x89abcdef)
|
||||
self.assertEqual(data, b'\x00\x01\x02\x03\x04\x01\x23\x45\x67\x89\xab\xcd\xef\x0d')
|
||||
|
||||
data = bytearray(range(0, 5))
|
||||
wire_codec.serialize_message_footer(data, 0x89abcdef)
|
||||
self.assertEqual(data, b'\x89\xab\xcd\xef\x04')
|
||||
|
||||
for i in range(0, 13):
|
||||
data = bytearray(i)
|
||||
if i < 4:
|
||||
with self.assertRaises(ValueError):
|
||||
wire_codec.serialize_message_footer(data, 0x00)
|
||||
if i < 5:
|
||||
with self.assertRaises(ValueError):
|
||||
wire_codec.serialize_report_header(data, 0x00, 0x00)
|
||||
with self.assertRaises(ValueError):
|
||||
wire_codec.serialize_opened_session(data, 0x00)
|
||||
with self.assertRaises(ValueError):
|
||||
wire_codec.serialize_message_header(data, 0x00, 0x00)
|
||||
|
||||
def test_decode_empty(self):
|
||||
message = b'\xab\xcd\xef\x12' + b'\x00\x00\x00\x00' + b'\x00' * 51
|
||||
|
||||
record = []
|
||||
genfunc = self._record(record, 0xabcdef12, 0, 0xdeadbeef, 'dummy')
|
||||
decoder = wire_codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy')
|
||||
decoder.send(None)
|
||||
|
||||
try:
|
||||
decoder.send(message)
|
||||
except StopIteration as e:
|
||||
res = e.value
|
||||
self.assertEqual(res, None)
|
||||
self.assertEqual(len(record), 1)
|
||||
self.assertIsInstance(record[0], EOFError)
|
||||
|
||||
def test_decode_one_report_aligned_correct(self):
|
||||
data = bytes(range(0, 47))
|
||||
footer = b'\x2f\x1c\x12\xce'
|
||||
message = b'\xab\xcd\xef\x12' + b'\x00\x00\x00\x2f' + data + footer
|
||||
|
||||
record = []
|
||||
genfunc = self._record(record, 0xabcdef12, 47, 0xdeadbeef, 'dummy')
|
||||
decoder = wire_codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy')
|
||||
decoder.send(None)
|
||||
|
||||
try:
|
||||
decoder.send(message)
|
||||
except StopIteration as e:
|
||||
res = e.value
|
||||
self.assertEqual(res, None)
|
||||
self.assertEqual(len(record), 2)
|
||||
self.assertEqual(record[0], data)
|
||||
self.assertIsInstance(record[1], EOFError)
|
||||
|
||||
def test_decode_one_report_aligned_incorrect(self):
|
||||
data = bytes(range(0, 47))
|
||||
footer = bytes(4) # wrong checksum
|
||||
message = b'\xab\xcd\xef\x12' + b'\x00\x00\x00\x2f' + data + footer
|
||||
|
||||
record = []
|
||||
genfunc = self._record(record, 0xabcdef12, 47, 0xdeadbeef, 'dummy')
|
||||
decoder = wire_codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy')
|
||||
decoder.send(None)
|
||||
|
||||
try:
|
||||
decoder.send(message)
|
||||
except StopIteration as e:
|
||||
res = e.value
|
||||
self.assertEqual(res, None)
|
||||
self.assertEqual(len(record), 2)
|
||||
self.assertEqual(record[0], data)
|
||||
self.assertIsInstance(record[1], wire_codec.MessageChecksumError)
|
||||
|
||||
def test_decode_generated_range(self):
|
||||
for data_len in range(1, 512):
|
||||
data = random.bytes(data_len)
|
||||
data_chunks = [data[:51]] + list(chunks(data[51:], 59))
|
||||
|
||||
msg_type = 0xabcdef12
|
||||
data_csum = ubinascii.crc32(data)
|
||||
header = ustruct.pack('>L', msg_type) + ustruct.pack('>L', data_len)
|
||||
footer = ustruct.pack('>L', data_csum)
|
||||
|
||||
message = header + data + footer
|
||||
message_chunks = [c + '\x00' * (59 - len(c)) for c in list(chunks(message, 59))]
|
||||
|
||||
record = []
|
||||
genfunc = self._record(record, msg_type, data_len, 0xdeadbeef, 'dummy')
|
||||
decoder = wire_codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy')
|
||||
decoder.send(None)
|
||||
|
||||
res = 1
|
||||
try:
|
||||
for c in message_chunks:
|
||||
decoder.send(c)
|
||||
except StopIteration as e:
|
||||
res = e.value
|
||||
self.assertEqual(res, None)
|
||||
self.assertEqual(len(record), len(data_chunks) + 1)
|
||||
for i in range(0, len(data_chunks)):
|
||||
self.assertEqual(record[i], data_chunks[i])
|
||||
self.assertIsInstance(record[-1], EOFError)
|
||||
|
||||
def test_encode_empty(self):
|
||||
record = []
|
||||
target = self._record(record)()
|
||||
target.send(None)
|
||||
|
||||
wire_codec.encode_wire_message(0xabcdef12, b'', 0xdeadbeef, target)
|
||||
self.assertEqual(len(record), 1)
|
||||
self.assertEqual(record[0], b'H\xde\xad\xbe\xef\xab\xcd\xef\x12\x00\x00\x00\x00' + '\0' * 51)
|
||||
|
||||
def test_encode_one_report_aligned(self):
|
||||
data = bytes(range(0, 47))
|
||||
footer = b'\x2f\x1c\x12\xce'
|
||||
|
||||
record = []
|
||||
target = self._record(record)()
|
||||
target.send(None)
|
||||
|
||||
wire_codec.encode_wire_message(0xabcdef12, data, 0xdeadbeef, target)
|
||||
self.assertEqual(record, [b'H\xde\xad\xbe\xef\xab\xcd\xef\x12\x00\x00\x00\x2f' + data + footer])
|
||||
|
||||
def test_encode_generated_range(self):
|
||||
for data_len in range(1, 1024):
|
||||
data = random.bytes(data_len)
|
||||
|
||||
msg_type = 0xabcdef12
|
||||
session_id = 0xdeadbeef
|
||||
|
||||
data_csum = ubinascii.crc32(data)
|
||||
header = ustruct.pack('>L', msg_type) + ustruct.pack('>L', data_len)
|
||||
footer = ustruct.pack('>L', data_csum)
|
||||
session_header = ustruct.pack('>L', session_id)
|
||||
|
||||
message = header + data + footer
|
||||
report0 = b'H' + session_header + message[:59]
|
||||
reports = [b'D' + session_header + c for c in chunks(message[59:], 59)]
|
||||
reports.insert(0, report0)
|
||||
reports[-1] = reports[-1] + b'\x00' * (64 - len(reports[-1]))
|
||||
|
||||
received = 0
|
||||
def genfunc():
|
||||
nonlocal received
|
||||
while True:
|
||||
self.assertEqual((yield), reports[received])
|
||||
received += 1
|
||||
target = genfunc()
|
||||
target.send(None)
|
||||
|
||||
wire_codec.encode_wire_message(msg_type, data, session_id, target)
|
||||
self.assertEqual(received, len(reports))
|
||||
|
||||
def _record(self, record, *_args):
|
||||
def genfunc(*args):
|
||||
self.assertEqual(args, _args)
|
||||
while True:
|
||||
try:
|
||||
v = yield
|
||||
except Exception as e:
|
||||
record.append(e)
|
||||
else:
|
||||
record.append(v)
|
||||
return genfunc
|
||||
|
||||
|
||||
# def test_write_wire_msg(self):
|
||||
|
||||
# # Writing message spanning multiple reports calls msg.send() with
|
||||
# # correct data
|
||||
|
||||
# sent_reps = []
|
||||
|
||||
# def dummy_send(iface, rep):
|
||||
# sent_reps.append(bytes(rep))
|
||||
# return len(rep)
|
||||
|
||||
# msg.send = dummy_send
|
||||
|
||||
# content = bytes([x for x in range(0, 256)])
|
||||
# message = b'##\xab\xcd\x00\x00\x01\00' + content
|
||||
# reports = [b'\x3f' + ch + '\x00' *
|
||||
# (63 - len(ch)) for ch in chunks(message, 63)]
|
||||
|
||||
# writer = write_wire_msg(int('0xabcd'), content)
|
||||
# res = 1 # Something not None
|
||||
# try:
|
||||
# while True:
|
||||
# writer.send(None)
|
||||
# except StopIteration as e:
|
||||
# res = e.value
|
||||
# self.assertEqual(res, None)
|
||||
# self.assertEqual(sent_reps, reports)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@ -35,29 +35,41 @@ _MSG_FOOTER_LEN = ustruct.calcsize(_MSG_FOOTER)
|
||||
|
||||
|
||||
def parse_report(data):
|
||||
if len(data) != _REP_LEN:
|
||||
raise ValueError('Invalid buffer size')
|
||||
marker, session_id = ustruct.unpack(_REP_HEADER, data)
|
||||
return marker, session_id, data[_REP_HEADER_LEN:]
|
||||
|
||||
|
||||
def parse_message(data):
|
||||
if len(data) != _REP_LEN - _REP_HEADER_LEN:
|
||||
raise ValueError('Invalid buffer size')
|
||||
msg_type, data_len = ustruct.unpack(_MSG_HEADER, data)
|
||||
return msg_type, data_len, data[_MSG_HEADER_LEN:]
|
||||
|
||||
|
||||
def parse_message_footer(data):
|
||||
if len(data) != _MSG_FOOTER_LEN:
|
||||
raise ValueError('Invalid buffer size')
|
||||
data_checksum, = ustruct.unpack(_MSG_FOOTER, data)
|
||||
return data_checksum,
|
||||
|
||||
|
||||
def serialize_report_header(data, marker, session_id):
|
||||
if len(data) < _REP_HEADER_LEN:
|
||||
raise ValueError('Invalid buffer size')
|
||||
ustruct.pack_into(_REP_HEADER, data, 0, marker, session_id)
|
||||
|
||||
|
||||
def serialize_message_header(data, msg_type, msg_len):
|
||||
if len(data) < _REP_HEADER_LEN + _MSG_HEADER_LEN:
|
||||
raise ValueError('Invalid buffer size')
|
||||
ustruct.pack_into(_MSG_HEADER, data, _REP_HEADER_LEN, msg_type, msg_len)
|
||||
|
||||
|
||||
def serialize_message_footer(data, checksum):
|
||||
if len(data) < _MSG_FOOTER_LEN:
|
||||
raise ValueError('Invalid buffer size')
|
||||
ustruct.pack_into(_MSG_FOOTER, data, 0, checksum)
|
||||
|
||||
|
||||
@ -112,6 +124,11 @@ Pass report payloads as `memoryview` for cheaper slicing.
|
||||
|
||||
|
||||
def encode_wire_message(msg_type, msg_data, session_id, target):
|
||||
'''Encode a full wire message directly to reports and stream it to target.
|
||||
|
||||
Target receives `memoryview`s of HID reports which are valid until the targets
|
||||
`send()` method returns.
|
||||
'''
|
||||
report = memoryview(bytearray(_REP_LEN))
|
||||
serialize_report_header(report, REP_MARKER_HEADER, session_id)
|
||||
serialize_message_header(report, msg_type, len(msg_data))
|
||||
@ -139,7 +156,7 @@ def encode_wire_message(msg_type, msg_data, session_id, target):
|
||||
msg_footer = None
|
||||
continue
|
||||
|
||||
# FIXME: optimize speed
|
||||
# fill the rest of the report with 0x00
|
||||
x = 0
|
||||
to_fill = len(target_data)
|
||||
while x < to_fill:
|
||||
@ -154,8 +171,8 @@ def encode_wire_message(msg_type, msg_data, session_id, target):
|
||||
# reset to skip the magic and session ID
|
||||
if first:
|
||||
serialize_report_header(report, REP_MARKER_DATA, session_id)
|
||||
target_data = report[_REP_HEADER_LEN:]
|
||||
first = False
|
||||
target_data = report[_REP_HEADER_LEN:]
|
||||
|
||||
|
||||
def encode_session_open_message(session_id, target):
|
||||
|
Loading…
Reference in New Issue
Block a user