trezor.wire: rename modules

pull/25/head
Jan Pochyla 8 years ago
parent b145f8f309
commit 1b27bb480d

@ -6,12 +6,12 @@ from trezor.messages import get_protobuf_type
from trezor.workflow import start_workflow
from trezor import log
from .wire_io import read_report_stream, write_report_stream
from .wire_dispatcher import dispatch_reports_by_session
from .wire_codec import \
from .io import read_report_stream, write_report_stream
from .dispatcher import dispatch_reports_by_session
from .codec import \
decode_wire_stream, encode_wire_message, \
encode_session_open_message, encode_session_close_message
from .wire_codec_v1 import \
from .codec_v1 import \
SESSION_V1, decode_wire_v1_stream, encode_wire_v1_message
_session_handlers = {} # session id -> generator

@ -1,6 +1,6 @@
from trezor import log
from .wire_codec import parse_report, REP_MARKER_OPEN, REP_MARKER_CLOSE
from .wire_codec_v1 import detect_v1, parse_report_v1
from .codec import parse_report, REP_MARKER_OPEN, REP_MARKER_CLOSE
from .codec_v1 import detect_v1, parse_report_v1
def dispatch_reports_by_session(handlers,

@ -1,5 +1,5 @@
from ubinascii import hexlify
from micropython import const
from ubinascii import hexlify
from trezor import msg, loop, log
_DEFAULT_IFACE = const(0xFF00) # TODO: use proper interface
@ -15,5 +15,5 @@ def read_report_stream(target, iface=_DEFAULT_IFACE):
def write_report_stream(iface=_DEFAULT_IFACE):
while True:
report = yield
log.debug(__name__, 'write report %s', hexlify(report))
log.info(__name__, 'write report %s', hexlify(report))
msg.send(iface, report)

@ -6,7 +6,7 @@ import ubinascii
from trezor.crypto import random
from trezor.utils import chunks
from trezor.wire import wire_codec
from trezor.wire import codec
class TestWireCodec(unittest.TestCase):
# pylint: disable=C0301
@ -14,66 +14,66 @@ class TestWireCodec(unittest.TestCase):
def test_parse(self):
d = b'O' + b'\x01\x23\x45\x67' + bytes(range(0, 59))
m, s, d = wire_codec.parse_report(d)
m, s, d = codec.parse_report(d)
self.assertEqual(m, b'O'[0])
self.assertEqual(s, 0x01234567)
self.assertEqual(d, bytes(range(0, 59)))
t, l, d = wire_codec.parse_message(d)
t, l, d = codec.parse_message(d)
self.assertEqual(t, 0x00010203)
self.assertEqual(l, 0x04050607)
self.assertEqual(d, bytes(range(8, 59)))
f, = wire_codec.parse_message_footer(d[0:4])
f, = codec.parse_message_footer(d[0:4])
self.assertEqual(f, 0x08090a0b)
for i in range(0, 1024):
if i != 64:
with self.assertRaises(ValueError):
wire_codec.parse_report(bytes(range(0, i)))
codec.parse_report(bytes(range(0, i)))
if i != 59:
with self.assertRaises(ValueError):
wire_codec.parse_message(bytes(range(0, i)))
codec.parse_message(bytes(range(0, i)))
if i != 4:
with self.assertRaises(ValueError):
wire_codec.parse_message_footer(bytes(range(0, i)))
codec.parse_message_footer(bytes(range(0, i)))
def test_serialize(self):
data = bytearray(range(0, 6))
wire_codec.serialize_report_header(data, 0x12, 0x3456789a)
codec.serialize_report_header(data, 0x12, 0x3456789a)
self.assertEqual(data, b'\x12\x34\x56\x78\x9a\x05')
data = bytearray(range(0, 6))
wire_codec.serialize_opened_session(data, 0x3456789a)
self.assertEqual(data, bytes([wire_codec.REP_MARKER_OPEN]) + b'\x34\x56\x78\x9a\x05')
codec.serialize_opened_session(data, 0x3456789a)
self.assertEqual(data, bytes([codec.REP_MARKER_OPEN]) + b'\x34\x56\x78\x9a\x05')
data = bytearray(range(0, 14))
wire_codec.serialize_message_header(data, 0x01234567, 0x89abcdef)
codec.serialize_message_header(data, 0x01234567, 0x89abcdef)
self.assertEqual(data, b'\x00\x01\x02\x03\x04\x01\x23\x45\x67\x89\xab\xcd\xef\x0d')
data = bytearray(range(0, 5))
wire_codec.serialize_message_footer(data, 0x89abcdef)
codec.serialize_message_footer(data, 0x89abcdef)
self.assertEqual(data, b'\x89\xab\xcd\xef\x04')
for i in range(0, 13):
data = bytearray(i)
if i < 4:
with self.assertRaises(ValueError):
wire_codec.serialize_message_footer(data, 0x00)
codec.serialize_message_footer(data, 0x00)
if i < 5:
with self.assertRaises(ValueError):
wire_codec.serialize_report_header(data, 0x00, 0x00)
codec.serialize_report_header(data, 0x00, 0x00)
with self.assertRaises(ValueError):
wire_codec.serialize_opened_session(data, 0x00)
codec.serialize_opened_session(data, 0x00)
with self.assertRaises(ValueError):
wire_codec.serialize_message_header(data, 0x00, 0x00)
codec.serialize_message_header(data, 0x00, 0x00)
def test_decode_empty(self):
message = b'\xab\xcd\xef\x12' + b'\x00\x00\x00\x00' + b'\x00' * 51
record = []
genfunc = self._record(record, 0xabcdef12, 0, 0xdeadbeef, 'dummy')
decoder = wire_codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy')
decoder = codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy')
decoder.send(None)
try:
@ -91,7 +91,7 @@ class TestWireCodec(unittest.TestCase):
record = []
genfunc = self._record(record, 0xabcdef12, 47, 0xdeadbeef, 'dummy')
decoder = wire_codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy')
decoder = codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy')
decoder.send(None)
try:
@ -110,7 +110,7 @@ class TestWireCodec(unittest.TestCase):
record = []
genfunc = self._record(record, 0xabcdef12, 47, 0xdeadbeef, 'dummy')
decoder = wire_codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy')
decoder = codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy')
decoder.send(None)
try:
@ -120,7 +120,7 @@ class TestWireCodec(unittest.TestCase):
self.assertEqual(res, None)
self.assertEqual(len(record), 2)
self.assertEqual(record[0], data)
self.assertIsInstance(record[1], wire_codec.MessageChecksumError)
self.assertIsInstance(record[1], codec.MessageChecksumError)
def test_decode_generated_range(self):
for data_len in range(1, 512):
@ -137,7 +137,7 @@ class TestWireCodec(unittest.TestCase):
record = []
genfunc = self._record(record, msg_type, data_len, 0xdeadbeef, 'dummy')
decoder = wire_codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy')
decoder = codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy')
decoder.send(None)
res = 1
@ -157,7 +157,7 @@ class TestWireCodec(unittest.TestCase):
target = self._record(record)()
target.send(None)
wire_codec.encode_wire_message(0xabcdef12, b'', 0xdeadbeef, target)
codec.encode_wire_message(0xabcdef12, b'', 0xdeadbeef, target)
self.assertEqual(len(record), 1)
self.assertEqual(record[0], b'H\xde\xad\xbe\xef\xab\xcd\xef\x12\x00\x00\x00\x00' + '\0' * 51)
@ -169,7 +169,7 @@ class TestWireCodec(unittest.TestCase):
target = self._record(record)()
target.send(None)
wire_codec.encode_wire_message(0xabcdef12, data, 0xdeadbeef, target)
codec.encode_wire_message(0xabcdef12, data, 0xdeadbeef, target)
self.assertEqual(record, [b'H\xde\xad\xbe\xef\xab\xcd\xef\x12\x00\x00\x00\x2f' + data + footer])
def test_encode_generated_range(self):
@ -199,7 +199,7 @@ class TestWireCodec(unittest.TestCase):
target = genfunc()
target.send(None)
wire_codec.encode_wire_message(msg_type, data, session_id, target)
codec.encode_wire_message(msg_type, data, session_id, target)
self.assertEqual(received, len(reports))
def _record(self, record, *_args):

@ -1,12 +1,11 @@
from common import *
import ustruct
import ubinascii
from trezor.crypto import random
from trezor.utils import chunks
from trezor.wire import wire_codec_v1
from trezor.wire import codec_v1
class TestWireCodecV1(unittest.TestCase):
# pylint: disable=C0301
@ -14,21 +13,21 @@ class TestWireCodecV1(unittest.TestCase):
def test_detect(self):
for i in range(0, 256):
if i == ord(b'?'):
self.assertTrue(wire_codec_v1.detect_v1(bytes([i]) + b'\x00' * 63))
self.assertTrue(codec_v1.detect_v1(bytes([i]) + b'\x00' * 63))
else:
self.assertFalse(wire_codec_v1.detect_v1(bytes([i]) + b'\x00' * 63))
self.assertFalse(codec_v1.detect_v1(bytes([i]) + b'\x00' * 63))
def test_parse(self):
d = bytes(range(0, 55))
m = b'##\x00\x00\x00\x00\x00\x37' + d
r = b'?' + m
rm, rs, rd = wire_codec_v1.parse_report_v1(r)
rm, rs, rd = codec_v1.parse_report_v1(r)
self.assertEqual(rm, None)
self.assertEqual(rs, 0)
self.assertEqual(rd, m)
mt, ml, md = wire_codec_v1.parse_message(m)
mt, ml, md = codec_v1.parse_message(m)
self.assertEqual(mt, 0)
self.assertEqual(ml, len(d))
self.assertEqual(md, d)
@ -36,34 +35,34 @@ class TestWireCodecV1(unittest.TestCase):
for i in range(0, 1024):
if i != 64:
with self.assertRaises(ValueError):
wire_codec_v1.parse_report_v1(bytes(range(0, i)))
codec_v1.parse_report_v1(bytes(range(0, i)))
for hx in range(0, 256):
for hy in range(0, 256):
if hx != ord(b'#') and hy != ord(b'#'):
with self.assertRaises(ValueError):
wire_codec_v1.parse_message(bytes([hx, hy]) + m[2:])
codec_v1.parse_message(bytes([hx, hy]) + m[2:])
def test_serialize(self):
data = bytearray(range(0, 10))
wire_codec_v1.serialize_message_header(data, 0x1234, 0x56789abc)
codec_v1.serialize_message_header(data, 0x1234, 0x56789abc)
self.assertEqual(data, b'\x00##\x12\x34\x56\x78\x9a\xbc\x09')
data = bytearray(9)
with self.assertRaises(ValueError):
wire_codec_v1.serialize_message_header(data, 65536, 0)
codec_v1.serialize_message_header(data, 65536, 0)
for i in range(0, 8):
data = bytearray(i)
with self.assertRaises(ValueError):
wire_codec_v1.serialize_message_header(data, 0x1234, 0x56789abc)
codec_v1.serialize_message_header(data, 0x1234, 0x56789abc)
def test_decode_empty(self):
message = b'##' + b'\xab\xcd' + b'\x00\x00\x00\x00' + b'\x00' * 55
record = []
genfunc = self._record(record, 0xabcd, 0, 0xdeadbeef, 'dummy')
decoder = wire_codec_v1.decode_wire_v1_stream(genfunc, 0xdeadbeef, 'dummy')
decoder = codec_v1.decode_wire_v1_stream(genfunc, 0xdeadbeef, 'dummy')
decoder.send(None)
try:
@ -80,7 +79,7 @@ class TestWireCodecV1(unittest.TestCase):
record = []
genfunc = self._record(record, 0xabcd, 55, 0xdeadbeef, 'dummy')
decoder = wire_codec_v1.decode_wire_v1_stream(genfunc, 0xdeadbeef, 'dummy')
decoder = codec_v1.decode_wire_v1_stream(genfunc, 0xdeadbeef, 'dummy')
decoder.send(None)
try:
@ -105,7 +104,7 @@ class TestWireCodecV1(unittest.TestCase):
record = []
genfunc = self._record(record, msg_type, data_len, 0xdeadbeef, 'dummy')
decoder = wire_codec_v1.decode_wire_v1_stream(genfunc, 0xdeadbeef, 'dummy')
decoder = codec_v1.decode_wire_v1_stream(genfunc, 0xdeadbeef, 'dummy')
decoder.send(None)
res = 1
@ -125,7 +124,7 @@ class TestWireCodecV1(unittest.TestCase):
target = self._record(record)()
target.send(None)
wire_codec_v1.encode_wire_v1_message(0xabcd, b'', target)
codec_v1.encode_wire_v1_message(0xabcd, b'', target)
self.assertEqual(len(record), 1)
self.assertEqual(record[0], b'?##\xab\xcd\x00\x00\x00\x00' + '\0' * 55)
@ -136,7 +135,7 @@ class TestWireCodecV1(unittest.TestCase):
target = self._record(record)()
target.send(None)
wire_codec_v1.encode_wire_v1_message(0xabcd, data, target)
codec_v1.encode_wire_v1_message(0xabcd, data, target)
self.assertEqual(record, [b'?##\xab\xcd\x00\x00\x00\x37' + data])
def test_encode_generated_range(self):
@ -159,7 +158,7 @@ class TestWireCodecV1(unittest.TestCase):
target = genfunc()
target.send(None)
wire_codec_v1.encode_wire_v1_message(msg_type, data, target)
codec_v1.encode_wire_v1_message(msg_type, data, target)
self.assertEqual(received, len(reports))
def _record(self, record, *_args):
Loading…
Cancel
Save