1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-14 03:30:02 +00:00

fix(core): fix handshake reading, session creation part 1

This commit is contained in:
M1nd3r 2024-03-27 14:16:09 +01:00
parent 37547b19da
commit 912c85e21e
4 changed files with 114 additions and 23 deletions

View File

@ -132,6 +132,7 @@ async def handle_thp_session(iface: WireInterface, is_debug_session: bool = Fals
# loop.clear() above.
if __debug__:
log.exception(__name__, exc)
print("Exception raised:", exc)
async def handle_session(

View File

@ -1,11 +1,13 @@
import ustruct # pyright: ignore[reportMissingModuleSource]
from micropython import const # pyright: ignore[reportMissingModuleSource]
from typing import TYPE_CHECKING # pyright:ignore[reportShadowedImports]
from ubinascii import hexlify
import usb
from storage import cache_thp
from storage.cache_thp import KEY_LENGTH, TAG_LENGTH, ChannelCache
from trezor import loop, protobuf, utils
from trezor.wire.thp import thp_messages
from ..protocol_common import Context
from . import ChannelState, SessionState, checksum
@ -51,39 +53,49 @@ class ChannelContext(Context):
self.sessions = load_cached_sessions(self)
@classmethod
def create_new_channel(cls, iface: WireInterface) -> "ChannelContext":
def create_new_channel(
cls, iface: WireInterface, buffer: utils.BufferType
) -> "ChannelContext":
channel_cache = cache_thp.get_new_unauthenticated_channel(_encode_iface(iface))
return cls(channel_cache)
r = cls(channel_cache)
r.set_buffer(buffer)
r.set_channel_state(ChannelState.TH1)
return r
# ACCESS TO CHANNEL_DATA
def get_channel_state(self) -> ChannelState:
def get_channel_state(self) -> int:
state = int.from_bytes(self.channel_cache.state, "big")
return ChannelState(state)
print("get_ch_state", state)
return state
def set_channel_state(self, state: ChannelState) -> None:
self.channel_cache.state = bytearray(state.value.to_bytes(1, "big"))
print("set_ch_state", int.from_bytes(state.to_bytes(1, "big"), "big"))
self.channel_cache.state = bytearray(state.to_bytes(1, "big"))
def set_buffer(self, buffer: utils.BufferType) -> None:
self.buffer = buffer
print("set buffer channel", type(self.buffer))
# CALLED BY THP_MAIN_LOOP
async def receive_packet(self, packet: utils.BufferType):
print("receive packet")
ctrl_byte = packet[0]
if _is_ctrl_byte_continuation(ctrl_byte):
await self._handle_cont_packet(packet)
else:
await self._handle_init_packet(packet)
if self.expected_payload_length == self.bytes_read:
if self.expected_payload_length + INIT_DATA_OFFSET == self.bytes_read:
self._finish_message()
await self._handle_completed_message()
async def _handle_init_packet(self, packet):
async def _handle_init_packet(self, packet: utils.BufferType):
print("handle_init_packet")
ctrl_byte, _, payload_length = ustruct.unpack(">BHH", packet)
self.expected_payload_length = payload_length
packet_payload = packet[5:]
# If the channel does not "own" the buffer lock, decrypt first packet
# TODO do it only when needed!
if _is_ctrl_byte_encrypted_transport(ctrl_byte):
@ -103,20 +115,33 @@ class ChannelContext(Context):
else:
pass
# TODO use small buffer
# TODO for now, we create a new big buffer every time. It should be changed
self.buffer = _get_buffer_for_payload(payload_length, packet)
print("self.buffer2")
try:
# TODO for now, we create a new big buffer every time. It should be changed
self.buffer: utils.BufferType = _get_buffer_for_message(
payload_length, self.buffer
)
except Exception as e:
print(e)
print("payload len", payload_length)
print("self.buffer", self.buffer)
print("self.buuffer.type", type(self.buffer))
print("len", len(self.buffer))
await self._buffer_packet_data(self.buffer, packet, 0)
print("end init")
async def _handle_cont_packet(self, packet):
async def _handle_cont_packet(self, packet: utils.BufferType):
print("cont")
if not self.is_cont_packet_expected:
return # Continuation packet is not expected, ignoring
await self._buffer_packet_data(self.buffer, packet, CONT_DATA_OFFSET)
async def _handle_completed_message(self):
print("handling completed message")
ctrl_byte, _, payload_length = ustruct.unpack(">BHH", self.buffer)
msg_len = payload_length + INIT_DATA_OFFSET
print("checksum check")
printBytes(self.buffer)
if not checksum.is_valid(
checksum=self.buffer[msg_len - CHECKSUM_LENGTH : msg_len],
data=self.buffer[: msg_len - CHECKSUM_LENGTH],
@ -124,7 +149,7 @@ class ChannelContext(Context):
# checksum is not valid -> ignore message
self._todo_clear_buffer()
return
print("sync bit")
sync_bit = (ctrl_byte & 0x10) >> 4
if _is_ctrl_byte_ack(ctrl_byte):
self._handle_received_ACK(sync_bit)
@ -132,6 +157,7 @@ class ChannelContext(Context):
return
state = self.get_channel_state()
_print_state(state)
if state is ChannelState.TH1:
if not _is_ctrl_byte_handshake_init:
@ -152,15 +178,26 @@ class ChannelContext(Context):
return
if not _is_ctrl_byte_encrypted_transport(ctrl_byte):
print("message is not encrypted. Ignoring")
# TODO ignore message
self._todo_clear_buffer()
return
if state is ChannelState.ENCRYPTED_TRANSPORT:
self._decrypt_buffer()
session_id, message_type = ustruct.unpack(
">BH", self.buffer[INIT_DATA_OFFSET:]
)
if session_id == 0:
try:
message = thp_messages.decode_message(
self.buffer[INIT_DATA_OFFSET + 3 :], message_type
)
print(message)
except Exception as e:
print(e)
# TODO not finished
if session_id not in self.sessions:
raise Exception("Unalloacted session")
@ -174,6 +211,7 @@ class ChannelContext(Context):
)
if state is ChannelState.TH2:
print("th2 branche")
host_encrypted_static_pubkey = self.buffer[
INIT_DATA_OFFSET : INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH
]
@ -187,6 +225,7 @@ class ChannelContext(Context):
# TODO send ack in response
# TODO send hanshake completion response
self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
print("end completed message")
def _decrypt(self, payload) -> bytes:
return payload # TODO add decryption process
@ -196,9 +235,10 @@ class ChannelContext(Context):
# TODO decode buffer in place
async def _buffer_packet_data(
self, payload_buffer, packet: utils.BufferType, offset
self, payload_buffer: utils.BufferType, packet: utils.BufferType, offset: int
):
self.bytes_read += utils.memcpy(payload_buffer, self.bytes_read, packet, offset)
print("bytes, read:", self.bytes_read)
def _finish_message(self):
self.bytes_read = 0
@ -221,7 +261,8 @@ class ChannelContext(Context):
# OTHER
def _todo_clear_buffer(self):
raise NotImplementedError()
# TODO Buffer clearing not implemented
pass
# TODO add debug logging to ACK handling
def _handle_received_ACK(self, sync_bit: int) -> None:
@ -273,22 +314,26 @@ def _encode_iface(iface: WireInterface) -> bytes:
raise Exception("Unknown WireInterface")
def _get_buffer_for_payload(
def _get_buffer_for_message(
payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN
) -> utils.BufferType:
if payload_length > max_length:
length = payload_length + INIT_DATA_OFFSET
print("length", length)
print("existing buffer type", type(existing_buffer))
if length > max_length:
raise ThpError("Message too large")
if payload_length > len(existing_buffer):
if length > len(existing_buffer):
# allocate a new buffer to fit the message
try:
payload: utils.BufferType = bytearray(payload_length)
payload: utils.BufferType = bytearray(length)
except MemoryError:
payload = bytearray(REPORT_LENGTH)
raise ThpError("Message too large")
return payload
# reuse a part of the supplied buffer
return memoryview(existing_buffer)[:payload_length]
return memoryview(existing_buffer)[:length]
def _is_ctrl_byte_continuation(ctrl_byte: int) -> bool:
@ -305,3 +350,33 @@ def _is_ctrl_byte_handshake_init(ctrl_byte: int) -> bool:
def _is_ctrl_byte_ack(ctrl_byte: int) -> bool:
return ctrl_byte & 0xEF == ACK_MESSAGE
def _print_state(cs: int) -> None:
if cs == ChannelState.ENCRYPTED_TRANSPORT:
print("state: encrypted transport")
elif cs == ChannelState.TH1:
print("state: th1")
elif cs == ChannelState.TH2:
print("state: th2")
elif cs == ChannelState.TP1:
print("state: tp1")
elif cs == ChannelState.TP2:
print("state: tp2")
elif cs == ChannelState.TP3:
print("state: tp3")
elif cs == ChannelState.TP4:
print("state: tp4")
elif cs == ChannelState.TP5:
print("state: tp5")
elif cs == ChannelState.UNALLOCATED:
print("state: unallocated")
elif cs == ChannelState.UNAUTHENTICATED:
print("state: unauthenticated")
else:
print(cs)
print("state: <not implemented printout>")
def printBytes(a):
print(hexlify(a).decode("utf-8"))

View File

@ -1,7 +1,9 @@
import ustruct # pyright:ignore[reportMissingModuleSource]
from storage.cache_thp import BROADCAST_CHANNEL_ID
from trezor import protobuf
from .. import message_handler
from ..protocol_common import Message
CODEC_V1 = 0x3F
@ -73,3 +75,12 @@ def get_channel_allocation_response(nonce: bytes, new_cid: bytes) -> bytes:
def get_error_unallocated_channel() -> bytes:
return _ERROR_UNALLOCATED_SESSION
def get_handshake_init_response() -> bytes:
return b"\x00" # TODO implement
def decode_message(buffer: bytes, msg_type: int) -> protobuf.MessageType:
expected_type = protobuf.type_for_wire(msg_type)
return message_handler.wrap_protobuf_load(buffer, expected_type)

View File

@ -54,6 +54,7 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag
def set_buffer(buffer):
global _BUFFER
_BUFFER = buffer
print("setbuffer,", type(_BUFFER))
async def thp_main_loop(iface: WireInterface, is_debug_session=False):
@ -64,6 +65,7 @@ async def thp_main_loop(iface: WireInterface, is_debug_session=False):
read = loop.wait(iface.iface_num() | io.POLL_READ)
while True:
print("main loop")
packet = await read
ctrl_byte, cid = ustruct.unpack(">BH", packet)
@ -86,6 +88,7 @@ async def thp_main_loop(iface: WireInterface, is_debug_session=False):
raise ThpError("Channel has different WireInterface")
if channel.get_channel_state() != ChannelState.UNALLOCATED:
print("packet type in loop:", type(packet))
await channel.receive_packet(packet)
continue
@ -330,6 +333,7 @@ async def _write_report(write, iface: WireInterface, report: bytearray) -> None:
async def _handle_broadcast(
iface: WireInterface, ctrl_byte, packet
) -> MessageWithId | None:
global _BUFFER
if ctrl_byte != _CHANNEL_ALLOCATION_REQ:
raise ThpError("Unexpected ctrl_byte in broadcast channel packet")
if __debug__:
@ -342,7 +346,7 @@ async def _handle_broadcast(
if not checksum.is_valid(payload[-4:], header.to_bytes() + payload[:-4]):
raise ThpError("Checksum is not valid")
new_context: ChannelContext = ChannelContext.create_new_channel(iface)
new_context: ChannelContext = ChannelContext.create_new_channel(iface, _BUFFER)
cid = int.from_bytes(new_context.channel_id, "big")
_CHANNEL_CONTEXTS[cid] = new_context