mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-06-17 21:48:47 +00:00
chore(core): adapt trezorlib transports to session based
[no changelog]
This commit is contained in:
parent
0743b65159
commit
efcdd6843b
@ -21,6 +21,7 @@ import typing as t
|
|||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
from ..client import ProtocolVersion
|
||||||
from ..log import DUMP_PACKETS
|
from ..log import DUMP_PACKETS
|
||||||
from . import DeviceIsBusy, Transport, TransportException
|
from . import DeviceIsBusy, Transport, TransportException
|
||||||
|
|
||||||
@ -63,6 +64,35 @@ def is_legacy_bridge() -> bool:
|
|||||||
return get_bridge_version() < TREZORD_VERSION_MODERN
|
return get_bridge_version() < TREZORD_VERSION_MODERN
|
||||||
|
|
||||||
|
|
||||||
|
def detect_protocol_version(transport: "BridgeTransport") -> int:
|
||||||
|
from .. import mapping, messages
|
||||||
|
|
||||||
|
protocol_version = ProtocolVersion.PROTOCOL_V1
|
||||||
|
request_type, request_data = mapping.DEFAULT_MAPPING.encode(messages.Initialize())
|
||||||
|
transport.deprecated_begin_session()
|
||||||
|
transport.deprecated_write(request_type, request_data)
|
||||||
|
|
||||||
|
response_type, response_data = transport.deprecated_read()
|
||||||
|
_ = mapping.DEFAULT_MAPPING.decode(response_type, response_data)
|
||||||
|
transport.deprecated_begin_session()
|
||||||
|
|
||||||
|
return protocol_version
|
||||||
|
|
||||||
|
|
||||||
|
def _is_transport_valid(transport: "BridgeTransport") -> bool:
|
||||||
|
is_valid = detect_protocol_version(transport) == ProtocolVersion.PROTOCOL_V1
|
||||||
|
if not is_valid:
|
||||||
|
LOG.warning("Detected unsupported Bridge transport!")
|
||||||
|
return is_valid
|
||||||
|
|
||||||
|
|
||||||
|
def filter_invalid_bridge_transports(
|
||||||
|
transports: t.Iterable["BridgeTransport"],
|
||||||
|
) -> t.Sequence["BridgeTransport"]:
|
||||||
|
"""Filters out invalid bridge transports. Keeps only valid ones."""
|
||||||
|
return [t for t in transports if _is_transport_valid(t)]
|
||||||
|
|
||||||
|
|
||||||
class BridgeHandle:
|
class BridgeHandle:
|
||||||
def __init__(self, transport: "BridgeTransport") -> None:
|
def __init__(self, transport: "BridgeTransport") -> None:
|
||||||
self.transport = transport
|
self.transport = transport
|
||||||
@ -152,9 +182,12 @@ class BridgeTransport(Transport):
|
|||||||
) -> t.Iterable["BridgeTransport"]:
|
) -> t.Iterable["BridgeTransport"]:
|
||||||
try:
|
try:
|
||||||
legacy = is_legacy_bridge()
|
legacy = is_legacy_bridge()
|
||||||
return [
|
return filter_invalid_bridge_transports(
|
||||||
BridgeTransport(dev, legacy) for dev in call_bridge("enumerate").json()
|
[
|
||||||
]
|
BridgeTransport(dev, legacy)
|
||||||
|
for dev in call_bridge("enumerate").json()
|
||||||
|
]
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
47
python/src/trezorlib/transport/thp/channel_data.py
Normal file
47
python/src/trezorlib/transport/thp/channel_data.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from binascii import hexlify
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelData:
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
protocol_version_major: int,
|
||||||
|
protocol_version_minor: int,
|
||||||
|
transport_path: str,
|
||||||
|
channel_id: int,
|
||||||
|
key_request: bytes,
|
||||||
|
key_response: bytes,
|
||||||
|
nonce_request: int,
|
||||||
|
nonce_response: int,
|
||||||
|
sync_bit_send: int,
|
||||||
|
sync_bit_receive: int,
|
||||||
|
handshake_hash: bytes,
|
||||||
|
) -> None:
|
||||||
|
self.protocol_version_major: int = protocol_version_major
|
||||||
|
self.protocol_version_minor: int = protocol_version_minor
|
||||||
|
self.transport_path: str = transport_path
|
||||||
|
self.channel_id: int = channel_id
|
||||||
|
self.key_request: str = hexlify(key_request).decode()
|
||||||
|
self.key_response: str = hexlify(key_response).decode()
|
||||||
|
self.nonce_request: int = nonce_request
|
||||||
|
self.nonce_response: int = nonce_response
|
||||||
|
self.sync_bit_receive: int = sync_bit_receive
|
||||||
|
self.sync_bit_send: int = sync_bit_send
|
||||||
|
self.handshake_hash: str = hexlify(handshake_hash).decode()
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
return {
|
||||||
|
"protocol_version_major": self.protocol_version_major,
|
||||||
|
"protocol_version_minor": self.protocol_version_minor,
|
||||||
|
"transport_path": self.transport_path,
|
||||||
|
"channel_id": self.channel_id,
|
||||||
|
"key_request": self.key_request,
|
||||||
|
"key_response": self.key_response,
|
||||||
|
"nonce_request": self.nonce_request,
|
||||||
|
"nonce_response": self.nonce_response,
|
||||||
|
"sync_bit_send": self.sync_bit_send,
|
||||||
|
"sync_bit_receive": self.sync_bit_receive,
|
||||||
|
"handshake_hash": self.handshake_hash,
|
||||||
|
}
|
142
python/src/trezorlib/transport/thp/channel_database.py
Normal file
142
python/src/trezorlib/transport/thp/channel_database.py
Normal file
@ -0,0 +1,142 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import typing as t
|
||||||
|
|
||||||
|
from .channel_data import ChannelData
|
||||||
|
from .protocol_and_channel import Channel
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
db: "ChannelDatabase | None" = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_channel_db() -> ChannelDatabase:
|
||||||
|
if db is None:
|
||||||
|
set_channel_database(should_not_store=True)
|
||||||
|
assert db is not None
|
||||||
|
return db
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelDatabase:
|
||||||
|
|
||||||
|
def load_stored_channels(self) -> t.List[ChannelData]: ...
|
||||||
|
def clear_stored_channels(self) -> None: ...
|
||||||
|
|
||||||
|
def save_channel(self, new_channel: Channel): ...
|
||||||
|
|
||||||
|
def remove_channel(self, transport_path: str) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
|
class DummyChannelDatabase(ChannelDatabase):
|
||||||
|
|
||||||
|
def load_stored_channels(self) -> t.List[ChannelData]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def clear_stored_channels(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def save_channel(self, new_channel: Channel):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def remove_channel(self, transport_path: str) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class JsonChannelDatabase(ChannelDatabase):
|
||||||
|
def __init__(self, data_path: str) -> None:
|
||||||
|
self.data_path = data_path
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def load_stored_channels(self) -> t.List[ChannelData]:
|
||||||
|
dicts = self._read_all_channels()
|
||||||
|
return [dict_to_channel_data(d) for d in dicts]
|
||||||
|
|
||||||
|
def clear_stored_channels(self) -> None:
|
||||||
|
LOG.debug("Clearing contents of %s", self.data_path)
|
||||||
|
with open(self.data_path, "w") as f:
|
||||||
|
json.dump([], f)
|
||||||
|
try:
|
||||||
|
os.remove(self.data_path)
|
||||||
|
except Exception as e:
|
||||||
|
LOG.exception("Failed to delete %s (%s)", self.data_path, str(type(e)))
|
||||||
|
|
||||||
|
def _read_all_channels(self) -> t.List:
|
||||||
|
ensure_file_exists(self.data_path)
|
||||||
|
with open(self.data_path, "r") as f:
|
||||||
|
return json.load(f)
|
||||||
|
|
||||||
|
def _save_all_channels(self, channels: t.List[t.Dict]) -> None:
|
||||||
|
LOG.debug("saving all channels")
|
||||||
|
with open(self.data_path, "w") as f:
|
||||||
|
json.dump(channels, f, indent=4)
|
||||||
|
|
||||||
|
def save_channel(self, new_channel: Channel):
|
||||||
|
|
||||||
|
LOG.debug("save channel")
|
||||||
|
channels = self._read_all_channels()
|
||||||
|
transport_path = new_channel.transport.get_path()
|
||||||
|
|
||||||
|
# If the channel is found in database: replace the old entry by the new
|
||||||
|
for i, channel in enumerate(channels):
|
||||||
|
if channel["transport_path"] == transport_path:
|
||||||
|
LOG.debug("Modified channel entry for %s", transport_path)
|
||||||
|
channels[i] = new_channel.get_channel_data().to_dict()
|
||||||
|
self._save_all_channels(channels)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Channel was not found: add a new channel entry
|
||||||
|
LOG.debug("Created a new channel entry on path %s", transport_path)
|
||||||
|
channels.append(new_channel.get_channel_data().to_dict())
|
||||||
|
self._save_all_channels(channels)
|
||||||
|
|
||||||
|
def remove_channel(self, transport_path: str) -> None:
|
||||||
|
LOG.debug(
|
||||||
|
"Removing channel with path %s from the channel database.",
|
||||||
|
transport_path,
|
||||||
|
)
|
||||||
|
channels = self._read_all_channels()
|
||||||
|
remaining_channels = [
|
||||||
|
ch for ch in channels if ch["transport_path"] != transport_path
|
||||||
|
]
|
||||||
|
self._save_all_channels(remaining_channels)
|
||||||
|
|
||||||
|
|
||||||
|
def dict_to_channel_data(dict: t.Dict) -> ChannelData:
|
||||||
|
return ChannelData(
|
||||||
|
protocol_version_major=dict["protocol_version_minor"],
|
||||||
|
protocol_version_minor=dict["protocol_version_major"],
|
||||||
|
transport_path=dict["transport_path"],
|
||||||
|
channel_id=dict["channel_id"],
|
||||||
|
key_request=bytes.fromhex(dict["key_request"]),
|
||||||
|
key_response=bytes.fromhex(dict["key_response"]),
|
||||||
|
nonce_request=dict["nonce_request"],
|
||||||
|
nonce_response=dict["nonce_response"],
|
||||||
|
sync_bit_send=dict["sync_bit_send"],
|
||||||
|
sync_bit_receive=dict["sync_bit_receive"],
|
||||||
|
handshake_hash=bytes.fromhex(dict["handshake_hash"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_file_exists(file_path: str) -> None:
|
||||||
|
LOG.debug("checking if file %s exists", file_path)
|
||||||
|
if not os.path.exists(file_path):
|
||||||
|
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||||
|
LOG.debug("File %s does not exist. Creating a new one.", file_path)
|
||||||
|
with open(file_path, "w") as f:
|
||||||
|
json.dump([], f)
|
||||||
|
|
||||||
|
|
||||||
|
def set_channel_database(should_not_store: bool):
|
||||||
|
global db
|
||||||
|
if should_not_store:
|
||||||
|
db = DummyChannelDatabase()
|
||||||
|
else:
|
||||||
|
from platformdirs import user_cache_dir
|
||||||
|
|
||||||
|
APP_NAME = "@trezor" # TODO
|
||||||
|
DATA_PATH = os.path.join(user_cache_dir(appname=APP_NAME), "channel_data.json")
|
||||||
|
|
||||||
|
db = JsonChannelDatabase(DATA_PATH)
|
Loading…
Reference in New Issue
Block a user