From d0b29d4caae0ba5e77460b4ec29e972217e9fa7b Mon Sep 17 00:00:00 2001 From: Jan Pochyla Date: Fri, 21 Oct 2016 15:01:49 +0200 Subject: [PATCH] wire: add tests, fix missing 0-padding --- src/tests/test_wire.py.disabled | 87 ----------- src/tests/test_wire_codec.py | 250 ++++++++++++++++++++++++++++++++ src/trezor/wire/wire_codec.py | 21 ++- 3 files changed, 269 insertions(+), 89 deletions(-) delete mode 100644 src/tests/test_wire.py.disabled create mode 100644 src/tests/test_wire_codec.py diff --git a/src/tests/test_wire.py.disabled b/src/tests/test_wire.py.disabled deleted file mode 100644 index b17a568260..0000000000 --- a/src/tests/test_wire.py.disabled +++ /dev/null @@ -1,87 +0,0 @@ -import sys -sys.path.append('..') -sys.path.append('../lib') -import unittest - -from trezor import loop -from trezor import msg -from trezor.wire import read_wire_msg, write_wire_msg -from trezor.utils import chunks - - -class TestWire(unittest.TestCase): - - def test_read_wire_msg(self): - - # Reading empty message returns correct type and empty bytes - - reader = read_wire_msg() - reader.send(None) - - empty_message = b'\x3f##\xab\xcd\x00\x00\x00\x00' + b'\x00' * 55 - try: - reader.send((empty_message,)) - except StopIteration as e: - restype, resmsg = e.value - self.assertEqual(restype, int('0xabcd', 16)) - self.assertEqual(resmsg, b'') - - # Reading message from one report - - reader = read_wire_msg() - reader.send(None) - - content = bytes([x for x in range(0, 55)]) - message = b'\x3f##\xab\xcd\x00\x00\x00\x37' + content - try: - reader.send((message,)) - except StopIteration as e: - restype, resmsg = e.value - self.assertEqual(restype, int('0xabcd', 16)) - self.assertEqual(resmsg, content) - - # Reading message spanning multiple reports - - reader = read_wire_msg() - reader.send(None) - - content = bytes([x for x in range(0, 256)]) - message = b'##\xab\xcd\x00\x00\x01\00' + content - reports = [b'\x3f' + ch + '\x00' * (63 - len(ch)) for ch in chunks(message, 63)] - try: - for report in reports: - reader.send((report,)) - except StopIteration as e: - restype, resmsg = e.value - self.assertEqual(restype, int('0xabcd', 16)) - self.assertEqual(resmsg, content) - - def test_write_wire_msg(self): - - # Writing message spanning multiple reports calls msg.send() with correct data - - sent_reps = [] - - def dummy_send(iface, rep): - sent_reps.append(bytes(rep)) - return len(rep) - - msg.send = dummy_send - - content = bytes([x for x in range(0, 256)]) - message = b'##\xab\xcd\x00\x00\x01\00' + content - reports = [b'\x3f' + ch + '\x00' * (63 - len(ch)) for ch in chunks(message, 63)] - - writer = write_wire_msg(int('0xabcd'), content) - res = 1 # Something not None - try: - while True: - writer.send(None) - except StopIteration as e: - res = e.value - self.assertEqual(res, None) - self.assertEqual(sent_reps, reports) - - -if __name__ == '__main__': - unittest.main() diff --git a/src/tests/test_wire_codec.py b/src/tests/test_wire_codec.py new file mode 100644 index 0000000000..9ec5626e6d --- /dev/null +++ b/src/tests/test_wire_codec.py @@ -0,0 +1,250 @@ +import sys +sys.path.append('..') +sys.path.append('../lib') +import unittest +import ustruct +import ubinascii + +from trezor.wire import wire_codec +from trezor.utils import chunks +from trezor.crypto import random + + +class TestWireCodec(unittest.TestCase): + # pylint: disable=C0301 + + def test_parse(self): + d = b'O' + b'\x01\x23\x45\x67' + bytes(range(0, 59)) + + m, s, d = wire_codec.parse_report(d) + self.assertEqual(m, b'O'[0]) + self.assertEqual(s, 0x01234567) + self.assertEqual(d, bytes(range(0, 59))) + + t, l, d = wire_codec.parse_message(d) + self.assertEqual(t, 0x00010203) + self.assertEqual(l, 0x04050607) + self.assertEqual(d, bytes(range(8, 59))) + + f, = wire_codec.parse_message_footer(d[0:4]) + self.assertEqual(f, 0x08090a0b) + + for i in range(0, 1024): + if i != 64: + with self.assertRaises(ValueError): + wire_codec.parse_report(bytes(range(0, i))) + if i != 59: + with self.assertRaises(ValueError): + wire_codec.parse_message(bytes(range(0, i))) + if i != 4: + with self.assertRaises(ValueError): + wire_codec.parse_message_footer(bytes(range(0, i))) + + def test_serialize(self): + data = bytearray(range(0, 6)) + wire_codec.serialize_report_header(data, 0x12, 0x3456789a) + self.assertEqual(data, b'\x12\x34\x56\x78\x9a\x05') + + data = bytearray(range(0, 6)) + wire_codec.serialize_opened_session(data, 0x3456789a) + self.assertEqual(data, bytes([wire_codec.REP_MARKER_OPEN]) + b'\x34\x56\x78\x9a\x05') + + data = bytearray(range(0, 14)) + wire_codec.serialize_message_header(data, 0x01234567, 0x89abcdef) + self.assertEqual(data, b'\x00\x01\x02\x03\x04\x01\x23\x45\x67\x89\xab\xcd\xef\x0d') + + data = bytearray(range(0, 5)) + wire_codec.serialize_message_footer(data, 0x89abcdef) + self.assertEqual(data, b'\x89\xab\xcd\xef\x04') + + for i in range(0, 13): + data = bytearray(i) + if i < 4: + with self.assertRaises(ValueError): + wire_codec.serialize_message_footer(data, 0x00) + if i < 5: + with self.assertRaises(ValueError): + wire_codec.serialize_report_header(data, 0x00, 0x00) + with self.assertRaises(ValueError): + wire_codec.serialize_opened_session(data, 0x00) + with self.assertRaises(ValueError): + wire_codec.serialize_message_header(data, 0x00, 0x00) + + def test_decode_empty(self): + message = b'\xab\xcd\xef\x12' + b'\x00\x00\x00\x00' + b'\x00' * 51 + + record = [] + genfunc = self._record(record, 0xabcdef12, 0, 0xdeadbeef, 'dummy') + decoder = wire_codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy') + decoder.send(None) + + try: + decoder.send(message) + except StopIteration as e: + res = e.value + self.assertEqual(res, None) + self.assertEqual(len(record), 1) + self.assertIsInstance(record[0], EOFError) + + def test_decode_one_report_aligned_correct(self): + data = bytes(range(0, 47)) + footer = b'\x2f\x1c\x12\xce' + message = b'\xab\xcd\xef\x12' + b'\x00\x00\x00\x2f' + data + footer + + record = [] + genfunc = self._record(record, 0xabcdef12, 47, 0xdeadbeef, 'dummy') + decoder = wire_codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy') + decoder.send(None) + + try: + decoder.send(message) + except StopIteration as e: + res = e.value + self.assertEqual(res, None) + self.assertEqual(len(record), 2) + self.assertEqual(record[0], data) + self.assertIsInstance(record[1], EOFError) + + def test_decode_one_report_aligned_incorrect(self): + data = bytes(range(0, 47)) + footer = bytes(4) # wrong checksum + message = b'\xab\xcd\xef\x12' + b'\x00\x00\x00\x2f' + data + footer + + record = [] + genfunc = self._record(record, 0xabcdef12, 47, 0xdeadbeef, 'dummy') + decoder = wire_codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy') + decoder.send(None) + + try: + decoder.send(message) + except StopIteration as e: + res = e.value + self.assertEqual(res, None) + self.assertEqual(len(record), 2) + self.assertEqual(record[0], data) + self.assertIsInstance(record[1], wire_codec.MessageChecksumError) + + def test_decode_generated_range(self): + for data_len in range(1, 512): + data = random.bytes(data_len) + data_chunks = [data[:51]] + list(chunks(data[51:], 59)) + + msg_type = 0xabcdef12 + data_csum = ubinascii.crc32(data) + header = ustruct.pack('>L', msg_type) + ustruct.pack('>L', data_len) + footer = ustruct.pack('>L', data_csum) + + message = header + data + footer + message_chunks = [c + '\x00' * (59 - len(c)) for c in list(chunks(message, 59))] + + record = [] + genfunc = self._record(record, msg_type, data_len, 0xdeadbeef, 'dummy') + decoder = wire_codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy') + decoder.send(None) + + res = 1 + try: + for c in message_chunks: + decoder.send(c) + except StopIteration as e: + res = e.value + self.assertEqual(res, None) + self.assertEqual(len(record), len(data_chunks) + 1) + for i in range(0, len(data_chunks)): + self.assertEqual(record[i], data_chunks[i]) + self.assertIsInstance(record[-1], EOFError) + + def test_encode_empty(self): + record = [] + target = self._record(record)() + target.send(None) + + wire_codec.encode_wire_message(0xabcdef12, b'', 0xdeadbeef, target) + self.assertEqual(len(record), 1) + self.assertEqual(record[0], b'H\xde\xad\xbe\xef\xab\xcd\xef\x12\x00\x00\x00\x00' + '\0' * 51) + + def test_encode_one_report_aligned(self): + data = bytes(range(0, 47)) + footer = b'\x2f\x1c\x12\xce' + + record = [] + target = self._record(record)() + target.send(None) + + wire_codec.encode_wire_message(0xabcdef12, data, 0xdeadbeef, target) + self.assertEqual(record, [b'H\xde\xad\xbe\xef\xab\xcd\xef\x12\x00\x00\x00\x2f' + data + footer]) + + def test_encode_generated_range(self): + for data_len in range(1, 1024): + data = random.bytes(data_len) + + msg_type = 0xabcdef12 + session_id = 0xdeadbeef + + data_csum = ubinascii.crc32(data) + header = ustruct.pack('>L', msg_type) + ustruct.pack('>L', data_len) + footer = ustruct.pack('>L', data_csum) + session_header = ustruct.pack('>L', session_id) + + message = header + data + footer + report0 = b'H' + session_header + message[:59] + reports = [b'D' + session_header + c for c in chunks(message[59:], 59)] + reports.insert(0, report0) + reports[-1] = reports[-1] + b'\x00' * (64 - len(reports[-1])) + + received = 0 + def genfunc(): + nonlocal received + while True: + self.assertEqual((yield), reports[received]) + received += 1 + target = genfunc() + target.send(None) + + wire_codec.encode_wire_message(msg_type, data, session_id, target) + self.assertEqual(received, len(reports)) + + def _record(self, record, *_args): + def genfunc(*args): + self.assertEqual(args, _args) + while True: + try: + v = yield + except Exception as e: + record.append(e) + else: + record.append(v) + return genfunc + + +# def test_write_wire_msg(self): + +# # Writing message spanning multiple reports calls msg.send() with +# # correct data + +# sent_reps = [] + +# def dummy_send(iface, rep): +# sent_reps.append(bytes(rep)) +# return len(rep) + +# msg.send = dummy_send + +# content = bytes([x for x in range(0, 256)]) +# message = b'##\xab\xcd\x00\x00\x01\00' + content +# reports = [b'\x3f' + ch + '\x00' * +# (63 - len(ch)) for ch in chunks(message, 63)] + +# writer = write_wire_msg(int('0xabcd'), content) +# res = 1 # Something not None +# try: +# while True: +# writer.send(None) +# except StopIteration as e: +# res = e.value +# self.assertEqual(res, None) +# self.assertEqual(sent_reps, reports) + + +if __name__ == '__main__': + unittest.main() diff --git a/src/trezor/wire/wire_codec.py b/src/trezor/wire/wire_codec.py index 096b775efc..af03e4414b 100644 --- a/src/trezor/wire/wire_codec.py +++ b/src/trezor/wire/wire_codec.py @@ -35,29 +35,41 @@ _MSG_FOOTER_LEN = ustruct.calcsize(_MSG_FOOTER) def parse_report(data): + if len(data) != _REP_LEN: + raise ValueError('Invalid buffer size') marker, session_id = ustruct.unpack(_REP_HEADER, data) return marker, session_id, data[_REP_HEADER_LEN:] def parse_message(data): + if len(data) != _REP_LEN - _REP_HEADER_LEN: + raise ValueError('Invalid buffer size') msg_type, data_len = ustruct.unpack(_MSG_HEADER, data) return msg_type, data_len, data[_MSG_HEADER_LEN:] def parse_message_footer(data): + if len(data) != _MSG_FOOTER_LEN: + raise ValueError('Invalid buffer size') data_checksum, = ustruct.unpack(_MSG_FOOTER, data) return data_checksum, def serialize_report_header(data, marker, session_id): + if len(data) < _REP_HEADER_LEN: + raise ValueError('Invalid buffer size') ustruct.pack_into(_REP_HEADER, data, 0, marker, session_id) def serialize_message_header(data, msg_type, msg_len): + if len(data) < _REP_HEADER_LEN + _MSG_HEADER_LEN: + raise ValueError('Invalid buffer size') ustruct.pack_into(_MSG_HEADER, data, _REP_HEADER_LEN, msg_type, msg_len) def serialize_message_footer(data, checksum): + if len(data) < _MSG_FOOTER_LEN: + raise ValueError('Invalid buffer size') ustruct.pack_into(_MSG_FOOTER, data, 0, checksum) @@ -112,6 +124,11 @@ Pass report payloads as `memoryview` for cheaper slicing. def encode_wire_message(msg_type, msg_data, session_id, target): + '''Encode a full wire message directly to reports and stream it to target. + +Target receives `memoryview`s of HID reports which are valid until the targets +`send()` method returns. + ''' report = memoryview(bytearray(_REP_LEN)) serialize_report_header(report, REP_MARKER_HEADER, session_id) serialize_message_header(report, msg_type, len(msg_data)) @@ -139,7 +156,7 @@ def encode_wire_message(msg_type, msg_data, session_id, target): msg_footer = None continue - # FIXME: optimize speed + # fill the rest of the report with 0x00 x = 0 to_fill = len(target_data) while x < to_fill: @@ -154,8 +171,8 @@ def encode_wire_message(msg_type, msg_data, session_id, target): # reset to skip the magic and session ID if first: serialize_report_header(report, REP_MARKER_DATA, session_id) - target_data = report[_REP_HEADER_LEN:] first = False + target_data = report[_REP_HEADER_LEN:] def encode_session_open_message(session_id, target):