diff --git a/core/src/storage/cache_thp.py b/core/src/storage/cache_thp.py index 82fee12e1..828017e2b 100644 --- a/core/src/storage/cache_thp.py +++ b/core/src/storage/cache_thp.py @@ -129,7 +129,7 @@ def _get_cid(session: SessionThpCache) -> int: return int.from_bytes(session.session_id[2:], "big") -def create_new_unauthenticated_session(session_id: bytearray) -> SessionThpCache: +def create_new_unauthenticated_session(session_id: bytes) -> SessionThpCache: if len(session_id) != 4: raise ValueError("session_id must be 4 bytes long.") global _active_session_idx @@ -185,7 +185,7 @@ def get_least_recently_used_authetnicated_session_index() -> int: # The function start_session should not be used in production code. It is present only to assure compatibility with old tests. -def start_session(session_id: bytes) -> bytes: # TODO incomplete +def start_session(session_id: bytes | None) -> bytes: # TODO incomplete global _active_session_idx global _is_active_session_authenticated diff --git a/core/src/trezor/wire/__init__.py b/core/src/trezor/wire/__init__.py index 075763149..67e0f8fcc 100644 --- a/core/src/trezor/wire/__init__.py +++ b/core/src/trezor/wire/__init__.py @@ -47,7 +47,7 @@ if TYPE_CHECKING: Msg = TypeVar("Msg", bound=protobuf.MessageType) HandlerTask = Coroutine[Any, Any, protobuf.MessageType] Handler = Callable[[Msg], HandlerTask] - + HandlerWithSessionId = Callable[[Msg, bytes | None], HandlerTask] LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType) @@ -199,7 +199,7 @@ async def _handle_single_message( async def handle_session( - iface: WireInterface, session_id: int, is_debug_session: bool = False + iface: WireInterface, session_id: bytes, is_debug_session: bool = False ) -> None: if __debug__ and is_debug_session: ctx_buffer = WIRE_BUFFER_DEBUG @@ -268,7 +268,9 @@ async def handle_session( log.exception(__name__, exc) -def _find_handler_placeholder(iface: WireInterface, msg_type: int) -> Handler | None: +def _find_handler_placeholder( + iface: WireInterface, msg_type: int +) -> Handler | HandlerWithSessionId | None: """Placeholder handler lookup before a proper one is registered.""" return None diff --git a/core/src/trezor/wire/context.py b/core/src/trezor/wire/context.py index 226edfcbb..badff42de 100644 --- a/core/src/trezor/wire/context.py +++ b/core/src/trezor/wire/context.py @@ -65,7 +65,7 @@ class Context: self, iface: WireInterface, buffer: bytearray, - session_id: bytearray | None = None, + session_id: bytes | None = None, ) -> None: self.iface = iface self.buffer = buffer diff --git a/core/src/trezor/wire/protocol.py b/core/src/trezor/wire/protocol.py index c90a6eb2c..95d27b4f1 100644 --- a/core/src/trezor/wire/protocol.py +++ b/core/src/trezor/wire/protocol.py @@ -8,12 +8,14 @@ if TYPE_CHECKING: class WireProtocol: - async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message: + async def read_message( + self, iface: WireInterface, buffer: utils.BufferType + ) -> Message: if utils.USE_THP: return await thp_v1.read_message(iface, buffer) return await codec_v1.read_message(iface, buffer) - async def write_message(iface: WireInterface, message: Message) -> None: + async def write_message(self, iface: WireInterface, message: Message) -> None: if utils.USE_THP: return thp_v1.write_to_wire(iface, message) return codec_v1.write_message(iface, message.type, message.data) diff --git a/core/src/trezor/wire/thp_session.py b/core/src/trezor/wire/thp_session.py index b0a9579c3..863c5dd7e 100644 --- a/core/src/trezor/wire/thp_session.py +++ b/core/src/trezor/wire/thp_session.py @@ -16,17 +16,6 @@ class ThpError(WireError): pass -class WorkflowState(IntEnum): - NOT_STARTED = 0 - PENDING = 1 - FINISHED = 2 - - -class Workflow: - id: int - workflow_state: WorkflowState - - class SessionState(IntEnum): UNALLOCATED = 0 INITIALIZED = 1 # do not change, is denoted as constant in storage.cache _THP_SESSION_STATE_INITIALIZED = 1 @@ -36,19 +25,6 @@ class SessionState(IntEnum): APP_TRAFFIC = 5 -def get_workflow() -> Workflow: - pass # TODO - - -def print_all_test_sessions() -> None: - for session in storage_thp_cache._UNAUTHENTICATED_SESSIONS: - if session is None: - print("none") - else: - print(hexlify(session.session_id).decode("utf-8"), session.state) - - -# def create_autenticated_session(unauthenticated_session: SessionThpCache): storage_thp_cache.start_session() # TODO something like this but for THP raise @@ -135,14 +111,14 @@ def _get_id(iface: WireInterface, cid: int) -> bytearray: return ustruct.pack(">HH", iface.iface_num(), cid) -def _get_authenticated_session_or_none(session_id) -> SessionThpCache: +def _get_authenticated_session_or_none(session_id) -> SessionThpCache | None: for authenticated_session in storage_thp_cache._SESSIONS: if authenticated_session.session_id == session_id: return authenticated_session return None -def _get_unauthenticated_session_or_none(session_id) -> SessionThpCache: +def _get_unauthenticated_session_or_none(session_id) -> SessionThpCache | None: for unauthenticated_session in storage_thp_cache._UNAUTHENTICATED_SESSIONS: if unauthenticated_session.session_id == session_id: return unauthenticated_session @@ -162,5 +138,5 @@ def _decode_session_state(state: bytearray) -> int: return ustruct.unpack("B", state)[0] -def _encode_session_state(state: SessionState) -> bytearray: - return ustruct.pack("B", state) +def _encode_session_state(state: SessionState) -> bytes: + return SessionState.to_bytes(state, 1, "big") diff --git a/core/src/trezor/wire/thp_v1.py b/core/src/trezor/wire/thp_v1.py index 25efb6b7f..6e609bdd4 100644 --- a/core/src/trezor/wire/thp_v1.py +++ b/core/src/trezor/wire/thp_v1.py @@ -73,7 +73,7 @@ class InterruptingInitPacket: async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message: msg = await read_message_or_init_packet(iface, buffer) while type(msg) is not Message: - if msg is InterruptingInitPacket: + if isinstance(msg, InterruptingInitPacket): msg = await read_message_or_init_packet(iface, buffer, msg.initReport) else: raise ThpError("Unexpected output of read_message_or_init_packet:") @@ -104,6 +104,9 @@ async def read_message_or_init_packet( report = None continue + if report is None: + raise ThpError("Reading failed unexpectedly, report is None.") + payload_length = ustruct.unpack(">H", report[3:])[0] payload = _get_buffer_for_payload(payload_length, buffer) header = InitHeader(ctrl_byte, cid, payload_length) @@ -352,7 +355,7 @@ def _get_new_channel_id() -> int: return THP.get_next_channel_id() -def _is_checksum_valid(checksum: bytearray, data: bytearray) -> bool: +def _is_checksum_valid(checksum: bytes | utils.BufferType, data: bytearray) -> bool: data_checksum = _compute_checksum_bytes(data) return checksum == data_checksum