trezor.wire: rename modules

pull/25/head
Jan Pochyla 8 years ago
parent b145f8f309
commit 1b27bb480d

@ -6,12 +6,12 @@ from trezor.messages import get_protobuf_type
from trezor.workflow import start_workflow from trezor.workflow import start_workflow
from trezor import log from trezor import log
from .wire_io import read_report_stream, write_report_stream from .io import read_report_stream, write_report_stream
from .wire_dispatcher import dispatch_reports_by_session from .dispatcher import dispatch_reports_by_session
from .wire_codec import \ from .codec import \
decode_wire_stream, encode_wire_message, \ decode_wire_stream, encode_wire_message, \
encode_session_open_message, encode_session_close_message encode_session_open_message, encode_session_close_message
from .wire_codec_v1 import \ from .codec_v1 import \
SESSION_V1, decode_wire_v1_stream, encode_wire_v1_message SESSION_V1, decode_wire_v1_stream, encode_wire_v1_message
_session_handlers = {} # session id -> generator _session_handlers = {} # session id -> generator

@ -1,6 +1,6 @@
from trezor import log from trezor import log
from .wire_codec import parse_report, REP_MARKER_OPEN, REP_MARKER_CLOSE from .codec import parse_report, REP_MARKER_OPEN, REP_MARKER_CLOSE
from .wire_codec_v1 import detect_v1, parse_report_v1 from .codec_v1 import detect_v1, parse_report_v1
def dispatch_reports_by_session(handlers, def dispatch_reports_by_session(handlers,

@ -1,5 +1,5 @@
from ubinascii import hexlify
from micropython import const from micropython import const
from ubinascii import hexlify
from trezor import msg, loop, log from trezor import msg, loop, log
_DEFAULT_IFACE = const(0xFF00) # TODO: use proper interface _DEFAULT_IFACE = const(0xFF00) # TODO: use proper interface
@ -15,5 +15,5 @@ def read_report_stream(target, iface=_DEFAULT_IFACE):
def write_report_stream(iface=_DEFAULT_IFACE): def write_report_stream(iface=_DEFAULT_IFACE):
while True: while True:
report = yield report = yield
log.debug(__name__, 'write report %s', hexlify(report)) log.info(__name__, 'write report %s', hexlify(report))
msg.send(iface, report) msg.send(iface, report)

@ -6,7 +6,7 @@ import ubinascii
from trezor.crypto import random from trezor.crypto import random
from trezor.utils import chunks from trezor.utils import chunks
from trezor.wire import wire_codec from trezor.wire import codec
class TestWireCodec(unittest.TestCase): class TestWireCodec(unittest.TestCase):
# pylint: disable=C0301 # pylint: disable=C0301
@ -14,66 +14,66 @@ class TestWireCodec(unittest.TestCase):
def test_parse(self): def test_parse(self):
d = b'O' + b'\x01\x23\x45\x67' + bytes(range(0, 59)) d = b'O' + b'\x01\x23\x45\x67' + bytes(range(0, 59))
m, s, d = wire_codec.parse_report(d) m, s, d = codec.parse_report(d)
self.assertEqual(m, b'O'[0]) self.assertEqual(m, b'O'[0])
self.assertEqual(s, 0x01234567) self.assertEqual(s, 0x01234567)
self.assertEqual(d, bytes(range(0, 59))) self.assertEqual(d, bytes(range(0, 59)))
t, l, d = wire_codec.parse_message(d) t, l, d = codec.parse_message(d)
self.assertEqual(t, 0x00010203) self.assertEqual(t, 0x00010203)
self.assertEqual(l, 0x04050607) self.assertEqual(l, 0x04050607)
self.assertEqual(d, bytes(range(8, 59))) self.assertEqual(d, bytes(range(8, 59)))
f, = wire_codec.parse_message_footer(d[0:4]) f, = codec.parse_message_footer(d[0:4])
self.assertEqual(f, 0x08090a0b) self.assertEqual(f, 0x08090a0b)
for i in range(0, 1024): for i in range(0, 1024):
if i != 64: if i != 64:
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
wire_codec.parse_report(bytes(range(0, i))) codec.parse_report(bytes(range(0, i)))
if i != 59: if i != 59:
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
wire_codec.parse_message(bytes(range(0, i))) codec.parse_message(bytes(range(0, i)))
if i != 4: if i != 4:
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
wire_codec.parse_message_footer(bytes(range(0, i))) codec.parse_message_footer(bytes(range(0, i)))
def test_serialize(self): def test_serialize(self):
data = bytearray(range(0, 6)) data = bytearray(range(0, 6))
wire_codec.serialize_report_header(data, 0x12, 0x3456789a) codec.serialize_report_header(data, 0x12, 0x3456789a)
self.assertEqual(data, b'\x12\x34\x56\x78\x9a\x05') self.assertEqual(data, b'\x12\x34\x56\x78\x9a\x05')
data = bytearray(range(0, 6)) data = bytearray(range(0, 6))
wire_codec.serialize_opened_session(data, 0x3456789a) codec.serialize_opened_session(data, 0x3456789a)
self.assertEqual(data, bytes([wire_codec.REP_MARKER_OPEN]) + b'\x34\x56\x78\x9a\x05') self.assertEqual(data, bytes([codec.REP_MARKER_OPEN]) + b'\x34\x56\x78\x9a\x05')
data = bytearray(range(0, 14)) data = bytearray(range(0, 14))
wire_codec.serialize_message_header(data, 0x01234567, 0x89abcdef) codec.serialize_message_header(data, 0x01234567, 0x89abcdef)
self.assertEqual(data, b'\x00\x01\x02\x03\x04\x01\x23\x45\x67\x89\xab\xcd\xef\x0d') self.assertEqual(data, b'\x00\x01\x02\x03\x04\x01\x23\x45\x67\x89\xab\xcd\xef\x0d')
data = bytearray(range(0, 5)) data = bytearray(range(0, 5))
wire_codec.serialize_message_footer(data, 0x89abcdef) codec.serialize_message_footer(data, 0x89abcdef)
self.assertEqual(data, b'\x89\xab\xcd\xef\x04') self.assertEqual(data, b'\x89\xab\xcd\xef\x04')
for i in range(0, 13): for i in range(0, 13):
data = bytearray(i) data = bytearray(i)
if i < 4: if i < 4:
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
wire_codec.serialize_message_footer(data, 0x00) codec.serialize_message_footer(data, 0x00)
if i < 5: if i < 5:
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
wire_codec.serialize_report_header(data, 0x00, 0x00) codec.serialize_report_header(data, 0x00, 0x00)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
wire_codec.serialize_opened_session(data, 0x00) codec.serialize_opened_session(data, 0x00)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
wire_codec.serialize_message_header(data, 0x00, 0x00) codec.serialize_message_header(data, 0x00, 0x00)
def test_decode_empty(self): def test_decode_empty(self):
message = b'\xab\xcd\xef\x12' + b'\x00\x00\x00\x00' + b'\x00' * 51 message = b'\xab\xcd\xef\x12' + b'\x00\x00\x00\x00' + b'\x00' * 51
record = [] record = []
genfunc = self._record(record, 0xabcdef12, 0, 0xdeadbeef, 'dummy') genfunc = self._record(record, 0xabcdef12, 0, 0xdeadbeef, 'dummy')
decoder = wire_codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy') decoder = codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy')
decoder.send(None) decoder.send(None)
try: try:
@ -91,7 +91,7 @@ class TestWireCodec(unittest.TestCase):
record = [] record = []
genfunc = self._record(record, 0xabcdef12, 47, 0xdeadbeef, 'dummy') genfunc = self._record(record, 0xabcdef12, 47, 0xdeadbeef, 'dummy')
decoder = wire_codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy') decoder = codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy')
decoder.send(None) decoder.send(None)
try: try:
@ -110,7 +110,7 @@ class TestWireCodec(unittest.TestCase):
record = [] record = []
genfunc = self._record(record, 0xabcdef12, 47, 0xdeadbeef, 'dummy') genfunc = self._record(record, 0xabcdef12, 47, 0xdeadbeef, 'dummy')
decoder = wire_codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy') decoder = codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy')
decoder.send(None) decoder.send(None)
try: try:
@ -120,7 +120,7 @@ class TestWireCodec(unittest.TestCase):
self.assertEqual(res, None) self.assertEqual(res, None)
self.assertEqual(len(record), 2) self.assertEqual(len(record), 2)
self.assertEqual(record[0], data) self.assertEqual(record[0], data)
self.assertIsInstance(record[1], wire_codec.MessageChecksumError) self.assertIsInstance(record[1], codec.MessageChecksumError)
def test_decode_generated_range(self): def test_decode_generated_range(self):
for data_len in range(1, 512): for data_len in range(1, 512):
@ -137,7 +137,7 @@ class TestWireCodec(unittest.TestCase):
record = [] record = []
genfunc = self._record(record, msg_type, data_len, 0xdeadbeef, 'dummy') genfunc = self._record(record, msg_type, data_len, 0xdeadbeef, 'dummy')
decoder = wire_codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy') decoder = codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy')
decoder.send(None) decoder.send(None)
res = 1 res = 1
@ -157,7 +157,7 @@ class TestWireCodec(unittest.TestCase):
target = self._record(record)() target = self._record(record)()
target.send(None) target.send(None)
wire_codec.encode_wire_message(0xabcdef12, b'', 0xdeadbeef, target) codec.encode_wire_message(0xabcdef12, b'', 0xdeadbeef, target)
self.assertEqual(len(record), 1) self.assertEqual(len(record), 1)
self.assertEqual(record[0], b'H\xde\xad\xbe\xef\xab\xcd\xef\x12\x00\x00\x00\x00' + '\0' * 51) self.assertEqual(record[0], b'H\xde\xad\xbe\xef\xab\xcd\xef\x12\x00\x00\x00\x00' + '\0' * 51)
@ -169,7 +169,7 @@ class TestWireCodec(unittest.TestCase):
target = self._record(record)() target = self._record(record)()
target.send(None) target.send(None)
wire_codec.encode_wire_message(0xabcdef12, data, 0xdeadbeef, target) codec.encode_wire_message(0xabcdef12, data, 0xdeadbeef, target)
self.assertEqual(record, [b'H\xde\xad\xbe\xef\xab\xcd\xef\x12\x00\x00\x00\x2f' + data + footer]) self.assertEqual(record, [b'H\xde\xad\xbe\xef\xab\xcd\xef\x12\x00\x00\x00\x2f' + data + footer])
def test_encode_generated_range(self): def test_encode_generated_range(self):
@ -199,7 +199,7 @@ class TestWireCodec(unittest.TestCase):
target = genfunc() target = genfunc()
target.send(None) target.send(None)
wire_codec.encode_wire_message(msg_type, data, session_id, target) codec.encode_wire_message(msg_type, data, session_id, target)
self.assertEqual(received, len(reports)) self.assertEqual(received, len(reports))
def _record(self, record, *_args): def _record(self, record, *_args):

