mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-26 17:38:39 +00:00
wire: add tests for v1 codec
This commit is contained in:
parent
7404a76aeb
commit
507d9bdf68
180
src/tests/test_wire_codec_v1.py
Normal file
180
src/tests/test_wire_codec_v1.py
Normal file
@ -0,0 +1,180 @@
|
|||||||
|
import sys
|
||||||
|
sys.path.append('..')
|
||||||
|
sys.path.append('../lib')
|
||||||
|
import unittest
|
||||||
|
import ustruct
|
||||||
|
|
||||||
|
from trezor.wire import wire_codec_v1
|
||||||
|
from trezor.utils import chunks
|
||||||
|
from trezor.crypto import random
|
||||||
|
|
||||||
|
|
||||||
|
class TestWireCodecV1(unittest.TestCase):
|
||||||
|
# pylint: disable=C0301
|
||||||
|
|
||||||
|
def test_detect(self):
|
||||||
|
for i in range(0, 256):
|
||||||
|
if i == ord(b'?'):
|
||||||
|
self.assertTrue(wire_codec_v1.detect_v1(bytes([i]) + b'\x00' * 63))
|
||||||
|
else:
|
||||||
|
self.assertFalse(wire_codec_v1.detect_v1(bytes([i]) + b'\x00' * 63))
|
||||||
|
|
||||||
|
def test_parse(self):
|
||||||
|
d = bytes(range(0, 55))
|
||||||
|
m = b'##\x00\x00\x00\x00\x00\x37' + d
|
||||||
|
r = b'?' + m
|
||||||
|
|
||||||
|
rm, rs, rd = wire_codec_v1.parse_report_v1(r)
|
||||||
|
self.assertEqual(rm, None)
|
||||||
|
self.assertEqual(rs, 0)
|
||||||
|
self.assertEqual(rd, m)
|
||||||
|
|
||||||
|
mt, ml, md = wire_codec_v1.parse_message(m)
|
||||||
|
self.assertEqual(mt, 0)
|
||||||
|
self.assertEqual(ml, len(d))
|
||||||
|
self.assertEqual(md, d)
|
||||||
|
|
||||||
|
for i in range(0, 1024):
|
||||||
|
if i != 64:
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
wire_codec_v1.parse_report_v1(bytes(range(0, i)))
|
||||||
|
|
||||||
|
for hx in range(0, 256):
|
||||||
|
for hy in range(0, 256):
|
||||||
|
if hx != ord(b'#') and hy != ord(b'#'):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
wire_codec_v1.parse_message(bytes([hx, hy]) + m[2:])
|
||||||
|
|
||||||
|
def test_serialize(self):
|
||||||
|
data = bytearray(range(0, 10))
|
||||||
|
wire_codec_v1.serialize_message_header(data, 0x1234, 0x56789abc)
|
||||||
|
self.assertEqual(data, b'\x00##\x12\x34\x56\x78\x9a\xbc\x09')
|
||||||
|
|
||||||
|
data = bytearray(9)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
wire_codec_v1.serialize_message_header(data, 65536, 0)
|
||||||
|
|
||||||
|
for i in range(0, 8):
|
||||||
|
data = bytearray(i)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
wire_codec_v1.serialize_message_header(data, 0x1234, 0x56789abc)
|
||||||
|
|
||||||
|
def test_decode_empty(self):
|
||||||
|
message = b'##' + b'\xab\xcd' + b'\x00\x00\x00\x00' + b'\x00' * 55
|
||||||
|
|
||||||
|
record = []
|
||||||
|
genfunc = self._record(record, 0xabcd, 0, 0xdeadbeef, 'dummy')
|
||||||
|
decoder = wire_codec_v1.decode_wire_v1_stream(genfunc, 0xdeadbeef, 'dummy')
|
||||||
|
decoder.send(None)
|
||||||
|
|
||||||
|
try:
|
||||||
|
decoder.send(message)
|
||||||
|
except StopIteration as e:
|
||||||
|
res = e.value
|
||||||
|
self.assertEqual(res, None)
|
||||||
|
self.assertEqual(len(record), 1)
|
||||||
|
self.assertIsInstance(record[0], EOFError)
|
||||||
|
|
||||||
|
def test_decode_one_report_aligned(self):
|
||||||
|
data = bytes(range(0, 55))
|
||||||
|
message = b'##' + b'\xab\xcd' + b'\x00\x00\x00\x37' + data
|
||||||
|
|
||||||
|
record = []
|
||||||
|
genfunc = self._record(record, 0xabcd, 55, 0xdeadbeef, 'dummy')
|
||||||
|
decoder = wire_codec_v1.decode_wire_v1_stream(genfunc, 0xdeadbeef, 'dummy')
|
||||||
|
decoder.send(None)
|
||||||
|
|
||||||
|
try:
|
||||||
|
decoder.send(message)
|
||||||
|
except StopIteration as e:
|
||||||
|
res = e.value
|
||||||
|
self.assertEqual(res, None)
|
||||||
|
self.assertEqual(len(record), 2)
|
||||||
|
self.assertEqual(record[0], data)
|
||||||
|
self.assertIsInstance(record[1], EOFError)
|
||||||
|
|
||||||
|
def test_decode_generated_range(self):
|
||||||
|
for data_len in range(1, 512):
|
||||||
|
data = random.bytes(data_len)
|
||||||
|
data_chunks = [data[:55]] + list(chunks(data[55:], 63))
|
||||||
|
|
||||||
|
msg_type = 0xabcd
|
||||||
|
header = b'##' + ustruct.pack('>H', msg_type) + ustruct.pack('>L', data_len)
|
||||||
|
|
||||||
|
message = header + data
|
||||||
|
message_chunks = [c + '\x00' * (63 - len(c)) for c in list(chunks(message, 63))]
|
||||||
|
|
||||||
|
record = []
|
||||||
|
genfunc = self._record(record, msg_type, data_len, 0xdeadbeef, 'dummy')
|
||||||
|
decoder = wire_codec_v1.decode_wire_v1_stream(genfunc, 0xdeadbeef, 'dummy')
|
||||||
|
decoder.send(None)
|
||||||
|
|
||||||
|
res = 1
|
||||||
|
try:
|
||||||
|
for c in message_chunks:
|
||||||
|
decoder.send(c)
|
||||||
|
except StopIteration as e:
|
||||||
|
res = e.value
|
||||||
|
self.assertEqual(res, None)
|
||||||
|
self.assertEqual(len(record), len(data_chunks) + 1)
|
||||||
|
for i in range(0, len(data_chunks)):
|
||||||
|
self.assertEqual(record[i], data_chunks[i])
|
||||||
|
self.assertIsInstance(record[-1], EOFError)
|
||||||
|
|
||||||
|
def test_encode_empty(self):
|
||||||
|
record = []
|
||||||
|
target = self._record(record)()
|
||||||
|
target.send(None)
|
||||||
|
|
||||||
|
wire_codec_v1.encode_wire_v1_message(0xabcd, b'', target)
|
||||||
|
self.assertEqual(len(record), 1)
|
||||||
|
self.assertEqual(record[0], b'?##\xab\xcd\x00\x00\x00\x00' + '\0' * 55)
|
||||||
|
|
||||||
|
def test_encode_one_report_aligned(self):
|
||||||
|
data = bytes(range(0, 55))
|
||||||
|
|
||||||
|
record = []
|
||||||
|
target = self._record(record)()
|
||||||
|
target.send(None)
|
||||||
|
|
||||||
|
wire_codec_v1.encode_wire_v1_message(0xabcd, data, target)
|
||||||
|
self.assertEqual(record, [b'?##\xab\xcd\x00\x00\x00\x37' + data])
|
||||||
|
|
||||||
|
def test_encode_generated_range(self):
|
||||||
|
for data_len in range(1, 1024):
|
||||||
|
data = random.bytes(data_len)
|
||||||
|
|
||||||
|
msg_type = 0xabcd
|
||||||
|
header = b'##' + ustruct.pack('>H', msg_type) + ustruct.pack('>L', data_len)
|
||||||
|
|
||||||
|
message = header + data
|
||||||
|
reports = [b'?' + c for c in chunks(message, 63)]
|
||||||
|
reports[-1] = reports[-1] + b'\x00' * (64 - len(reports[-1]))
|
||||||
|
|
||||||
|
received = 0
|
||||||
|
def genfunc():
|
||||||
|
nonlocal received
|
||||||
|
while True:
|
||||||
|
self.assertEqual((yield), reports[received])
|
||||||
|
received += 1
|
||||||
|
target = genfunc()
|
||||||
|
target.send(None)
|
||||||
|
|
||||||
|
wire_codec_v1.encode_wire_v1_message(msg_type, data, target)
|
||||||
|
self.assertEqual(received, len(reports))
|
||||||
|
|
||||||
|
def _record(self, record, *_args):
|
||||||
|
def genfunc(*args):
|
||||||
|
self.assertEqual(args, _args)
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
v = yield
|
||||||
|
except Exception as e:
|
||||||
|
record.append(e)
|
||||||
|
else:
|
||||||
|
record.append(v)
|
||||||
|
return genfunc
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
@ -84,10 +84,11 @@ class MessageChecksumError(Exception):
|
|||||||
def decode_wire_stream(genfunc, session_id, *args):
|
def decode_wire_stream(genfunc, session_id, *args):
|
||||||
'''Decode a wire message from the report data and stream it to target.
|
'''Decode a wire message from the report data and stream it to target.
|
||||||
|
|
||||||
Receives report payloads.
|
Receives report payloads. After first report, creates target by calling
|
||||||
Sends (msg_type, data_len) to target, followed by data chunks.
|
`genfunc(msg_type, data_len, session_id, *args)` and sends chunks of message
|
||||||
Throws EOFError after last data chunk, in case of valid checksum.
|
data.
|
||||||
Throws MessageChecksumError to target if data doesn't match the checksum.
|
Throws `EOFError` to target after last data chunk, in case of valid checksum.
|
||||||
|
Throws `MessageChecksumError` to target if data doesn't match the checksum.
|
||||||
|
|
||||||
Pass report payloads as `memoryview` for cheaper slicing.
|
Pass report payloads as `memoryview` for cheaper slicing.
|
||||||
'''
|
'''
|
||||||
|
@ -2,7 +2,7 @@ from micropython import const
|
|||||||
import ustruct
|
import ustruct
|
||||||
|
|
||||||
SESSION_V1 = const(0)
|
SESSION_V1 = const(0)
|
||||||
REP_MARKER_V1 = const(63) # ord('?)
|
REP_MARKER_V1 = const(63) # ord('?')
|
||||||
REP_MARKER_V1_LEN = const(1) # len('?')
|
REP_MARKER_V1_LEN = const(1) # len('?')
|
||||||
|
|
||||||
_REP_LEN = const(64)
|
_REP_LEN = const(64)
|
||||||
@ -16,18 +16,23 @@ def detect_v1(data):
|
|||||||
|
|
||||||
|
|
||||||
def parse_report_v1(data):
|
def parse_report_v1(data):
|
||||||
|
if len(data) != _REP_LEN:
|
||||||
|
raise ValueError('Invalid buffer size')
|
||||||
return None, SESSION_V1, data[1:]
|
return None, SESSION_V1, data[1:]
|
||||||
|
|
||||||
|
|
||||||
def parse_message(data):
|
def parse_message(data):
|
||||||
magic1, magic2, msg_type, data_len = ustruct.unpack(_MSG_HEADER_V1, data)
|
magic1, magic2, msg_type, data_len = ustruct.unpack(_MSG_HEADER_V1, data)
|
||||||
if magic1 != _MSG_HEADER_MAGIC or magic2 != _MSG_HEADER_MAGIC:
|
if magic1 != _MSG_HEADER_MAGIC or magic2 != _MSG_HEADER_MAGIC:
|
||||||
raise Exception('Corrupted magic bytes')
|
raise ValueError('Corrupted magic bytes')
|
||||||
|
|
||||||
return msg_type, data_len, data[_MSG_HEADER_V1_LEN:]
|
return msg_type, data_len, data[_MSG_HEADER_V1_LEN:]
|
||||||
|
|
||||||
|
|
||||||
def serialize_message_header(data, msg_type, msg_len):
|
def serialize_message_header(data, msg_type, msg_len):
|
||||||
|
if len(data) < REP_MARKER_V1_LEN + _MSG_HEADER_V1_LEN:
|
||||||
|
raise ValueError('Invalid buffer size')
|
||||||
|
if msg_type < 0 or msg_type > 65535:
|
||||||
|
raise ValueError('Value is out of range')
|
||||||
ustruct.pack_into(
|
ustruct.pack_into(
|
||||||
_MSG_HEADER_V1, data, REP_MARKER_V1_LEN,
|
_MSG_HEADER_V1, data, REP_MARKER_V1_LEN,
|
||||||
_MSG_HEADER_MAGIC, _MSG_HEADER_MAGIC, msg_type, msg_len)
|
_MSG_HEADER_MAGIC, _MSG_HEADER_MAGIC, msg_type, msg_len)
|
||||||
@ -36,10 +41,10 @@ def serialize_message_header(data, msg_type, msg_len):
|
|||||||
def decode_wire_v1_stream(genfunc, session_id, *args):
|
def decode_wire_v1_stream(genfunc, session_id, *args):
|
||||||
'''Decode a v1 wire message from the report data and stream it to target.
|
'''Decode a v1 wire message from the report data and stream it to target.
|
||||||
|
|
||||||
Receives report payloads.
|
Receives report payloads. After first report, creates target by calling
|
||||||
Sends (msg_type, data_len) to target, followed by data chunks.
|
`genfunc(msg_type, data_len, session_id, *args)` and sends chunks of message
|
||||||
Throws EOFError after last data chunk, in case of valid checksum.
|
data.
|
||||||
Throws MessageChecksumError to target if data doesn't match the checksum.
|
Throws `EOFError` to target after last data chunk.
|
||||||
|
|
||||||
Pass report payloads as `memoryview` for cheaper slicing.
|
Pass report payloads as `memoryview` for cheaper slicing.
|
||||||
'''
|
'''
|
||||||
@ -47,7 +52,6 @@ Pass report payloads as `memoryview` for cheaper slicing.
|
|||||||
message = yield # read first report
|
message = yield # read first report
|
||||||
msg_type, data_len, data = parse_message(message)
|
msg_type, data_len, data = parse_message(message)
|
||||||
|
|
||||||
print(msg_type, data_len, bytes(data))
|
|
||||||
target = genfunc(msg_type, data_len, session_id, *args)
|
target = genfunc(msg_type, data_len, session_id, *args)
|
||||||
target.send(None)
|
target.send(None)
|
||||||
|
|
||||||
@ -65,6 +69,11 @@ Pass report payloads as `memoryview` for cheaper slicing.
|
|||||||
|
|
||||||
|
|
||||||
def encode_wire_v1_message(msg_type, msg_data, target):
|
def encode_wire_v1_message(msg_type, msg_data, target):
|
||||||
|
'''Encode a full v1 wire message directly to reports and stream it to target.
|
||||||
|
|
||||||
|
Target receives `memoryview`s of HID reports which are valid until the targets
|
||||||
|
`send()` method returns.
|
||||||
|
'''
|
||||||
report = memoryview(bytearray(_REP_LEN))
|
report = memoryview(bytearray(_REP_LEN))
|
||||||
report[0] = REP_MARKER_V1
|
report[0] = REP_MARKER_V1
|
||||||
serialize_message_header(report, msg_type, len(msg_data))
|
serialize_message_header(report, msg_type, len(msg_data))
|
||||||
@ -79,7 +88,7 @@ def encode_wire_v1_message(msg_type, msg_data, target):
|
|||||||
source_data = source_data[n:]
|
source_data = source_data[n:]
|
||||||
target_data = target_data[n:]
|
target_data = target_data[n:]
|
||||||
|
|
||||||
# FIXME: optimize speed
|
# fill the rest of the report with 0x00
|
||||||
x = 0
|
x = 0
|
||||||
to_fill = len(target_data)
|
to_fill = len(target_data)
|
||||||
while x < to_fill:
|
while x < to_fill:
|
||||||
|
Loading…
Reference in New Issue
Block a user