1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-23 13:51:00 +00:00

feat(python): rework channel database

[no changelog]
This commit is contained in:
M1nd3r 2024-11-19 20:23:10 +01:00
parent 71229b0a0a
commit b2b786d9e0
6 changed files with 138 additions and 97 deletions

View File

@ -8,3 +8,4 @@ typing_extensions>=4.7.1
construct-classes>=0.1.2 construct-classes>=0.1.2
appdirs>=1.4.4 appdirs>=1.4.4
cryptography >=43.0.3 cryptography >=43.0.3
platformdirs >=2

View File

@ -29,7 +29,7 @@ from .. import exceptions, transport, ui
from ..client import TrezorClient from ..client import TrezorClient
from ..messages import Capability from ..messages import Capability
from ..transport import Transport from ..transport import Transport
from ..transport.thp import channel_database from ..transport.thp.channel_database import get_channel_db
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -102,7 +102,7 @@ def get_passphrase(
def get_client(transport: Transport) -> TrezorClient: def get_client(transport: Transport) -> TrezorClient:
stored_channels = channel_database.load_stored_channels() stored_channels = get_channel_db().load_stored_channels()
stored_transport_paths = [ch.transport_path for ch in stored_channels] stored_transport_paths = [ch.transport_path for ch in stored_channels]
path = transport.get_path() path = transport.get_path()
if path in stored_transport_paths: if path in stored_transport_paths:
@ -115,7 +115,7 @@ def get_client(transport: Transport) -> TrezorClient:
) )
except Exception: except Exception:
LOG.debug("Failed to resume a channel. Replacing by a new one.") LOG.debug("Failed to resume a channel. Replacing by a new one.")
channel_database.remove_channel(path) get_channel_db().remove_channel(path)
client = TrezorClient(transport) client = TrezorClient(transport)
else: else:
client = TrezorClient(transport) client = TrezorClient(transport)
@ -355,7 +355,7 @@ def with_client(
try: try:
return func(client, *args, **kwargs) return func(client, *args, **kwargs)
finally: finally:
channel_database.save_channel(client.protocol) get_channel_db().save_channel(client.protocol)
# if not session_was_resumed: # if not session_was_resumed:
# try: # try:
# client.end_session() # client.end_session()

View File

@ -29,6 +29,7 @@ from ..client import TrezorClient
from ..transport import DeviceIsBusy, enumerate_devices from ..transport import DeviceIsBusy, enumerate_devices
from ..transport.session import Session from ..transport.session import Session
from ..transport.thp import channel_database from ..transport.thp import channel_database
from ..transport.thp.channel_database import get_channel_db
from ..transport.udp import UdpTransport from ..transport.udp import UdpTransport
from . import ( from . import (
AliasedGroup, AliasedGroup,
@ -196,6 +197,13 @@ def configure_logging(verbose: int) -> None:
"--record", "--record",
help="Record screen changes into a specified directory.", help="Record screen changes into a specified directory.",
) )
@click.option(
"-n",
"--no-store",
is_flag=True,
help="Do not store channels data between commands.",
default=False,
)
@click.version_option(version=__version__) @click.version_option(version=__version__)
@click.pass_context @click.pass_context
def cli_main( def cli_main(
@ -207,9 +215,10 @@ def cli_main(
script: bool, script: bool,
session_id: Optional[str], session_id: Optional[str],
record: Optional[str], record: Optional[str],
no_store: bool,
) -> None: ) -> None:
configure_logging(verbose) configure_logging(verbose)
channel_database.set_channel_database(should_not_store=no_store)
bytes_session_id: Optional[bytes] = None bytes_session_id: Optional[bytes] = None
if session_id is not None: if session_id is not None:
try: try:
@ -296,10 +305,7 @@ def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]:
try: try:
client = get_client(transport) client = get_client(transport)
description = format_device_name(client.features) description = format_device_name(client.features)
# json_string = channel_database.channel_to_str(client.protocol) get_channel_db().save_channel(client.protocol)
# print(json_string)
channel_database.save_channel(client.protocol)
# client.end_session()
except DeviceIsBusy: except DeviceIsBusy:
description = "Device is in use by another process" description = "Device is in use by another process"
except Exception: except Exception:
@ -376,9 +382,14 @@ def clear_session(session: "Session") -> None:
@cli.command() @cli.command()
def new_clear_session() -> None: def delete_channels() -> None:
"""New Clear session (remove cached channels from trezorlib).""" """
channel_database.clear_stored_channels() Delete cached channels.
Do not use together with the `-n` (`--no-store`) flag,
as the JSON database will not be deleted.
"""
get_channel_db().clear_stored_channels()
@cli.command() @cli.command()

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import json import json
import logging import logging
import os import os
@ -8,39 +10,104 @@ from .protocol_and_channel import ProtocolAndChannel
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
if True: db: "ChannelDatabase | None" = None
from platformdirs import user_cache_dir, user_config_dir
APP_NAME = "@trezor" # TODO
DATA_PATH = os.path.join(user_cache_dir(appname=APP_NAME), "channel_data.json")
CONFIG_PATH = os.path.join(user_config_dir(appname=APP_NAME), "config.json")
else:
DATA_PATH = os.path.join("./channel_data.json")
CONFIG_PATH = os.path.join("./config.json")
class ChannelDatabase: # TODO not finished def get_channel_db() -> ChannelDatabase:
should_store: bool = False if db is None:
set_channel_database(should_not_store=True)
def __init__( assert db is not None
self, config_path: str = CONFIG_PATH, data_path: str = DATA_PATH return db
) -> None:
if not os.path.exists(CONFIG_PATH):
with open(CONFIG_PATH, "w") as f:
json.dump([], f)
def load_stored_channels() -> t.List[ChannelData]: class ChannelDatabase:
dicts = read_all_channels()
def load_stored_channels(self) -> t.List[ChannelData]: ...
def clear_stored_channels(self) -> None: ...
def read_all_channels(self) -> t.List: ...
def save_all_channels(self, channels: t.List[t.Dict]) -> None: ...
def save_channel(self, new_channel: ProtocolAndChannel): ...
def remove_channel(self, transport_path: str) -> None: ...
class DummyChannelDatabase(ChannelDatabase):
def load_stored_channels(self) -> t.List[ChannelData]:
return []
def clear_stored_channels(self) -> None:
pass
def read_all_channels(self) -> t.List:
return []
def save_all_channels(self, channels: t.List[t.Dict]) -> None:
return
def save_channel(self, new_channel: ProtocolAndChannel):
pass
def remove_channel(self, transport_path: str) -> None:
pass
class JsonChannelDatabase(ChannelDatabase):
def __init__(self, data_path: str) -> None:
self.data_path = data_path
super().__init__()
def load_stored_channels(self) -> t.List[ChannelData]:
dicts = self.read_all_channels()
return [dict_to_channel_data(d) for d in dicts] return [dict_to_channel_data(d) for d in dicts]
def clear_stored_channels(self) -> None:
LOG.debug("Clearing contents of %s", self.data_path)
with open(self.data_path, "w") as f:
json.dump([], f)
try:
os.remove(self.data_path)
except Exception as e:
LOG.exception("Failed to delete %s (%s)", self.data_path, str(type(e)))
def channel_to_str(channel: ProtocolAndChannel) -> str: def read_all_channels(self) -> t.List:
return json.dumps(channel.get_channel_data().to_dict()) ensure_file_exists(self.data_path)
with open(self.data_path, "r") as f:
return json.load(f)
def save_all_channels(self, channels: t.List[t.Dict]) -> None:
LOG.debug("saving all channels")
with open(self.data_path, "w") as f:
json.dump(channels, f, indent=4)
def str_to_channel_data(channel_data: str) -> ChannelData: def save_channel(self, new_channel: ProtocolAndChannel):
return dict_to_channel_data(json.loads(channel_data))
LOG.debug("save channel")
channels = self.read_all_channels()
transport_path = new_channel.transport.get_path()
# If the channel is found in database: replace the old entry by the new
for i, channel in enumerate(channels):
if channel["transport_path"] == transport_path:
LOG.debug("Modified channel entry for %s", transport_path)
channels[i] = new_channel.get_channel_data().to_dict()
self.save_all_channels(channels)
return
# Channel was not found: add a new channel entry
LOG.debug("Created a new channel entry on path %s", transport_path)
channels.append(new_channel.get_channel_data().to_dict())
self.save_all_channels(channels)
def remove_channel(self, transport_path: str) -> None:
LOG.debug(
"Removing channel with path %s from the channel database.",
transport_path,
)
channels = self.read_all_channels()
remaining_channels = [
ch for ch in channels if ch["transport_path"] != transport_path
]
self.save_all_channels(remaining_channels)
def dict_to_channel_data(dict: t.Dict) -> ChannelData: def dict_to_channel_data(dict: t.Dict) -> ChannelData:
@ -57,63 +124,23 @@ def dict_to_channel_data(dict: t.Dict) -> ChannelData:
) )
def ensure_file_exists() -> None: def ensure_file_exists(file_path: str) -> None:
LOG.debug("checking if file %s exists", DATA_PATH) LOG.debug("checking if file %s exists", file_path)
if not os.path.exists(DATA_PATH): if not os.path.exists(file_path):
os.makedirs(os.path.dirname(DATA_PATH), exist_ok=True) os.makedirs(os.path.dirname(file_path), exist_ok=True)
LOG.debug("File %s does not exist. Creating a new one.", DATA_PATH) LOG.debug("File %s does not exist. Creating a new one.", file_path)
with open(DATA_PATH, "w") as f: with open(file_path, "w") as f:
json.dump([], f) json.dump([], f)
def clear_stored_channels() -> None: def set_channel_database(should_not_store: bool):
LOG.debug("Clearing contents of %s", DATA_PATH) global db
with open(DATA_PATH, "w") as f: if should_not_store:
json.dump([], f) db = DummyChannelDatabase()
try: else:
os.remove(DATA_PATH) from platformdirs import user_cache_dir
except Exception as e:
LOG.exception("Failed to delete %s (%s)", DATA_PATH, str(type(e)))
APP_NAME = "@trezor" # TODO
DATA_PATH = os.path.join(user_cache_dir(appname=APP_NAME), "channel_data.json")
def read_all_channels() -> t.List: db = JsonChannelDatabase(DATA_PATH)
ensure_file_exists()
with open(DATA_PATH, "r") as f:
return json.load(f)
def save_all_channels(channels: t.List[t.Dict]) -> None:
LOG.debug("saving all channels")
with open(DATA_PATH, "w") as f:
json.dump(channels, f, indent=4)
def save_channel(new_channel: ProtocolAndChannel):
LOG.debug("save channel")
channels = read_all_channels()
transport_path = new_channel.transport.get_path()
# If the channel is found in database: replace the old entry by the new
for i, channel in enumerate(channels):
if channel["transport_path"] == transport_path:
LOG.debug("Modified channel entry for %s", transport_path)
channels[i] = new_channel.get_channel_data().to_dict()
save_all_channels(channels)
return
# Channel was not found: add a new channel entry
LOG.debug("Created a new channel entry on path %s", transport_path)
channels.append(new_channel.get_channel_data().to_dict())
save_all_channels(channels)
def remove_channel(transport_path: str) -> None:
LOG.debug(
"Removing channel with path %s from the channel database.",
transport_path,
)
channels = read_all_channels()
remaining_channels = [
ch for ch in channels if ch["transport_path"] != transport_path
]
save_all_channels(remaining_channels)

