diff --git a/src/trezor/wire/__init__.py b/src/trezor/wire/__init__.py index c44805b73..4ca6c66a9 100644 --- a/src/trezor/wire/__init__.py +++ b/src/trezor/wire/__init__.py @@ -6,12 +6,12 @@ from trezor.messages import get_protobuf_type from trezor.workflow import start_workflow from trezor import log -from .wire_io import read_report_stream, write_report_stream -from .wire_dispatcher import dispatch_reports_by_session -from .wire_codec import \ +from .io import read_report_stream, write_report_stream +from .dispatcher import dispatch_reports_by_session +from .codec import \ decode_wire_stream, encode_wire_message, \ encode_session_open_message, encode_session_close_message -from .wire_codec_v1 import \ +from .codec_v1 import \ SESSION_V1, decode_wire_v1_stream, encode_wire_v1_message _session_handlers = {} # session id -> generator diff --git a/src/trezor/wire/wire_codec.py b/src/trezor/wire/codec.py similarity index 100% rename from src/trezor/wire/wire_codec.py rename to src/trezor/wire/codec.py diff --git a/src/trezor/wire/wire_codec_v1.py b/src/trezor/wire/codec_v1.py similarity index 100% rename from src/trezor/wire/wire_codec_v1.py rename to src/trezor/wire/codec_v1.py diff --git a/src/trezor/wire/wire_dispatcher.py b/src/trezor/wire/dispatcher.py similarity index 91% rename from src/trezor/wire/wire_dispatcher.py rename to src/trezor/wire/dispatcher.py index 771254e4d..55eae3d42 100644 --- a/src/trezor/wire/wire_dispatcher.py +++ b/src/trezor/wire/dispatcher.py @@ -1,6 +1,6 @@ from trezor import log -from .wire_codec import parse_report, REP_MARKER_OPEN, REP_MARKER_CLOSE -from .wire_codec_v1 import detect_v1, parse_report_v1 +from .codec import parse_report, REP_MARKER_OPEN, REP_MARKER_CLOSE +from .codec_v1 import detect_v1, parse_report_v1 def dispatch_reports_by_session(handlers, diff --git a/src/trezor/wire/wire_io.py b/src/trezor/wire/io.py similarity index 88% rename from src/trezor/wire/wire_io.py rename to src/trezor/wire/io.py index fd371a46e..e148ca5fb 100644 --- a/src/trezor/wire/wire_io.py +++ b/src/trezor/wire/io.py @@ -1,5 +1,5 @@ -from ubinascii import hexlify from micropython import const +from ubinascii import hexlify from trezor import msg, loop, log _DEFAULT_IFACE = const(0xFF00) # TODO: use proper interface @@ -15,5 +15,5 @@ def read_report_stream(target, iface=_DEFAULT_IFACE): def write_report_stream(iface=_DEFAULT_IFACE): while True: report = yield - log.debug(__name__, 'write report %s', hexlify(report)) + log.info(__name__, 'write report %s', hexlify(report)) msg.send(iface, report) diff --git a/tests/test_trezor.wire.wire_codec.py b/tests/test_trezor.wire.codec.py similarity index 79% rename from tests/test_trezor.wire.wire_codec.py rename to tests/test_trezor.wire.codec.py index cd989d658..00599ebcf 100644 --- a/tests/test_trezor.wire.wire_codec.py +++ b/tests/test_trezor.wire.codec.py @@ -6,7 +6,7 @@ import ubinascii from trezor.crypto import random from trezor.utils import chunks -from trezor.wire import wire_codec +from trezor.wire import codec class TestWireCodec(unittest.TestCase): # pylint: disable=C0301 @@ -14,66 +14,66 @@ class TestWireCodec(unittest.TestCase): def test_parse(self): d = b'O' + b'\x01\x23\x45\x67' + bytes(range(0, 59)) - m, s, d = wire_codec.parse_report(d) + m, s, d = codec.parse_report(d) self.assertEqual(m, b'O'[0]) self.assertEqual(s, 0x01234567) self.assertEqual(d, bytes(range(0, 59))) - t, l, d = wire_codec.parse_message(d) + t, l, d = codec.parse_message(d) self.assertEqual(t, 0x00010203) self.assertEqual(l, 0x04050607) self.assertEqual(d, bytes(range(8, 59))) - f, = wire_codec.parse_message_footer(d[0:4]) + f, = codec.parse_message_footer(d[0:4]) self.assertEqual(f, 0x08090a0b) for i in range(0, 1024): if i != 64: with self.assertRaises(ValueError): - wire_codec.parse_report(bytes(range(0, i))) + codec.parse_report(bytes(range(0, i))) if i != 59: with self.assertRaises(ValueError): - wire_codec.parse_message(bytes(range(0, i))) + codec.parse_message(bytes(range(0, i))) if i != 4: with self.assertRaises(ValueError): - wire_codec.parse_message_footer(bytes(range(0, i))) + codec.parse_message_footer(bytes(range(0, i))) def test_serialize(self): data = bytearray(range(0, 6)) - wire_codec.serialize_report_header(data, 0x12, 0x3456789a) + codec.serialize_report_header(data, 0x12, 0x3456789a) self.assertEqual(data, b'\x12\x34\x56\x78\x9a\x05') data = bytearray(range(0, 6)) - wire_codec.serialize_opened_session(data, 0x3456789a) - self.assertEqual(data, bytes([wire_codec.REP_MARKER_OPEN]) + b'\x34\x56\x78\x9a\x05') + codec.serialize_opened_session(data, 0x3456789a) + self.assertEqual(data, bytes([codec.REP_MARKER_OPEN]) + b'\x34\x56\x78\x9a\x05') data = bytearray(range(0, 14)) - wire_codec.serialize_message_header(data, 0x01234567, 0x89abcdef) + codec.serialize_message_header(data, 0x01234567, 0x89abcdef) self.assertEqual(data, b'\x00\x01\x02\x03\x04\x01\x23\x45\x67\x89\xab\xcd\xef\x0d') data = bytearray(range(0, 5)) - wire_codec.serialize_message_footer(data, 0x89abcdef) + codec.serialize_message_footer(data, 0x89abcdef) self.assertEqual(data, b'\x89\xab\xcd\xef\x04') for i in range(0, 13): data = bytearray(i) if i < 4: with self.assertRaises(ValueError): - wire_codec.serialize_message_footer(data, 0x00) + codec.serialize_message_footer(data, 0x00) if i < 5: with self.assertRaises(ValueError): - wire_codec.serialize_report_header(data, 0x00, 0x00) + codec.serialize_report_header(data, 0x00, 0x00) with self.assertRaises(ValueError): - wire_codec.serialize_opened_session(data, 0x00) + codec.serialize_opened_session(data, 0x00) with self.assertRaises(ValueError): - wire_codec.serialize_message_header(data, 0x00, 0x00) + codec.serialize_message_header(data, 0x00, 0x00) def test_decode_empty(self): message = b'\xab\xcd\xef\x12' + b'\x00\x00\x00\x00' + b'\x00' * 51 record = [] genfunc = self._record(record, 0xabcdef12, 0, 0xdeadbeef, 'dummy') - decoder = wire_codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy') + decoder = codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy') decoder.send(None) try: @@ -91,7 +91,7 @@ class TestWireCodec(unittest.TestCase): record = [] genfunc = self._record(record, 0xabcdef12, 47, 0xdeadbeef, 'dummy') - decoder = wire_codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy') + decoder = codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy') decoder.send(None) try: @@ -110,7 +110,7 @@ class TestWireCodec(unittest.TestCase): record = [] genfunc = self._record(record, 0xabcdef12, 47, 0xdeadbeef, 'dummy') - decoder = wire_codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy') + decoder = codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy') decoder.send(None) try: @@ -120,7 +120,7 @@ class TestWireCodec(unittest.TestCase): self.assertEqual(res, None) self.assertEqual(len(record), 2) self.assertEqual(record[0], data) - self.assertIsInstance(record[1], wire_codec.MessageChecksumError) + self.assertIsInstance(record[1], codec.MessageChecksumError) def test_decode_generated_range(self): for data_len in range(1, 512): @@ -137,7 +137,7 @@ class TestWireCodec(unittest.TestCase): record = [] genfunc = self._record(record, msg_type, data_len, 0xdeadbeef, 'dummy') - decoder = wire_codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy') + decoder = codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy') decoder.send(None) res = 1 @@ -157,7 +157,7 @@ class TestWireCodec(unittest.TestCase): target = self._record(record)() target.send(None) - wire_codec.encode_wire_message(0xabcdef12, b'', 0xdeadbeef, target) + codec.encode_wire_message(0xabcdef12, b'', 0xdeadbeef, target) self.assertEqual(len(record), 1) self.assertEqual(record[0], b'H\xde\xad\xbe\xef\xab\xcd\xef\x12\x00\x00\x00\x00' + '\0' * 51) @@ -169,7 +169,7 @@ class TestWireCodec(unittest.TestCase): target = self._record(record)() target.send(None) - wire_codec.encode_wire_message(0xabcdef12, data, 0xdeadbeef, target) + codec.encode_wire_message(0xabcdef12, data, 0xdeadbeef, target) self.assertEqual(record, [b'H\xde\xad\xbe\xef\xab\xcd\xef\x12\x00\x00\x00\x2f' + data + footer]) def test_encode_generated_range(self): @@ -199,7 +199,7 @@ class TestWireCodec(unittest.TestCase): target = genfunc() target.send(None) - wire_codec.encode_wire_message(msg_type, data, session_id, target) + codec.encode_wire_message(msg_type, data, session_id, target) self.assertEqual(received, len(reports)) def _record(self, record, *_args): diff --git a/tests/test_trezor.wire.wire_codec_v1.py b/tests/test_trezor.wire.codec_v1.py similarity index 80% rename from tests/test_trezor.wire.wire_codec_v1.py rename to tests/test_trezor.wire.codec_v1.py index 10a224f06..73b377b39 100644 --- a/tests/test_trezor.wire.wire_codec_v1.py +++ b/tests/test_trezor.wire.codec_v1.py @@ -1,12 +1,11 @@ from common import * import ustruct -import ubinascii from trezor.crypto import random from trezor.utils import chunks -from trezor.wire import wire_codec_v1 +from trezor.wire import codec_v1 class TestWireCodecV1(unittest.TestCase): # pylint: disable=C0301 @@ -14,21 +13,21 @@ class TestWireCodecV1(unittest.TestCase): def test_detect(self): for i in range(0, 256): if i == ord(b'?'): - self.assertTrue(wire_codec_v1.detect_v1(bytes([i]) + b'\x00' * 63)) + self.assertTrue(codec_v1.detect_v1(bytes([i]) + b'\x00' * 63)) else: - self.assertFalse(wire_codec_v1.detect_v1(bytes([i]) + b'\x00' * 63)) + self.assertFalse(codec_v1.detect_v1(bytes([i]) + b'\x00' * 63)) def test_parse(self): d = bytes(range(0, 55)) m = b'##\x00\x00\x00\x00\x00\x37' + d r = b'?' + m - rm, rs, rd = wire_codec_v1.parse_report_v1(r) + rm, rs, rd = codec_v1.parse_report_v1(r) self.assertEqual(rm, None) self.assertEqual(rs, 0) self.assertEqual(rd, m) - mt, ml, md = wire_codec_v1.parse_message(m) + mt, ml, md = codec_v1.parse_message(m) self.assertEqual(mt, 0) self.assertEqual(ml, len(d)) self.assertEqual(md, d) @@ -36,34 +35,34 @@ class TestWireCodecV1(unittest.TestCase): for i in range(0, 1024): if i != 64: with self.assertRaises(ValueError): - wire_codec_v1.parse_report_v1(bytes(range(0, i))) + codec_v1.parse_report_v1(bytes(range(0, i))) for hx in range(0, 256): for hy in range(0, 256): if hx != ord(b'#') and hy != ord(b'#'): with self.assertRaises(ValueError): - wire_codec_v1.parse_message(bytes([hx, hy]) + m[2:]) + codec_v1.parse_message(bytes([hx, hy]) + m[2:]) def test_serialize(self): data = bytearray(range(0, 10)) - wire_codec_v1.serialize_message_header(data, 0x1234, 0x56789abc) + codec_v1.serialize_message_header(data, 0x1234, 0x56789abc) self.assertEqual(data, b'\x00##\x12\x34\x56\x78\x9a\xbc\x09') data = bytearray(9) with self.assertRaises(ValueError): - wire_codec_v1.serialize_message_header(data, 65536, 0) + codec_v1.serialize_message_header(data, 65536, 0) for i in range(0, 8): data = bytearray(i) with self.assertRaises(ValueError): - wire_codec_v1.serialize_message_header(data, 0x1234, 0x56789abc) + codec_v1.serialize_message_header(data, 0x1234, 0x56789abc) def test_decode_empty(self): message = b'##' + b'\xab\xcd' + b'\x00\x00\x00\x00' + b'\x00' * 55 record = [] genfunc = self._record(record, 0xabcd, 0, 0xdeadbeef, 'dummy') - decoder = wire_codec_v1.decode_wire_v1_stream(genfunc, 0xdeadbeef, 'dummy') + decoder = codec_v1.decode_wire_v1_stream(genfunc, 0xdeadbeef, 'dummy') decoder.send(None) try: @@ -80,7 +79,7 @@ class TestWireCodecV1(unittest.TestCase): record = [] genfunc = self._record(record, 0xabcd, 55, 0xdeadbeef, 'dummy') - decoder = wire_codec_v1.decode_wire_v1_stream(genfunc, 0xdeadbeef, 'dummy') + decoder = codec_v1.decode_wire_v1_stream(genfunc, 0xdeadbeef, 'dummy') decoder.send(None) try: @@ -105,7 +104,7 @@ class TestWireCodecV1(unittest.TestCase): record = [] genfunc = self._record(record, msg_type, data_len, 0xdeadbeef, 'dummy') - decoder = wire_codec_v1.decode_wire_v1_stream(genfunc, 0xdeadbeef, 'dummy') + decoder = codec_v1.decode_wire_v1_stream(genfunc, 0xdeadbeef, 'dummy') decoder.send(None) res = 1 @@ -125,7 +124,7 @@ class TestWireCodecV1(unittest.TestCase): target = self._record(record)() target.send(None) - wire_codec_v1.encode_wire_v1_message(0xabcd, b'', target) + codec_v1.encode_wire_v1_message(0xabcd, b'', target) self.assertEqual(len(record), 1) self.assertEqual(record[0], b'?##\xab\xcd\x00\x00\x00\x00' + '\0' * 55) @@ -136,7 +135,7 @@ class TestWireCodecV1(unittest.TestCase): target = self._record(record)() target.send(None) - wire_codec_v1.encode_wire_v1_message(0xabcd, data, target) + codec_v1.encode_wire_v1_message(0xabcd, data, target) self.assertEqual(record, [b'?##\xab\xcd\x00\x00\x00\x37' + data]) def test_encode_generated_range(self): @@ -159,7 +158,7 @@ class TestWireCodecV1(unittest.TestCase): target = genfunc() target.send(None) - wire_codec_v1.encode_wire_v1_message(msg_type, data, target) + codec_v1.encode_wire_v1_message(msg_type, data, target) self.assertEqual(received, len(reports)) def _record(self, record, *_args):