1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-05-11 19:38:48 +00:00

autopep8 and minor cosmetics

This commit is contained in:
Jan Pochyla 2016-09-28 23:28:53 +02:00
parent 976b14a5b8
commit c93133026e
3 changed files with 31 additions and 24 deletions

View File

@ -12,9 +12,7 @@ from .wire_codec import \
decode_wire_stream, encode_wire_message, \ decode_wire_stream, encode_wire_message, \
encode_session_open_message, encode_session_close_message encode_session_open_message, encode_session_close_message
from .wire_codec_v1 import \ from .wire_codec_v1 import \
SESSION_V1, \ SESSION_V1, decode_wire_v1_stream, encode_wire_v1_message
decode_wire_v1_stream, \
encode_wire_v1_message
_session_handlers = {} # session id -> generator _session_handlers = {} # session id -> generator
_workflow_genfuncs = {} # wire type -> (generator function, args) _workflow_genfuncs = {} # wire type -> (generator function, args)
@ -138,8 +136,8 @@ async def monitor_workflow(workflow, session_id):
finally: finally:
if session_id in _opened_sessions: if session_id in _opened_sessions:
if session_id == SESSION_V1: if session_id == SESSION_V1:
wire_decoder = decode_wire_v1_stream(_handle_registered_type, wire_decoder = decode_wire_v1_stream(
SESSION_V1) _handle_registered_type, session_id)
else: else:
wire_decoder = decode_wire_stream( wire_decoder = decode_wire_stream(
_handle_registered_type, session_id) _handle_registered_type, session_id)

View File

@ -83,12 +83,8 @@ Throws MessageChecksumError to target if data doesn't match the checksum.
target.send(None) target.send(None)
checksum = 0 # crc32 checksum = 0 # crc32
nreports = 1
while data_len > 0: while data_len > 0:
if nreports > 1:
data_tail = memoryview((yield)) # read next report
nreports += 1
data_chunk = data_tail[:data_len] # slice off the garbage at the end data_chunk = data_tail[:data_len] # slice off the garbage at the end
data_tail = data_tail[len(data_chunk):] # slice off what we have read data_tail = data_tail[len(data_chunk):] # slice off what we have read
@ -97,6 +93,9 @@ Throws MessageChecksumError to target if data doesn't match the checksum.
checksum = ubinascii.crc32(data_chunk, checksum) checksum = ubinascii.crc32(data_chunk, checksum)
if data_len > 0:
data_tail = memoryview((yield)) # read next report
msg_footer = data_tail[:_MSG_FOOTER_LEN] msg_footer = data_tail[:_MSG_FOOTER_LEN]
if len(msg_footer) < _MSG_FOOTER_LEN: if len(msg_footer) < _MSG_FOOTER_LEN:
data_tail = yield # read report with the rest of checksum data_tail = yield # read report with the rest of checksum
@ -137,7 +136,7 @@ def encode_wire_message(msg_type, msg_data, session_id, target):
msg_footer = None msg_footer = None
continue continue
# FIXME: Optimize speed # FIXME: optimize speed
x = 0 x = 0
to_fill = len(target_data) to_fill = len(target_data)
while x < to_fill: while x < to_fill:

View File

