From 4a255e8b776eddcef9e2bf0d54f0d06686d7040e Mon Sep 17 00:00:00 2001 From: Jan Pochyla Date: Wed, 18 May 2016 18:54:42 +0200 Subject: [PATCH] add write_message --- src/tests/test_msg.py | 30 ++++++++++++++--------- src/trezor/msg.py | 57 +++++++++++++++++++++++++++++++------------ 2 files changed, 60 insertions(+), 27 deletions(-) diff --git a/src/tests/test_msg.py b/src/tests/test_msg.py index 9bc3e347c..d515b7ec3 100644 --- a/src/tests/test_msg.py +++ b/src/tests/test_msg.py @@ -4,7 +4,13 @@ sys.path.append('../lib') import unittest from trezor import loop -from trezor.msg import read_report, parse_header, read_message +from trezor import msg +from trezor.msg import read_report, read_message, write_message + + +def chunks(l, n): + for i in range(0, len(l), n): + yield l[i:i + n] class TestMsg(unittest.TestCase): @@ -24,13 +30,6 @@ class TestMsg(unittest.TestCase): result = e.value self.assertEqual(result, empty_report) - def test_parse_header(self): - - report = b'\x3f##\xab\xcd\x12\x34\x56\x78' - msgtype, msglen = parse_header(report) - self.assertEqual(msgtype, int('0xabcd', 16)) - self.assertEqual(msglen, int('0x12345678', 16)) - def test_read_message(self): reader = read_message() @@ -56,10 +55,6 @@ class TestMsg(unittest.TestCase): self.assertEqual(restype, int('0xabcd', 16)) self.assertEqual(resmsg, content) - def chunks(l, n): - for i in range(0, len(l), n): - yield l[i:i + n] - reader = read_message() reader.send(None) @@ -74,6 +69,17 @@ class TestMsg(unittest.TestCase): self.assertEqual(restype, int('0xabcd', 16)) self.assertEqual(resmsg, content) + def test_write_message(self): + + written_reports = [] + msg.write_report = lambda report: written_reports.append(bytes(report)) + + content = bytes([x for x in range(0, 256)]) + message = b'##\xab\xcd\x00\x00\x01\00' + content + reports = [b'\x3f' + ch + '\x00' * (63 - len(ch)) for ch in chunks(message, 63)] + write_message(int('0xabcd'), content) + self.assertEqual(written_reports, reports) + if __name__ == '__main__': unittest.main() diff --git a/src/trezor/msg.py b/src/trezor/msg.py index 742cf81e3..3a5497610 100644 --- a/src/trezor/msg.py +++ b/src/trezor/msg.py @@ -15,38 +15,65 @@ def send(msg): return _msg.send(msg) -REPORT_LEN = 64 -REPORT_NUM = 63 -HEADER_MAGIC = 35 # '#' +REPORT_LEN = const(64) +REPORT_NUM = const(63) +HEADER_MAGIC = const(35) # '#' def read_report(): report = yield loop.Select(loop.HID_READ) - assert report[0] == REPORT_NUM, 'Malformed report number' - assert len(report) == REPORT_LEN, 'Incorrect report length' - return memoryview(report) + assert report[0] == REPORT_NUM + return report -def parse_header(report): - assert report[1] == HEADER_MAGIC and report[2] == HEADER_MAGIC, 'Header not found' - return ustruct.unpack_from('>HL', report, 3) +def write_report(report): + return send(report) # FIXME def read_message(): report = yield from read_report() - (msgtype, msglen) = parse_header(report) + assert report[1] == HEADER_MAGIC + assert report[2] == HEADER_MAGIC + (msgtype, msglen) = ustruct.unpack_from('>HL', report, 3) - repdata = report[1 + 8:] + # TODO: validate msglen for sane values + + report = memoryview(report) + repdata = report[9:] repdata = repdata[:msglen] - msgbuf = bytearray(repdata) - remaining = msglen - len(msgbuf) + msgdata = bytearray(repdata) # TODO: allocate msglen bytes + remaining = msglen - len(msgdata) while remaining > 0: report = yield from read_report() + report = memoryview(report) repdata = report[1:] repdata = repdata[:remaining] - msgbuf.extend(repdata) + msgdata.extend(repdata) remaining -= len(repdata) - return (msgtype, msgbuf) + return (msgtype, msgdata) + + +def write_message(msgtype, msgdata): + report = bytearray(REPORT_LEN) + report[0] = REPORT_NUM + report[1] = HEADER_MAGIC + report[2] = HEADER_MAGIC + ustruct.pack_into('>HL', report, 3, msgtype, len(msgdata)) + + msgdata = memoryview(msgdata) + report = memoryview(report) + repdata = report[9:] + + while msgdata: + n = min(len(repdata), len(msgdata)) + repdata[:n] = msgdata[:n] + i = n + while i < len(repdata): + repdata[i] = 0 + i += 1 + write_report(report) + msgdata = msgdata[n:] + repdata = report[1:]