From efcdd6843bb6ad32a59f8f21aef0e85c3e584a01 Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Tue, 4 Feb 2025 15:21:19 +0100 Subject: [PATCH] chore(core): adapt trezorlib transports to session based [no changelog] --- python/src/trezorlib/transport/bridge.py | 39 ++++- .../trezorlib/transport/thp/channel_data.py | 47 ++++++ .../transport/thp/channel_database.py | 142 ++++++++++++++++++ 3 files changed, 225 insertions(+), 3 deletions(-) create mode 100644 python/src/trezorlib/transport/thp/channel_data.py create mode 100644 python/src/trezorlib/transport/thp/channel_database.py diff --git a/python/src/trezorlib/transport/bridge.py b/python/src/trezorlib/transport/bridge.py index 65c545a768..0c84637071 100644 --- a/python/src/trezorlib/transport/bridge.py +++ b/python/src/trezorlib/transport/bridge.py @@ -21,6 +21,7 @@ import typing as t import requests +from ..client import ProtocolVersion from ..log import DUMP_PACKETS from . import DeviceIsBusy, Transport, TransportException @@ -63,6 +64,35 @@ def is_legacy_bridge() -> bool: return get_bridge_version() < TREZORD_VERSION_MODERN +def detect_protocol_version(transport: "BridgeTransport") -> int: + from .. import mapping, messages + + protocol_version = ProtocolVersion.PROTOCOL_V1 + request_type, request_data = mapping.DEFAULT_MAPPING.encode(messages.Initialize()) + transport.deprecated_begin_session() + transport.deprecated_write(request_type, request_data) + + response_type, response_data = transport.deprecated_read() + _ = mapping.DEFAULT_MAPPING.decode(response_type, response_data) + transport.deprecated_begin_session() + + return protocol_version + + +def _is_transport_valid(transport: "BridgeTransport") -> bool: + is_valid = detect_protocol_version(transport) == ProtocolVersion.PROTOCOL_V1 + if not is_valid: + LOG.warning("Detected unsupported Bridge transport!") + return is_valid + + +def filter_invalid_bridge_transports( + transports: t.Iterable["BridgeTransport"], +) -> t.Sequence["BridgeTransport"]: + """Filters out invalid bridge transports. Keeps only valid ones.""" + return [t for t in transports if _is_transport_valid(t)] + + class BridgeHandle: def __init__(self, transport: "BridgeTransport") -> None: self.transport = transport @@ -152,9 +182,12 @@ class BridgeTransport(Transport): ) -> t.Iterable["BridgeTransport"]: try: legacy = is_legacy_bridge() - return [ - BridgeTransport(dev, legacy) for dev in call_bridge("enumerate").json() - ] + return filter_invalid_bridge_transports( + [ + BridgeTransport(dev, legacy) + for dev in call_bridge("enumerate").json() + ] + ) except Exception: return [] diff --git a/python/src/trezorlib/transport/thp/channel_data.py b/python/src/trezorlib/transport/thp/channel_data.py new file mode 100644 index 0000000000..4d9d11d8d0 --- /dev/null +++ b/python/src/trezorlib/transport/thp/channel_data.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +from binascii import hexlify + + +class ChannelData: + + def __init__( + self, + protocol_version_major: int, + protocol_version_minor: int, + transport_path: str, + channel_id: int, + key_request: bytes, + key_response: bytes, + nonce_request: int, + nonce_response: int, + sync_bit_send: int, + sync_bit_receive: int, + handshake_hash: bytes, + ) -> None: + self.protocol_version_major: int = protocol_version_major + self.protocol_version_minor: int = protocol_version_minor + self.transport_path: str = transport_path + self.channel_id: int = channel_id + self.key_request: str = hexlify(key_request).decode() + self.key_response: str = hexlify(key_response).decode() + self.nonce_request: int = nonce_request + self.nonce_response: int = nonce_response + self.sync_bit_receive: int = sync_bit_receive + self.sync_bit_send: int = sync_bit_send + self.handshake_hash: str = hexlify(handshake_hash).decode() + + def to_dict(self): + return { + "protocol_version_major": self.protocol_version_major, + "protocol_version_minor": self.protocol_version_minor, + "transport_path": self.transport_path, + "channel_id": self.channel_id, + "key_request": self.key_request, + "key_response": self.key_response, + "nonce_request": self.nonce_request, + "nonce_response": self.nonce_response, + "sync_bit_send": self.sync_bit_send, + "sync_bit_receive": self.sync_bit_receive, + "handshake_hash": self.handshake_hash, + } diff --git a/python/src/trezorlib/transport/thp/channel_database.py b/python/src/trezorlib/transport/thp/channel_database.py new file mode 100644 index 0000000000..51dc850b26 --- /dev/null +++ b/python/src/trezorlib/transport/thp/channel_database.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +import json +import logging +import os +import typing as t + +from .channel_data import ChannelData +from .protocol_and_channel import Channel + +LOG = logging.getLogger(__name__) + +db: "ChannelDatabase | None" = None + + +def get_channel_db() -> ChannelDatabase: + if db is None: + set_channel_database(should_not_store=True) + assert db is not None + return db + + +class ChannelDatabase: + + def load_stored_channels(self) -> t.List[ChannelData]: ... + def clear_stored_channels(self) -> None: ... + + def save_channel(self, new_channel: Channel): ... + + def remove_channel(self, transport_path: str) -> None: ... + + +class DummyChannelDatabase(ChannelDatabase): + + def load_stored_channels(self) -> t.List[ChannelData]: + return [] + + def clear_stored_channels(self) -> None: + pass + + def save_channel(self, new_channel: Channel): + pass + + def remove_channel(self, transport_path: str) -> None: + pass + + +class JsonChannelDatabase(ChannelDatabase): + def __init__(self, data_path: str) -> None: + self.data_path = data_path + super().__init__() + + def load_stored_channels(self) -> t.List[ChannelData]: + dicts = self._read_all_channels() + return [dict_to_channel_data(d) for d in dicts] + + def clear_stored_channels(self) -> None: + LOG.debug("Clearing contents of %s", self.data_path) + with open(self.data_path, "w") as f: + json.dump([], f) + try: + os.remove(self.data_path) + except Exception as e: + LOG.exception("Failed to delete %s (%s)", self.data_path, str(type(e))) + + def _read_all_channels(self) -> t.List: + ensure_file_exists(self.data_path) + with open(self.data_path, "r") as f: + return json.load(f) + + def _save_all_channels(self, channels: t.List[t.Dict]) -> None: + LOG.debug("saving all channels") + with open(self.data_path, "w") as f: + json.dump(channels, f, indent=4) + + def save_channel(self, new_channel: Channel): + + LOG.debug("save channel") + channels = self._read_all_channels() + transport_path = new_channel.transport.get_path() + + # If the channel is found in database: replace the old entry by the new + for i, channel in enumerate(channels): + if channel["transport_path"] == transport_path: + LOG.debug("Modified channel entry for %s", transport_path) + channels[i] = new_channel.get_channel_data().to_dict() + self._save_all_channels(channels) + return + + # Channel was not found: add a new channel entry + LOG.debug("Created a new channel entry on path %s", transport_path) + channels.append(new_channel.get_channel_data().to_dict()) + self._save_all_channels(channels) + + def remove_channel(self, transport_path: str) -> None: + LOG.debug( + "Removing channel with path %s from the channel database.", + transport_path, + ) + channels = self._read_all_channels() + remaining_channels = [ + ch for ch in channels if ch["transport_path"] != transport_path + ] + self._save_all_channels(remaining_channels) + + +def dict_to_channel_data(dict: t.Dict) -> ChannelData: + return ChannelData( + protocol_version_major=dict["protocol_version_minor"], + protocol_version_minor=dict["protocol_version_major"], + transport_path=dict["transport_path"], + channel_id=dict["channel_id"], + key_request=bytes.fromhex(dict["key_request"]), + key_response=bytes.fromhex(dict["key_response"]), + nonce_request=dict["nonce_request"], + nonce_response=dict["nonce_response"], + sync_bit_send=dict["sync_bit_send"], + sync_bit_receive=dict["sync_bit_receive"], + handshake_hash=bytes.fromhex(dict["handshake_hash"]), + ) + + +def ensure_file_exists(file_path: str) -> None: + LOG.debug("checking if file %s exists", file_path) + if not os.path.exists(file_path): + os.makedirs(os.path.dirname(file_path), exist_ok=True) + LOG.debug("File %s does not exist. Creating a new one.", file_path) + with open(file_path, "w") as f: + json.dump([], f) + + +def set_channel_database(should_not_store: bool): + global db + if should_not_store: + db = DummyChannelDatabase() + else: + from platformdirs import user_cache_dir + + APP_NAME = "@trezor" # TODO + DATA_PATH = os.path.join(user_cache_dir(appname=APP_NAME), "channel_data.json") + + db = JsonChannelDatabase(DATA_PATH)