mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-14 03:30:02 +00:00
fix(core): fix continuation packet ignoring, unify logging
This commit is contained in:
parent
f31d8f59ce
commit
30da02b0f2
@ -44,6 +44,7 @@ MAX_PAYLOAD_LEN = const(60000)
|
||||
|
||||
class Channel(Context):
|
||||
def __init__(self, channel_cache: ChannelCache) -> None:
|
||||
print("channel.__init__")
|
||||
iface = _decode_iface(channel_cache.iface)
|
||||
super().__init__(iface, channel_cache.channel_id)
|
||||
self.channel_cache = channel_cache
|
||||
@ -70,34 +71,39 @@ class Channel(Context):
|
||||
|
||||
def get_channel_state(self) -> int:
|
||||
state = int.from_bytes(self.channel_cache.state, "big")
|
||||
print("get_ch_state", state)
|
||||
print("channel.get_ch_state:", state)
|
||||
return state
|
||||
|
||||
def set_channel_state(self, state: ChannelState) -> None:
|
||||
print("set_ch_state", int.from_bytes(state.to_bytes(1, "big"), "big"))
|
||||
print("channel.set_ch_state:", int.from_bytes(state.to_bytes(1, "big"), "big"))
|
||||
self.channel_cache.state = bytearray(state.to_bytes(1, "big"))
|
||||
|
||||
def set_buffer(self, buffer: utils.BufferType) -> None:
|
||||
self.buffer = buffer
|
||||
print("set buffer channel", type(self.buffer))
|
||||
print("channel.set_buffer:", type(self.buffer))
|
||||
|
||||
# CALLED BY THP_MAIN_LOOP
|
||||
|
||||
async def receive_packet(self, packet: utils.BufferType):
|
||||
print("receive packet")
|
||||
print("channel.receive_packet")
|
||||
ctrl_byte = packet[0]
|
||||
if _is_ctrl_byte_continuation(ctrl_byte):
|
||||
await self._handle_cont_packet(packet)
|
||||
else:
|
||||
await self._handle_init_packet(packet)
|
||||
print("receive packet", self.expected_payload_length, self.bytes_read)
|
||||
printBytes(self.buffer)
|
||||
if self.expected_payload_length + INIT_DATA_OFFSET == self.bytes_read:
|
||||
self._finish_message()
|
||||
await self._handle_completed_message()
|
||||
elif self.expected_payload_length + INIT_DATA_OFFSET > self.bytes_read:
|
||||
self.is_cont_packet_expected = True
|
||||
else:
|
||||
raise ThpError(
|
||||
"Read more bytes than is the expected length of the message, this should not happen!"
|
||||
)
|
||||
|
||||
async def _handle_init_packet(self, packet: utils.BufferType):
|
||||
print("handle_init_packet")
|
||||
print("channel._handle_init_packet")
|
||||
ctrl_byte, _, payload_length = ustruct.unpack(">BHH", packet)
|
||||
self.expected_payload_length = payload_length
|
||||
packet_payload = packet[5:]
|
||||
@ -127,20 +133,19 @@ class Channel(Context):
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print("payload len", payload_length)
|
||||
print("len", len(self.buffer))
|
||||
print("channel._handle_init_packet - payload len", payload_length)
|
||||
print("channel._handle_init_packet - buffer len", len(self.buffer))
|
||||
await self._buffer_packet_data(self.buffer, packet, 0)
|
||||
print("end init")
|
||||
print("channel._handle_init_packet - end")
|
||||
|
||||
async def _handle_cont_packet(self, packet: utils.BufferType):
|
||||
print("cont")
|
||||
print("channel._handle_cont_packet")
|
||||
if not self.is_cont_packet_expected:
|
||||
raise ThpError("Continuation packet is not expected, ignoring")
|
||||
await self._buffer_packet_data(self.buffer, packet, CONT_DATA_OFFSET)
|
||||
|
||||
async def _handle_completed_message(self) -> None:
|
||||
print("handling completed message")
|
||||
print("send snyc bit::", THP.sync_get_send_bit(self.channel_cache))
|
||||
print("channel._handle_completed_message")
|
||||
ctrl_byte, _, payload_length = ustruct.unpack(">BHH", self.buffer)
|
||||
message_length = payload_length + INIT_DATA_OFFSET
|
||||
|
||||
@ -148,7 +153,7 @@ class Channel(Context):
|
||||
|
||||
# Synchronization process
|
||||
sync_bit = (ctrl_byte & 0x10) >> 4
|
||||
print("sync bit:", sync_bit)
|
||||
print("channel._handle_completed_message - sync bit of message:", sync_bit)
|
||||
|
||||
# 1: Handle ACKs
|
||||
if _is_ctrl_byte_ack(ctrl_byte):
|
||||
@ -173,10 +178,10 @@ class Channel(Context):
|
||||
await self._handle_valid_message(
|
||||
payload_length, message_length, ctrl_byte, sync_bit
|
||||
)
|
||||
print("end handle completed message")
|
||||
print("channel._handle_completed_message - end")
|
||||
|
||||
def _check_checksum(self, message_length: int):
|
||||
print("checksum check")
|
||||
print("channel._check_checksum")
|
||||
if not checksum.is_valid(
|
||||
checksum=self.buffer[message_length - CHECKSUM_LENGTH : message_length],
|
||||
data=self.buffer[: message_length - CHECKSUM_LENGTH],
|
||||
@ -229,7 +234,7 @@ class Channel(Context):
|
||||
return
|
||||
|
||||
async def _handle_state_TH2(self, message_length: int, sync_bit: int) -> None:
|
||||
print("th2 branche")
|
||||
print("channel._handle_state_TH2")
|
||||
host_encrypted_static_pubkey = self.buffer[
|
||||
INIT_DATA_OFFSET : INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH
|
||||
]
|
||||
@ -253,10 +258,11 @@ class Channel(Context):
|
||||
self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
|
||||
|
||||
def _handle_state_ENCRYPTED_TRANSPORT(self, message_length: int) -> None:
|
||||
print("channel._handle_state_ENCRYPTED_TRANSPORT")
|
||||
self._decrypt_buffer(message_length)
|
||||
session_id, message_type = ustruct.unpack(">BH", self.buffer[INIT_DATA_OFFSET:])
|
||||
if session_id == 0:
|
||||
self._handle_channel_comms(message_length, message_type)
|
||||
self._handle_channel_message(message_length, message_type)
|
||||
return
|
||||
|
||||
if session_id not in self.sessions:
|
||||
@ -273,29 +279,25 @@ class Channel(Context):
|
||||
)
|
||||
)
|
||||
|
||||
def _handle_channel_comms(self, message_length: int, message_type: int) -> None:
|
||||
try:
|
||||
buf = self.buffer[INIT_DATA_OFFSET + 3 : message_length - CHECKSUM_LENGTH]
|
||||
def _handle_channel_message(self, message_length: int, message_type: int) -> None:
|
||||
buf = self.buffer[INIT_DATA_OFFSET + 3 : message_length - CHECKSUM_LENGTH]
|
||||
|
||||
expected_type = protobuf.type_for_wire(message_type)
|
||||
message = message_handler.wrap_protobuf_load(buf, expected_type)
|
||||
print(message)
|
||||
# TODO handle other messages than CreateNewSession
|
||||
assert isinstance(message, ThpCreateNewSession)
|
||||
print("passphrase:", message.passphrase)
|
||||
# await thp_messages.handle_CreateNewSession(message)
|
||||
if message.passphrase is not None:
|
||||
self.create_new_session(message.passphrase)
|
||||
else:
|
||||
self.create_new_session()
|
||||
# TODO reuse existing buffer and compute size dynamically
|
||||
bufferrone = bytearray(2)
|
||||
message_size: int = thp_messages.get_new_session_message(bufferrone)
|
||||
print(message_size) # TODO adjust
|
||||
loop.schedule(self.write_and_encrypt(bufferrone))
|
||||
except Exception as e:
|
||||
print("Proč??")
|
||||
print(e)
|
||||
expected_type = protobuf.type_for_wire(message_type)
|
||||
message = message_handler.wrap_protobuf_load(buf, expected_type)
|
||||
print("channel._handle_channel_message:", message)
|
||||
# TODO handle other messages than CreateNewSession
|
||||
assert isinstance(message, ThpCreateNewSession)
|
||||
print("channel._handle_channel_message - passphrase:", message.passphrase)
|
||||
# await thp_messages.handle_CreateNewSession(message)
|
||||
if message.passphrase is not None:
|
||||
self.create_new_session(message.passphrase)
|
||||
else:
|
||||
self.create_new_session()
|
||||
# TODO reuse existing buffer and compute size dynamically
|
||||
bufferrone = bytearray(2)
|
||||
message_size: int = thp_messages.get_new_session_message(bufferrone)
|
||||
print(message_size) # TODO adjust
|
||||
loop.schedule(self.write_and_encrypt(bufferrone))
|
||||
# TODO not finished
|
||||
|
||||
def _decrypt_single_packet_payload(self, payload: bytes) -> bytearray:
|
||||
@ -315,7 +317,7 @@ class Channel(Context):
|
||||
)
|
||||
|
||||
def _encrypt(self, buffer: bytearray, noise_payload_len: int) -> None:
|
||||
print("\n Encrypting ")
|
||||
print("channel._encrypt")
|
||||
min_required_length = noise_payload_len + TAG_LENGTH + CHECKSUM_LENGTH
|
||||
if len(buffer) < min_required_length or not isinstance(buffer, bytearray):
|
||||
new_buffer = bytearray(min_required_length)
|
||||
@ -334,7 +336,6 @@ class Channel(Context):
|
||||
self, payload_buffer: utils.BufferType, packet: utils.BufferType, offset: int
|
||||
):
|
||||
self.bytes_read += utils.memcpy(payload_buffer, self.bytes_read, packet, offset)
|
||||
print("bytes, read:", self.bytes_read)
|
||||
|
||||
def _finish_message(self):
|
||||
self.bytes_read = 0
|
||||
@ -366,7 +367,7 @@ class Channel(Context):
|
||||
# CALLED BY WORKFLOW / SESSION CONTEXT
|
||||
|
||||
async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None:
|
||||
print("write")
|
||||
print("channel.write")
|
||||
noise_payload_len = self._encode_into_buffer(msg, session_id)
|
||||
await self.write_and_encrypt(self.buffer[:noise_payload_len])
|
||||
|
||||
@ -381,7 +382,7 @@ class Channel(Context):
|
||||
loop.schedule(self._write_encrypted_payload_loop(self.buffer[:payload_length]))
|
||||
|
||||
async def _write_encrypted_payload_loop(self, payload: bytes) -> None:
|
||||
print("write loop before while")
|
||||
print("channel._write_encrypted_payload_loop")
|
||||
payload_len = len(payload) + CHECKSUM_LENGTH
|
||||
sync_bit = THP.sync_get_send_bit(self.channel_cache)
|
||||
ctrl_byte = self._add_sync_bit_to_ctrl_byte(ENCRYPTED_TRANSPORT, sync_bit)
|
||||
@ -395,9 +396,9 @@ class Channel(Context):
|
||||
THP.sync_set_can_send_message(self.channel_cache, False)
|
||||
while True:
|
||||
print(
|
||||
"write encrypted payload loop - start, sync_bit:",
|
||||
"channel._write_encrypted_payload_loop - loop start, sync_bit:",
|
||||
header.ctrl_byte & 0x10,
|
||||
" send_sync_bit:",
|
||||
" sync_send_bit:",
|
||||
THP.sync_get_send_bit(self.channel_cache),
|
||||
)
|
||||
await self._write_payload_to_wire(header, payload, payload_len)
|
||||
@ -411,7 +412,7 @@ class Channel(Context):
|
||||
async def _write_payload_to_wire(
|
||||
self, header: InitHeader, payload: bytes, payload_len: int
|
||||
):
|
||||
print("write payload to wire:")
|
||||
print("chanel._write_payload_to_wire")
|
||||
# prepare the report buffer with header data
|
||||
report = bytearray(REPORT_LENGTH)
|
||||
header.pack_to_buffer(report)
|
||||
@ -468,13 +469,16 @@ class Channel(Context):
|
||||
self,
|
||||
passphrase="",
|
||||
) -> None: # TODO change it to output session data
|
||||
print("create new session")
|
||||
print("channel.create_new_session")
|
||||
from trezor.wire.thp.session_context import SessionContext
|
||||
|
||||
session = SessionContext.create_new_session(self)
|
||||
self.sessions[session.session_id] = session
|
||||
loop.schedule(session.handle())
|
||||
print("new session created. Session id:", session.session_id)
|
||||
print(
|
||||
"channel.create_new_session - new session created. Session id:",
|
||||
session.session_id,
|
||||
)
|
||||
print(self.sessions)
|
||||
|
||||
def _todo_clear_buffer(self):
|
||||
@ -484,8 +488,10 @@ class Channel(Context):
|
||||
# TODO add debug logging to ACK handling
|
||||
def _handle_received_ACK(self, sync_bit: int) -> None:
|
||||
if self._ack_is_not_expected():
|
||||
print("channel._handle_received_ACK - ack is not expected")
|
||||
return
|
||||
if self._ack_has_incorrect_sync_bit(sync_bit):
|
||||
print("channel._handle_received_ACK - ack has incorrect sync bit")
|
||||
return
|
||||
|
||||
if self.waiting_for_ack_timeout is not None:
|
||||
@ -535,8 +541,10 @@ def _get_buffer_for_message(
|
||||
payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN
|
||||
) -> utils.BufferType:
|
||||
length = payload_length + INIT_DATA_OFFSET
|
||||
print("length", length)
|
||||
print("existing buffer type", type(existing_buffer))
|
||||
print("channel._get_buffer_for_message - length", length)
|
||||
print(
|
||||
"channel._get_buffer_for_message - existing buffer type", type(existing_buffer)
|
||||
)
|
||||
if length > max_length:
|
||||
raise ThpError("Message too large")
|
||||
|
||||
|
@ -139,11 +139,11 @@ class SessionContext(Context):
|
||||
|
||||
|
||||
def load_cached_sessions(channel: Channel) -> dict[int, SessionContext]: # TODO
|
||||
print("start loading sessions from cache")
|
||||
print("session_context.load_cached_sessions")
|
||||
sessions: dict[int, SessionContext] = {}
|
||||
cached_sessions = cache_thp.get_all_allocated_sessions()
|
||||
print(
|
||||
"loaded a total of ",
|
||||
"session_context.load_cached_sessions - loaded a total of ",
|
||||
len(cached_sessions),
|
||||
"sessions from cache",
|
||||
)
|
||||
|
@ -45,7 +45,7 @@ async def thp_main_loop(iface: WireInterface, is_debug_session=False):
|
||||
|
||||
while True:
|
||||
try:
|
||||
print("main loop")
|
||||
print("thp_v1.thp_main_loop")
|
||||
packet = await read
|
||||
ctrl_byte, cid = ustruct.unpack(">BH", packet)
|
||||
|
||||
@ -68,7 +68,6 @@ async def thp_main_loop(iface: WireInterface, is_debug_session=False):
|
||||
raise ThpError("Channel has different WireInterface")
|
||||
|
||||
if channel.get_channel_state() != ChannelState.UNALLOCATED:
|
||||
print("packet type in loop:", type(packet))
|
||||
await channel.receive_packet(packet)
|
||||
continue
|
||||
await _handle_unallocated(iface, cid)
|
||||
|
Loading…
Reference in New Issue
Block a user