|
|
|
@ -68,7 +68,6 @@ class Channel(Context):
|
|
|
|
|
from trezor.wire.thp.session_context import load_cached_sessions
|
|
|
|
|
|
|
|
|
|
self.connection_context = None
|
|
|
|
|
|
|
|
|
|
self.sessions = load_cached_sessions(self)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@ -107,13 +106,12 @@ class Channel(Context):
|
|
|
|
|
async def receive_packet(self, packet: utils.BufferType):
|
|
|
|
|
if __debug__:
|
|
|
|
|
log.debug(__name__, "receive_packet")
|
|
|
|
|
ctrl_byte = packet[0]
|
|
|
|
|
if _is_ctrl_byte_continuation(ctrl_byte):
|
|
|
|
|
await self._handle_cont_packet(packet)
|
|
|
|
|
else:
|
|
|
|
|
await self._handle_init_packet(packet)
|
|
|
|
|
|
|
|
|
|
await self._handle_received_packet(packet)
|
|
|
|
|
|
|
|
|
|
if __debug__:
|
|
|
|
|
log.debug(__name__, "self.buffer: %s", utils.get_bytes_as_str(self.buffer))
|
|
|
|
|
|
|
|
|
|
if self.expected_payload_length + INIT_DATA_OFFSET == self.bytes_read:
|
|
|
|
|
self._finish_message()
|
|
|
|
|
await self._handle_completed_message()
|
|
|
|
@ -124,7 +122,14 @@ class Channel(Context):
|
|
|
|
|
"Read more bytes than is the expected length of the message, this should not happen!"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
async def _handle_init_packet(self, packet: utils.BufferType):
|
|
|
|
|
async def _handle_received_packet(self, packet: utils.BufferType) -> None:
|
|
|
|
|
ctrl_byte = packet[0]
|
|
|
|
|
if _is_ctrl_byte_continuation(ctrl_byte):
|
|
|
|
|
await self._handle_cont_packet(packet)
|
|
|
|
|
else:
|
|
|
|
|
await self._handle_init_packet(packet)
|
|
|
|
|
|
|
|
|
|
async def _handle_init_packet(self, packet: utils.BufferType) -> None:
|
|
|
|
|
if __debug__:
|
|
|
|
|
log.debug(__name__, "handle_init_packet")
|
|
|
|
|
ctrl_byte, _, payload_length = ustruct.unpack(">BHH", packet)
|
|
|
|
@ -135,6 +140,16 @@ class Channel(Context):
|
|
|
|
|
if _is_ctrl_byte_encrypted_transport(ctrl_byte):
|
|
|
|
|
packet_payload = self._decrypt_single_packet_payload(packet_payload)
|
|
|
|
|
|
|
|
|
|
self._select_buffer(packet_payload, payload_length)
|
|
|
|
|
await self._buffer_packet_data(self.buffer, packet, 0)
|
|
|
|
|
|
|
|
|
|
if __debug__:
|
|
|
|
|
log.debug(__name__, "handle_init_packet - payload len: %d", payload_length)
|
|
|
|
|
log.debug(__name__, "handle_init_packet - buffer len: %d", len(self.buffer))
|
|
|
|
|
|
|
|
|
|
def _select_buffer(
|
|
|
|
|
self, packet_payload: utils.BufferType, payload_length: int
|
|
|
|
|
) -> None:
|
|
|
|
|
state = self.get_channel_state()
|
|
|
|
|
|
|
|
|
|
if state is ChannelState.ENCRYPTED_TRANSPORT:
|
|
|
|
@ -157,16 +172,8 @@ class Channel(Context):
|
|
|
|
|
except Exception as e:
|
|
|
|
|
if __debug__:
|
|
|
|
|
log.exception(__name__, e)
|
|
|
|
|
if __debug__:
|
|
|
|
|
log.debug(__name__, "handle_init_packet - payload len: %d", payload_length)
|
|
|
|
|
if __debug__:
|
|
|
|
|
log.debug(__name__, "handle_init_packet - buffer len: %d", len(self.buffer))
|
|
|
|
|
|
|
|
|
|
await self._buffer_packet_data(self.buffer, packet, 0)
|
|
|
|
|
if __debug__:
|
|
|
|
|
log.debug(__name__, "handle_init_packet - end")
|
|
|
|
|
|
|
|
|
|
async def _handle_cont_packet(self, packet: utils.BufferType):
|
|
|
|
|
async def _handle_cont_packet(self, packet: utils.BufferType) -> None:
|
|
|
|
|
if __debug__:
|
|
|
|
|
log.debug(__name__, "handle_cont_packet")
|
|
|
|
|
if not self.is_cont_packet_expected:
|
|
|
|
@ -257,9 +264,11 @@ class Channel(Context):
|
|
|
|
|
if state is ChannelState.TH2:
|
|
|
|
|
await self._handle_state_TH2(message_length, ctrl_byte, sync_bit)
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
if is_channel_state_pairing(state):
|
|
|
|
|
await self._handle_pairing(message_length)
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
raise ThpError("Unimplemented channel state")
|
|
|
|
|
|
|
|
|
|
async def _handle_state_TH1(
|
|
|
|
@ -314,7 +323,7 @@ class Channel(Context):
|
|
|
|
|
"ThpHandshakeCompletionReqNoisePayload",
|
|
|
|
|
)
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
|
assert isinstance(noise_payload, ThpHandshakeCompletionReqNoisePayload)
|
|
|
|
|
assert ThpHandshakeCompletionReqNoisePayload.is_type_of(noise_payload)
|
|
|
|
|
for i in noise_payload.pairing_methods:
|
|
|
|
|
self.selected_pairing_methods.append(i)
|
|
|
|
|
if __debug__:
|
|
|
|
@ -325,6 +334,7 @@ class Channel(Context):
|
|
|
|
|
utils.get_bytes_as_str(handshake_completion_request_noise_payload),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# TODO add credential recognition
|
|
|
|
|
paired: bool = False # TODO should be output from credential check
|
|
|
|
|
|
|
|
|
|
# send hanshake completion response
|
|
|
|
@ -334,7 +344,6 @@ class Channel(Context):
|
|
|
|
|
thp_messages.get_handshake_completion_response(paired=paired),
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
# TODO add credential recognition
|
|
|
|
|
if paired:
|
|
|
|
|
self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
|
|
|
|
|
else:
|
|
|
|
@ -343,6 +352,7 @@ class Channel(Context):
|
|
|
|
|
async def _handle_state_ENCRYPTED_TRANSPORT(self, message_length: int) -> None:
|
|
|
|
|
if __debug__:
|
|
|
|
|
log.debug(__name__, "handle_state_ENCRYPTED_TRANSPORT")
|
|
|
|
|
|
|
|
|
|
self._decrypt_buffer(message_length)
|
|
|
|
|
session_id, message_type = ustruct.unpack(">BH", self.buffer[INIT_DATA_OFFSET:])
|
|
|
|
|
if session_id == 0:
|
|
|
|
@ -434,6 +444,8 @@ class Channel(Context):
|
|
|
|
|
response_message = await task
|
|
|
|
|
# TODO handle
|
|
|
|
|
await self.write(response_message)
|
|
|
|
|
if __debug__:
|
|
|
|
|
log.debug(__name__, "_handle_channel_message - end")
|
|
|
|
|
|
|
|
|
|
def _decrypt_single_packet_payload(self, payload: bytes) -> bytearray:
|
|
|
|
|
payload_buffer = bytearray(payload)
|
|
|
|
@ -536,6 +548,8 @@ class Channel(Context):
|
|
|
|
|
ENCRYPTED_TRANSPORT, memoryview(self.buffer[:payload_length])
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
if __debug__:
|
|
|
|
|
log.debug(__name__, "Scheduled _write_encrypted_payload_loop")
|
|
|
|
|
|
|
|
|
|
async def _write_encrypted_payload_loop(
|
|
|
|
|
self, ctrl_byte: int, payload: bytes
|
|
|
|
@ -574,6 +588,8 @@ class Channel(Context):
|
|
|
|
|
not workflow.tasks
|
|
|
|
|
and self.get_channel_state() is ChannelState.ENCRYPTED_TRANSPORT
|
|
|
|
|
):
|
|
|
|
|
if __debug__:
|
|
|
|
|
log.debug(__name__, "Clearing loop from channel")
|
|
|
|
|
loop.clear()
|
|
|
|
|
|
|
|
|
|
async def _wait_for_ack(self) -> None:
|
|
|
|
@ -653,10 +669,11 @@ def _get_buffer_for_message(
|
|
|
|
|
) -> utils.BufferType:
|
|
|
|
|
length = payload_length + INIT_DATA_OFFSET
|
|
|
|
|
if __debug__:
|
|
|
|
|
log.debug(__name__, "get_buffer_for_message - length: %d", length)
|
|
|
|
|
log.debug(
|
|
|
|
|
__name__,
|
|
|
|
|
"get_buffer_for_message - existing buffer type: %s",
|
|
|
|
|
"get_buffer_for_message - length: %d, %s %s",
|
|
|
|
|
length,
|
|
|
|
|
"existing buffer type:",
|
|
|
|
|
type(existing_buffer),
|
|
|
|
|
)
|
|
|
|
|
if length > max_length:
|
|
|
|
|