diff --git a/src/tests/test_msg.py b/src/tests/test_msg.py index d515b7ec3..3b6c0b2da 100644 --- a/src/tests/test_msg.py +++ b/src/tests/test_msg.py @@ -5,7 +5,7 @@ import unittest from trezor import loop from trezor import msg -from trezor.msg import read_report, read_message, write_message +from trezor.msg import read_wire_msg, write_wire_msg def chunks(l, n): @@ -15,47 +15,32 @@ def chunks(l, n): class TestMsg(unittest.TestCase): - def test_read_report(self): + def test_read_wire_msg(self): - reader = read_report() - syscall = reader.send(None) - - self.assertIsInstance(syscall, loop.Select) - self.assertEqual(syscall.events, (loop.HID_READ,)) - - empty_report = b'\x3f' + b'\x00' * 63 - try: - reader.send(empty_report) - except StopIteration as e: - result = e.value - self.assertEqual(result, empty_report) - - def test_read_message(self): - - reader = read_message() + reader = read_wire_msg() reader.send(None) empty_message = b'\x3f##\xab\xcd\x00\x00\x00\x00' + b'\x00' * 55 try: - reader.send(empty_message) + reader.send((loop.HID_READ, empty_message)) except StopIteration as e: restype, resmsg = e.value self.assertEqual(restype, int('0xabcd', 16)) self.assertEqual(resmsg, b'') - reader = read_message() + reader = read_wire_msg() reader.send(None) content = bytes([x for x in range(0, 55)]) message = b'\x3f##\xab\xcd\x00\x00\x00\x37' + content try: - reader.send(message) + reader.send((loop.HID_READ, message)) except StopIteration as e: restype, resmsg = e.value self.assertEqual(restype, int('0xabcd', 16)) self.assertEqual(resmsg, content) - reader = read_message() + reader = read_wire_msg() reader.send(None) content = bytes([x for x in range(0, 256)]) @@ -63,22 +48,22 @@ class TestMsg(unittest.TestCase): reports = [b'\x3f' + ch + '\x00' * (63 - len(ch)) for ch in chunks(message, 63)] try: for report in reports: - reader.send(report) + reader.send((loop.HID_READ, report)) except StopIteration as e: restype, resmsg = e.value self.assertEqual(restype, int('0xabcd', 16)) self.assertEqual(resmsg, content) - def test_write_message(self): + def test_write_wire_msg(self): - written_reports = [] - msg.write_report = lambda report: written_reports.append(bytes(report)) + sent_reps = [] + msg.send = lambda rep: sent_reps.append(bytes(rep)) content = bytes([x for x in range(0, 256)]) message = b'##\xab\xcd\x00\x00\x01\00' + content reports = [b'\x3f' + ch + '\x00' * (63 - len(ch)) for ch in chunks(message, 63)] - write_message(int('0xabcd'), content) - self.assertEqual(written_reports, reports) + write_wire_msg(int('0xabcd'), content) + self.assertEqual(sent_reps, reports) if __name__ == '__main__':