1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-26 01:18:28 +00:00

use memoryviews in wire_codec

This commit is contained in:
Jan Pochyla 2016-09-25 15:35:47 +02:00 committed by Pavol Rusnak
parent 7570977cc4
commit 64746d01b4
No known key found for this signature in database
GPG Key ID: 91F3B339B9A02A3D

View File

@ -35,7 +35,6 @@ _MSG_FOOTER_LEN = ustruct.calcsize(_MSG_FOOTER)
def parse_report(data): def parse_report(data):
marker, session_id = ustruct.unpack(_REP_HEADER, data) marker, session_id = ustruct.unpack(_REP_HEADER, data)
# TODO: handle v1 protocol
return marker, session_id, data[_REP_HEADER_LEN:] return marker, session_id, data[_REP_HEADER_LEN:]
@ -77,7 +76,7 @@ Sends (msg_type, data_len) to target, followed by data chunks.
Throws EOFError after last data chunk, in case of valid checksum. Throws EOFError after last data chunk, in case of valid checksum.
Throws MessageChecksumError to target if data doesn't match the checksum. Throws MessageChecksumError to target if data doesn't match the checksum.
''' '''
message = yield # read first report message = memoryview((yield)) # read first report
msg_type, data_len, data_tail = parse_message(message) msg_type, data_len, data_tail = parse_message(message)
target = genfunc(msg_type, data_len, session_id, *args) target = genfunc(msg_type, data_len, session_id, *args)
@ -86,11 +85,9 @@ Throws MessageChecksumError to target if data doesn't match the checksum.
checksum = 0 # crc32 checksum = 0 # crc32
nreports = 1 nreports = 1
compute_checksum = hasattr(ubinascii, 'crc32')
while data_len > 0: while data_len > 0:
if nreports > 1: if nreports > 1:
data_tail = yield # read next report data_tail = memoryview((yield)) # read next report
nreports += 1 nreports += 1
data_chunk = data_tail[:data_len] # slice off the garbage at the end data_chunk = data_tail[:data_len] # slice off the garbage at the end
@ -98,46 +95,35 @@ Throws MessageChecksumError to target if data doesn't match the checksum.
data_len -= len(data_chunk) data_len -= len(data_chunk)
target.send(data_chunk) target.send(data_chunk)
if compute_checksum: checksum = ubinascii.crc32(data_chunk, checksum)
checksum = ubinascii.crc32(data_chunk, checksum) & 0xffffffff
msg_footer = data_tail[:_MSG_FOOTER_LEN] msg_footer = data_tail[:_MSG_FOOTER_LEN]
if len(msg_footer) < _MSG_FOOTER_LEN: if len(msg_footer) < _MSG_FOOTER_LEN:
data_tail = yield # read report with the rest of checksum data_tail = yield # read report with the rest of checksum
msg_footer += data_tail[:_MSG_FOOTER_LEN - len(msg_footer)] msg_footer += data_tail[:_MSG_FOOTER_LEN - len(msg_footer)]
if compute_checksum: data_checksum, = parse_message_footer(msg_footer)
data_checksum, = parse_message_footer(msg_footer)
else:
data_checksum = checksum
if data_checksum != checksum: if data_checksum != checksum:
target.throw(MessageChecksumError( target.throw(MessageChecksumError((checksum, data_checksum)))
'Message checksum mismatch, expected %d, received %d' % (checksum, data_checksum)))
else: else:
target.throw(EOFError()) target.throw(EOFError())
def encode_wire_message(msg_type, msg_data, session_id, target): def encode_wire_message(msg_type, msg_data, session_id, target):
report = bytearray(_REP_LEN) report = memoryview(bytearray(_REP_LEN))
serialize_report_header(report, REP_MARKER_HEADER, session_id) serialize_report_header(report, REP_MARKER_HEADER, session_id)
serialize_message_header(report, msg_type, len(msg_data)) serialize_message_header(report, msg_type, len(msg_data))
msg_data = memoryview(msg_data) source_data = memoryview(msg_data)
report = memoryview(report)
source_data = msg_data
target_data = report[_REP_HEADER_LEN + _MSG_HEADER_LEN:] target_data = report[_REP_HEADER_LEN + _MSG_HEADER_LEN:]
compute_checksum = hasattr(ubinascii, 'crc32') checksum = ubinascii.crc32(msg_data)
if compute_checksum:
checksum = ubinascii.crc32(msg_data) & 0xffffffff
else:
checksum = 0
msg_footer = bytearray(_MSG_FOOTER_LEN) msg_footer = bytearray(_MSG_FOOTER_LEN)
serialize_message_footer(msg_footer, checksum) serialize_message_footer(msg_footer, checksum)
first = True
while True: while True:
# move as much as possible from source to target # move as much as possible from source to target
n = min(len(target_data), len(source_data)) n = min(len(target_data), len(source_data))
@ -157,8 +143,10 @@ def encode_wire_message(msg_type, msg_data, session_id, target):
break break
# reset to skip the magic and session ID # reset to skip the magic and session ID
serialize_report_header(report, REP_MARKER_DATA, session_id) if first:
target_data = report[_REP_HEADER_LEN:] serialize_report_header(report, REP_MARKER_DATA, session_id)
target_data = report[_REP_HEADER_LEN:]
first = False
def encode_session_open_message(session_id, target): def encode_session_open_message(session_id, target):