mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-13 02:58:57 +00:00
wip trezorlib channel persistence
This commit is contained in:
parent
8b3bfe648c
commit
1b871fd01c
1
python/channel_data.json
Normal file
1
python/channel_data.json
Normal file
@ -0,0 +1 @@
|
||||
[]
|
@ -27,6 +27,7 @@ import click
|
||||
from .. import __version__, log, messages, protobuf
|
||||
from ..client import TrezorClient
|
||||
from ..transport import DeviceIsBusy, new_enumerate_devices
|
||||
from ..transport.new import channel_database
|
||||
from ..transport.new.client import NewTrezorClient
|
||||
from ..transport.udp import UdpTransport
|
||||
from . import (
|
||||
@ -287,18 +288,32 @@ def list_devices(no_resolve: bool) -> Optional[Iterable["NewTransport"]]:
|
||||
if no_resolve:
|
||||
return new_enumerate_devices()
|
||||
|
||||
stored_channels = channel_database.load_stored_channels()
|
||||
stored_transport_paths = [ch.transport_path for ch in stored_channels]
|
||||
for transport in new_enumerate_devices():
|
||||
try:
|
||||
client = NewTrezorClient(transport)
|
||||
session = client.get_management_session()
|
||||
path = transport.get_path()
|
||||
if path in stored_transport_paths:
|
||||
stored_channel_with_correct_transport_path = next(
|
||||
ch for ch in stored_channels if ch.transport_path == path
|
||||
)
|
||||
client = NewTrezorClient.resume(
|
||||
transport, stored_channel_with_correct_transport_path
|
||||
)
|
||||
else:
|
||||
client = NewTrezorClient(transport)
|
||||
|
||||
session = client.get_management_session()
|
||||
description = format_device_name(session.features)
|
||||
# json_string = channel_database.channel_to_str(client.protocol)
|
||||
# print(json_string)
|
||||
channel_database.save_channel(client.protocol)
|
||||
# client.end_session()
|
||||
except DeviceIsBusy:
|
||||
description = "Device is in use by another process"
|
||||
except Exception:
|
||||
description = "Failed to read details"
|
||||
click.echo(f"{transport} - {description}")
|
||||
click.echo(f"{transport.get_path()} - {description}")
|
||||
return None
|
||||
|
||||
|
||||
|
@ -1,11 +1,40 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from binascii import hexlify
|
||||
|
||||
|
||||
class ChannelData:
|
||||
key_request: bytes
|
||||
key_response: bytes
|
||||
nonce_request: int
|
||||
nonce_response: int
|
||||
channel_id: bytes
|
||||
sync_bit_send: int
|
||||
sync_bit_receive: int
|
||||
def __init__(
|
||||
self,
|
||||
protocol_version: int,
|
||||
transport_path: str,
|
||||
channel_id: int,
|
||||
key_request: bytes,
|
||||
key_response: bytes,
|
||||
nonce_request: int,
|
||||
nonce_response: int,
|
||||
sync_bit_send: int,
|
||||
sync_bit_receive: int,
|
||||
) -> None:
|
||||
self.protocol_version: int = protocol_version
|
||||
self.transport_path: str = transport_path
|
||||
self.channel_id: int = channel_id
|
||||
self.key_request: str = hexlify(key_request).decode()
|
||||
self.key_response: str = hexlify(key_response).decode()
|
||||
self.nonce_request: int = nonce_request
|
||||
self.nonce_response: int = nonce_response
|
||||
self.sync_bit_receive: int = sync_bit_receive
|
||||
self.sync_bit_send: int = sync_bit_send
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"protocol_version": self.protocol_version,
|
||||
"transport_path": self.transport_path,
|
||||
"channel_id": self.channel_id,
|
||||
"key_request": self.key_request,
|
||||
"key_response": self.key_response,
|
||||
"nonce_request": self.nonce_request,
|
||||
"nonce_response": self.nonce_response,
|
||||
"sync_bit_send": self.sync_bit_send,
|
||||
"sync_bit_receive": self.sync_bit_receive,
|
||||
}
|
||||
|
85
python/src/trezorlib/transport/new/channel_database.py
Normal file
85
python/src/trezorlib/transport/new/channel_database.py
Normal file
@ -0,0 +1,85 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from .channel_data import ChannelData
|
||||
from .protocol_and_channel import ProtocolAndChannel
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
FILE_PATH = "channel_data.json"
|
||||
|
||||
|
||||
def load_stored_channels() -> t.List[ChannelData]:
|
||||
dicts = read_all_channels()
|
||||
return [dict_to_channel_data(d) for d in dicts]
|
||||
|
||||
|
||||
def channel_to_str(channel: ProtocolAndChannel) -> str:
|
||||
return json.dumps(channel.get_channel_data().to_dict())
|
||||
|
||||
|
||||
def str_to_channel_data(channel_data: str) -> ChannelData:
|
||||
return dict_to_channel_data(json.loads(channel_data))
|
||||
|
||||
|
||||
def dict_to_channel_data(dict: t.Dict) -> ChannelData:
|
||||
return ChannelData(
|
||||
protocol_version=dict["protocol_version"],
|
||||
transport_path=dict["transport_path"],
|
||||
channel_id=dict["channel_id"],
|
||||
key_request=bytes.fromhex(dict["key_request"]),
|
||||
key_response=bytes.fromhex(dict["key_response"]),
|
||||
nonce_request=dict["nonce_request"],
|
||||
nonce_response=dict["nonce_response"],
|
||||
sync_bit_send=dict["sync_bit_send"],
|
||||
sync_bit_receive=dict["sync_bit_receive"],
|
||||
)
|
||||
|
||||
|
||||
def ensure_file_exists() -> None:
|
||||
LOG.debug("checking if file %s exists", FILE_PATH)
|
||||
if not os.path.exists(FILE_PATH):
|
||||
LOG.debug("File %s does not exist. Creating a new one.", FILE_PATH)
|
||||
with open(FILE_PATH, "w") as f:
|
||||
json.dump([], f)
|
||||
|
||||
|
||||
def read_all_channels() -> t.List:
|
||||
ensure_file_exists()
|
||||
with open(FILE_PATH, "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def save_all_channels(channels: t.List[t.Dict]) -> None:
|
||||
LOG.debug("saving all channels")
|
||||
with open(FILE_PATH, "w") as f:
|
||||
json.dump(channels, f, indent=4)
|
||||
|
||||
|
||||
def save_channel(new_channel: ProtocolAndChannel):
|
||||
LOG.debug("save channel")
|
||||
channels = read_all_channels()
|
||||
transport_path = new_channel.transport.get_path()
|
||||
|
||||
# If channel is modified: replace the old by the new
|
||||
for i, channel in enumerate(channels):
|
||||
if channel["transport_path"] == transport_path:
|
||||
LOG.debug("Modified channel entry for %s", transport_path)
|
||||
channels[i] = new_channel.get_channel_data().to_dict()
|
||||
save_all_channels(channels)
|
||||
return
|
||||
|
||||
# Else: add a new channel entry
|
||||
LOG.debug("Created a new channel entry on path %s", transport_path)
|
||||
channels.append(new_channel.get_channel_data().to_dict())
|
||||
save_all_channels(channels)
|
||||
|
||||
|
||||
def remove_channel(transport_path: str) -> None:
|
||||
channels = read_all_channels()
|
||||
remaining_channels = [
|
||||
ch for ch in channels if ch["transport_path"] != transport_path
|
||||
]
|
||||
save_all_channels(remaining_channels)
|
@ -33,11 +33,22 @@ class NewTrezorClient:
|
||||
self.protocol = self._get_protocol()
|
||||
else:
|
||||
self.protocol = protocol
|
||||
self.protocol.mapping = self.mapping
|
||||
|
||||
@classmethod
|
||||
def resume(
|
||||
cls, transport: NewTransport, channel_data: ChannelData
|
||||
) -> NewTrezorClient: ...
|
||||
cls,
|
||||
transport: NewTransport,
|
||||
channel_data: ChannelData,
|
||||
protobuf_mapping: ProtobufMapping | None = None,
|
||||
) -> NewTrezorClient:
|
||||
if protobuf_mapping is None:
|
||||
protobuf_mapping = mapping.DEFAULT_MAPPING
|
||||
if channel_data.protocol_version == 2:
|
||||
protocol = ProtocolV2(transport, protobuf_mapping, channel_data)
|
||||
else:
|
||||
protocol = ProtocolV1(transport, protobuf_mapping, channel_data)
|
||||
return NewTrezorClient(transport, protobuf_mapping, protocol)
|
||||
|
||||
def get_session(
|
||||
self,
|
||||
@ -59,7 +70,6 @@ class NewTrezorClient:
|
||||
self.management_session = SessionV1.new(self, "", False)
|
||||
elif isinstance(self.protocol, ProtocolV2):
|
||||
self.management_session = SessionV2(self, b"\x00")
|
||||
|
||||
assert self.management_session is not None
|
||||
return self.management_session
|
||||
|
||||
|
@ -13,15 +13,16 @@ LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProtocolAndChannel:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transport: NewTransport,
|
||||
mapping: ProtobufMapping,
|
||||
channel_keys: ChannelData | None = None,
|
||||
channel_data: ChannelData | None = None,
|
||||
) -> None:
|
||||
self.transport = transport
|
||||
self.mapping = mapping
|
||||
self.channel_keys = channel_keys
|
||||
self.channel_keys = channel_data
|
||||
|
||||
def close(self) -> None: ...
|
||||
|
||||
@ -29,7 +30,8 @@ class ProtocolAndChannel:
|
||||
|
||||
# def read(self, session_id: bytes) -> t.Any: ...
|
||||
|
||||
def get_channel_keys(self) -> ChannelData: ...
|
||||
def get_channel_data(self) -> ChannelData:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ProtocolV1(ProtocolAndChannel):
|
||||
@ -101,15 +103,3 @@ class ProtocolV1(ProtocolAndChannel):
|
||||
if chunk[:1] != b"?":
|
||||
raise RuntimeError("Unexpected magic characters")
|
||||
return chunk[1:]
|
||||
|
||||
|
||||
class Channel:
|
||||
id: int
|
||||
channel_keys: ChannelData | None
|
||||
|
||||
def __init__(self, id: int, keys: ChannelData) -> None:
|
||||
self.id = id
|
||||
self.channel_keys = keys
|
||||
|
||||
def read(self) -> t.Any: ...
|
||||
def write(self, msg: t.Any) -> None: ...
|
||||
|
@ -17,7 +17,7 @@ from ..thp.checksum import CHECKSUM_LENGTH
|
||||
from ..thp.packet_header import PacketHeader
|
||||
from . import control_byte
|
||||
from .channel_data import ChannelData
|
||||
from .protocol_and_channel import Channel, ProtocolAndChannel
|
||||
from .protocol_and_channel import ProtocolAndChannel
|
||||
from .transport import NewTransport
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
@ -47,31 +47,56 @@ def _get_iv_from_nonce(nonce: int) -> bytes:
|
||||
|
||||
|
||||
class ProtocolV2(ProtocolAndChannel):
|
||||
channel_id: int
|
||||
|
||||
key_request: bytes
|
||||
key_response: bytes
|
||||
nonce_request: int
|
||||
nonce_response: int
|
||||
channel_id: int
|
||||
sync_bit_send: int
|
||||
sync_bit_receive: int
|
||||
|
||||
has_valid_channel: bool = False
|
||||
has_valid_features: bool = False
|
||||
features: messages.Features
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transport: NewTransport,
|
||||
mapping: ProtobufMapping,
|
||||
channel_keys: ChannelData | None = None,
|
||||
channel_data: ChannelData | None = None,
|
||||
) -> None:
|
||||
super().__init__(transport, mapping, channel_keys)
|
||||
self.channel: Channel | None = None
|
||||
super().__init__(transport, mapping, channel_data)
|
||||
if channel_data is not None:
|
||||
self.channel_id = channel_data.channel_id
|
||||
self.key_request = bytes.fromhex(channel_data.key_request)
|
||||
self.key_response = bytes.fromhex(channel_data.key_response)
|
||||
self.nonce_request = channel_data.nonce_request
|
||||
self.nonce_response = channel_data.nonce_response
|
||||
self.sync_bit_receive = channel_data.sync_bit_receive
|
||||
self.sync_bit_send = channel_data.sync_bit_send
|
||||
self.has_valid_channel = True
|
||||
|
||||
def get_channel(self) -> ProtocolV2:
|
||||
if not self.has_valid_channel:
|
||||
self._establish_new_channel()
|
||||
if not self.has_valid_features:
|
||||
self.update_features()
|
||||
return self
|
||||
|
||||
def get_channel_data(self) -> ChannelData:
|
||||
return ChannelData(
|
||||
protocol_version=2,
|
||||
transport_path=self.transport.get_path(),
|
||||
channel_id=self.channel_id,
|
||||
key_request=self.key_request,
|
||||
key_response=self.key_response,
|
||||
nonce_request=self.nonce_request,
|
||||
nonce_response=self.nonce_response,
|
||||
sync_bit_receive=self.sync_bit_receive,
|
||||
sync_bit_send=self.sync_bit_send,
|
||||
)
|
||||
|
||||
def read(self, session_id: int) -> t.Any:
|
||||
header, data = self._read_until_valid_crc_check()
|
||||
# TODO
|
||||
@ -83,7 +108,6 @@ class ProtocolV2(ProtocolAndChannel):
|
||||
def update_features(self) -> None:
|
||||
message = messages.GetFeatures()
|
||||
message_type, message_data = self.mapping.encode(message)
|
||||
|
||||
self.session_id: int = 0
|
||||
self._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
|
||||
_ = self._read_until_valid_crc_check() # TODO check ACK
|
||||
@ -91,6 +115,7 @@ class ProtocolV2(ProtocolAndChannel):
|
||||
features = self.mapping.decode(msg_type, msg_data)
|
||||
assert isinstance(features, messages.Features)
|
||||
self.features = features
|
||||
self.has_valid_features = True
|
||||
|
||||
def _establish_new_channel(self) -> None:
|
||||
self.sync_bit_send = 0
|
||||
|
@ -70,8 +70,11 @@ class SessionV2(Session):
|
||||
def __init__(self, client: NewTrezorClient, id: bytes) -> None:
|
||||
super().__init__(client, id)
|
||||
assert isinstance(client.protocol, ProtocolV2)
|
||||
|
||||
self.channel: ProtocolV2 = client.protocol.get_channel()
|
||||
self.update_id_and_sid(id)
|
||||
if not self.channel.has_valid_features:
|
||||
self.channel.update_features()
|
||||
self.features = self.channel.features
|
||||
|
||||
def call(self, msg: t.Any) -> t.Any:
|
||||
|
Loading…
Reference in New Issue
Block a user