From 121ed1f5307469691dabaab243e7c0a328b821e2 Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Tue, 6 Aug 2024 15:45:44 +0200 Subject: [PATCH] refactor(trezorlib): decouple protocol from handler [no changelog] --- python/src/trezorlib/transport/hid.py | 6 +- python/src/trezorlib/transport/protocol.py | 68 +++++++++++++--------- python/src/trezorlib/transport/udp.py | 6 +- python/src/trezorlib/transport/webusb.py | 7 ++- 4 files changed, 54 insertions(+), 33 deletions(-) diff --git a/python/src/trezorlib/transport/hid.py b/python/src/trezorlib/transport/hid.py index 65fa08ccd..60d5f8a30 100644 --- a/python/src/trezorlib/transport/hid.py +++ b/python/src/trezorlib/transport/hid.py @@ -21,6 +21,7 @@ from typing import Any, Dict, Iterable, List, Optional from ..log import DUMP_PACKETS from ..models import TREZOR_ONE, TrezorModel +from ..transport.protocol import Handle from . import UDEV_RULES_STR, TransportException from .protocol import ProtocolBasedTransport, ProtocolV1 @@ -127,7 +128,10 @@ class HidTransport(ProtocolBasedTransport): self.device = device self.handle = HidHandle(device["path"], device["serial_number"]) - super().__init__(protocol=ProtocolV1(self.handle)) + super().__init__(protocol=ProtocolV1()) + + def get_handle(self) -> Handle: + return self.handle def get_path(self) -> str: return f"{self.PATH_PREFIX}:{self.device['path'].decode()}" diff --git a/python/src/trezorlib/transport/protocol.py b/python/src/trezorlib/transport/protocol.py index a5a0ee6be..d44fe016c 100644 --- a/python/src/trezorlib/transport/protocol.py +++ b/python/src/trezorlib/transport/protocol.py @@ -16,7 +16,7 @@ import logging import struct -from typing import Tuple +from typing import Callable, Tuple from typing_extensions import Protocol as StructuralType @@ -71,25 +71,18 @@ class Protocol: its messages. """ - def __init__(self, handle: Handle) -> None: - self.handle = handle + def __init__(self) -> None: self.session_counter = 0 - # XXX we might be able to remove this now that TrezorClient does session handling - def begin_session(self) -> None: - if self.session_counter == 0: - self.handle.open() - self.session_counter += 1 - - def end_session(self) -> None: - self.session_counter = max(self.session_counter - 1, 0) - if self.session_counter == 0: - self.handle.close() - - def read(self) -> MessagePayload: + def read(self, read_chunk: Callable[[], bytes]) -> MessagePayload: raise NotImplementedError - def write(self, message_type: int, message_data: bytes) -> None: + def write( + self, + message_type: int, + message_data: bytes, + write_chunk: Callable[[bytes], None], + ) -> None: raise NotImplementedError @@ -102,18 +95,30 @@ class ProtocolBasedTransport(Transport): def __init__(self, protocol: Protocol) -> None: self.protocol = protocol + self.session_counter = 0 def write(self, message_type: int, message_data: bytes) -> None: - self.protocol.write(message_type, message_data) + self.protocol.write( + message_type, + message_data, + self.get_handle().write_chunk, + ) def read(self) -> MessagePayload: - return self.protocol.read() + return self.protocol.read(self.get_handle().read_chunk) def begin_session(self) -> None: - self.protocol.begin_session() + if self.session_counter == 0: + self.get_handle().open() + self.session_counter += 1 def end_session(self) -> None: - self.protocol.end_session() + self.session_counter = max(self.session_counter - 1, 0) + if self.session_counter == 0: + self.get_handle().close() + + def get_handle(self) -> Handle: + raise NotImplementedError class ProtocolV1(Protocol): @@ -123,7 +128,12 @@ class ProtocolV1(Protocol): HEADER_LEN = struct.calcsize(">HL") - def write(self, message_type: int, message_data: bytes) -> None: + def write( + self, + message_type: int, + message_data: bytes, + write_chunk: Callable[[bytes], None], + ) -> None: header = struct.pack(">HL", message_type, len(message_data)) buffer = bytearray(b"##" + header + message_data) @@ -131,23 +141,23 @@ class ProtocolV1(Protocol): # Report ID, data padded to 63 bytes chunk = b"?" + buffer[: REPLEN - 1] chunk = chunk.ljust(REPLEN, b"\x00") - self.handle.write_chunk(chunk) + write_chunk(chunk) buffer = buffer[63:] - def read(self) -> MessagePayload: + def read(self, read_chunk: Callable[[], bytes]) -> MessagePayload: buffer = bytearray() # Read header with first part of message data - msg_type, datalen, first_chunk = self.read_first() + msg_type, datalen, first_chunk = self.read_first(read_chunk) buffer.extend(first_chunk) # Read the rest of the message while len(buffer) < datalen: - buffer.extend(self.read_next()) + buffer.extend(self.read_next(read_chunk)) return msg_type, buffer[:datalen] - def read_first(self) -> Tuple[int, int, bytes]: - chunk = self.handle.read_chunk() + def read_first(self, read_chunk: Callable[[], bytes]) -> Tuple[int, int, bytes]: + chunk = read_chunk() if chunk[:3] != b"?##": raise RuntimeError("Unexpected magic characters") try: @@ -158,8 +168,8 @@ class ProtocolV1(Protocol): data = chunk[3 + self.HEADER_LEN :] return msg_type, datalen, data - def read_next(self) -> bytes: - chunk = self.handle.read_chunk() + def read_next(self, read_chunk: Callable[[], bytes]) -> bytes: + chunk = read_chunk() if chunk[:1] != b"?": raise RuntimeError("Unexpected magic characters") return chunk[1:] diff --git a/python/src/trezorlib/transport/udp.py b/python/src/trezorlib/transport/udp.py index 7e4c4614c..6673ebbdf 100644 --- a/python/src/trezorlib/transport/udp.py +++ b/python/src/trezorlib/transport/udp.py @@ -20,6 +20,7 @@ import time from typing import TYPE_CHECKING, Iterable, Optional from ..log import DUMP_PACKETS +from ..transport.protocol import Handle from . import TransportException from .protocol import ProtocolBasedTransport, ProtocolV1 @@ -49,7 +50,10 @@ class UdpTransport(ProtocolBasedTransport): self.device = (host, port) self.socket: Optional[socket.socket] = None - super().__init__(protocol=ProtocolV1(self)) + super().__init__(protocol=ProtocolV1()) + + def get_handle(self) -> Handle: + return self def get_path(self) -> str: return "{}:{}:{}".format(self.PATH_PREFIX, *self.device) diff --git a/python/src/trezorlib/transport/webusb.py b/python/src/trezorlib/transport/webusb.py index 1b60df61b..d5d4f87f4 100644 --- a/python/src/trezorlib/transport/webusb.py +++ b/python/src/trezorlib/transport/webusb.py @@ -23,7 +23,7 @@ from typing import Iterable, List, Optional from ..log import DUMP_PACKETS from ..models import TREZORS, TrezorModel from . import UDEV_RULES_STR, DeviceIsBusy, TransportException -from .protocol import ProtocolBasedTransport, ProtocolV1 +from .protocol import Handle, ProtocolBasedTransport, ProtocolV1 LOG = logging.getLogger(__name__) @@ -112,7 +112,10 @@ class WebUsbTransport(ProtocolBasedTransport): self.handle = handle self.debug = debug - super().__init__(protocol=ProtocolV1(handle)) + super().__init__(protocol=ProtocolV1()) + + def get_handle(self) -> Handle: + return self.handle def get_path(self) -> str: return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}"