@ -1,12 +1,11 @@
from common import * from common import *
import ustruct import ustruct
import ubinascii
from trezor.crypto import random from trezor.crypto import random
from trezor.utils import chunks from trezor.utils import chunks
from trezor.wire import wire_codec_v1 from trezor.wire import codec_v1
class TestWireCodecV1(unittest.TestCase): class TestWireCodecV1(unittest.TestCase):
# pylint: disable=C0301 # pylint: disable=C0301
@ -14,21 +13,21 @@ class TestWireCodecV1(unittest.TestCase):
def test_detect(self): def test_detect(self):
for i in range(0, 256): for i in range(0, 256):
if i == ord(b'?'): if i == ord(b'?'):
self.assertTrue(wire_codec_v1.detect_v1(bytes([i]) + b'\x00' * 63)) self.assertTrue(codec_v1.detect_v1(bytes([i]) + b'\x00' * 63))
else: else:
self.assertFalse(wire_codec_v1.detect_v1(bytes([i]) + b'\x00' * 63)) self.assertFalse(codec_v1.detect_v1(bytes([i]) + b'\x00' * 63))
def test_parse(self): def test_parse(self):
d = bytes(range(0, 55)) d = bytes(range(0, 55))
m = b'##\x00\x00\x00\x00\x00\x37' + d m = b'##\x00\x00\x00\x00\x00\x37' + d
r = b'?' + m r = b'?' + m
rm, rs, rd = wire_codec_v1.parse_report_v1(r) rm, rs, rd = codec_v1.parse_report_v1(r)
self.assertEqual(rm, None) self.assertEqual(rm, None)
self.assertEqual(rs, 0) self.assertEqual(rs, 0)
self.assertEqual(rd, m) self.assertEqual(rd, m)
mt, ml, md = wire_codec_v1.parse_message(m) mt, ml, md = codec_v1.parse_message(m)
self.assertEqual(mt, 0) self.assertEqual(mt, 0)
self.assertEqual(ml, len(d)) self.assertEqual(ml, len(d))
self.assertEqual(md, d) self.assertEqual(md, d)
@ -36,34 +35,34 @@ class TestWireCodecV1(unittest.TestCase):
for i in range(0, 1024): for i in range(0, 1024):
if i != 64: if i != 64:
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
wire_codec_v1.parse_report_v1(bytes(range(0, i))) codec_v1.parse_report_v1(bytes(range(0, i)))
for hx in range(0, 256): for hx in range(0, 256):
for hy in range(0, 256): for hy in range(0, 256):
if hx != ord(b'#') and hy != ord(b'#'): if hx != ord(b'#') and hy != ord(b'#'):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
wire_codec_v1.parse_message(bytes([hx, hy]) + m[2:]) codec_v1.parse_message(bytes([hx, hy]) + m[2:])
def test_serialize(self): def test_serialize(self):
data = bytearray(range(0, 10)) data = bytearray(range(0, 10))
wire_codec_v1.serialize_message_header(data, 0x1234, 0x56789abc) codec_v1.serialize_message_header(data, 0x1234, 0x56789abc)
self.assertEqual(data, b'\x00##\x12\x34\x56\x78\x9a\xbc\x09') self.assertEqual(data, b'\x00##\x12\x34\x56\x78\x9a\xbc\x09')
data = bytearray(9) data = bytearray(9)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
wire_codec_v1.serialize_message_header(data, 65536, 0) codec_v1.serialize_message_header(data, 65536, 0)
for i in range(0, 8): for i in range(0, 8):
data = bytearray(i) data = bytearray(i)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
wire_codec_v1.serialize_message_header(data, 0x1234, 0x56789abc) codec_v1.serialize_message_header(data, 0x1234, 0x56789abc)
def test_decode_empty(self): def test_decode_empty(self):
message = b'##' + b'\xab\xcd' + b'\x00\x00\x00\x00' + b'\x00' * 55 message = b'##' + b'\xab\xcd' + b'\x00\x00\x00\x00' + b'\x00' * 55
record = [] record = []
genfunc = self._record(record, 0xabcd, 0, 0xdeadbeef, 'dummy') genfunc = self._record(record, 0xabcd, 0, 0xdeadbeef, 'dummy')
decoder = wire_codec_v1.decode_wire_v1_stream(genfunc, 0xdeadbeef, 'dummy') decoder = codec_v1.decode_wire_v1_stream(genfunc, 0xdeadbeef, 'dummy')
decoder.send(None) decoder.send(None)
try: try:
@ -80,7 +79,7 @@ class TestWireCodecV1(unittest.TestCase):
record = [] record = []
genfunc = self._record(record, 0xabcd, 55, 0xdeadbeef, 'dummy') genfunc = self._record(record, 0xabcd, 55, 0xdeadbeef, 'dummy')
decoder = wire_codec_v1.decode_wire_v1_stream(genfunc, 0xdeadbeef, 'dummy') decoder = codec_v1.decode_wire_v1_stream(genfunc, 0xdeadbeef, 'dummy')
decoder.send(None) decoder.send(None)
try: try:
@ -105,7 +104,7 @@ class TestWireCodecV1(unittest.TestCase):
record = [] record = []
genfunc = self._record(record, msg_type, data_len, 0xdeadbeef, 'dummy') genfunc = self._record(record, msg_type, data_len, 0xdeadbeef, 'dummy')
decoder = wire_codec_v1.decode_wire_v1_stream(genfunc, 0xdeadbeef, 'dummy') decoder = codec_v1.decode_wire_v1_stream(genfunc, 0xdeadbeef, 'dummy')
decoder.send(None) decoder.send(None)
res = 1 res = 1
@ -125,7 +124,7 @@ class TestWireCodecV1(unittest.TestCase):
target = self._record(record)() target = self._record(record)()
target.send(None) target.send(None)
wire_codec_v1.encode_wire_v1_message(0xabcd, b'', target) codec_v1.encode_wire_v1_message(0xabcd, b'', target)
self.assertEqual(len(record), 1) self.assertEqual(len(record), 1)
self.assertEqual(record[0], b'?##\xab\xcd\x00\x00\x00\x00' + '\0' * 55) self.assertEqual(record[0], b'?##\xab\xcd\x00\x00\x00\x00' + '\0' * 55)
@ -136,7 +135,7 @@ class TestWireCodecV1(unittest.TestCase):
target = self._record(record)() target = self._record(record)()
target.send(None) target.send(None)
wire_codec_v1.encode_wire_v1_message(0xabcd, data, target) codec_v1.encode_wire_v1_message(0xabcd, data, target)
self.assertEqual(record, [b'?##\xab\xcd\x00\x00\x00\x37' + data]) self.assertEqual(record, [b'?##\xab\xcd\x00\x00\x00\x37' + data])
def test_encode_generated_range(self): def test_encode_generated_range(self):
@ -159,7 +158,7 @@ class TestWireCodecV1(unittest.TestCase):
target = genfunc() target = genfunc()
target.send(None) target.send(None)
wire_codec_v1.encode_wire_v1_message(msg_type, data, target) codec_v1.encode_wire_v1_message(msg_type, data, target)
self.assertEqual(received, len(reports)) self.assertEqual(received, len(reports))
def _record(self, record, *_args): def _record(self, record, *_args):
Loading…
Cancel
Save