Remake ChannelContext, change buffer types

M1nd3r/thp2
M1nd3r 3 weeks ago
parent cad4b5ee55
commit 9a10fa8998

@ -1,6 +1,11 @@
from typing import TYPE_CHECKING
from trezor import log, loop
from trezor.messages import ThpCreateNewSession, ThpNewSession
from trezor.wire.thp import ChannelContext, SessionState
from trezor.wire.thp import SessionState
if TYPE_CHECKING:
from trezor.wire.thp import ChannelContext
async def create_new_session(

@ -1,4 +1,4 @@
from trezor import log, protobuf
from trezor import protobuf
from trezor.enums import MessageType, ThpPairingMethod
from trezor.messages import (
ThpCodeEntryChallenge,
@ -25,6 +25,9 @@ from trezor.wire.thp.thp_session import ThpError
# TODO implement the following handlers
if __debug__:
from trezor import log
async def handle_pairing_request(
ctx: PairingContext, message: protobuf.MessageType

@ -3,12 +3,43 @@ from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
if TYPE_CHECKING:
from enum import IntEnum
from trezorio import WireInterface
from typing import Protocol
from storage.cache_thp import ChannelCache
from trezor import loop, protobuf, utils
from trezor.enums import FailureType
from trezor.wire.thp.pairing_context import PairingContext
from trezor.wire.thp.session_context import SessionContext
class ChannelContext(Protocol):
buffer: utils.BufferType
iface: WireInterface
channel_id: bytes
channel_cache: ChannelCache
selected_pairing_methods = [] # TODO add type
sessions: dict[int, SessionContext]
waiting_for_ack_timeout: loop.spawn | None
write_task_spawn: loop.spawn | None
connection_context: PairingContext | None
def get_channel_state(self) -> int: ...
def set_channel_state(self, state: "ChannelState") -> None: ...
async def write(
self, msg: protobuf.MessageType, session_id: int = 0
) -> None: ...
async def write_error(self, err_type: FailureType, message: str) -> None: ...
async def write_handshake_message(
self, ctrl_byte: int, payload: bytes
) -> None: ...
def decrypt_buffer(self, message_length: int) -> None: ...
def get_channel_id_int(self) -> int: ...
else:
IntEnum = object
@ -36,29 +67,6 @@ class WireInterfaceType(IntEnum):
BLE = 2
class ChannelContext:
def __init__(self, iface: WireInterface, channel_cache: ChannelCache):
self.buffer: utils.BufferType
self.iface: WireInterface = iface
self.channel_id: bytes = channel_cache.channel_id
self.channel_cache: ChannelCache = channel_cache
self.selected_pairing_methods = []
self.sessions: dict[int, SessionContext] = {}
self.waiting_for_ack_timeout: loop.spawn | None = None
self.write_task_spawn: loop.spawn | None = None
self.connection_context: PairingContext | None = None
def get_channel_state(self) -> int: ...
def set_channel_state(self, state: ChannelState) -> None: ...
async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None: ...
async def write_error(self, err_type: FailureType, message: str) -> None: ...
async def write_handshake_message(self, ctrl_byte: int, payload: bytes) -> None: ...
def decrypt_buffer(self, message_length: int) -> None: ...
def get_channel_id_int(self) -> int:
return int.from_bytes(self.channel_id, "big")
def is_channel_state_pairing(state: int) -> bool:
if state in (
ChannelState.TP1,

@ -6,14 +6,7 @@ from trezor import log, loop, protobuf, utils, workflow
from trezor.enums import FailureType
from trezor.wire.thp import interface_manager, received_message_handler
from . import (
ChannelContext,
ChannelState,
checksum,
control_byte,
crypto,
memory_manager,
)
from . import ChannelState, checksum, control_byte, crypto, memory_manager
from . import thp_session as THP
from .checksum import CHECKSUM_LENGTH
from .thp_messages import ENCRYPTED_TRANSPORT, ERROR, InitHeader
@ -31,19 +24,32 @@ if __debug__:
if TYPE_CHECKING:
from trezorio import WireInterface # pyright: ignore[reportMissingImports]
from . import ChannelContext, PairingContext
from .session_context import SessionContext
else:
ChannelContext = object
class Channel(ChannelContext):
class Channel:
def __init__(self, channel_cache: ChannelCache) -> None:
if __debug__:
log.debug(__name__, "channel initialization")
iface: WireInterface = interface_manager.decode_iface(channel_cache.iface)
super().__init__(iface, channel_cache)
self.channel_cache = channel_cache
self.iface: WireInterface = interface_manager.decode_iface(channel_cache.iface)
self.channel_cache: ChannelCache = channel_cache
self.is_cont_packet_expected: bool = False
self.expected_payload_length: int = 0
self.bytes_read: int = 0
self.buffer: utils.BufferType
self.channel_id: bytes = channel_cache.channel_id
self.selected_pairing_methods = []
self.sessions: dict[int, SessionContext] = {}
self.waiting_for_ack_timeout: loop.spawn | None = None
self.write_task_spawn: loop.spawn | None = None
self.connection_context: PairingContext | None = None
# ACCESS TO CHANNEL_DATA
def get_channel_id_int(self) -> int:
return int.from_bytes(self.channel_id, "big")
def get_channel_state(self) -> int:
state = int.from_bytes(self.channel_cache.state, "big")
@ -168,7 +174,7 @@ class Channel(ChannelContext):
if __debug__:
log.debug(__name__, "write message: %s", msg.MESSAGE_NAME)
noise_payload_len = memory_manager.encode_into_buffer(
memoryview(self.buffer), msg, session_id
self.buffer, msg, session_id
)
await self.write_and_encrypt(self.buffer[:noise_payload_len])

@ -46,7 +46,7 @@ def select_buffer(
def encode_into_buffer(
buffer: memoryview, msg: protobuf.MessageType, session_id: int
buffer: utils.BufferType, msg: protobuf.MessageType, session_id: int
) -> int:
# cannot write message without wire type
@ -58,7 +58,7 @@ def encode_into_buffer(
if required_min_size > len(buffer):
# message is too big, we need to allocate a new buffer
buffer = memoryview(bytearray(required_min_size))
buffer = bytearray(required_min_size)
_encode_session_into_buffer(memoryview(buffer), session_id)
_encode_message_type_into_buffer(

@ -1,19 +1,23 @@
from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
from typing import TYPE_CHECKING
from trezor import log, loop, protobuf, workflow
from trezor import loop, protobuf, workflow
from trezor.wire import context, message_handler, protocol_common
from trezor.wire.context import UnexpectedMessageWithId
from trezor.wire.errors import ActionCancelled
from trezor.wire.protocol_common import Context, MessageWithType
from . import ChannelContext
from .session_context import UnexpectedMessageWithType
if TYPE_CHECKING:
from typing import Container # pyright:ignore[reportShadowedImports]
from typing import Container
from . import ChannelContext
pass
if __debug__:
from trezor import log
class PairingContext(Context):
def __init__(self, channel_ctx: ChannelContext) -> None:

@ -19,7 +19,6 @@ from trezor.wire.thp.thp_messages import (
)
from . import (
ChannelContext,
ChannelState,
SessionState,
checksum,
@ -33,6 +32,8 @@ from .writer import INIT_DATA_OFFSET, MESSAGE_TYPE_LENGTH, write_payload_to_wire
if TYPE_CHECKING:
from trezor.messages import ThpHandshakeCompletionReqNoisePayload
from . import ChannelContext
if __debug__:
from . import state_to_str

@ -6,7 +6,7 @@ from trezor.wire import message_handler, protocol_common
from trezor.wire.message_handler import AVOID_RESTARTING_FOR, failure
from ..protocol_common import Context, MessageWithType
from . import ChannelContext, SessionState
from . import SessionState
if TYPE_CHECKING:
from typing import ( # pyright: ignore[reportShadowedImports]
@ -15,6 +15,8 @@ if TYPE_CHECKING:
Container,
)
from . import ChannelContext
pass
_EXIT_LOOP = True

@ -1,7 +1,15 @@
from typing import TYPE_CHECKING
from storage import cache_thp
from trezor import log, loop
from trezor.wire.thp import ChannelContext
from trezor.wire.thp.session_context import SessionContext
from trezor import loop
from .session_context import SessionContext
if __debug__:
from trezor import log
if TYPE_CHECKING:
from . import ChannelContext
def create_new_session(channel_ctx: ChannelContext) -> SessionContext:

Loading…
Cancel
Save