|
|
|
@ -3,12 +3,43 @@ from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
|
from enum import IntEnum
|
|
|
|
|
from trezorio import WireInterface
|
|
|
|
|
from typing import Protocol
|
|
|
|
|
|
|
|
|
|
from storage.cache_thp import ChannelCache
|
|
|
|
|
from trezor import loop, protobuf, utils
|
|
|
|
|
from trezor.enums import FailureType
|
|
|
|
|
from trezor.wire.thp.pairing_context import PairingContext
|
|
|
|
|
from trezor.wire.thp.session_context import SessionContext
|
|
|
|
|
|
|
|
|
|
class ChannelContext(Protocol):
|
|
|
|
|
buffer: utils.BufferType
|
|
|
|
|
iface: WireInterface
|
|
|
|
|
channel_id: bytes
|
|
|
|
|
channel_cache: ChannelCache
|
|
|
|
|
selected_pairing_methods = [] # TODO add type
|
|
|
|
|
sessions: dict[int, SessionContext]
|
|
|
|
|
waiting_for_ack_timeout: loop.spawn | None
|
|
|
|
|
write_task_spawn: loop.spawn | None
|
|
|
|
|
connection_context: PairingContext | None
|
|
|
|
|
|
|
|
|
|
def get_channel_state(self) -> int: ...
|
|
|
|
|
|
|
|
|
|
def set_channel_state(self, state: "ChannelState") -> None: ...
|
|
|
|
|
|
|
|
|
|
async def write(
|
|
|
|
|
self, msg: protobuf.MessageType, session_id: int = 0
|
|
|
|
|
) -> None: ...
|
|
|
|
|
|
|
|
|
|
async def write_error(self, err_type: FailureType, message: str) -> None: ...
|
|
|
|
|
|
|
|
|
|
async def write_handshake_message(
|
|
|
|
|
self, ctrl_byte: int, payload: bytes
|
|
|
|
|
) -> None: ...
|
|
|
|
|
|
|
|
|
|
def decrypt_buffer(self, message_length: int) -> None: ...
|
|
|
|
|
|
|
|
|
|
def get_channel_id_int(self) -> int: ...
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
IntEnum = object
|
|
|
|
|
|
|
|
|
@ -36,29 +67,6 @@ class WireInterfaceType(IntEnum):
|
|
|
|
|
BLE = 2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ChannelContext:
|
|
|
|
|
def __init__(self, iface: WireInterface, channel_cache: ChannelCache):
|
|
|
|
|
self.buffer: utils.BufferType
|
|
|
|
|
self.iface: WireInterface = iface
|
|
|
|
|
self.channel_id: bytes = channel_cache.channel_id
|
|
|
|
|
self.channel_cache: ChannelCache = channel_cache
|
|
|
|
|
self.selected_pairing_methods = []
|
|
|
|
|
self.sessions: dict[int, SessionContext] = {}
|
|
|
|
|
self.waiting_for_ack_timeout: loop.spawn | None = None
|
|
|
|
|
self.write_task_spawn: loop.spawn | None = None
|
|
|
|
|
self.connection_context: PairingContext | None = None
|
|
|
|
|
|
|
|
|
|
def get_channel_state(self) -> int: ...
|
|
|
|
|
def set_channel_state(self, state: ChannelState) -> None: ...
|
|
|
|
|
async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None: ...
|
|
|
|
|
async def write_error(self, err_type: FailureType, message: str) -> None: ...
|
|
|
|
|
async def write_handshake_message(self, ctrl_byte: int, payload: bytes) -> None: ...
|
|
|
|
|
def decrypt_buffer(self, message_length: int) -> None: ...
|
|
|
|
|
|
|
|
|
|
def get_channel_id_int(self) -> int:
|
|
|
|
|
return int.from_bytes(self.channel_id, "big")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def is_channel_state_pairing(state: int) -> bool:
|
|
|
|
|
if state in (
|
|
|
|
|
ChannelState.TP1,
|
|
|
|
|