|
|
|
@ -1,11 +1,13 @@
|
|
|
|
|
import ustruct # pyright: ignore[reportMissingModuleSource]
|
|
|
|
|
from micropython import const # pyright: ignore[reportMissingModuleSource]
|
|
|
|
|
from typing import TYPE_CHECKING # pyright:ignore[reportShadowedImports]
|
|
|
|
|
from ubinascii import hexlify
|
|
|
|
|
|
|
|
|
|
import usb
|
|
|
|
|
from storage import cache_thp
|
|
|
|
|
from storage.cache_thp import KEY_LENGTH, TAG_LENGTH, ChannelCache
|
|
|
|
|
from trezor import loop, protobuf, utils
|
|
|
|
|
from trezor.wire.thp import thp_messages
|
|
|
|
|
|
|
|
|
|
from ..protocol_common import Context
|
|
|
|
|
from . import ChannelState, SessionState, checksum
|
|
|
|
@ -51,39 +53,49 @@ class ChannelContext(Context):
|
|
|
|
|
self.sessions = load_cached_sessions(self)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def create_new_channel(cls, iface: WireInterface) -> "ChannelContext":
|
|
|
|
|
def create_new_channel(
|
|
|
|
|
cls, iface: WireInterface, buffer: utils.BufferType
|
|
|
|
|
) -> "ChannelContext":
|
|
|
|
|
channel_cache = cache_thp.get_new_unauthenticated_channel(_encode_iface(iface))
|
|
|
|
|
return cls(channel_cache)
|
|
|
|
|
r = cls(channel_cache)
|
|
|
|
|
r.set_buffer(buffer)
|
|
|
|
|
r.set_channel_state(ChannelState.TH1)
|
|
|
|
|
return r
|
|
|
|
|
|
|
|
|
|
# ACCESS TO CHANNEL_DATA
|
|
|
|
|
|
|
|
|
|
def get_channel_state(self) -> ChannelState:
|
|
|
|
|
def get_channel_state(self) -> int:
|
|
|
|
|
state = int.from_bytes(self.channel_cache.state, "big")
|
|
|
|
|
return ChannelState(state)
|
|
|
|
|
print("get_ch_state", state)
|
|
|
|
|
return state
|
|
|
|
|
|
|
|
|
|
def set_channel_state(self, state: ChannelState) -> None:
|
|
|
|
|
self.channel_cache.state = bytearray(state.value.to_bytes(1, "big"))
|
|
|
|
|
print("set_ch_state", int.from_bytes(state.to_bytes(1, "big"), "big"))
|
|
|
|
|
self.channel_cache.state = bytearray(state.to_bytes(1, "big"))
|
|
|
|
|
|
|
|
|
|
def set_buffer(self, buffer: utils.BufferType) -> None:
|
|
|
|
|
self.buffer = buffer
|
|
|
|
|
print("set buffer channel", type(self.buffer))
|
|
|
|
|
|
|
|
|
|
# CALLED BY THP_MAIN_LOOP
|
|
|
|
|
|
|
|
|
|
async def receive_packet(self, packet: utils.BufferType):
|
|
|
|
|
print("receive packet")
|
|
|
|
|
ctrl_byte = packet[0]
|
|
|
|
|
if _is_ctrl_byte_continuation(ctrl_byte):
|
|
|
|
|
await self._handle_cont_packet(packet)
|
|
|
|
|
else:
|
|
|
|
|
await self._handle_init_packet(packet)
|
|
|
|
|
|
|
|
|
|
if self.expected_payload_length == self.bytes_read:
|
|
|
|
|
if self.expected_payload_length + INIT_DATA_OFFSET == self.bytes_read:
|
|
|
|
|
self._finish_message()
|
|
|
|
|
await self._handle_completed_message()
|
|
|
|
|
|
|
|
|
|
async def _handle_init_packet(self, packet):
|
|
|
|
|
async def _handle_init_packet(self, packet: utils.BufferType):
|
|
|
|
|
print("handle_init_packet")
|
|
|
|
|
ctrl_byte, _, payload_length = ustruct.unpack(">BHH", packet)
|
|
|
|
|
self.expected_payload_length = payload_length
|
|
|
|
|
packet_payload = packet[5:]
|
|
|
|
|
|
|
|
|
|
# If the channel does not "own" the buffer lock, decrypt first packet
|
|
|
|
|
# TODO do it only when needed!
|
|
|
|
|
if _is_ctrl_byte_encrypted_transport(ctrl_byte):
|
|
|
|
@ -103,20 +115,33 @@ class ChannelContext(Context):
|
|
|
|
|
else:
|
|
|
|
|
pass
|
|
|
|
|
# TODO use small buffer
|
|
|
|
|
|
|
|
|
|
# TODO for now, we create a new big buffer every time. It should be changed
|
|
|
|
|
self.buffer = _get_buffer_for_payload(payload_length, packet)
|
|
|
|
|
|
|
|
|
|
print("self.buffer2")
|
|
|
|
|
try:
|
|
|
|
|
# TODO for now, we create a new big buffer every time. It should be changed
|
|
|
|
|
self.buffer: utils.BufferType = _get_buffer_for_message(
|
|
|
|
|
payload_length, self.buffer
|
|
|
|
|
)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
print(e)
|
|
|
|
|
print("payload len", payload_length)
|
|
|
|
|
print("self.buffer", self.buffer)
|
|
|
|
|
print("self.buuffer.type", type(self.buffer))
|
|
|
|
|
print("len", len(self.buffer))
|
|
|
|
|
await self._buffer_packet_data(self.buffer, packet, 0)
|
|
|
|
|
print("end init")
|
|
|
|
|
|
|
|
|
|
async def _handle_cont_packet(self, packet):
|
|
|
|
|
async def _handle_cont_packet(self, packet: utils.BufferType):
|
|
|
|
|
print("cont")
|
|
|
|
|
if not self.is_cont_packet_expected:
|
|
|
|
|
return # Continuation packet is not expected, ignoring
|
|
|
|
|
await self._buffer_packet_data(self.buffer, packet, CONT_DATA_OFFSET)
|
|
|
|
|
|
|
|
|
|
async def _handle_completed_message(self):
|
|
|
|
|
print("handling completed message")
|
|
|
|
|
ctrl_byte, _, payload_length = ustruct.unpack(">BHH", self.buffer)
|
|
|
|
|
msg_len = payload_length + INIT_DATA_OFFSET
|
|
|
|
|
print("checksum check")
|
|
|
|
|
printBytes(self.buffer)
|
|
|
|
|
if not checksum.is_valid(
|
|
|
|
|
checksum=self.buffer[msg_len - CHECKSUM_LENGTH : msg_len],
|
|
|
|
|
data=self.buffer[: msg_len - CHECKSUM_LENGTH],
|
|
|
|
@ -124,7 +149,7 @@ class ChannelContext(Context):
|
|
|
|
|
# checksum is not valid -> ignore message
|
|
|
|
|
self._todo_clear_buffer()
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
print("sync bit")
|
|
|
|
|
sync_bit = (ctrl_byte & 0x10) >> 4
|
|
|
|
|
if _is_ctrl_byte_ack(ctrl_byte):
|
|
|
|
|
self._handle_received_ACK(sync_bit)
|
|
|
|
@ -132,6 +157,7 @@ class ChannelContext(Context):
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
state = self.get_channel_state()
|
|
|
|
|
_print_state(state)
|
|
|
|
|
|
|
|
|
|
if state is ChannelState.TH1:
|
|
|
|
|
if not _is_ctrl_byte_handshake_init:
|
|
|
|
@ -152,15 +178,26 @@ class ChannelContext(Context):
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
if not _is_ctrl_byte_encrypted_transport(ctrl_byte):
|
|
|
|
|
print("message is not encrypted. Ignoring")
|
|
|
|
|
# TODO ignore message
|
|
|
|
|
self._todo_clear_buffer()
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
if state is ChannelState.ENCRYPTED_TRANSPORT:
|
|
|
|
|
self._decrypt_buffer()
|
|
|
|
|
session_id, message_type = ustruct.unpack(
|
|
|
|
|
">BH", self.buffer[INIT_DATA_OFFSET:]
|
|
|
|
|
)
|
|
|
|
|
if session_id == 0:
|
|
|
|
|
try:
|
|
|
|
|
message = thp_messages.decode_message(
|
|
|
|
|
self.buffer[INIT_DATA_OFFSET + 3 :], message_type
|
|
|
|
|
)
|
|
|
|
|
print(message)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
print(e)
|
|
|
|
|
|
|
|
|
|
# TODO not finished
|
|
|
|
|
|
|
|
|
|
if session_id not in self.sessions:
|
|
|
|
|
raise Exception("Unalloacted session")
|
|
|
|
|
|
|
|
|
@ -174,6 +211,7 @@ class ChannelContext(Context):
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if state is ChannelState.TH2:
|
|
|
|
|
print("th2 branche")
|
|
|
|
|
host_encrypted_static_pubkey = self.buffer[
|
|
|
|
|
INIT_DATA_OFFSET : INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH
|
|
|
|
|
]
|
|
|
|
@ -187,6 +225,7 @@ class ChannelContext(Context):
|
|
|
|
|
# TODO send ack in response
|
|
|
|
|
# TODO send hanshake completion response
|
|
|
|
|
self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
|
|
|
|
|
print("end completed message")
|
|
|
|
|
|
|
|
|
|
def _decrypt(self, payload) -> bytes:
|
|
|
|
|
return payload # TODO add decryption process
|
|
|
|
@ -196,9 +235,10 @@ class ChannelContext(Context):
|
|
|
|
|
# TODO decode buffer in place
|
|
|
|
|
|
|
|
|
|
async def _buffer_packet_data(
|
|
|
|
|
self, payload_buffer, packet: utils.BufferType, offset
|
|
|
|
|
self, payload_buffer: utils.BufferType, packet: utils.BufferType, offset: int
|
|
|
|
|
):
|
|
|
|
|
self.bytes_read += utils.memcpy(payload_buffer, self.bytes_read, packet, offset)
|
|
|
|
|
print("bytes, read:", self.bytes_read)
|
|
|
|
|
|
|
|
|
|
def _finish_message(self):
|
|
|
|
|
self.bytes_read = 0
|
|
|
|
@ -221,7 +261,8 @@ class ChannelContext(Context):
|
|
|
|
|
# OTHER
|
|
|
|
|
|
|
|
|
|
def _todo_clear_buffer(self):
|
|
|
|
|
raise NotImplementedError()
|
|
|
|
|
# TODO Buffer clearing not implemented
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
# TODO add debug logging to ACK handling
|
|
|
|
|
def _handle_received_ACK(self, sync_bit: int) -> None:
|
|
|
|
@ -273,22 +314,26 @@ def _encode_iface(iface: WireInterface) -> bytes:
|
|
|
|
|
raise Exception("Unknown WireInterface")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_buffer_for_payload(
|
|
|
|
|
def _get_buffer_for_message(
|
|
|
|
|
payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN
|
|
|
|
|
) -> utils.BufferType:
|
|
|
|
|
if payload_length > max_length:
|
|
|
|
|
length = payload_length + INIT_DATA_OFFSET
|
|
|
|
|
print("length", length)
|
|
|
|
|
print("existing buffer type", type(existing_buffer))
|
|
|
|
|
if length > max_length:
|
|
|
|
|
raise ThpError("Message too large")
|
|
|
|
|
if payload_length > len(existing_buffer):
|
|
|
|
|
|
|
|
|
|
if length > len(existing_buffer):
|
|
|
|
|
# allocate a new buffer to fit the message
|
|
|
|
|
try:
|
|
|
|
|
payload: utils.BufferType = bytearray(payload_length)
|
|
|
|
|
payload: utils.BufferType = bytearray(length)
|
|
|
|
|
except MemoryError:
|
|
|
|
|
payload = bytearray(REPORT_LENGTH)
|
|
|
|
|
raise ThpError("Message too large")
|
|
|
|
|
return payload
|
|
|
|
|
|
|
|
|
|
# reuse a part of the supplied buffer
|
|
|
|
|
return memoryview(existing_buffer)[:payload_length]
|
|
|
|
|
return memoryview(existing_buffer)[:length]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _is_ctrl_byte_continuation(ctrl_byte: int) -> bool:
|
|
|
|
@ -305,3 +350,33 @@ def _is_ctrl_byte_handshake_init(ctrl_byte: int) -> bool:
|
|
|
|
|
|
|
|
|
|
def _is_ctrl_byte_ack(ctrl_byte: int) -> bool:
|
|
|
|
|
return ctrl_byte & 0xEF == ACK_MESSAGE
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _print_state(cs: int) -> None:
|
|
|
|
|
if cs == ChannelState.ENCRYPTED_TRANSPORT:
|
|
|
|
|
print("state: encrypted transport")
|
|
|
|
|
elif cs == ChannelState.TH1:
|
|
|
|
|
print("state: th1")
|
|
|
|
|
elif cs == ChannelState.TH2:
|
|
|
|
|
print("state: th2")
|
|
|
|
|
elif cs == ChannelState.TP1:
|
|
|
|
|
print("state: tp1")
|
|
|
|
|
elif cs == ChannelState.TP2:
|
|
|
|
|
print("state: tp2")
|
|
|
|
|
elif cs == ChannelState.TP3:
|
|
|
|
|
print("state: tp3")
|
|
|
|
|
elif cs == ChannelState.TP4:
|
|
|
|
|
print("state: tp4")
|
|
|
|
|
elif cs == ChannelState.TP5:
|
|
|
|
|
print("state: tp5")
|
|
|
|
|
elif cs == ChannelState.UNALLOCATED:
|
|
|
|
|
print("state: unallocated")
|
|
|
|
|
elif cs == ChannelState.UNAUTHENTICATED:
|
|
|
|
|
print("state: unauthenticated")
|
|
|
|
|
else:
|
|
|
|
|
print(cs)
|
|
|
|
|
print("state: <not implemented printout>")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def printBytes(a):
|
|
|
|
|
print(hexlify(a).decode("utf-8"))
|
|
|
|
|