View File

@ -18,7 +18,8 @@ from ..thp import checksum, curve25519, thp_io
from ..thp.channel_data import ChannelData from ..thp.channel_data import ChannelData
from ..thp.checksum import CHECKSUM_LENGTH from ..thp.checksum import CHECKSUM_LENGTH
from ..thp.message_header import MessageHeader from ..thp.message_header import MessageHeader
from . import channel_database, control_byte from . import control_byte
from .channel_database import ChannelDatabase, get_channel_db
from .protocol_and_channel import ProtocolAndChannel from .protocol_and_channel import ProtocolAndChannel
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -76,6 +77,7 @@ class ProtocolV2(ProtocolAndChannel):
self.sync_bit_receive = channel_data.sync_bit_receive self.sync_bit_receive = channel_data.sync_bit_receive
self.sync_bit_send = channel_data.sync_bit_send self.sync_bit_send = channel_data.sync_bit_send
self._has_valid_channel = True self._has_valid_channel = True
self.channel_database: ChannelDatabase = get_channel_db()
def get_channel(self) -> ProtocolV2: def get_channel(self) -> ProtocolV2:
if not self._has_valid_channel: if not self._has_valid_channel:
@ -99,13 +101,13 @@ class ProtocolV2(ProtocolAndChannel):
sid, msg_type, msg_data = self.read_and_decrypt() sid, msg_type, msg_data = self.read_and_decrypt()
if sid != session_id: if sid != session_id:
raise Exception("Received messsage on a different session.") raise Exception("Received messsage on a different session.")
channel_database.save_channel(self) self.channel_database.save_channel(self)
return self.mapping.decode(msg_type, msg_data) return self.mapping.decode(msg_type, msg_data)
def write(self, session_id: int, msg: t.Any) -> None: def write(self, session_id: int, msg: t.Any) -> None:
msg_type, msg_data = self.mapping.encode(msg) msg_type, msg_data = self.mapping.encode(msg)
self._encrypt_and_write(session_id, msg_type, msg_data) self._encrypt_and_write(session_id, msg_type, msg_data)
channel_database.save_channel(self) self.channel_database.save_channel(self)
def get_features(self) -> messages.Features: def get_features(self) -> messages.Features:
if not self._has_valid_channel: if not self._has_valid_channel:

View File

@ -322,9 +322,9 @@ def client(
# Get a new client # Get a new client
_raw_client = _get_raw_client(request) _raw_client = _get_raw_client(request)
from trezorlib.transport.thp import channel_database from trezorlib.transport.thp.channel_database import get_channel_db
channel_database.clear_stored_channels() get_channel_db().clear_stored_channels()
_raw_client.protocol = None _raw_client.protocol = None
_raw_client.__init__( _raw_client.__init__(
transport=_raw_client.transport, transport=_raw_client.transport,