mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-16 04:29:08 +00:00
add write_message
This commit is contained in:
parent
f98fc4c0c8
commit
4a255e8b77
@ -4,7 +4,13 @@ sys.path.append('../lib')
|
||||
import unittest
|
||||
|
||||
from trezor import loop
|
||||
from trezor.msg import read_report, parse_header, read_message
|
||||
from trezor import msg
|
||||
from trezor.msg import read_report, read_message, write_message
|
||||
|
||||
|
||||
def chunks(l, n):
|
||||
for i in range(0, len(l), n):
|
||||
yield l[i:i + n]
|
||||
|
||||
|
||||
class TestMsg(unittest.TestCase):
|
||||
@ -24,13 +30,6 @@ class TestMsg(unittest.TestCase):
|
||||
result = e.value
|
||||
self.assertEqual(result, empty_report)
|
||||
|
||||
def test_parse_header(self):
|
||||
|
||||
report = b'\x3f##\xab\xcd\x12\x34\x56\x78'
|
||||
msgtype, msglen = parse_header(report)
|
||||
self.assertEqual(msgtype, int('0xabcd', 16))
|
||||
self.assertEqual(msglen, int('0x12345678', 16))
|
||||
|
||||
def test_read_message(self):
|
||||
|
||||
reader = read_message()
|
||||
@ -56,10 +55,6 @@ class TestMsg(unittest.TestCase):
|
||||
self.assertEqual(restype, int('0xabcd', 16))
|
||||
self.assertEqual(resmsg, content)
|
||||
|
||||
def chunks(l, n):
|
||||
for i in range(0, len(l), n):
|
||||
yield l[i:i + n]
|
||||
|
||||
reader = read_message()
|
||||
reader.send(None)
|
||||
|
||||
@ -74,6 +69,17 @@ class TestMsg(unittest.TestCase):
|
||||
self.assertEqual(restype, int('0xabcd', 16))
|
||||
self.assertEqual(resmsg, content)
|
||||
|
||||
def test_write_message(self):
|
||||
|
||||
written_reports = []
|
||||
msg.write_report = lambda report: written_reports.append(bytes(report))
|
||||
|
||||
content = bytes([x for x in range(0, 256)])
|
||||
message = b'##\xab\xcd\x00\x00\x01\00' + content
|
||||
reports = [b'\x3f' + ch + '\x00' * (63 - len(ch)) for ch in chunks(message, 63)]
|
||||
write_message(int('0xabcd'), content)
|
||||
self.assertEqual(written_reports, reports)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
@ -15,38 +15,65 @@ def send(msg):
|
||||
return _msg.send(msg)
|
||||
|
||||
|
||||
REPORT_LEN = 64
|
||||
REPORT_NUM = 63
|
||||
HEADER_MAGIC = 35 # '#'
|
||||
REPORT_LEN = const(64)
|
||||
REPORT_NUM = const(63)
|
||||
HEADER_MAGIC = const(35) # '#'
|
||||
|
||||
|
||||
def read_report():
|
||||
report = yield loop.Select(loop.HID_READ)
|
||||
assert report[0] == REPORT_NUM, 'Malformed report number'
|
||||
assert len(report) == REPORT_LEN, 'Incorrect report length'
|
||||
return memoryview(report)
|
||||
assert report[0] == REPORT_NUM
|
||||
return report
|
||||
|
||||
|
||||
def parse_header(report):
|
||||
assert report[1] == HEADER_MAGIC and report[2] == HEADER_MAGIC, 'Header not found'
|
||||
return ustruct.unpack_from('>HL', report, 3)
|
||||
def write_report(report):
|
||||
return send(report) # FIXME
|
||||
|
||||
|
||||
def read_message():
|
||||
report = yield from read_report()
|
||||
(msgtype, msglen) = parse_header(report)
|
||||
assert report[1] == HEADER_MAGIC
|
||||
assert report[2] == HEADER_MAGIC
|
||||
(msgtype, msglen) = ustruct.unpack_from('>HL', report, 3)
|
||||
|
||||
repdata = report[1 + 8:]
|
||||
# TODO: validate msglen for sane values
|
||||
|
||||
report = memoryview(report)
|
||||
repdata = report[9:]
|
||||
repdata = repdata[:msglen]
|
||||
msgbuf = bytearray(repdata)
|
||||
|
||||
remaining = msglen - len(msgbuf)
|
||||
msgdata = bytearray(repdata) # TODO: allocate msglen bytes
|
||||
remaining = msglen - len(msgdata)
|
||||
|
||||
while remaining > 0:
|
||||
report = yield from read_report()
|
||||
report = memoryview(report)
|
||||
repdata = report[1:]
|
||||
repdata = repdata[:remaining]
|
||||
msgbuf.extend(repdata)
|
||||
msgdata.extend(repdata)
|
||||
remaining -= len(repdata)
|
||||
|
||||
return (msgtype, msgbuf)
|
||||
return (msgtype, msgdata)
|
||||
|
||||
|
||||
def write_message(msgtype, msgdata):
|
||||
report = bytearray(REPORT_LEN)
|
||||
report[0] = REPORT_NUM
|
||||
report[1] = HEADER_MAGIC
|
||||
report[2] = HEADER_MAGIC
|
||||
ustruct.pack_into('>HL', report, 3, msgtype, len(msgdata))
|
||||
|
||||
msgdata = memoryview(msgdata)
|
||||
report = memoryview(report)
|
||||
repdata = report[9:]
|
||||
|
||||
while msgdata:
|
||||
n = min(len(repdata), len(msgdata))
|
||||
repdata[:n] = msgdata[:n]
|
||||
i = n
|
||||
while i < len(repdata):
|
||||
repdata[i] = 0
|
||||
i += 1
|
||||
write_report(report)
|
||||
msgdata = msgdata[n:]
|
||||
repdata = report[1:]
|
||||
|
Loading…
Reference in New Issue
Block a user