|
|
@ -6,14 +6,7 @@ from trezor import log, loop, protobuf, utils, workflow
|
|
|
|
from trezor.enums import FailureType
|
|
|
|
from trezor.enums import FailureType
|
|
|
|
from trezor.wire.thp import interface_manager, received_message_handler
|
|
|
|
from trezor.wire.thp import interface_manager, received_message_handler
|
|
|
|
|
|
|
|
|
|
|
|
from . import (
|
|
|
|
from . import ChannelState, checksum, control_byte, crypto, memory_manager
|
|
|
|
ChannelContext,
|
|
|
|
|
|
|
|
ChannelState,
|
|
|
|
|
|
|
|
checksum,
|
|
|
|
|
|
|
|
control_byte,
|
|
|
|
|
|
|
|
crypto,
|
|
|
|
|
|
|
|
memory_manager,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
from . import thp_session as THP
|
|
|
|
from . import thp_session as THP
|
|
|
|
from .checksum import CHECKSUM_LENGTH
|
|
|
|
from .checksum import CHECKSUM_LENGTH
|
|
|
|
from .thp_messages import ENCRYPTED_TRANSPORT, ERROR, InitHeader
|
|
|
|
from .thp_messages import ENCRYPTED_TRANSPORT, ERROR, InitHeader
|
|
|
@ -31,19 +24,32 @@ if __debug__:
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
from trezorio import WireInterface # pyright: ignore[reportMissingImports]
|
|
|
|
from trezorio import WireInterface # pyright: ignore[reportMissingImports]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from . import ChannelContext, PairingContext
|
|
|
|
|
|
|
|
from .session_context import SessionContext
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
ChannelContext = object
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Channel(ChannelContext):
|
|
|
|
class Channel:
|
|
|
|
def __init__(self, channel_cache: ChannelCache) -> None:
|
|
|
|
def __init__(self, channel_cache: ChannelCache) -> None:
|
|
|
|
if __debug__:
|
|
|
|
if __debug__:
|
|
|
|
log.debug(__name__, "channel initialization")
|
|
|
|
log.debug(__name__, "channel initialization")
|
|
|
|
iface: WireInterface = interface_manager.decode_iface(channel_cache.iface)
|
|
|
|
self.iface: WireInterface = interface_manager.decode_iface(channel_cache.iface)
|
|
|
|
super().__init__(iface, channel_cache)
|
|
|
|
self.channel_cache: ChannelCache = channel_cache
|
|
|
|
self.channel_cache = channel_cache
|
|
|
|
|
|
|
|
self.is_cont_packet_expected: bool = False
|
|
|
|
self.is_cont_packet_expected: bool = False
|
|
|
|
self.expected_payload_length: int = 0
|
|
|
|
self.expected_payload_length: int = 0
|
|
|
|
self.bytes_read: int = 0
|
|
|
|
self.bytes_read: int = 0
|
|
|
|
|
|
|
|
self.buffer: utils.BufferType
|
|
|
|
|
|
|
|
self.channel_id: bytes = channel_cache.channel_id
|
|
|
|
|
|
|
|
self.selected_pairing_methods = []
|
|
|
|
|
|
|
|
self.sessions: dict[int, SessionContext] = {}
|
|
|
|
|
|
|
|
self.waiting_for_ack_timeout: loop.spawn | None = None
|
|
|
|
|
|
|
|
self.write_task_spawn: loop.spawn | None = None
|
|
|
|
|
|
|
|
self.connection_context: PairingContext | None = None
|
|
|
|
|
|
|
|
|
|
|
|
# ACCESS TO CHANNEL_DATA
|
|
|
|
# ACCESS TO CHANNEL_DATA
|
|
|
|
|
|
|
|
def get_channel_id_int(self) -> int:
|
|
|
|
|
|
|
|
return int.from_bytes(self.channel_id, "big")
|
|
|
|
|
|
|
|
|
|
|
|
def get_channel_state(self) -> int:
|
|
|
|
def get_channel_state(self) -> int:
|
|
|
|
state = int.from_bytes(self.channel_cache.state, "big")
|
|
|
|
state = int.from_bytes(self.channel_cache.state, "big")
|
|
|
@ -168,7 +174,7 @@ class Channel(ChannelContext):
|
|
|
|
if __debug__:
|
|
|
|
if __debug__:
|
|
|
|
log.debug(__name__, "write message: %s", msg.MESSAGE_NAME)
|
|
|
|
log.debug(__name__, "write message: %s", msg.MESSAGE_NAME)
|
|
|
|
noise_payload_len = memory_manager.encode_into_buffer(
|
|
|
|
noise_payload_len = memory_manager.encode_into_buffer(
|
|
|
|
memoryview(self.buffer), msg, session_id
|
|
|
|
self.buffer, msg, session_id
|
|
|
|
)
|
|
|
|
)
|
|
|
|
await self.write_and_encrypt(self.buffer[:noise_payload_len])
|
|
|
|
await self.write_and_encrypt(self.buffer[:noise_payload_len])
|
|
|
|
|
|
|
|
|
|
|
|