mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-14 03:30:02 +00:00
wip trezorlib channel persistence
This commit is contained in:
parent
8b3bfe648c
commit
1b871fd01c
1
python/channel_data.json
Normal file
1
python/channel_data.json
Normal file
@ -0,0 +1 @@
|
|||||||
|
[]
|
@ -27,6 +27,7 @@ import click
|
|||||||
from .. import __version__, log, messages, protobuf
|
from .. import __version__, log, messages, protobuf
|
||||||
from ..client import TrezorClient
|
from ..client import TrezorClient
|
||||||
from ..transport import DeviceIsBusy, new_enumerate_devices
|
from ..transport import DeviceIsBusy, new_enumerate_devices
|
||||||
|
from ..transport.new import channel_database
|
||||||
from ..transport.new.client import NewTrezorClient
|
from ..transport.new.client import NewTrezorClient
|
||||||
from ..transport.udp import UdpTransport
|
from ..transport.udp import UdpTransport
|
||||||
from . import (
|
from . import (
|
||||||
@ -287,18 +288,32 @@ def list_devices(no_resolve: bool) -> Optional[Iterable["NewTransport"]]:
|
|||||||
if no_resolve:
|
if no_resolve:
|
||||||
return new_enumerate_devices()
|
return new_enumerate_devices()
|
||||||
|
|
||||||
|
stored_channels = channel_database.load_stored_channels()
|
||||||
|
stored_transport_paths = [ch.transport_path for ch in stored_channels]
|
||||||
for transport in new_enumerate_devices():
|
for transport in new_enumerate_devices():
|
||||||
try:
|
try:
|
||||||
client = NewTrezorClient(transport)
|
path = transport.get_path()
|
||||||
session = client.get_management_session()
|
if path in stored_transport_paths:
|
||||||
|
stored_channel_with_correct_transport_path = next(
|
||||||
|
ch for ch in stored_channels if ch.transport_path == path
|
||||||
|
)
|
||||||
|
client = NewTrezorClient.resume(
|
||||||
|
transport, stored_channel_with_correct_transport_path
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
client = NewTrezorClient(transport)
|
||||||
|
|
||||||
|
session = client.get_management_session()
|
||||||
description = format_device_name(session.features)
|
description = format_device_name(session.features)
|
||||||
|
# json_string = channel_database.channel_to_str(client.protocol)
|
||||||
|
# print(json_string)
|
||||||
|
channel_database.save_channel(client.protocol)
|
||||||
# client.end_session()
|
# client.end_session()
|
||||||
except DeviceIsBusy:
|
except DeviceIsBusy:
|
||||||
description = "Device is in use by another process"
|
description = "Device is in use by another process"
|
||||||
except Exception:
|
except Exception:
|
||||||
description = "Failed to read details"
|
description = "Failed to read details"
|
||||||
click.echo(f"{transport} - {description}")
|
click.echo(f"{transport.get_path()} - {description}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,11 +1,40 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from binascii import hexlify
|
||||||
|
|
||||||
|
|
||||||
class ChannelData:
|
class ChannelData:
|
||||||
key_request: bytes
|
def __init__(
|
||||||
key_response: bytes
|
self,
|
||||||
nonce_request: int
|
protocol_version: int,
|
||||||
nonce_response: int
|
transport_path: str,
|
||||||
channel_id: bytes
|
channel_id: int,
|
||||||
sync_bit_send: int
|
key_request: bytes,
|
||||||
sync_bit_receive: int
|
key_response: bytes,
|
||||||
|
nonce_request: int,
|
||||||
|
nonce_response: int,
|
||||||
|
sync_bit_send: int,
|
||||||
|
sync_bit_receive: int,
|
||||||
|
) -> None:
|
||||||
|
self.protocol_version: int = protocol_version
|
||||||
|
self.transport_path: str = transport_path
|
||||||
|
self.channel_id: int = channel_id
|
||||||
|
self.key_request: str = hexlify(key_request).decode()
|
||||||
|
self.key_response: str = hexlify(key_response).decode()
|
||||||
|
self.nonce_request: int = nonce_request
|
||||||
|
self.nonce_response: int = nonce_response
|
||||||
|
self.sync_bit_receive: int = sync_bit_receive
|
||||||
|
self.sync_bit_send: int = sync_bit_send
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
return {
|
||||||
|
"protocol_version": self.protocol_version,
|
||||||
|
"transport_path": self.transport_path,
|
||||||
|
"channel_id": self.channel_id,
|
||||||
|
"key_request": self.key_request,
|
||||||
|
"key_response": self.key_response,
|
||||||
|
"nonce_request": self.nonce_request,
|
||||||
|
"nonce_response": self.nonce_response,
|
||||||
|
"sync_bit_send": self.sync_bit_send,
|
||||||
|
"sync_bit_receive": self.sync_bit_receive,
|
||||||
|
}
|
||||||
|
85
python/src/trezorlib/transport/new/channel_database.py
Normal file
85
python/src/trezorlib/transport/new/channel_database.py
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import typing as t
|
||||||
|
|
||||||
|
from .channel_data import ChannelData
|
||||||
|
from .protocol_and_channel import ProtocolAndChannel
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
FILE_PATH = "channel_data.json"
|
||||||
|
|
||||||
|
|
||||||
|
def load_stored_channels() -> t.List[ChannelData]:
|
||||||
|
dicts = read_all_channels()
|
||||||
|
return [dict_to_channel_data(d) for d in dicts]
|
||||||
|
|
||||||
|
|
||||||
|
def channel_to_str(channel: ProtocolAndChannel) -> str:
|
||||||
|
return json.dumps(channel.get_channel_data().to_dict())
|
||||||
|
|
||||||
|
|
||||||
|
def str_to_channel_data(channel_data: str) -> ChannelData:
|
||||||
|
return dict_to_channel_data(json.loads(channel_data))
|
||||||
|
|
||||||
|
|
||||||
|
def dict_to_channel_data(dict: t.Dict) -> ChannelData:
|
||||||
|
return ChannelData(
|
||||||
|
protocol_version=dict["protocol_version"],
|
||||||
|
transport_path=dict["transport_path"],
|
||||||
|
channel_id=dict["channel_id"],
|
||||||
|
key_request=bytes.fromhex(dict["key_request"]),
|
||||||
|
key_response=bytes.fromhex(dict["key_response"]),
|
||||||
|
nonce_request=dict["nonce_request"],
|
||||||
|
nonce_response=dict["nonce_response"],
|
||||||
|
sync_bit_send=dict["sync_bit_send"],
|
||||||
|
sync_bit_receive=dict["sync_bit_receive"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_file_exists() -> None:
|
||||||
|
LOG.debug("checking if file %s exists", FILE_PATH)
|
||||||
|
if not os.path.exists(FILE_PATH):
|
||||||
|
LOG.debug("File %s does not exist. Creating a new one.", FILE_PATH)
|
||||||
|
with open(FILE_PATH, "w") as f:
|
||||||
|
json.dump([], f)
|
||||||
|
|
||||||
|
|
||||||
|
def read_all_channels() -> t.List:
|
||||||
|
ensure_file_exists()
|
||||||
|
with open(FILE_PATH, "r") as f:
|
||||||
|
return json.load(f)
|
||||||
|
|
||||||
|
|
||||||
|
def save_all_channels(channels: t.List[t.Dict]) -> None:
|
||||||
|
LOG.debug("saving all channels")
|
||||||
|
with open(FILE_PATH, "w") as f:
|
||||||
|
json.dump(channels, f, indent=4)
|
||||||
|
|
||||||
|
|
||||||
|
def save_channel(new_channel: ProtocolAndChannel):
|
||||||
|
LOG.debug("save channel")
|
||||||
|
channels = read_all_channels()
|
||||||
|
transport_path = new_channel.transport.get_path()
|
||||||
|
|
||||||
|
# If channel is modified: replace the old by the new
|
||||||
|
for i, channel in enumerate(channels):
|
||||||
|
if channel["transport_path"] == transport_path:
|
||||||
|
LOG.debug("Modified channel entry for %s", transport_path)
|
||||||
|
channels[i] = new_channel.get_channel_data().to_dict()
|
||||||
|
save_all_channels(channels)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Else: add a new channel entry
|
||||||
|
LOG.debug("Created a new channel entry on path %s", transport_path)
|
||||||
|
channels.append(new_channel.get_channel_data().to_dict())
|
||||||
|
save_all_channels(channels)
|
||||||
|
|
||||||
|
|
||||||
|
def remove_channel(transport_path: str) -> None:
|
||||||
|
channels = read_all_channels()
|
||||||
|
remaining_channels = [
|
||||||
|
ch for ch in channels if ch["transport_path"] != transport_path
|
||||||
|
]
|
||||||
|
save_all_channels(remaining_channels)
|
@ -33,11 +33,22 @@ class NewTrezorClient:
|
|||||||
self.protocol = self._get_protocol()
|
self.protocol = self._get_protocol()
|
||||||
else:
|
else:
|
||||||
self.protocol = protocol
|
self.protocol = protocol
|
||||||
|
self.protocol.mapping = self.mapping
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def resume(
|
def resume(
|
||||||
cls, transport: NewTransport, channel_data: ChannelData
|
cls,
|
||||||
) -> NewTrezorClient: ...
|
transport: NewTransport,
|
||||||
|
channel_data: ChannelData,
|
||||||
|
protobuf_mapping: ProtobufMapping | None = None,
|
||||||
|
) -> NewTrezorClient:
|
||||||
|
if protobuf_mapping is None:
|
||||||
|
protobuf_mapping = mapping.DEFAULT_MAPPING
|
||||||
|
if channel_data.protocol_version == 2:
|
||||||
|
protocol = ProtocolV2(transport, protobuf_mapping, channel_data)
|
||||||
|
else:
|
||||||
|
protocol = ProtocolV1(transport, protobuf_mapping, channel_data)
|
||||||
|
return NewTrezorClient(transport, protobuf_mapping, protocol)
|
||||||
|
|
||||||
def get_session(
|
def get_session(
|
||||||
self,
|
self,
|
||||||
@ -59,7 +70,6 @@ class NewTrezorClient:
|
|||||||
self.management_session = SessionV1.new(self, "", False)
|
self.management_session = SessionV1.new(self, "", False)
|
||||||
elif isinstance(self.protocol, ProtocolV2):
|
elif isinstance(self.protocol, ProtocolV2):
|
||||||
self.management_session = SessionV2(self, b"\x00")
|
self.management_session = SessionV2(self, b"\x00")
|
||||||
|
|
||||||
assert self.management_session is not None
|
assert self.management_session is not None
|
||||||
return self.management_session
|
return self.management_session
|
||||||
|
|
||||||
|
@ -13,15 +13,16 @@ LOG = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class ProtocolAndChannel:
|
class ProtocolAndChannel:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
transport: NewTransport,
|
transport: NewTransport,
|
||||||
mapping: ProtobufMapping,
|
mapping: ProtobufMapping,
|
||||||
channel_keys: ChannelData | None = None,
|
channel_data: ChannelData | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.transport = transport
|
self.transport = transport
|
||||||
self.mapping = mapping
|
self.mapping = mapping
|
||||||
self.channel_keys = channel_keys
|
self.channel_keys = channel_data
|
||||||
|
|
||||||
def close(self) -> None: ...
|
def close(self) -> None: ...
|
||||||
|
|
||||||
@ -29,7 +30,8 @@ class ProtocolAndChannel:
|
|||||||
|
|
||||||
# def read(self, session_id: bytes) -> t.Any: ...
|
# def read(self, session_id: bytes) -> t.Any: ...
|
||||||
|
|
||||||
def get_channel_keys(self) -> ChannelData: ...
|
def get_channel_data(self) -> ChannelData:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class ProtocolV1(ProtocolAndChannel):
|
class ProtocolV1(ProtocolAndChannel):
|
||||||
@ -101,15 +103,3 @@ class ProtocolV1(ProtocolAndChannel):
|
|||||||
if chunk[:1] != b"?":
|
if chunk[:1] != b"?":
|
||||||
raise RuntimeError("Unexpected magic characters")
|
raise RuntimeError("Unexpected magic characters")
|
||||||
return chunk[1:]
|
return chunk[1:]
|
||||||
|
|
||||||
|
|
||||||
class Channel:
|
|
||||||
id: int
|
|
||||||
channel_keys: ChannelData | None
|
|
||||||
|
|
||||||
def __init__(self, id: int, keys: ChannelData) -> None:
|
|
||||||
self.id = id
|
|
||||||
self.channel_keys = keys
|
|
||||||
|
|
||||||
def read(self) -> t.Any: ...
|
|
||||||
def write(self, msg: t.Any) -> None: ...
|
|
||||||
|
@ -17,7 +17,7 @@ from ..thp.checksum import CHECKSUM_LENGTH
|
|||||||
from ..thp.packet_header import PacketHeader
|
from ..thp.packet_header import PacketHeader
|
||||||
from . import control_byte
|
from . import control_byte
|
||||||
from .channel_data import ChannelData
|
from .channel_data import ChannelData
|
||||||
from .protocol_and_channel import Channel, ProtocolAndChannel
|
from .protocol_and_channel import ProtocolAndChannel
|
||||||
from .transport import NewTransport
|
from .transport import NewTransport
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
@ -47,31 +47,56 @@ def _get_iv_from_nonce(nonce: int) -> bytes:
|
|||||||
|
|
||||||
|
|
||||||
class ProtocolV2(ProtocolAndChannel):
|
class ProtocolV2(ProtocolAndChannel):
|
||||||
|
channel_id: int
|
||||||
|
|
||||||
key_request: bytes
|
key_request: bytes
|
||||||
key_response: bytes
|
key_response: bytes
|
||||||
nonce_request: int
|
nonce_request: int
|
||||||
nonce_response: int
|
nonce_response: int
|
||||||
channel_id: int
|
|
||||||
sync_bit_send: int
|
sync_bit_send: int
|
||||||
sync_bit_receive: int
|
sync_bit_receive: int
|
||||||
|
|
||||||
has_valid_channel: bool = False
|
has_valid_channel: bool = False
|
||||||
|
has_valid_features: bool = False
|
||||||
features: messages.Features
|
features: messages.Features
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
transport: NewTransport,
|
transport: NewTransport,
|
||||||
mapping: ProtobufMapping,
|
mapping: ProtobufMapping,
|
||||||
channel_keys: ChannelData | None = None,
|
channel_data: ChannelData | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(transport, mapping, channel_keys)
|
super().__init__(transport, mapping, channel_data)
|
||||||
self.channel: Channel | None = None
|
if channel_data is not None:
|
||||||
|
self.channel_id = channel_data.channel_id
|
||||||
|
self.key_request = bytes.fromhex(channel_data.key_request)
|
||||||
|
self.key_response = bytes.fromhex(channel_data.key_response)
|
||||||
|
self.nonce_request = channel_data.nonce_request
|
||||||
|
self.nonce_response = channel_data.nonce_response
|
||||||
|
self.sync_bit_receive = channel_data.sync_bit_receive
|
||||||
|
self.sync_bit_send = channel_data.sync_bit_send
|
||||||
|
self.has_valid_channel = True
|
||||||
|
|
||||||
def get_channel(self) -> ProtocolV2:
|
def get_channel(self) -> ProtocolV2:
|
||||||
if not self.has_valid_channel:
|
if not self.has_valid_channel:
|
||||||
self._establish_new_channel()
|
self._establish_new_channel()
|
||||||
|
if not self.has_valid_features:
|
||||||
self.update_features()
|
self.update_features()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def get_channel_data(self) -> ChannelData:
|
||||||
|
return ChannelData(
|
||||||
|
protocol_version=2,
|
||||||
|
transport_path=self.transport.get_path(),
|
||||||
|
channel_id=self.channel_id,
|
||||||
|
key_request=self.key_request,
|
||||||
|
key_response=self.key_response,
|
||||||
|
nonce_request=self.nonce_request,
|
||||||
|
nonce_response=self.nonce_response,
|
||||||
|
sync_bit_receive=self.sync_bit_receive,
|
||||||
|
sync_bit_send=self.sync_bit_send,
|
||||||
|
)
|
||||||
|
|
||||||
def read(self, session_id: int) -> t.Any:
|
def read(self, session_id: int) -> t.Any:
|
||||||
header, data = self._read_until_valid_crc_check()
|
header, data = self._read_until_valid_crc_check()
|
||||||
# TODO
|
# TODO
|
||||||
@ -83,7 +108,6 @@ class ProtocolV2(ProtocolAndChannel):
|
|||||||
def update_features(self) -> None:
|
def update_features(self) -> None:
|
||||||
message = messages.GetFeatures()
|
message = messages.GetFeatures()
|
||||||
message_type, message_data = self.mapping.encode(message)
|
message_type, message_data = self.mapping.encode(message)
|
||||||
|
|
||||||
self.session_id: int = 0
|
self.session_id: int = 0
|
||||||
self._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
|
self._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
|
||||||
_ = self._read_until_valid_crc_check() # TODO check ACK
|
_ = self._read_until_valid_crc_check() # TODO check ACK
|
||||||
@ -91,6 +115,7 @@ class ProtocolV2(ProtocolAndChannel):
|
|||||||
features = self.mapping.decode(msg_type, msg_data)
|
features = self.mapping.decode(msg_type, msg_data)
|
||||||
assert isinstance(features, messages.Features)
|
assert isinstance(features, messages.Features)
|
||||||
self.features = features
|
self.features = features
|
||||||
|
self.has_valid_features = True
|
||||||
|
|
||||||
def _establish_new_channel(self) -> None:
|
def _establish_new_channel(self) -> None:
|
||||||
self.sync_bit_send = 0
|
self.sync_bit_send = 0
|
||||||
|
@ -70,8 +70,11 @@ class SessionV2(Session):
|
|||||||
def __init__(self, client: NewTrezorClient, id: bytes) -> None:
|
def __init__(self, client: NewTrezorClient, id: bytes) -> None:
|
||||||
super().__init__(client, id)
|
super().__init__(client, id)
|
||||||
assert isinstance(client.protocol, ProtocolV2)
|
assert isinstance(client.protocol, ProtocolV2)
|
||||||
|
|
||||||
self.channel: ProtocolV2 = client.protocol.get_channel()
|
self.channel: ProtocolV2 = client.protocol.get_channel()
|
||||||
self.update_id_and_sid(id)
|
self.update_id_and_sid(id)
|
||||||
|
if not self.channel.has_valid_features:
|
||||||
|
self.channel.update_features()
|
||||||
self.features = self.channel.features
|
self.features = self.channel.features
|
||||||
|
|
||||||
def call(self, msg: t.Any) -> t.Any:
|
def call(self, msg: t.Any) -> t.Any:
|
||||||
|
Loading…
Reference in New Issue
Block a user