1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-08 22:40:59 +00:00

wire: add tests for v1 codec

This commit is contained in:
Jan Pochyla 2016-10-24 13:29:43 +02:00
parent 7404a76aeb
commit 507d9bdf68
3 changed files with 203 additions and 13 deletions

View File

@ -0,0 +1,180 @@
import sys
sys.path.append('..')
sys.path.append('../lib')
import unittest
import ustruct
from trezor.wire import wire_codec_v1
from trezor.utils import chunks
from trezor.crypto import random
class TestWireCodecV1(unittest.TestCase):
# pylint: disable=C0301
def test_detect(self):
for i in range(0, 256):
if i == ord(b'?'):
self.assertTrue(wire_codec_v1.detect_v1(bytes([i]) + b'\x00' * 63))
else:
self.assertFalse(wire_codec_v1.detect_v1(bytes([i]) + b'\x00' * 63))
def test_parse(self):
d = bytes(range(0, 55))
m = b'##\x00\x00\x00\x00\x00\x37' + d
r = b'?' + m
rm, rs, rd = wire_codec_v1.parse_report_v1(r)
self.assertEqual(rm, None)
self.assertEqual(rs, 0)
self.assertEqual(rd, m)
mt, ml, md = wire_codec_v1.parse_message(m)
self.assertEqual(mt, 0)
self.assertEqual(ml, len(d))
self.assertEqual(md, d)
for i in range(0, 1024):
if i != 64:
with self.assertRaises(ValueError):
wire_codec_v1.parse_report_v1(bytes(range(0, i)))
for hx in range(0, 256):
for hy in range(0, 256):
if hx != ord(b'#') and hy != ord(b'#'):
with self.assertRaises(ValueError):
wire_codec_v1.parse_message(bytes([hx, hy]) + m[2:])
def test_serialize(self):
data = bytearray(range(0, 10))
wire_codec_v1.serialize_message_header(data, 0x1234, 0x56789abc)
self.assertEqual(data, b'\x00##\x12\x34\x56\x78\x9a\xbc\x09')
data = bytearray(9)
with self.assertRaises(ValueError):
wire_codec_v1.serialize_message_header(data, 65536, 0)
for i in range(0, 8):
data = bytearray(i)
with self.assertRaises(ValueError):
wire_codec_v1.serialize_message_header(data, 0x1234, 0x56789abc)
def test_decode_empty(self):
message = b'##' + b'\xab\xcd' + b'\x00\x00\x00\x00' + b'\x00' * 55
record = []
genfunc = self._record(record, 0xabcd, 0, 0xdeadbeef, 'dummy')
decoder = wire_codec_v1.decode_wire_v1_stream(genfunc, 0xdeadbeef, 'dummy')
decoder.send(None)
try:
decoder.send(message)
except StopIteration as e:
res = e.value
self.assertEqual(res, None)
self.assertEqual(len(record), 1)
self.assertIsInstance(record[0], EOFError)
def test_decode_one_report_aligned(self):
data = bytes(range(0, 55))
message = b'##' + b'\xab\xcd' + b'\x00\x00\x00\x37' + data
record = []
genfunc = self._record(record, 0xabcd, 55, 0xdeadbeef, 'dummy')
decoder = wire_codec_v1.decode_wire_v1_stream(genfunc, 0xdeadbeef, 'dummy')
decoder.send(None)
try:
decoder.send(message)
except StopIteration as e:
res = e.value
self.assertEqual(res, None)
self.assertEqual(len(record), 2)
self.assertEqual(record[0], data)
self.assertIsInstance(record[1], EOFError)
def test_decode_generated_range(self):
for data_len in range(1, 512):
data = random.bytes(data_len)
data_chunks = [data[:55]] + list(chunks(data[55:], 63))
msg_type = 0xabcd
header = b'##' + ustruct.pack('>H', msg_type) + ustruct.pack('>L', data_len)
message = header + data
message_chunks = [c + '\x00' * (63 - len(c)) for c in list(chunks(message, 63))]
record = []
genfunc = self._record(record, msg_type, data_len, 0xdeadbeef, 'dummy')
decoder = wire_codec_v1.decode_wire_v1_stream(genfunc, 0xdeadbeef, 'dummy')
decoder.send(None)
res = 1
try:
for c in message_chunks:
decoder.send(c)
except StopIteration as e:
res = e.value
self.assertEqual(res, None)
self.assertEqual(len(record), len(data_chunks) + 1)
for i in range(0, len(data_chunks)):
self.assertEqual(record[i], data_chunks[i])
self.assertIsInstance(record[-1], EOFError)
def test_encode_empty(self):
record = []
target = self._record(record)()
target.send(None)
wire_codec_v1.encode_wire_v1_message(0xabcd, b'', target)
self.assertEqual(len(record), 1)
self.assertEqual(record[0], b'?##\xab\xcd\x00\x00\x00\x00' + '\0' * 55)
def test_encode_one_report_aligned(self):
data = bytes(range(0, 55))
record = []
target = self._record(record)()
target.send(None)
wire_codec_v1.encode_wire_v1_message(0xabcd, data, target)
self.assertEqual(record, [b'?##\xab\xcd\x00\x00\x00\x37' + data])
def test_encode_generated_range(self):
for data_len in range(1, 1024):
data = random.bytes(data_len)
msg_type = 0xabcd
header = b'##' + ustruct.pack('>H', msg_type) + ustruct.pack('>L', data_len)
message = header + data
reports = [b'?' + c for c in chunks(message, 63)]
reports[-1] = reports[-1] + b'\x00' * (64 - len(reports[-1]))
received = 0
def genfunc():
nonlocal received
while True:
self.assertEqual((yield), reports[received])
received += 1
target = genfunc()
target.send(None)
wire_codec_v1.encode_wire_v1_message(msg_type, data, target)
self.assertEqual(received, len(reports))
def _record(self, record, *_args):
def genfunc(*args):
self.assertEqual(args, _args)
while True:
try:
v = yield
except Exception as e:
record.append(e)
else:
record.append(v)
return genfunc
if __name__ == '__main__':
unittest.main()

