From 507d9bdf68716110f17b47300784f1e5177e59d5 Mon Sep 17 00:00:00 2001 From: Jan Pochyla Date: Mon, 24 Oct 2016 13:29:43 +0200 Subject: [PATCH] wire: add tests for v1 codec --- src/tests/test_wire_codec_v1.py | 180 +++++++++++++++++++++++++++++++ src/trezor/wire/wire_codec.py | 9 +- src/trezor/wire/wire_codec_v1.py | 27 +++-- 3 files changed, 203 insertions(+), 13 deletions(-) create mode 100644 src/tests/test_wire_codec_v1.py diff --git a/src/tests/test_wire_codec_v1.py b/src/tests/test_wire_codec_v1.py new file mode 100644 index 0000000000..e08729fecd --- /dev/null +++ b/src/tests/test_wire_codec_v1.py @@ -0,0 +1,180 @@ +import sys +sys.path.append('..') +sys.path.append('../lib') +import unittest +import ustruct + +from trezor.wire import wire_codec_v1 +from trezor.utils import chunks +from trezor.crypto import random + + +class TestWireCodecV1(unittest.TestCase): + # pylint: disable=C0301 + + def test_detect(self): + for i in range(0, 256): + if i == ord(b'?'): + self.assertTrue(wire_codec_v1.detect_v1(bytes([i]) + b'\x00' * 63)) + else: + self.assertFalse(wire_codec_v1.detect_v1(bytes([i]) + b'\x00' * 63)) + + def test_parse(self): + d = bytes(range(0, 55)) + m = b'##\x00\x00\x00\x00\x00\x37' + d + r = b'?' + m + + rm, rs, rd = wire_codec_v1.parse_report_v1(r) + self.assertEqual(rm, None) + self.assertEqual(rs, 0) + self.assertEqual(rd, m) + + mt, ml, md = wire_codec_v1.parse_message(m) + self.assertEqual(mt, 0) + self.assertEqual(ml, len(d)) + self.assertEqual(md, d) + + for i in range(0, 1024): + if i != 64: + with self.assertRaises(ValueError): + wire_codec_v1.parse_report_v1(bytes(range(0, i))) + + for hx in range(0, 256): + for hy in range(0, 256): + if hx != ord(b'#') and hy != ord(b'#'): + with self.assertRaises(ValueError): + wire_codec_v1.parse_message(bytes([hx, hy]) + m[2:]) + + def test_serialize(self): + data = bytearray(range(0, 10)) + wire_codec_v1.serialize_message_header(data, 0x1234, 0x56789abc) + self.assertEqual(data, b'\x00##\x12\x34\x56\x78\x9a\xbc\x09') + + data = bytearray(9) + with self.assertRaises(ValueError): + wire_codec_v1.serialize_message_header(data, 65536, 0) + + for i in range(0, 8): + data = bytearray(i) + with self.assertRaises(ValueError): + wire_codec_v1.serialize_message_header(data, 0x1234, 0x56789abc) + + def test_decode_empty(self): + message = b'##' + b'\xab\xcd' + b'\x00\x00\x00\x00' + b'\x00' * 55 + + record = [] + genfunc = self._record(record, 0xabcd, 0, 0xdeadbeef, 'dummy') + decoder = wire_codec_v1.decode_wire_v1_stream(genfunc, 0xdeadbeef, 'dummy') + decoder.send(None) + + try: + decoder.send(message) + except StopIteration as e: + res = e.value + self.assertEqual(res, None) + self.assertEqual(len(record), 1) + self.assertIsInstance(record[0], EOFError) + + def test_decode_one_report_aligned(self): + data = bytes(range(0, 55)) + message = b'##' + b'\xab\xcd' + b'\x00\x00\x00\x37' + data + + record = [] + genfunc = self._record(record, 0xabcd, 55, 0xdeadbeef, 'dummy') + decoder = wire_codec_v1.decode_wire_v1_stream(genfunc, 0xdeadbeef, 'dummy') + decoder.send(None) + + try: + decoder.send(message) + except StopIteration as e: + res = e.value + self.assertEqual(res, None) + self.assertEqual(len(record), 2) + self.assertEqual(record[0], data) + self.assertIsInstance(record[1], EOFError) + + def test_decode_generated_range(self): + for data_len in range(1, 512): + data = random.bytes(data_len) + data_chunks = [data[:55]] + list(chunks(data[55:], 63)) + + msg_type = 0xabcd + header = b'##' + ustruct.pack('>H', msg_type) + ustruct.pack('>L', data_len) + + message = header + data + message_chunks = [c + '\x00' * (63 - len(c)) for c in list(chunks(message, 63))] + + record = [] + genfunc = self._record(record, msg_type, data_len, 0xdeadbeef, 'dummy') + decoder = wire_codec_v1.decode_wire_v1_stream(genfunc, 0xdeadbeef, 'dummy') + decoder.send(None) + + res = 1 + try: + for c in message_chunks: + decoder.send(c) + except StopIteration as e: + res = e.value + self.assertEqual(res, None) + self.assertEqual(len(record), len(data_chunks) + 1) + for i in range(0, len(data_chunks)): + self.assertEqual(record[i], data_chunks[i]) + self.assertIsInstance(record[-1], EOFError) + + def test_encode_empty(self): + record = [] + target = self._record(record)() + target.send(None) + + wire_codec_v1.encode_wire_v1_message(0xabcd, b'', target) + self.assertEqual(len(record), 1) + self.assertEqual(record[0], b'?##\xab\xcd\x00\x00\x00\x00' + '\0' * 55) + + def test_encode_one_report_aligned(self): + data = bytes(range(0, 55)) + + record = [] + target = self._record(record)() + target.send(None) + + wire_codec_v1.encode_wire_v1_message(0xabcd, data, target) + self.assertEqual(record, [b'?##\xab\xcd\x00\x00\x00\x37' + data]) + + def test_encode_generated_range(self): + for data_len in range(1, 1024): + data = random.bytes(data_len) + + msg_type = 0xabcd + header = b'##' + ustruct.pack('>H', msg_type) + ustruct.pack('>L', data_len) + + message = header + data + reports = [b'?' + c for c in chunks(message, 63)] + reports[-1] = reports[-1] + b'\x00' * (64 - len(reports[-1])) + + received = 0 + def genfunc(): + nonlocal received + while True: + self.assertEqual((yield), reports[received]) + received += 1 + target = genfunc() + target.send(None) + + wire_codec_v1.encode_wire_v1_message(msg_type, data, target) + self.assertEqual(received, len(reports)) + + def _record(self, record, *_args): + def genfunc(*args): + self.assertEqual(args, _args) + while True: + try: + v = yield + except Exception as e: + record.append(e) + else: + record.append(v) + return genfunc + + +if __name__ == '__main__': + unittest.main() diff --git a/src/trezor/wire/wire_codec.py b/src/trezor/wire/wire_codec.py index af03e4414b..22a2669a8b 100644 --- a/src/trezor/wire/wire_codec.py +++ b/src/trezor/wire/wire_codec.py @@ -84,10 +84,11 @@ class MessageChecksumError(Exception): def decode_wire_stream(genfunc, session_id, *args): '''Decode a wire message from the report data and stream it to target. -Receives report payloads. -Sends (msg_type, data_len) to target, followed by data chunks. -Throws EOFError after last data chunk, in case of valid checksum. -Throws MessageChecksumError to target if data doesn't match the checksum. +Receives report payloads. After first report, creates target by calling +`genfunc(msg_type, data_len, session_id, *args)` and sends chunks of message +data. +Throws `EOFError` to target after last data chunk, in case of valid checksum. +Throws `MessageChecksumError` to target if data doesn't match the checksum. Pass report payloads as `memoryview` for cheaper slicing. ''' diff --git a/src/trezor/wire/wire_codec_v1.py b/src/trezor/wire/wire_codec_v1.py index 9b7ef69e7a..8e6b0ac6b0 100644 --- a/src/trezor/wire/wire_codec_v1.py +++ b/src/trezor/wire/wire_codec_v1.py @@ -2,7 +2,7 @@ from micropython import const import ustruct SESSION_V1 = const(0) -REP_MARKER_V1 = const(63) # ord('?) +REP_MARKER_V1 = const(63) # ord('?') REP_MARKER_V1_LEN = const(1) # len('?') _REP_LEN = const(64) @@ -16,18 +16,23 @@ def detect_v1(data): def parse_report_v1(data): + if len(data) != _REP_LEN: + raise ValueError('Invalid buffer size') return None, SESSION_V1, data[1:] def parse_message(data): magic1, magic2, msg_type, data_len = ustruct.unpack(_MSG_HEADER_V1, data) if magic1 != _MSG_HEADER_MAGIC or magic2 != _MSG_HEADER_MAGIC: - raise Exception('Corrupted magic bytes') - + raise ValueError('Corrupted magic bytes') return msg_type, data_len, data[_MSG_HEADER_V1_LEN:] def serialize_message_header(data, msg_type, msg_len): + if len(data) < REP_MARKER_V1_LEN + _MSG_HEADER_V1_LEN: + raise ValueError('Invalid buffer size') + if msg_type < 0 or msg_type > 65535: + raise ValueError('Value is out of range') ustruct.pack_into( _MSG_HEADER_V1, data, REP_MARKER_V1_LEN, _MSG_HEADER_MAGIC, _MSG_HEADER_MAGIC, msg_type, msg_len) @@ -36,10 +41,10 @@ def serialize_message_header(data, msg_type, msg_len): def decode_wire_v1_stream(genfunc, session_id, *args): '''Decode a v1 wire message from the report data and stream it to target. -Receives report payloads. -Sends (msg_type, data_len) to target, followed by data chunks. -Throws EOFError after last data chunk, in case of valid checksum. -Throws MessageChecksumError to target if data doesn't match the checksum. +Receives report payloads. After first report, creates target by calling +`genfunc(msg_type, data_len, session_id, *args)` and sends chunks of message +data. +Throws `EOFError` to target after last data chunk. Pass report payloads as `memoryview` for cheaper slicing. ''' @@ -47,7 +52,6 @@ Pass report payloads as `memoryview` for cheaper slicing. message = yield # read first report msg_type, data_len, data = parse_message(message) - print(msg_type, data_len, bytes(data)) target = genfunc(msg_type, data_len, session_id, *args) target.send(None) @@ -65,6 +69,11 @@ Pass report payloads as `memoryview` for cheaper slicing. def encode_wire_v1_message(msg_type, msg_data, target): + '''Encode a full v1 wire message directly to reports and stream it to target. + +Target receives `memoryview`s of HID reports which are valid until the targets +`send()` method returns. + ''' report = memoryview(bytearray(_REP_LEN)) report[0] = REP_MARKER_V1 serialize_message_header(report, msg_type, len(msg_data)) @@ -79,7 +88,7 @@ def encode_wire_v1_message(msg_type, msg_data, target): source_data = source_data[n:] target_data = target_data[n:] - # FIXME: optimize speed + # fill the rest of the report with 0x00 x = 0 to_fill = len(target_data) while x < to_fill: