mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-26 09:28:13 +00:00
use memoryviews in wire_codec
This commit is contained in:
parent
7570977cc4
commit
64746d01b4
@ -35,7 +35,6 @@ _MSG_FOOTER_LEN = ustruct.calcsize(_MSG_FOOTER)
|
|||||||
|
|
||||||
def parse_report(data):
|
def parse_report(data):
|
||||||
marker, session_id = ustruct.unpack(_REP_HEADER, data)
|
marker, session_id = ustruct.unpack(_REP_HEADER, data)
|
||||||
# TODO: handle v1 protocol
|
|
||||||
return marker, session_id, data[_REP_HEADER_LEN:]
|
return marker, session_id, data[_REP_HEADER_LEN:]
|
||||||
|
|
||||||
|
|
||||||
@ -77,7 +76,7 @@ Sends (msg_type, data_len) to target, followed by data chunks.
|
|||||||
Throws EOFError after last data chunk, in case of valid checksum.
|
Throws EOFError after last data chunk, in case of valid checksum.
|
||||||
Throws MessageChecksumError to target if data doesn't match the checksum.
|
Throws MessageChecksumError to target if data doesn't match the checksum.
|
||||||
'''
|
'''
|
||||||
message = yield # read first report
|
message = memoryview((yield)) # read first report
|
||||||
msg_type, data_len, data_tail = parse_message(message)
|
msg_type, data_len, data_tail = parse_message(message)
|
||||||
|
|
||||||
target = genfunc(msg_type, data_len, session_id, *args)
|
target = genfunc(msg_type, data_len, session_id, *args)
|
||||||
@ -86,11 +85,9 @@ Throws MessageChecksumError to target if data doesn't match the checksum.
|
|||||||
checksum = 0 # crc32
|
checksum = 0 # crc32
|
||||||
nreports = 1
|
nreports = 1
|
||||||
|
|
||||||
compute_checksum = hasattr(ubinascii, 'crc32')
|
|
||||||
|
|
||||||
while data_len > 0:
|
while data_len > 0:
|
||||||
if nreports > 1:
|
if nreports > 1:
|
||||||
data_tail = yield # read next report
|
data_tail = memoryview((yield)) # read next report
|
||||||
nreports += 1
|
nreports += 1
|
||||||
|
|
||||||
data_chunk = data_tail[:data_len] # slice off the garbage at the end
|
data_chunk = data_tail[:data_len] # slice off the garbage at the end
|
||||||
@ -98,46 +95,35 @@ Throws MessageChecksumError to target if data doesn't match the checksum.
|
|||||||
data_len -= len(data_chunk)
|
data_len -= len(data_chunk)
|
||||||
target.send(data_chunk)
|
target.send(data_chunk)
|
||||||
|
|
||||||
if compute_checksum:
|
checksum = ubinascii.crc32(data_chunk, checksum)
|
||||||
checksum = ubinascii.crc32(data_chunk, checksum) & 0xffffffff
|
|
||||||
|
|
||||||
msg_footer = data_tail[:_MSG_FOOTER_LEN]
|
msg_footer = data_tail[:_MSG_FOOTER_LEN]
|
||||||
if len(msg_footer) < _MSG_FOOTER_LEN:
|
if len(msg_footer) < _MSG_FOOTER_LEN:
|
||||||
data_tail = yield # read report with the rest of checksum
|
data_tail = yield # read report with the rest of checksum
|
||||||
msg_footer += data_tail[:_MSG_FOOTER_LEN - len(msg_footer)]
|
msg_footer += data_tail[:_MSG_FOOTER_LEN - len(msg_footer)]
|
||||||
|
|
||||||
if compute_checksum:
|
|
||||||
data_checksum, = parse_message_footer(msg_footer)
|
data_checksum, = parse_message_footer(msg_footer)
|
||||||
else:
|
|
||||||
data_checksum = checksum
|
|
||||||
if data_checksum != checksum:
|
if data_checksum != checksum:
|
||||||
target.throw(MessageChecksumError(
|
target.throw(MessageChecksumError((checksum, data_checksum)))
|
||||||
'Message checksum mismatch, expected %d, received %d' % (checksum, data_checksum)))
|
|
||||||
else:
|
else:
|
||||||
target.throw(EOFError())
|
target.throw(EOFError())
|
||||||
|
|
||||||
|
|
||||||
def encode_wire_message(msg_type, msg_data, session_id, target):
|
def encode_wire_message(msg_type, msg_data, session_id, target):
|
||||||
report = bytearray(_REP_LEN)
|
report = memoryview(bytearray(_REP_LEN))
|
||||||
serialize_report_header(report, REP_MARKER_HEADER, session_id)
|
serialize_report_header(report, REP_MARKER_HEADER, session_id)
|
||||||
serialize_message_header(report, msg_type, len(msg_data))
|
serialize_message_header(report, msg_type, len(msg_data))
|
||||||
|
|
||||||
msg_data = memoryview(msg_data)
|
source_data = memoryview(msg_data)
|
||||||
report = memoryview(report)
|
|
||||||
|
|
||||||
source_data = msg_data
|
|
||||||
target_data = report[_REP_HEADER_LEN + _MSG_HEADER_LEN:]
|
target_data = report[_REP_HEADER_LEN + _MSG_HEADER_LEN:]
|
||||||
|
|
||||||
compute_checksum = hasattr(ubinascii, 'crc32')
|
checksum = ubinascii.crc32(msg_data)
|
||||||
|
|
||||||
if compute_checksum:
|
|
||||||
checksum = ubinascii.crc32(msg_data) & 0xffffffff
|
|
||||||
else:
|
|
||||||
checksum = 0
|
|
||||||
|
|
||||||
msg_footer = bytearray(_MSG_FOOTER_LEN)
|
msg_footer = bytearray(_MSG_FOOTER_LEN)
|
||||||
serialize_message_footer(msg_footer, checksum)
|
serialize_message_footer(msg_footer, checksum)
|
||||||
|
|
||||||
|
first = True
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
# move as much as possible from source to target
|
# move as much as possible from source to target
|
||||||
n = min(len(target_data), len(source_data))
|
n = min(len(target_data), len(source_data))
|
||||||
@ -157,8 +143,10 @@ def encode_wire_message(msg_type, msg_data, session_id, target):
|
|||||||
break
|
break
|
||||||
|
|
||||||
# reset to skip the magic and session ID
|
# reset to skip the magic and session ID
|
||||||
|
if first:
|
||||||
serialize_report_header(report, REP_MARKER_DATA, session_id)
|
serialize_report_header(report, REP_MARKER_DATA, session_id)
|
||||||
target_data = report[_REP_HEADER_LEN:]
|
target_data = report[_REP_HEADER_LEN:]
|
||||||
|
first = False
|
||||||
|
|
||||||
|
|
||||||
def encode_session_open_message(session_id, target):
|
def encode_session_open_message(session_id, target):
|
||||||
|
Loading…
Reference in New Issue
Block a user