|
|
@ -6,7 +6,7 @@ import ubinascii
|
|
|
|
from trezor.crypto import random
|
|
|
|
from trezor.crypto import random
|
|
|
|
from trezor.utils import chunks
|
|
|
|
from trezor.utils import chunks
|
|
|
|
|
|
|
|
|
|
|
|
from trezor.wire import wire_codec
|
|
|
|
from trezor.wire import codec
|
|
|
|
|
|
|
|
|
|
|
|
class TestWireCodec(unittest.TestCase):
|
|
|
|
class TestWireCodec(unittest.TestCase):
|
|
|
|
# pylint: disable=C0301
|
|
|
|
# pylint: disable=C0301
|
|
|
@ -14,66 +14,66 @@ class TestWireCodec(unittest.TestCase):
|
|
|
|
def test_parse(self):
|
|
|
|
def test_parse(self):
|
|
|
|
d = b'O' + b'\x01\x23\x45\x67' + bytes(range(0, 59))
|
|
|
|
d = b'O' + b'\x01\x23\x45\x67' + bytes(range(0, 59))
|
|
|
|
|
|
|
|
|
|
|
|
m, s, d = wire_codec.parse_report(d)
|
|
|
|
m, s, d = codec.parse_report(d)
|
|
|
|
self.assertEqual(m, b'O'[0])
|
|
|
|
self.assertEqual(m, b'O'[0])
|
|
|
|
self.assertEqual(s, 0x01234567)
|
|
|
|
self.assertEqual(s, 0x01234567)
|
|
|
|
self.assertEqual(d, bytes(range(0, 59)))
|
|
|
|
self.assertEqual(d, bytes(range(0, 59)))
|
|
|
|
|
|
|
|
|
|
|
|
t, l, d = wire_codec.parse_message(d)
|
|
|
|
t, l, d = codec.parse_message(d)
|
|
|
|
self.assertEqual(t, 0x00010203)
|
|
|
|
self.assertEqual(t, 0x00010203)
|
|
|
|
self.assertEqual(l, 0x04050607)
|
|
|
|
self.assertEqual(l, 0x04050607)
|
|
|
|
self.assertEqual(d, bytes(range(8, 59)))
|
|
|
|
self.assertEqual(d, bytes(range(8, 59)))
|
|
|
|
|
|
|
|
|
|
|
|
f, = wire_codec.parse_message_footer(d[0:4])
|
|
|
|
f, = codec.parse_message_footer(d[0:4])
|
|
|
|
self.assertEqual(f, 0x08090a0b)
|
|
|
|
self.assertEqual(f, 0x08090a0b)
|
|
|
|
|
|
|
|
|
|
|
|
for i in range(0, 1024):
|
|
|
|
for i in range(0, 1024):
|
|
|
|
if i != 64:
|
|
|
|
if i != 64:
|
|
|
|
with self.assertRaises(ValueError):
|
|
|
|
with self.assertRaises(ValueError):
|
|
|
|
wire_codec.parse_report(bytes(range(0, i)))
|
|
|
|
codec.parse_report(bytes(range(0, i)))
|
|
|
|
if i != 59:
|
|
|
|
if i != 59:
|
|
|
|
with self.assertRaises(ValueError):
|
|
|
|
with self.assertRaises(ValueError):
|
|
|
|
wire_codec.parse_message(bytes(range(0, i)))
|
|
|
|
codec.parse_message(bytes(range(0, i)))
|
|
|
|
if i != 4:
|
|
|
|
if i != 4:
|
|
|
|
with self.assertRaises(ValueError):
|
|
|
|
with self.assertRaises(ValueError):
|
|
|
|
wire_codec.parse_message_footer(bytes(range(0, i)))
|
|
|
|
codec.parse_message_footer(bytes(range(0, i)))
|
|
|
|
|
|
|
|
|
|
|
|
def test_serialize(self):
|
|
|
|
def test_serialize(self):
|
|
|
|
data = bytearray(range(0, 6))
|
|
|
|
data = bytearray(range(0, 6))
|
|
|
|
wire_codec.serialize_report_header(data, 0x12, 0x3456789a)
|
|
|
|
codec.serialize_report_header(data, 0x12, 0x3456789a)
|
|
|
|
self.assertEqual(data, b'\x12\x34\x56\x78\x9a\x05')
|
|
|
|
self.assertEqual(data, b'\x12\x34\x56\x78\x9a\x05')
|
|
|
|
|
|
|
|
|
|
|
|
data = bytearray(range(0, 6))
|
|
|
|
data = bytearray(range(0, 6))
|
|
|
|
wire_codec.serialize_opened_session(data, 0x3456789a)
|
|
|
|
codec.serialize_opened_session(data, 0x3456789a)
|
|
|
|
self.assertEqual(data, bytes([wire_codec.REP_MARKER_OPEN]) + b'\x34\x56\x78\x9a\x05')
|
|
|
|
self.assertEqual(data, bytes([codec.REP_MARKER_OPEN]) + b'\x34\x56\x78\x9a\x05')
|
|
|
|
|
|
|
|
|
|
|
|
data = bytearray(range(0, 14))
|
|
|
|
data = bytearray(range(0, 14))
|
|
|
|
wire_codec.serialize_message_header(data, 0x01234567, 0x89abcdef)
|
|
|
|
codec.serialize_message_header(data, 0x01234567, 0x89abcdef)
|
|
|
|
self.assertEqual(data, b'\x00\x01\x02\x03\x04\x01\x23\x45\x67\x89\xab\xcd\xef\x0d')
|
|
|
|
self.assertEqual(data, b'\x00\x01\x02\x03\x04\x01\x23\x45\x67\x89\xab\xcd\xef\x0d')
|
|
|
|
|
|
|
|
|
|
|
|
data = bytearray(range(0, 5))
|
|
|
|
data = bytearray(range(0, 5))
|
|
|
|
wire_codec.serialize_message_footer(data, 0x89abcdef)
|
|
|
|
codec.serialize_message_footer(data, 0x89abcdef)
|
|
|
|
self.assertEqual(data, b'\x89\xab\xcd\xef\x04')
|
|
|
|
self.assertEqual(data, b'\x89\xab\xcd\xef\x04')
|
|
|
|
|
|
|
|
|
|
|
|
for i in range(0, 13):
|
|
|
|
for i in range(0, 13):
|
|
|
|
data = bytearray(i)
|
|
|
|
data = bytearray(i)
|
|
|
|
if i < 4:
|
|
|
|
if i < 4:
|
|
|
|
with self.assertRaises(ValueError):
|
|
|
|
with self.assertRaises(ValueError):
|
|
|
|
wire_codec.serialize_message_footer(data, 0x00)
|
|
|
|
codec.serialize_message_footer(data, 0x00)
|
|
|
|
if i < 5:
|
|
|
|
if i < 5:
|
|
|
|
with self.assertRaises(ValueError):
|
|
|
|
with self.assertRaises(ValueError):
|
|
|
|
wire_codec.serialize_report_header(data, 0x00, 0x00)
|
|
|
|
codec.serialize_report_header(data, 0x00, 0x00)
|
|
|
|
with self.assertRaises(ValueError):
|
|
|
|
with self.assertRaises(ValueError):
|
|
|
|
wire_codec.serialize_opened_session(data, 0x00)
|
|
|
|
codec.serialize_opened_session(data, 0x00)
|
|
|
|
with self.assertRaises(ValueError):
|
|
|
|
with self.assertRaises(ValueError):
|
|
|
|
wire_codec.serialize_message_header(data, 0x00, 0x00)
|
|
|
|
codec.serialize_message_header(data, 0x00, 0x00)
|
|
|
|
|
|
|
|
|
|
|
|
def test_decode_empty(self):
|
|
|
|
def test_decode_empty(self):
|
|
|
|
message = b'\xab\xcd\xef\x12' + b'\x00\x00\x00\x00' + b'\x00' * 51
|
|
|
|
message = b'\xab\xcd\xef\x12' + b'\x00\x00\x00\x00' + b'\x00' * 51
|
|
|
|
|
|
|
|
|
|
|
|
record = []
|
|
|
|
record = []
|
|
|
|
genfunc = self._record(record, 0xabcdef12, 0, 0xdeadbeef, 'dummy')
|
|
|
|
genfunc = self._record(record, 0xabcdef12, 0, 0xdeadbeef, 'dummy')
|
|
|
|
decoder = wire_codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy')
|
|
|
|
decoder = codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy')
|
|
|
|
decoder.send(None)
|
|
|
|
decoder.send(None)
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
try:
|
|
|
@ -91,7 +91,7 @@ class TestWireCodec(unittest.TestCase):
|
|
|
|
|
|
|
|
|
|
|
|
record = []
|
|
|
|
record = []
|
|
|
|
genfunc = self._record(record, 0xabcdef12, 47, 0xdeadbeef, 'dummy')
|
|
|
|
genfunc = self._record(record, 0xabcdef12, 47, 0xdeadbeef, 'dummy')
|
|
|
|
decoder = wire_codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy')
|
|
|
|
decoder = codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy')
|
|
|
|
decoder.send(None)
|
|
|
|
decoder.send(None)
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
try:
|
|
|
@ -110,7 +110,7 @@ class TestWireCodec(unittest.TestCase):
|
|
|
|
|
|
|
|
|
|
|
|
record = []
|
|
|
|
record = []
|
|
|
|
genfunc = self._record(record, 0xabcdef12, 47, 0xdeadbeef, 'dummy')
|
|
|
|
genfunc = self._record(record, 0xabcdef12, 47, 0xdeadbeef, 'dummy')
|
|
|
|
decoder = wire_codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy')
|
|
|
|
decoder = codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy')
|
|
|
|
decoder.send(None)
|
|
|
|
decoder.send(None)
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
try:
|
|
|
@ -120,7 +120,7 @@ class TestWireCodec(unittest.TestCase):
|
|
|
|
self.assertEqual(res, None)
|
|
|
|
self.assertEqual(res, None)
|
|
|
|
self.assertEqual(len(record), 2)
|
|
|
|
self.assertEqual(len(record), 2)
|
|
|
|
self.assertEqual(record[0], data)
|
|
|
|
self.assertEqual(record[0], data)
|
|
|
|
self.assertIsInstance(record[1], wire_codec.MessageChecksumError)
|
|
|
|
self.assertIsInstance(record[1], codec.MessageChecksumError)
|
|
|
|
|
|
|
|
|
|
|
|
def test_decode_generated_range(self):
|
|
|
|
def test_decode_generated_range(self):
|
|
|
|
for data_len in range(1, 512):
|
|
|
|
for data_len in range(1, 512):
|
|
|
@ -137,7 +137,7 @@ class TestWireCodec(unittest.TestCase):
|
|
|
|
|
|
|
|
|
|
|
|
record = []
|
|
|
|
record = []
|
|
|
|
genfunc = self._record(record, msg_type, data_len, 0xdeadbeef, 'dummy')
|
|
|
|
genfunc = self._record(record, msg_type, data_len, 0xdeadbeef, 'dummy')
|
|
|
|
decoder = wire_codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy')
|
|
|
|
decoder = codec.decode_wire_stream(genfunc, 0xdeadbeef, 'dummy')
|
|
|
|
decoder.send(None)
|
|
|
|
decoder.send(None)
|
|
|
|
|
|
|
|
|
|
|
|
res = 1
|
|
|
|
res = 1
|
|
|
@ -157,7 +157,7 @@ class TestWireCodec(unittest.TestCase):
|
|
|
|
target = self._record(record)()
|
|
|
|
target = self._record(record)()
|
|
|
|
target.send(None)
|
|
|
|
target.send(None)
|
|
|
|
|
|
|
|
|
|
|
|
wire_codec.encode_wire_message(0xabcdef12, b'', 0xdeadbeef, target)
|
|
|
|
codec.encode_wire_message(0xabcdef12, b'', 0xdeadbeef, target)
|
|
|
|
self.assertEqual(len(record), 1)
|
|
|
|
self.assertEqual(len(record), 1)
|
|
|
|
self.assertEqual(record[0], b'H\xde\xad\xbe\xef\xab\xcd\xef\x12\x00\x00\x00\x00' + '\0' * 51)
|
|
|
|
self.assertEqual(record[0], b'H\xde\xad\xbe\xef\xab\xcd\xef\x12\x00\x00\x00\x00' + '\0' * 51)
|
|
|
|
|
|
|
|
|
|
|
@ -169,7 +169,7 @@ class TestWireCodec(unittest.TestCase):
|
|
|
|
target = self._record(record)()
|
|
|
|
target = self._record(record)()
|
|
|
|
target.send(None)
|
|
|
|
target.send(None)
|
|
|
|
|
|
|
|
|
|
|
|
wire_codec.encode_wire_message(0xabcdef12, data, 0xdeadbeef, target)
|
|
|
|
codec.encode_wire_message(0xabcdef12, data, 0xdeadbeef, target)
|
|
|
|
self.assertEqual(record, [b'H\xde\xad\xbe\xef\xab\xcd\xef\x12\x00\x00\x00\x2f' + data + footer])
|
|
|
|
self.assertEqual(record, [b'H\xde\xad\xbe\xef\xab\xcd\xef\x12\x00\x00\x00\x2f' + data + footer])
|
|
|
|
|
|
|
|
|
|
|
|
def test_encode_generated_range(self):
|
|
|
|
def test_encode_generated_range(self):
|
|
|
@ -199,7 +199,7 @@ class TestWireCodec(unittest.TestCase):
|
|
|
|
target = genfunc()
|
|
|
|
target = genfunc()
|
|
|
|
target.send(None)
|
|
|
|
target.send(None)
|
|
|
|
|
|
|
|
|
|
|
|
wire_codec.encode_wire_message(msg_type, data, session_id, target)
|
|
|
|
codec.encode_wire_message(msg_type, data, session_id, target)
|
|
|
|
self.assertEqual(received, len(reports))
|
|
|
|
self.assertEqual(received, len(reports))
|
|
|
|
|
|
|
|
|
|
|
|
def _record(self, record, *_args):
|
|
|
|
def _record(self, record, *_args):
|