@ -1,28 +1,36 @@
import ustruct import ustruct
SESSION_V1 = const(0) SESSION_V1 = const(0)
REP_MARKER_V1 = const(63) # ord('?) REP_MARKER_V1 = const(63) # ord('?)
REP_MARKER_V1_LEN = const(1) # len('?') REP_MARKER_V1_LEN = const(1) # len('?')
_MSG_HEADER_MAGIC = const(35) # org('#') _REP_LEN = const(64)
_MSG_HEADER_V1 = '>BBHL' # wire type, data length _MSG_HEADER_MAGIC = const(35) # org('#')
_MSG_HEADER_V1 = '>BBHL' # magic, magic, wire type, data length
_MSG_HEADER_V1_LEN = ustruct.calcsize(_MSG_HEADER_V1) _MSG_HEADER_V1_LEN = ustruct.calcsize(_MSG_HEADER_V1)
def detect_v1(data): def detect_v1(data):
return (data[0] == REP_MARKER_V1) return (data[0] == REP_MARKER_V1)
def parse_report_v1(data): def parse_report_v1(data):
return None, SESSION_V1, data[1:] return None, SESSION_V1, data[1:]
def parse_message(data): def parse_message(data):
magic1, magic2, msg_type, data_len = ustruct.unpack(_MSG_HEADER_V1, data) magic1, magic2, msg_type, data_len = ustruct.unpack(_MSG_HEADER_V1, data)
if magic1 != _MSG_HEADER_MAGIC or magic2 != _MSG_HEADER_MAGIC: if magic1 != _MSG_HEADER_MAGIC or magic2 != _MSG_HEADER_MAGIC:
raise Exception("Corrupted magic bytes") raise Exception('Corrupted magic bytes')
return msg_type, data_len, data[_MSG_HEADER_V1_LEN:] return msg_type, data_len, data[_MSG_HEADER_V1_LEN:]
def serialize_message_header(data, msg_type, msg_len): def serialize_message_header(data, msg_type, msg_len):
ustruct.pack_into(_MSG_HEADER_V1, data, REP_MARKER_V1_LEN, _MSG_HEADER_MAGIC, _MSG_HEADER_MAGIC, msg_type, msg_len) ustruct.pack_into(
_MSG_HEADER_V1, data, REP_MARKER_V1_LEN,
_MSG_HEADER_MAGIC, _MSG_HEADER_MAGIC, msg_type, msg_len)
def decode_wire_v1_stream(genfunc, session_id, *args): def decode_wire_v1_stream(genfunc, session_id, *args):
'''Decode a v1 wire message from the report data and stream it to target. '''Decode a v1 wire message from the report data and stream it to target.
@ -32,8 +40,8 @@ Sends (msg_type, data_len) to target, followed by data chunks.
Throws EOFError after last data chunk, in case of valid checksum. Throws EOFError after last data chunk, in case of valid checksum.
Throws MessageChecksumError to target if data doesn't match the checksum. Throws MessageChecksumError to target if data doesn't match the checksum.
''' '''
message = yield # read first report message = yield # read first report
msg_type, data_len, data = parse_message(message) msg_type, data_len, data = parse_message(message)
print(msg_type, data_len, bytes(data)) print(msg_type, data_len, bytes(data))
@ -48,17 +56,18 @@ Throws MessageChecksumError to target if data doesn't match the checksum.
target.send(data_chunk) target.send(data_chunk)
if data_len > 0: if data_len > 0:
data = yield # First next record data = yield # read next report
target.throw(EOFError()) target.throw(EOFError())
def encode_wire_v1_message(msg_type, msg_data, target): def encode_wire_v1_message(msg_type, msg_data, target):
report = memoryview(bytearray(64)) # Maximum report length report = memoryview(bytearray(_REP_LEN))
report[0] = REP_MARKER_V1 # Put report marker report[0] = REP_MARKER_V1
serialize_message_header(report, msg_type, len(msg_data)) serialize_message_header(report, msg_type, len(msg_data))
source_data = memoryview(msg_data) source_data = memoryview(msg_data)
target_data = report[REP_MARKER_V1_LEN+_MSG_HEADER_V1_LEN:] target_data = report[REP_MARKER_V1_LEN + _MSG_HEADER_V1_LEN:]
while True: while True:
# move as much as possible from source to target # move as much as possible from source to target
@ -67,7 +76,7 @@ def encode_wire_v1_message(msg_type, msg_data, target):
source_data = source_data[n:] source_data = source_data[n:]
target_data = target_data[n:] target_data = target_data[n:]
# FIXME: Optimize speed # FIXME: optimize speed
x = 0 x = 0
to_fill = len(target_data) to_fill = len(target_data)
while x < to_fill: while x < to_fill:
@ -79,4 +88,5 @@ def encode_wire_v1_message(msg_type, msg_data, target):
if not source_data: if not source_data:
break break
target_data = report[REP_MARKER_V1_LEN:] # reset to skip the magic, not the whole header anymore
target_data = report[REP_MARKER_V1_LEN:]