diff --git a/src/tests/test_msg.py b/src/tests/test_msg.py new file mode 100644 index 0000000000..9bc3e347cf --- /dev/null +++ b/src/tests/test_msg.py @@ -0,0 +1,79 @@ +import sys +sys.path.append('..') +sys.path.append('../lib') +import unittest + +from trezor import loop +from trezor.msg import read_report, parse_header, read_message + + +class TestMsg(unittest.TestCase): + + def test_read_report(self): + + reader = read_report() + syscall = reader.send(None) + + self.assertIsInstance(syscall, loop.Select) + self.assertEqual(syscall.events, (loop.HID_READ,)) + + empty_report = b'\x3f' + b'\x00' * 63 + try: + reader.send(empty_report) + except StopIteration as e: + result = e.value + self.assertEqual(result, empty_report) + + def test_parse_header(self): + + report = b'\x3f##\xab\xcd\x12\x34\x56\x78' + msgtype, msglen = parse_header(report) + self.assertEqual(msgtype, int('0xabcd', 16)) + self.assertEqual(msglen, int('0x12345678', 16)) + + def test_read_message(self): + + reader = read_message() + reader.send(None) + + empty_message = b'\x3f##\xab\xcd\x00\x00\x00\x00' + b'\x00' * 55 + try: + reader.send(empty_message) + except StopIteration as e: + restype, resmsg = e.value + self.assertEqual(restype, int('0xabcd', 16)) + self.assertEqual(resmsg, b'') + + reader = read_message() + reader.send(None) + + content = bytes([x for x in range(0, 55)]) + message = b'\x3f##\xab\xcd\x00\x00\x00\x37' + content + try: + reader.send(message) + except StopIteration as e: + restype, resmsg = e.value + self.assertEqual(restype, int('0xabcd', 16)) + self.assertEqual(resmsg, content) + + def chunks(l, n): + for i in range(0, len(l), n): + yield l[i:i + n] + + reader = read_message() + reader.send(None) + + content = bytes([x for x in range(0, 256)]) + message = b'##\xab\xcd\x00\x00\x01\00' + content + reports = [b'\x3f' + ch + '\x00' * (63 - len(ch)) for ch in chunks(message, 63)] + try: + for report in reports: + reader.send(report) + except StopIteration as e: + restype, resmsg = e.value + self.assertEqual(restype, int('0xabcd', 16)) + self.assertEqual(resmsg, content) + + +if __name__ == '__main__': + unittest.main() diff --git a/src/trezor/msg.py b/src/trezor/msg.py index f815899c3b..742cf81e36 100644 --- a/src/trezor/msg.py +++ b/src/trezor/msg.py @@ -1,9 +1,52 @@ +import ustruct + +from trezor import loop + from TrezorMsg import Msg _msg = Msg() + def select(timeout_us): return _msg.select(timeout_us) + def send(msg): return _msg.send(msg) + + +REPORT_LEN = 64 +REPORT_NUM = 63 +HEADER_MAGIC = 35 # '#' + + +def read_report(): + report = yield loop.Select(loop.HID_READ) + assert report[0] == REPORT_NUM, 'Malformed report number' + assert len(report) == REPORT_LEN, 'Incorrect report length' + return memoryview(report) + + +def parse_header(report): + assert report[1] == HEADER_MAGIC and report[2] == HEADER_MAGIC, 'Header not found' + return ustruct.unpack_from('>HL', report, 3) + + +def read_message(): + report = yield from read_report() + (msgtype, msglen) = parse_header(report) + + repdata = report[1 + 8:] + repdata = repdata[:msglen] + msgbuf = bytearray(repdata) + + remaining = msglen - len(msgbuf) + + while remaining > 0: + report = yield from read_report() + repdata = report[1:] + repdata = repdata[:remaining] + msgbuf.extend(repdata) + remaining -= len(repdata) + + return (msgtype, msgbuf)