mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-23 13:51:00 +00:00
feat(python): rework channel database
[no changelog]
This commit is contained in:
parent
71229b0a0a
commit
b2b786d9e0
@ -8,3 +8,4 @@ typing_extensions>=4.7.1
|
|||||||
construct-classes>=0.1.2
|
construct-classes>=0.1.2
|
||||||
appdirs>=1.4.4
|
appdirs>=1.4.4
|
||||||
cryptography >=43.0.3
|
cryptography >=43.0.3
|
||||||
|
platformdirs >=2
|
||||||
|
@ -29,7 +29,7 @@ from .. import exceptions, transport, ui
|
|||||||
from ..client import TrezorClient
|
from ..client import TrezorClient
|
||||||
from ..messages import Capability
|
from ..messages import Capability
|
||||||
from ..transport import Transport
|
from ..transport import Transport
|
||||||
from ..transport.thp import channel_database
|
from ..transport.thp.channel_database import get_channel_db
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -102,7 +102,7 @@ def get_passphrase(
|
|||||||
|
|
||||||
|
|
||||||
def get_client(transport: Transport) -> TrezorClient:
|
def get_client(transport: Transport) -> TrezorClient:
|
||||||
stored_channels = channel_database.load_stored_channels()
|
stored_channels = get_channel_db().load_stored_channels()
|
||||||
stored_transport_paths = [ch.transport_path for ch in stored_channels]
|
stored_transport_paths = [ch.transport_path for ch in stored_channels]
|
||||||
path = transport.get_path()
|
path = transport.get_path()
|
||||||
if path in stored_transport_paths:
|
if path in stored_transport_paths:
|
||||||
@ -115,7 +115,7 @@ def get_client(transport: Transport) -> TrezorClient:
|
|||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
LOG.debug("Failed to resume a channel. Replacing by a new one.")
|
LOG.debug("Failed to resume a channel. Replacing by a new one.")
|
||||||
channel_database.remove_channel(path)
|
get_channel_db().remove_channel(path)
|
||||||
client = TrezorClient(transport)
|
client = TrezorClient(transport)
|
||||||
else:
|
else:
|
||||||
client = TrezorClient(transport)
|
client = TrezorClient(transport)
|
||||||
@ -355,7 +355,7 @@ def with_client(
|
|||||||
try:
|
try:
|
||||||
return func(client, *args, **kwargs)
|
return func(client, *args, **kwargs)
|
||||||
finally:
|
finally:
|
||||||
channel_database.save_channel(client.protocol)
|
get_channel_db().save_channel(client.protocol)
|
||||||
# if not session_was_resumed:
|
# if not session_was_resumed:
|
||||||
# try:
|
# try:
|
||||||
# client.end_session()
|
# client.end_session()
|
||||||
|
@ -29,6 +29,7 @@ from ..client import TrezorClient
|
|||||||
from ..transport import DeviceIsBusy, enumerate_devices
|
from ..transport import DeviceIsBusy, enumerate_devices
|
||||||
from ..transport.session import Session
|
from ..transport.session import Session
|
||||||
from ..transport.thp import channel_database
|
from ..transport.thp import channel_database
|
||||||
|
from ..transport.thp.channel_database import get_channel_db
|
||||||
from ..transport.udp import UdpTransport
|
from ..transport.udp import UdpTransport
|
||||||
from . import (
|
from . import (
|
||||||
AliasedGroup,
|
AliasedGroup,
|
||||||
@ -196,6 +197,13 @@ def configure_logging(verbose: int) -> None:
|
|||||||
"--record",
|
"--record",
|
||||||
help="Record screen changes into a specified directory.",
|
help="Record screen changes into a specified directory.",
|
||||||
)
|
)
|
||||||
|
@click.option(
|
||||||
|
"-n",
|
||||||
|
"--no-store",
|
||||||
|
is_flag=True,
|
||||||
|
help="Do not store channels data between commands.",
|
||||||
|
default=False,
|
||||||
|
)
|
||||||
@click.version_option(version=__version__)
|
@click.version_option(version=__version__)
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
def cli_main(
|
def cli_main(
|
||||||
@ -207,9 +215,10 @@ def cli_main(
|
|||||||
script: bool,
|
script: bool,
|
||||||
session_id: Optional[str],
|
session_id: Optional[str],
|
||||||
record: Optional[str],
|
record: Optional[str],
|
||||||
|
no_store: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
configure_logging(verbose)
|
configure_logging(verbose)
|
||||||
|
channel_database.set_channel_database(should_not_store=no_store)
|
||||||
bytes_session_id: Optional[bytes] = None
|
bytes_session_id: Optional[bytes] = None
|
||||||
if session_id is not None:
|
if session_id is not None:
|
||||||
try:
|
try:
|
||||||
@ -296,10 +305,7 @@ def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]:
|
|||||||
try:
|
try:
|
||||||
client = get_client(transport)
|
client = get_client(transport)
|
||||||
description = format_device_name(client.features)
|
description = format_device_name(client.features)
|
||||||
# json_string = channel_database.channel_to_str(client.protocol)
|
get_channel_db().save_channel(client.protocol)
|
||||||
# print(json_string)
|
|
||||||
channel_database.save_channel(client.protocol)
|
|
||||||
# client.end_session()
|
|
||||||
except DeviceIsBusy:
|
except DeviceIsBusy:
|
||||||
description = "Device is in use by another process"
|
description = "Device is in use by another process"
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -376,9 +382,14 @@ def clear_session(session: "Session") -> None:
|
|||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
def new_clear_session() -> None:
|
def delete_channels() -> None:
|
||||||
"""New Clear session (remove cached channels from trezorlib)."""
|
"""
|
||||||
channel_database.clear_stored_channels()
|
Delete cached channels.
|
||||||
|
|
||||||
|
Do not use together with the `-n` (`--no-store`) flag,
|
||||||
|
as the JSON database will not be deleted.
|
||||||
|
"""
|
||||||
|
get_channel_db().clear_stored_channels()
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
@ -8,39 +10,104 @@ from .protocol_and_channel import ProtocolAndChannel
|
|||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
if True:
|
db: "ChannelDatabase | None" = None
|
||||||
from platformdirs import user_cache_dir, user_config_dir
|
|
||||||
|
|
||||||
APP_NAME = "@trezor" # TODO
|
|
||||||
DATA_PATH = os.path.join(user_cache_dir(appname=APP_NAME), "channel_data.json")
|
|
||||||
CONFIG_PATH = os.path.join(user_config_dir(appname=APP_NAME), "config.json")
|
|
||||||
else:
|
|
||||||
DATA_PATH = os.path.join("./channel_data.json")
|
|
||||||
CONFIG_PATH = os.path.join("./config.json")
|
|
||||||
|
|
||||||
|
|
||||||
class ChannelDatabase: # TODO not finished
|
def get_channel_db() -> ChannelDatabase:
|
||||||
should_store: bool = False
|
if db is None:
|
||||||
|
set_channel_database(should_not_store=True)
|
||||||
def __init__(
|
assert db is not None
|
||||||
self, config_path: str = CONFIG_PATH, data_path: str = DATA_PATH
|
return db
|
||||||
) -> None:
|
|
||||||
if not os.path.exists(CONFIG_PATH):
|
|
||||||
with open(CONFIG_PATH, "w") as f:
|
|
||||||
json.dump([], f)
|
|
||||||
|
|
||||||
|
|
||||||
def load_stored_channels() -> t.List[ChannelData]:
|
class ChannelDatabase:
|
||||||
dicts = read_all_channels()
|
|
||||||
|
def load_stored_channels(self) -> t.List[ChannelData]: ...
|
||||||
|
def clear_stored_channels(self) -> None: ...
|
||||||
|
def read_all_channels(self) -> t.List: ...
|
||||||
|
def save_all_channels(self, channels: t.List[t.Dict]) -> None: ...
|
||||||
|
def save_channel(self, new_channel: ProtocolAndChannel): ...
|
||||||
|
def remove_channel(self, transport_path: str) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
|
class DummyChannelDatabase(ChannelDatabase):
|
||||||
|
|
||||||
|
def load_stored_channels(self) -> t.List[ChannelData]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def clear_stored_channels(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def read_all_channels(self) -> t.List:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def save_all_channels(self, channels: t.List[t.Dict]) -> None:
|
||||||
|
return
|
||||||
|
|
||||||
|
def save_channel(self, new_channel: ProtocolAndChannel):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def remove_channel(self, transport_path: str) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class JsonChannelDatabase(ChannelDatabase):
|
||||||
|
def __init__(self, data_path: str) -> None:
|
||||||
|
self.data_path = data_path
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def load_stored_channels(self) -> t.List[ChannelData]:
|
||||||
|
dicts = self.read_all_channels()
|
||||||
return [dict_to_channel_data(d) for d in dicts]
|
return [dict_to_channel_data(d) for d in dicts]
|
||||||
|
|
||||||
|
def clear_stored_channels(self) -> None:
|
||||||
|
LOG.debug("Clearing contents of %s", self.data_path)
|
||||||
|
with open(self.data_path, "w") as f:
|
||||||
|
json.dump([], f)
|
||||||
|
try:
|
||||||
|
os.remove(self.data_path)
|
||||||
|
except Exception as e:
|
||||||
|
LOG.exception("Failed to delete %s (%s)", self.data_path, str(type(e)))
|
||||||
|
|
||||||
def channel_to_str(channel: ProtocolAndChannel) -> str:
|
def read_all_channels(self) -> t.List:
|
||||||
return json.dumps(channel.get_channel_data().to_dict())
|
ensure_file_exists(self.data_path)
|
||||||
|
with open(self.data_path, "r") as f:
|
||||||
|
return json.load(f)
|
||||||
|
|
||||||
|
def save_all_channels(self, channels: t.List[t.Dict]) -> None:
|
||||||
|
LOG.debug("saving all channels")
|
||||||
|
with open(self.data_path, "w") as f:
|
||||||
|
json.dump(channels, f, indent=4)
|
||||||
|
|
||||||
def str_to_channel_data(channel_data: str) -> ChannelData:
|
def save_channel(self, new_channel: ProtocolAndChannel):
|
||||||
return dict_to_channel_data(json.loads(channel_data))
|
|
||||||
|
LOG.debug("save channel")
|
||||||
|
channels = self.read_all_channels()
|
||||||
|
transport_path = new_channel.transport.get_path()
|
||||||
|
|
||||||
|
# If the channel is found in database: replace the old entry by the new
|
||||||
|
for i, channel in enumerate(channels):
|
||||||
|
if channel["transport_path"] == transport_path:
|
||||||
|
LOG.debug("Modified channel entry for %s", transport_path)
|
||||||
|
channels[i] = new_channel.get_channel_data().to_dict()
|
||||||
|
self.save_all_channels(channels)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Channel was not found: add a new channel entry
|
||||||
|
LOG.debug("Created a new channel entry on path %s", transport_path)
|
||||||
|
channels.append(new_channel.get_channel_data().to_dict())
|
||||||
|
self.save_all_channels(channels)
|
||||||
|
|
||||||
|
def remove_channel(self, transport_path: str) -> None:
|
||||||
|
LOG.debug(
|
||||||
|
"Removing channel with path %s from the channel database.",
|
||||||
|
transport_path,
|
||||||
|
)
|
||||||
|
channels = self.read_all_channels()
|
||||||
|
remaining_channels = [
|
||||||
|
ch for ch in channels if ch["transport_path"] != transport_path
|
||||||
|
]
|
||||||
|
self.save_all_channels(remaining_channels)
|
||||||
|
|
||||||
|
|
||||||
def dict_to_channel_data(dict: t.Dict) -> ChannelData:
|
def dict_to_channel_data(dict: t.Dict) -> ChannelData:
|
||||||
@ -57,63 +124,23 @@ def dict_to_channel_data(dict: t.Dict) -> ChannelData:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def ensure_file_exists() -> None:
|
def ensure_file_exists(file_path: str) -> None:
|
||||||
LOG.debug("checking if file %s exists", DATA_PATH)
|
LOG.debug("checking if file %s exists", file_path)
|
||||||
if not os.path.exists(DATA_PATH):
|
if not os.path.exists(file_path):
|
||||||
os.makedirs(os.path.dirname(DATA_PATH), exist_ok=True)
|
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||||
LOG.debug("File %s does not exist. Creating a new one.", DATA_PATH)
|
LOG.debug("File %s does not exist. Creating a new one.", file_path)
|
||||||
with open(DATA_PATH, "w") as f:
|
with open(file_path, "w") as f:
|
||||||
json.dump([], f)
|
json.dump([], f)
|
||||||
|
|
||||||
|
|
||||||
def clear_stored_channels() -> None:
|
def set_channel_database(should_not_store: bool):
|
||||||
LOG.debug("Clearing contents of %s", DATA_PATH)
|
global db
|
||||||
with open(DATA_PATH, "w") as f:
|
if should_not_store:
|
||||||
json.dump([], f)
|
db = DummyChannelDatabase()
|
||||||
try:
|
else:
|
||||||
os.remove(DATA_PATH)
|
from platformdirs import user_cache_dir
|
||||||
except Exception as e:
|
|
||||||
LOG.exception("Failed to delete %s (%s)", DATA_PATH, str(type(e)))
|
|
||||||
|
|
||||||
|
APP_NAME = "@trezor" # TODO
|
||||||
|
DATA_PATH = os.path.join(user_cache_dir(appname=APP_NAME), "channel_data.json")
|
||||||
|
|
||||||
def read_all_channels() -> t.List:
|
db = JsonChannelDatabase(DATA_PATH)
|
||||||
ensure_file_exists()
|
|
||||||
with open(DATA_PATH, "r") as f:
|
|
||||||
return json.load(f)
|
|
||||||
|
|
||||||
|
|
||||||
def save_all_channels(channels: t.List[t.Dict]) -> None:
|
|
||||||
LOG.debug("saving all channels")
|
|
||||||
with open(DATA_PATH, "w") as f:
|
|
||||||
json.dump(channels, f, indent=4)
|
|
||||||
|
|
||||||
|
|
||||||
def save_channel(new_channel: ProtocolAndChannel):
|
|
||||||
LOG.debug("save channel")
|
|
||||||
channels = read_all_channels()
|
|
||||||
transport_path = new_channel.transport.get_path()
|
|
||||||
|
|
||||||
# If the channel is found in database: replace the old entry by the new
|
|
||||||
for i, channel in enumerate(channels):
|
|
||||||
if channel["transport_path"] == transport_path:
|
|
||||||
LOG.debug("Modified channel entry for %s", transport_path)
|
|
||||||
channels[i] = new_channel.get_channel_data().to_dict()
|
|
||||||
save_all_channels(channels)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Channel was not found: add a new channel entry
|
|
||||||
LOG.debug("Created a new channel entry on path %s", transport_path)
|
|
||||||
channels.append(new_channel.get_channel_data().to_dict())
|
|
||||||
save_all_channels(channels)
|
|
||||||
|
|
||||||
|
|
||||||
def remove_channel(transport_path: str) -> None:
|
|
||||||
LOG.debug(
|
|
||||||
"Removing channel with path %s from the channel database.",
|
|
||||||
transport_path,
|
|
||||||
)
|
|
||||||
channels = read_all_channels()
|
|
||||||
remaining_channels = [
|
|
||||||
ch for ch in channels if ch["transport_path"] != transport_path
|
|
||||||
]
|
|
||||||
save_all_channels(remaining_channels)
|
|
||||||
|
@ -18,7 +18,8 @@ from ..thp import checksum, curve25519, thp_io
|
|||||||
from ..thp.channel_data import ChannelData
|
from ..thp.channel_data import ChannelData
|
||||||
from ..thp.checksum import CHECKSUM_LENGTH
|
from ..thp.checksum import CHECKSUM_LENGTH
|
||||||
from ..thp.message_header import MessageHeader
|
from ..thp.message_header import MessageHeader
|
||||||
from . import channel_database, control_byte
|
from . import control_byte
|
||||||
|
from .channel_database import ChannelDatabase, get_channel_db
|
||||||
from .protocol_and_channel import ProtocolAndChannel
|
from .protocol_and_channel import ProtocolAndChannel
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
@ -76,6 +77,7 @@ class ProtocolV2(ProtocolAndChannel):
|
|||||||
self.sync_bit_receive = channel_data.sync_bit_receive
|
self.sync_bit_receive = channel_data.sync_bit_receive
|
||||||
self.sync_bit_send = channel_data.sync_bit_send
|
self.sync_bit_send = channel_data.sync_bit_send
|
||||||
self._has_valid_channel = True
|
self._has_valid_channel = True
|
||||||
|
self.channel_database: ChannelDatabase = get_channel_db()
|
||||||
|
|
||||||
def get_channel(self) -> ProtocolV2:
|
def get_channel(self) -> ProtocolV2:
|
||||||
if not self._has_valid_channel:
|
if not self._has_valid_channel:
|
||||||
@ -99,13 +101,13 @@ class ProtocolV2(ProtocolAndChannel):
|
|||||||
sid, msg_type, msg_data = self.read_and_decrypt()
|
sid, msg_type, msg_data = self.read_and_decrypt()
|
||||||
if sid != session_id:
|
if sid != session_id:
|
||||||
raise Exception("Received messsage on a different session.")
|
raise Exception("Received messsage on a different session.")
|
||||||
channel_database.save_channel(self)
|
self.channel_database.save_channel(self)
|
||||||
return self.mapping.decode(msg_type, msg_data)
|
return self.mapping.decode(msg_type, msg_data)
|
||||||
|
|
||||||
def write(self, session_id: int, msg: t.Any) -> None:
|
def write(self, session_id: int, msg: t.Any) -> None:
|
||||||
msg_type, msg_data = self.mapping.encode(msg)
|
msg_type, msg_data = self.mapping.encode(msg)
|
||||||
self._encrypt_and_write(session_id, msg_type, msg_data)
|
self._encrypt_and_write(session_id, msg_type, msg_data)
|
||||||
channel_database.save_channel(self)
|
self.channel_database.save_channel(self)
|
||||||
|
|
||||||
def get_features(self) -> messages.Features:
|
def get_features(self) -> messages.Features:
|
||||||
if not self._has_valid_channel:
|
if not self._has_valid_channel:
|
||||||
|
@ -322,9 +322,9 @@ def client(
|
|||||||
# Get a new client
|
# Get a new client
|
||||||
_raw_client = _get_raw_client(request)
|
_raw_client = _get_raw_client(request)
|
||||||
|
|
||||||
from trezorlib.transport.thp import channel_database
|
from trezorlib.transport.thp.channel_database import get_channel_db
|
||||||
|
|
||||||
channel_database.clear_stored_channels()
|
get_channel_db().clear_stored_channels()
|
||||||
_raw_client.protocol = None
|
_raw_client.protocol = None
|
||||||
_raw_client.__init__(
|
_raw_client.__init__(
|
||||||
transport=_raw_client.transport,
|
transport=_raw_client.transport,
|
||||||
|
Loading…
Reference in New Issue
Block a user