View File

@ -84,10 +84,11 @@ class MessageChecksumError(Exception):
def decode_wire_stream(genfunc, session_id, *args):
'''Decode a wire message from the report data and stream it to target.
Receives report payloads.
Sends (msg_type, data_len) to target, followed by data chunks.
Throws EOFError after last data chunk, in case of valid checksum.
Throws MessageChecksumError to target if data doesn't match the checksum.
Receives report payloads. After first report, creates target by calling
`genfunc(msg_type, data_len, session_id, *args)` and sends chunks of message
data.
Throws `EOFError` to target after last data chunk, in case of valid checksum.
Throws `MessageChecksumError` to target if data doesn't match the checksum.
Pass report payloads as `memoryview` for cheaper slicing.
'''

View File

@ -2,7 +2,7 @@ from micropython import const
import ustruct
SESSION_V1 = const(0)
REP_MARKER_V1 = const(63) # ord('?)
REP_MARKER_V1 = const(63) # ord('?')
REP_MARKER_V1_LEN = const(1) # len('?')
_REP_LEN = const(64)
@ -16,18 +16,23 @@ def detect_v1(data):
def parse_report_v1(data):
if len(data) != _REP_LEN:
raise ValueError('Invalid buffer size')
return None, SESSION_V1, data[1:]
def parse_message(data):
magic1, magic2, msg_type, data_len = ustruct.unpack(_MSG_HEADER_V1, data)
if magic1 != _MSG_HEADER_MAGIC or magic2 != _MSG_HEADER_MAGIC:
raise Exception('Corrupted magic bytes')
raise ValueError('Corrupted magic bytes')
return msg_type, data_len, data[_MSG_HEADER_V1_LEN:]
def serialize_message_header(data, msg_type, msg_len):
if len(data) < REP_MARKER_V1_LEN + _MSG_HEADER_V1_LEN:
raise ValueError('Invalid buffer size')
if msg_type < 0 or msg_type > 65535:
raise ValueError('Value is out of range')
ustruct.pack_into(
_MSG_HEADER_V1, data, REP_MARKER_V1_LEN,
_MSG_HEADER_MAGIC, _MSG_HEADER_MAGIC, msg_type, msg_len)
@ -36,10 +41,10 @@ def serialize_message_header(data, msg_type, msg_len):
def decode_wire_v1_stream(genfunc, session_id, *args):
'''Decode a v1 wire message from the report data and stream it to target.
Receives report payloads.
Sends (msg_type, data_len) to target, followed by data chunks.
Throws EOFError after last data chunk, in case of valid checksum.
Throws MessageChecksumError to target if data doesn't match the checksum.
Receives report payloads. After first report, creates target by calling
`genfunc(msg_type, data_len, session_id, *args)` and sends chunks of message
data.
Throws `EOFError` to target after last data chunk.
Pass report payloads as `memoryview` for cheaper slicing.
'''
@ -47,7 +52,6 @@ Pass report payloads as `memoryview` for cheaper slicing.
message = yield # read first report
msg_type, data_len, data = parse_message(message)
print(msg_type, data_len, bytes(data))
target = genfunc(msg_type, data_len, session_id, *args)
target.send(None)
@ -65,6 +69,11 @@ Pass report payloads as `memoryview` for cheaper slicing.
def encode_wire_v1_message(msg_type, msg_data, target):
'''Encode a full v1 wire message directly to reports and stream it to target.
Target receives `memoryview`s of HID reports which are valid until the targets
`send()` method returns.
'''
report = memoryview(bytearray(_REP_LEN))
report[0] = REP_MARKER_V1
serialize_message_header(report, msg_type, len(msg_data))
@ -79,7 +88,7 @@ def encode_wire_v1_message(msg_type, msg_data, target):
source_data = source_data[n:]
target_data = target_data[n:]
# FIXME: optimize speed
# fill the rest of the report with 0x00
x = 0
to_fill = len(target_data)
while x < to_fill: