1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-14 03:30:02 +00:00

wip trezorlib channel persistence

This commit is contained in:
M1nd3r 2024-09-12 17:03:39 +02:00
parent 8b3bfe648c
commit 1b871fd01c
8 changed files with 192 additions and 34 deletions

1
python/channel_data.json Normal file
View File

@ -0,0 +1 @@
[]

View File

@ -27,6 +27,7 @@ import click
from .. import __version__, log, messages, protobuf from .. import __version__, log, messages, protobuf
from ..client import TrezorClient from ..client import TrezorClient
from ..transport import DeviceIsBusy, new_enumerate_devices from ..transport import DeviceIsBusy, new_enumerate_devices
from ..transport.new import channel_database
from ..transport.new.client import NewTrezorClient from ..transport.new.client import NewTrezorClient
from ..transport.udp import UdpTransport from ..transport.udp import UdpTransport
from . import ( from . import (
@ -287,18 +288,32 @@ def list_devices(no_resolve: bool) -> Optional[Iterable["NewTransport"]]:
if no_resolve: if no_resolve:
return new_enumerate_devices() return new_enumerate_devices()
stored_channels = channel_database.load_stored_channels()
stored_transport_paths = [ch.transport_path for ch in stored_channels]
for transport in new_enumerate_devices(): for transport in new_enumerate_devices():
try: try:
client = NewTrezorClient(transport) path = transport.get_path()
session = client.get_management_session() if path in stored_transport_paths:
stored_channel_with_correct_transport_path = next(
ch for ch in stored_channels if ch.transport_path == path
)
client = NewTrezorClient.resume(
transport, stored_channel_with_correct_transport_path
)
else:
client = NewTrezorClient(transport)
session = client.get_management_session()
description = format_device_name(session.features) description = format_device_name(session.features)
# json_string = channel_database.channel_to_str(client.protocol)
# print(json_string)
channel_database.save_channel(client.protocol)
# client.end_session() # client.end_session()
except DeviceIsBusy: except DeviceIsBusy:
description = "Device is in use by another process" description = "Device is in use by another process"
except Exception: except Exception:
description = "Failed to read details" description = "Failed to read details"
click.echo(f"{transport} - {description}") click.echo(f"{transport.get_path()} - {description}")
return None return None

View File

@ -1,11 +1,40 @@
from __future__ import annotations from __future__ import annotations
from binascii import hexlify
class ChannelData: class ChannelData:
key_request: bytes def __init__(
key_response: bytes self,
nonce_request: int protocol_version: int,
nonce_response: int transport_path: str,
channel_id: bytes channel_id: int,
sync_bit_send: int key_request: bytes,
sync_bit_receive: int key_response: bytes,
nonce_request: int,
nonce_response: int,
sync_bit_send: int,
sync_bit_receive: int,
) -> None:
self.protocol_version: int = protocol_version
self.transport_path: str = transport_path
self.channel_id: int = channel_id
self.key_request: str = hexlify(key_request).decode()
self.key_response: str = hexlify(key_response).decode()
self.nonce_request: int = nonce_request
self.nonce_response: int = nonce_response
self.sync_bit_receive: int = sync_bit_receive
self.sync_bit_send: int = sync_bit_send
def to_dict(self):
return {
"protocol_version": self.protocol_version,
"transport_path": self.transport_path,
"channel_id": self.channel_id,
"key_request": self.key_request,
"key_response": self.key_response,
"nonce_request": self.nonce_request,
"nonce_response": self.nonce_response,
"sync_bit_send": self.sync_bit_send,
"sync_bit_receive": self.sync_bit_receive,
}

View File

@ -0,0 +1,85 @@
import json
import logging
import os
import typing as t
from .channel_data import ChannelData
from .protocol_and_channel import ProtocolAndChannel
LOG = logging.getLogger(__name__)
FILE_PATH = "channel_data.json"
def load_stored_channels() -> t.List[ChannelData]:
dicts = read_all_channels()
return [dict_to_channel_data(d) for d in dicts]
def channel_to_str(channel: ProtocolAndChannel) -> str:
return json.dumps(channel.get_channel_data().to_dict())
def str_to_channel_data(channel_data: str) -> ChannelData:
return dict_to_channel_data(json.loads(channel_data))
def dict_to_channel_data(dict: t.Dict) -> ChannelData:
return ChannelData(
protocol_version=dict["protocol_version"],
transport_path=dict["transport_path"],
channel_id=dict["channel_id"],
key_request=bytes.fromhex(dict["key_request"]),
key_response=bytes.fromhex(dict["key_response"]),
nonce_request=dict["nonce_request"],
nonce_response=dict["nonce_response"],
sync_bit_send=dict["sync_bit_send"],
sync_bit_receive=dict["sync_bit_receive"],
)
def ensure_file_exists() -> None:
LOG.debug("checking if file %s exists", FILE_PATH)
if not os.path.exists(FILE_PATH):
LOG.debug("File %s does not exist. Creating a new one.", FILE_PATH)
with open(FILE_PATH, "w") as f:
json.dump([], f)
def read_all_channels() -> t.List:
ensure_file_exists()
with open(FILE_PATH, "r") as f:
return json.load(f)
def save_all_channels(channels: t.List[t.Dict]) -> None:
LOG.debug("saving all channels")
with open(FILE_PATH, "w") as f:
json.dump(channels, f, indent=4)
def save_channel(new_channel: ProtocolAndChannel):
LOG.debug("save channel")
channels = read_all_channels()
transport_path = new_channel.transport.get_path()
# If channel is modified: replace the old by the new
for i, channel in enumerate(channels):
if channel["transport_path"] == transport_path:
LOG.debug("Modified channel entry for %s", transport_path)
channels[i] = new_channel.get_channel_data().to_dict()
save_all_channels(channels)
return
# Else: add a new channel entry
LOG.debug("Created a new channel entry on path %s", transport_path)
channels.append(new_channel.get_channel_data().to_dict())
save_all_channels(channels)
def remove_channel(transport_path: str) -> None:
channels = read_all_channels()
remaining_channels = [
ch for ch in channels if ch["transport_path"] != transport_path
]
save_all_channels(remaining_channels)

View File

@ -33,11 +33,22 @@ class NewTrezorClient:
self.protocol = self._get_protocol() self.protocol = self._get_protocol()
else: else:
self.protocol = protocol self.protocol = protocol
self.protocol.mapping = self.mapping
@classmethod @classmethod
def resume( def resume(
cls, transport: NewTransport, channel_data: ChannelData cls,
) -> NewTrezorClient: ... transport: NewTransport,
channel_data: ChannelData,
protobuf_mapping: ProtobufMapping | None = None,
) -> NewTrezorClient:
if protobuf_mapping is None:
protobuf_mapping = mapping.DEFAULT_MAPPING
if channel_data.protocol_version == 2:
protocol = ProtocolV2(transport, protobuf_mapping, channel_data)
else:
protocol = ProtocolV1(transport, protobuf_mapping, channel_data)
return NewTrezorClient(transport, protobuf_mapping, protocol)
def get_session( def get_session(
self, self,
@ -59,7 +70,6 @@ class NewTrezorClient:
self.management_session = SessionV1.new(self, "", False) self.management_session = SessionV1.new(self, "", False)
elif isinstance(self.protocol, ProtocolV2): elif isinstance(self.protocol, ProtocolV2):
self.management_session = SessionV2(self, b"\x00") self.management_session = SessionV2(self, b"\x00")
assert self.management_session is not None assert self.management_session is not None
return self.management_session return self.management_session

View File

@ -13,15 +13,16 @@ LOG = logging.getLogger(__name__)
class ProtocolAndChannel: class ProtocolAndChannel:
def __init__( def __init__(
self, self,
transport: NewTransport, transport: NewTransport,
mapping: ProtobufMapping, mapping: ProtobufMapping,
channel_keys: ChannelData | None = None, channel_data: ChannelData | None = None,
) -> None: ) -> None:
self.transport = transport self.transport = transport
self.mapping = mapping self.mapping = mapping
self.channel_keys = channel_keys self.channel_keys = channel_data
def close(self) -> None: ... def close(self) -> None: ...
@ -29,7 +30,8 @@ class ProtocolAndChannel:
# def read(self, session_id: bytes) -> t.Any: ... # def read(self, session_id: bytes) -> t.Any: ...
def get_channel_keys(self) -> ChannelData: ... def get_channel_data(self) -> ChannelData:
raise NotImplementedError
class ProtocolV1(ProtocolAndChannel): class ProtocolV1(ProtocolAndChannel):
@ -101,15 +103,3 @@ class ProtocolV1(ProtocolAndChannel):
if chunk[:1] != b"?": if chunk[:1] != b"?":
raise RuntimeError("Unexpected magic characters") raise RuntimeError("Unexpected magic characters")
return chunk[1:] return chunk[1:]
class Channel:
id: int
channel_keys: ChannelData | None
def __init__(self, id: int, keys: ChannelData) -> None:
self.id = id
self.channel_keys = keys
def read(self) -> t.Any: ...
def write(self, msg: t.Any) -> None: ...

View File

@ -17,7 +17,7 @@ from ..thp.checksum import CHECKSUM_LENGTH
from ..thp.packet_header import PacketHeader from ..thp.packet_header import PacketHeader
from . import control_byte from . import control_byte
from .channel_data import ChannelData from .channel_data import ChannelData
from .protocol_and_channel import Channel, ProtocolAndChannel from .protocol_and_channel import ProtocolAndChannel
from .transport import NewTransport from .transport import NewTransport
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -47,31 +47,56 @@ def _get_iv_from_nonce(nonce: int) -> bytes:
class ProtocolV2(ProtocolAndChannel): class ProtocolV2(ProtocolAndChannel):
channel_id: int
key_request: bytes key_request: bytes
key_response: bytes key_response: bytes
nonce_request: int nonce_request: int
nonce_response: int nonce_response: int
channel_id: int
sync_bit_send: int sync_bit_send: int
sync_bit_receive: int sync_bit_receive: int
has_valid_channel: bool = False has_valid_channel: bool = False
has_valid_features: bool = False
features: messages.Features features: messages.Features
def __init__( def __init__(
self, self,
transport: NewTransport, transport: NewTransport,
mapping: ProtobufMapping, mapping: ProtobufMapping,
channel_keys: ChannelData | None = None, channel_data: ChannelData | None = None,
) -> None: ) -> None:
super().__init__(transport, mapping, channel_keys) super().__init__(transport, mapping, channel_data)
self.channel: Channel | None = None if channel_data is not None:
self.channel_id = channel_data.channel_id
self.key_request = bytes.fromhex(channel_data.key_request)
self.key_response = bytes.fromhex(channel_data.key_response)
self.nonce_request = channel_data.nonce_request
self.nonce_response = channel_data.nonce_response
self.sync_bit_receive = channel_data.sync_bit_receive
self.sync_bit_send = channel_data.sync_bit_send
self.has_valid_channel = True
def get_channel(self) -> ProtocolV2: def get_channel(self) -> ProtocolV2:
if not self.has_valid_channel: if not self.has_valid_channel:
self._establish_new_channel() self._establish_new_channel()
if not self.has_valid_features:
self.update_features() self.update_features()
return self return self
def get_channel_data(self) -> ChannelData:
return ChannelData(
protocol_version=2,
transport_path=self.transport.get_path(),
channel_id=self.channel_id,
key_request=self.key_request,
key_response=self.key_response,
nonce_request=self.nonce_request,
nonce_response=self.nonce_response,
sync_bit_receive=self.sync_bit_receive,
sync_bit_send=self.sync_bit_send,
)
def read(self, session_id: int) -> t.Any: def read(self, session_id: int) -> t.Any:
header, data = self._read_until_valid_crc_check() header, data = self._read_until_valid_crc_check()
# TODO # TODO
@ -83,7 +108,6 @@ class ProtocolV2(ProtocolAndChannel):
def update_features(self) -> None: def update_features(self) -> None:
message = messages.GetFeatures() message = messages.GetFeatures()
message_type, message_data = self.mapping.encode(message) message_type, message_data = self.mapping.encode(message)
self.session_id: int = 0 self.session_id: int = 0
self._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data) self._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
_ = self._read_until_valid_crc_check() # TODO check ACK _ = self._read_until_valid_crc_check() # TODO check ACK
@ -91,6 +115,7 @@ class ProtocolV2(ProtocolAndChannel):
features = self.mapping.decode(msg_type, msg_data) features = self.mapping.decode(msg_type, msg_data)
assert isinstance(features, messages.Features) assert isinstance(features, messages.Features)
self.features = features self.features = features
self.has_valid_features = True
def _establish_new_channel(self) -> None: def _establish_new_channel(self) -> None:
self.sync_bit_send = 0 self.sync_bit_send = 0

View File

@ -70,8 +70,11 @@ class SessionV2(Session):
def __init__(self, client: NewTrezorClient, id: bytes) -> None: def __init__(self, client: NewTrezorClient, id: bytes) -> None:
super().__init__(client, id) super().__init__(client, id)
assert isinstance(client.protocol, ProtocolV2) assert isinstance(client.protocol, ProtocolV2)
self.channel: ProtocolV2 = client.protocol.get_channel() self.channel: ProtocolV2 = client.protocol.get_channel()
self.update_id_and_sid(id) self.update_id_and_sid(id)
if not self.channel.has_valid_features:
self.channel.update_features()
self.features = self.channel.features self.features = self.channel.features
def call(self, msg: t.Any) -> t.Any: def call(self, msg: t.Any) -> t.Any: