diff --git a/python/channel_data.json b/python/channel_data.json new file mode 100644 index 000000000..0637a088a --- /dev/null +++ b/python/channel_data.json @@ -0,0 +1 @@ +[] \ No newline at end of file diff --git a/python/src/trezorlib/cli/trezorctl.py b/python/src/trezorlib/cli/trezorctl.py index be9cc03e6..3897b5bce 100755 --- a/python/src/trezorlib/cli/trezorctl.py +++ b/python/src/trezorlib/cli/trezorctl.py @@ -27,6 +27,7 @@ import click from .. import __version__, log, messages, protobuf from ..client import TrezorClient from ..transport import DeviceIsBusy, new_enumerate_devices +from ..transport.new import channel_database from ..transport.new.client import NewTrezorClient from ..transport.udp import UdpTransport from . import ( @@ -287,18 +288,32 @@ def list_devices(no_resolve: bool) -> Optional[Iterable["NewTransport"]]: if no_resolve: return new_enumerate_devices() + stored_channels = channel_database.load_stored_channels() + stored_transport_paths = [ch.transport_path for ch in stored_channels] for transport in new_enumerate_devices(): try: - client = NewTrezorClient(transport) - session = client.get_management_session() + path = transport.get_path() + if path in stored_transport_paths: + stored_channel_with_correct_transport_path = next( + ch for ch in stored_channels if ch.transport_path == path + ) + client = NewTrezorClient.resume( + transport, stored_channel_with_correct_transport_path + ) + else: + client = NewTrezorClient(transport) + session = client.get_management_session() description = format_device_name(session.features) + # json_string = channel_database.channel_to_str(client.protocol) + # print(json_string) + channel_database.save_channel(client.protocol) # client.end_session() except DeviceIsBusy: description = "Device is in use by another process" except Exception: description = "Failed to read details" - click.echo(f"{transport} - {description}") + click.echo(f"{transport.get_path()} - {description}") return None diff --git a/python/src/trezorlib/transport/new/channel_data.py b/python/src/trezorlib/transport/new/channel_data.py index 08cdfe35d..3d70deeca 100644 --- a/python/src/trezorlib/transport/new/channel_data.py +++ b/python/src/trezorlib/transport/new/channel_data.py @@ -1,11 +1,40 @@ from __future__ import annotations +from binascii import hexlify + class ChannelData: - key_request: bytes - key_response: bytes - nonce_request: int - nonce_response: int - channel_id: bytes - sync_bit_send: int - sync_bit_receive: int + def __init__( + self, + protocol_version: int, + transport_path: str, + channel_id: int, + key_request: bytes, + key_response: bytes, + nonce_request: int, + nonce_response: int, + sync_bit_send: int, + sync_bit_receive: int, + ) -> None: + self.protocol_version: int = protocol_version + self.transport_path: str = transport_path + self.channel_id: int = channel_id + self.key_request: str = hexlify(key_request).decode() + self.key_response: str = hexlify(key_response).decode() + self.nonce_request: int = nonce_request + self.nonce_response: int = nonce_response + self.sync_bit_receive: int = sync_bit_receive + self.sync_bit_send: int = sync_bit_send + + def to_dict(self): + return { + "protocol_version": self.protocol_version, + "transport_path": self.transport_path, + "channel_id": self.channel_id, + "key_request": self.key_request, + "key_response": self.key_response, + "nonce_request": self.nonce_request, + "nonce_response": self.nonce_response, + "sync_bit_send": self.sync_bit_send, + "sync_bit_receive": self.sync_bit_receive, + } diff --git a/python/src/trezorlib/transport/new/channel_database.py b/python/src/trezorlib/transport/new/channel_database.py new file mode 100644 index 000000000..0fc7c687d --- /dev/null +++ b/python/src/trezorlib/transport/new/channel_database.py @@ -0,0 +1,85 @@ +import json +import logging +import os +import typing as t + +from .channel_data import ChannelData +from .protocol_and_channel import ProtocolAndChannel + +LOG = logging.getLogger(__name__) + +FILE_PATH = "channel_data.json" + + +def load_stored_channels() -> t.List[ChannelData]: + dicts = read_all_channels() + return [dict_to_channel_data(d) for d in dicts] + + +def channel_to_str(channel: ProtocolAndChannel) -> str: + return json.dumps(channel.get_channel_data().to_dict()) + + +def str_to_channel_data(channel_data: str) -> ChannelData: + return dict_to_channel_data(json.loads(channel_data)) + + +def dict_to_channel_data(dict: t.Dict) -> ChannelData: + return ChannelData( + protocol_version=dict["protocol_version"], + transport_path=dict["transport_path"], + channel_id=dict["channel_id"], + key_request=bytes.fromhex(dict["key_request"]), + key_response=bytes.fromhex(dict["key_response"]), + nonce_request=dict["nonce_request"], + nonce_response=dict["nonce_response"], + sync_bit_send=dict["sync_bit_send"], + sync_bit_receive=dict["sync_bit_receive"], + ) + + +def ensure_file_exists() -> None: + LOG.debug("checking if file %s exists", FILE_PATH) + if not os.path.exists(FILE_PATH): + LOG.debug("File %s does not exist. Creating a new one.", FILE_PATH) + with open(FILE_PATH, "w") as f: + json.dump([], f) + + +def read_all_channels() -> t.List: + ensure_file_exists() + with open(FILE_PATH, "r") as f: + return json.load(f) + + +def save_all_channels(channels: t.List[t.Dict]) -> None: + LOG.debug("saving all channels") + with open(FILE_PATH, "w") as f: + json.dump(channels, f, indent=4) + + +def save_channel(new_channel: ProtocolAndChannel): + LOG.debug("save channel") + channels = read_all_channels() + transport_path = new_channel.transport.get_path() + + # If channel is modified: replace the old by the new + for i, channel in enumerate(channels): + if channel["transport_path"] == transport_path: + LOG.debug("Modified channel entry for %s", transport_path) + channels[i] = new_channel.get_channel_data().to_dict() + save_all_channels(channels) + return + + # Else: add a new channel entry + LOG.debug("Created a new channel entry on path %s", transport_path) + channels.append(new_channel.get_channel_data().to_dict()) + save_all_channels(channels) + + +def remove_channel(transport_path: str) -> None: + channels = read_all_channels() + remaining_channels = [ + ch for ch in channels if ch["transport_path"] != transport_path + ] + save_all_channels(remaining_channels) diff --git a/python/src/trezorlib/transport/new/client.py b/python/src/trezorlib/transport/new/client.py index 115bf5987..151a6a743 100644 --- a/python/src/trezorlib/transport/new/client.py +++ b/python/src/trezorlib/transport/new/client.py @@ -33,11 +33,22 @@ class NewTrezorClient: self.protocol = self._get_protocol() else: self.protocol = protocol + self.protocol.mapping = self.mapping @classmethod def resume( - cls, transport: NewTransport, channel_data: ChannelData - ) -> NewTrezorClient: ... + cls, + transport: NewTransport, + channel_data: ChannelData, + protobuf_mapping: ProtobufMapping | None = None, + ) -> NewTrezorClient: + if protobuf_mapping is None: + protobuf_mapping = mapping.DEFAULT_MAPPING + if channel_data.protocol_version == 2: + protocol = ProtocolV2(transport, protobuf_mapping, channel_data) + else: + protocol = ProtocolV1(transport, protobuf_mapping, channel_data) + return NewTrezorClient(transport, protobuf_mapping, protocol) def get_session( self, @@ -59,7 +70,6 @@ class NewTrezorClient: self.management_session = SessionV1.new(self, "", False) elif isinstance(self.protocol, ProtocolV2): self.management_session = SessionV2(self, b"\x00") - assert self.management_session is not None return self.management_session diff --git a/python/src/trezorlib/transport/new/protocol_and_channel.py b/python/src/trezorlib/transport/new/protocol_and_channel.py index 011be6c37..8b03d2853 100644 --- a/python/src/trezorlib/transport/new/protocol_and_channel.py +++ b/python/src/trezorlib/transport/new/protocol_and_channel.py @@ -13,15 +13,16 @@ LOG = logging.getLogger(__name__) class ProtocolAndChannel: + def __init__( self, transport: NewTransport, mapping: ProtobufMapping, - channel_keys: ChannelData | None = None, + channel_data: ChannelData | None = None, ) -> None: self.transport = transport self.mapping = mapping - self.channel_keys = channel_keys + self.channel_keys = channel_data def close(self) -> None: ... @@ -29,7 +30,8 @@ class ProtocolAndChannel: # def read(self, session_id: bytes) -> t.Any: ... - def get_channel_keys(self) -> ChannelData: ... + def get_channel_data(self) -> ChannelData: + raise NotImplementedError class ProtocolV1(ProtocolAndChannel): @@ -101,15 +103,3 @@ class ProtocolV1(ProtocolAndChannel): if chunk[:1] != b"?": raise RuntimeError("Unexpected magic characters") return chunk[1:] - - -class Channel: - id: int - channel_keys: ChannelData | None - - def __init__(self, id: int, keys: ChannelData) -> None: - self.id = id - self.channel_keys = keys - - def read(self) -> t.Any: ... - def write(self, msg: t.Any) -> None: ... diff --git a/python/src/trezorlib/transport/new/protocol_v2.py b/python/src/trezorlib/transport/new/protocol_v2.py index f6fc47306..9e0cc8f1b 100644 --- a/python/src/trezorlib/transport/new/protocol_v2.py +++ b/python/src/trezorlib/transport/new/protocol_v2.py @@ -17,7 +17,7 @@ from ..thp.checksum import CHECKSUM_LENGTH from ..thp.packet_header import PacketHeader from . import control_byte from .channel_data import ChannelData -from .protocol_and_channel import Channel, ProtocolAndChannel +from .protocol_and_channel import ProtocolAndChannel from .transport import NewTransport LOG = logging.getLogger(__name__) @@ -47,31 +47,56 @@ def _get_iv_from_nonce(nonce: int) -> bytes: class ProtocolV2(ProtocolAndChannel): + channel_id: int + key_request: bytes key_response: bytes nonce_request: int nonce_response: int - channel_id: int sync_bit_send: int sync_bit_receive: int + has_valid_channel: bool = False + has_valid_features: bool = False features: messages.Features def __init__( self, transport: NewTransport, mapping: ProtobufMapping, - channel_keys: ChannelData | None = None, + channel_data: ChannelData | None = None, ) -> None: - super().__init__(transport, mapping, channel_keys) - self.channel: Channel | None = None + super().__init__(transport, mapping, channel_data) + if channel_data is not None: + self.channel_id = channel_data.channel_id + self.key_request = bytes.fromhex(channel_data.key_request) + self.key_response = bytes.fromhex(channel_data.key_response) + self.nonce_request = channel_data.nonce_request + self.nonce_response = channel_data.nonce_response + self.sync_bit_receive = channel_data.sync_bit_receive + self.sync_bit_send = channel_data.sync_bit_send + self.has_valid_channel = True def get_channel(self) -> ProtocolV2: if not self.has_valid_channel: self._establish_new_channel() + if not self.has_valid_features: self.update_features() return self + def get_channel_data(self) -> ChannelData: + return ChannelData( + protocol_version=2, + transport_path=self.transport.get_path(), + channel_id=self.channel_id, + key_request=self.key_request, + key_response=self.key_response, + nonce_request=self.nonce_request, + nonce_response=self.nonce_response, + sync_bit_receive=self.sync_bit_receive, + sync_bit_send=self.sync_bit_send, + ) + def read(self, session_id: int) -> t.Any: header, data = self._read_until_valid_crc_check() # TODO @@ -83,7 +108,6 @@ class ProtocolV2(ProtocolAndChannel): def update_features(self) -> None: message = messages.GetFeatures() message_type, message_data = self.mapping.encode(message) - self.session_id: int = 0 self._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data) _ = self._read_until_valid_crc_check() # TODO check ACK @@ -91,6 +115,7 @@ class ProtocolV2(ProtocolAndChannel): features = self.mapping.decode(msg_type, msg_data) assert isinstance(features, messages.Features) self.features = features + self.has_valid_features = True def _establish_new_channel(self) -> None: self.sync_bit_send = 0 diff --git a/python/src/trezorlib/transport/new/session.py b/python/src/trezorlib/transport/new/session.py index 5e825802e..349850937 100644 --- a/python/src/trezorlib/transport/new/session.py +++ b/python/src/trezorlib/transport/new/session.py @@ -70,8 +70,11 @@ class SessionV2(Session): def __init__(self, client: NewTrezorClient, id: bytes) -> None: super().__init__(client, id) assert isinstance(client.protocol, ProtocolV2) + self.channel: ProtocolV2 = client.protocol.get_channel() self.update_id_and_sid(id) + if not self.channel.has_valid_features: + self.channel.update_features() self.features = self.channel.features def call(self, msg: t.Any) -> t.Any: