1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-04-15 14:55:43 +00:00
This commit is contained in:
M1nd3r 2025-04-11 23:18:30 +02:00
parent bab685036b
commit a937756f14
9 changed files with 84 additions and 69 deletions

View File

@ -26,7 +26,7 @@ from .tools import parse_path
from .transport import Transport, get_transport
from .transport.thp.protocol_and_channel import Channel
from .transport.thp.protocol_v1 import ProtocolV1Channel
from .transport.thp.protocol_v2 import ProtocolV2Channel
from .transport.thp.protocol_v2 import ProtocolV2Channel, TrezorState
if t.TYPE_CHECKING:
from .transport.session import Session, SessionV1
@ -62,6 +62,7 @@ class TrezorClient:
_setup_pin: str | None = None # Should be used only by conftest
_last_active_session: SessionV1 | None = None
_session_id_counter: int = 0
def __init__(
self,
transport: Transport,
@ -99,6 +100,26 @@ class TrezorClient:
else:
raise Exception("Unknown protocol version")
def do_pairing(self) -> None:
from .transport.session import SessionV2
assert self.protocol_version == ProtocolVersion.V2
session = SessionV2(client=self, id=b"\x00")
session.call(
messages.ThpPairingRequest(host_name="Trezorlib"),
expect=messages.ThpPairingRequestApproved,
skip_firmware_version_check=True,
)
session.call(
messages.ThpSelectMethod(
selected_pairing_method=messages.ThpPairingMethod.SkipPairing
),
expect=messages.ThpEndResponse,
skip_firmware_version_check=True,
)
assert isinstance(self.protocol, ProtocolV2Channel)
self.protocol._has_valid_channel = True
def get_session(
self,
passphrase: str | object = "",
@ -128,17 +149,21 @@ class TrezorClient:
if isinstance(self.protocol, ProtocolV2Channel):
from .transport.session import SessionV2
if self.protocol.trezor_state is TrezorState.UNPAIRED:
self.do_pairing()
if passphrase is SEEDLESS:
return SessionV2(self, id=b"\x00")
if self._session_id_counter >= 255:
self._session_id_counter = 0
assert isinstance(passphrase, str) or passphrase is None
session_id = b"\x01" # TODO fix this with ProtocolV2 session rework
if session_id is not None:
sid = int.from_bytes(session_id, "big")
else:
sid = 1
assert 0 <= sid <= 255
return SessionV2.new(self, passphrase, derive_cardano, sid)
self._session_id_counter += 1
return SessionV2.new(
self, passphrase, derive_cardano, self._session_id_counter
)
raise NotImplementedError
def get_seedless_session(self) -> Session:
@ -150,11 +175,20 @@ class TrezorClient:
@property
def features(self) -> messages.Features:
if self._features is None:
self._features = self.protocol.get_features()
self._features = self._get_features()
self.check_firmware_version(warn_only=True)
assert self._features is not None
return self._features
def _get_features(self) -> messages.Features:
if isinstance(self.protocol, ProtocolV2Channel):
if (
self.protocol.trezor_state is TrezorState.UNPAIRED
or not self.protocol._has_valid_channel
):
self.do_pairing()
return self.protocol.get_features()
@property
def protocol_version(self) -> int:
return self._protocol_version

View File

@ -1030,6 +1030,8 @@ class TrezorClientDebugLink(TrezorClient):
# without special DebugLink interface provided
# by the device.
protocol: ProtocolV1Channel | ProtocolV2Channel
def __init__(
self,
transport: Transport,
@ -1075,8 +1077,8 @@ class TrezorClientDebugLink(TrezorClient):
# and know the supported debug capabilities
if self.protocol_version is ProtocolVersion.V2:
assert isinstance(self.protocol, ProtocolV2Channel)
self.protocol._helper_debug = self.debug
self.protocol = self.protocol.get_channel()
self.do_pairing()
# self.protocol = self.protocol.get_channel()
self.debug.model = self.model
self.debug.version = self.version

View File

@ -26,9 +26,11 @@ class Session:
self,
msg: MessageType,
expect: type[MT] = MessageType,
skip_firmware_version_check: bool = False,
_passphrase_ack: messages.PassphraseAck | None = None,
) -> MT:
self.client.check_firmware_version()
if not skip_firmware_version_check:
self.client.check_firmware_version()
resp = self.call_raw(msg)
while True:
@ -260,14 +262,11 @@ class SessionV2(Session):
return session
def __init__(self, client: TrezorClient, id: bytes) -> None:
from ..debuglink import TrezorClientDebugLink
super().__init__(client, id)
assert isinstance(client.protocol, ProtocolV2Channel)
if isinstance(client, TrezorClientDebugLink):
client.protocol._helper_debug = client.debug
self.channel: ProtocolV2Channel = client.protocol.get_channel()
self.channel: ProtocolV2Channel = client.protocol
self.update_id_and_sid(id)
def _write(self, msg: t.Any) -> None:

View File

@ -1,14 +1,7 @@
from __future__ import annotations
import logging
import typing as t
from ... import messages
from ...mapping import ProtobufMapping
from .. import Transport
LOG = logging.getLogger(__name__)
class Channel:
@ -25,9 +18,3 @@ class Channel:
def update_features(self) -> None:
raise NotImplementedError
def read(self, timeout: float | None = None) -> t.Any:
raise NotImplementedError
def write(self, msg: t.Any) -> None:
raise NotImplementedError

View File

@ -4,6 +4,7 @@ import logging
import os
import typing as t
from binascii import hexlify
from enum import IntEnum
from noise.connection import Keypair, NoiseConnection
@ -21,10 +22,15 @@ LOG = logging.getLogger(__name__)
DEFAULT_SESSION_ID: int = 0
if t.TYPE_CHECKING:
from ...debuglink import DebugLink
pass
MT = t.TypeVar("MT", bound=protobuf.MessageType)
class TrezorState(IntEnum):
UNPAIRED = 0x00
PAIRED = 0x01
class ProtocolV2Channel(Channel):
channel_id: int
sync_bit_send: int
@ -33,18 +39,20 @@ class ProtocolV2Channel(Channel):
_has_valid_channel: bool = False
_features: messages.Features | None = None
_helper_debug: DebugLink | None = None
trezor_state: int = TrezorState.UNPAIRED
def __init__(
self,
transport: Transport,
mapping: ProtobufMapping,
credential: bytes | None = None,
) -> None:
super().__init__(transport, mapping)
self.trezor_state = self.prepare_channel_without_pairing(credential=credential)
def get_channel(self) -> ProtocolV2Channel:
if not self._has_valid_channel:
self._establish_new_channel(self._helper_debug)
raise RuntimeError("Channel is invalidated")
return self
def read(self, session_id: int) -> t.Any:
@ -61,7 +69,7 @@ class ProtocolV2Channel(Channel):
def get_features(self) -> messages.Features:
if not self._has_valid_channel:
self._establish_new_channel(self._helper_debug)
raise RuntimeError("Channel is invalidated")
if self._features is None:
self.update_features()
assert self._features is not None
@ -96,11 +104,10 @@ class ProtocolV2Channel(Channel):
assert isinstance(msg, message_type)
return msg
def _establish_new_channel(self, helper_debug: DebugLink | None = None) -> None:
def prepare_channel_without_pairing(self, credential: bytes | None = None) -> int:
self._reset_sync_bits()
self._do_channel_allocation()
self._do_handshake()
self._do_pairing(helper_debug)
return self._do_handshake(credential=credential)
def _reset_sync_bits(self) -> None:
self.sync_bit_send = 0
@ -148,7 +155,7 @@ class ProtocolV2Channel(Channel):
self,
credential: bytes | None = None,
host_static_randomness: bytes | None = None,
):
) -> int:
randomness_static = host_static_randomness or os.urandom(32)
@ -160,7 +167,7 @@ class ProtocolV2Channel(Channel):
credential,
)
self._read_ack()
self._read_handshake_completion_response()
return self._read_handshake_completion_response()
def _send_handshake_init_request(self) -> None:
ha_init_req_header = MessageHeader(0, self.channel_id, 36)
@ -215,7 +222,7 @@ class ProtocolV2Channel(Channel):
)
self.handshake_hash = self._noise.get_handshake_hash()
def _read_handshake_completion_response(self) -> None:
def _read_handshake_completion_response(self) -> int:
# Read handshake completion response, ignore payload as we do not care about the state
header, data = self._read_until_valid_crc_check()
if not header.is_handshake_comp_response():
@ -228,25 +235,7 @@ class ProtocolV2Channel(Channel):
print("trezor state:", trezor_state)
assert trezor_state == b"\x00" or trezor_state == b"\x01"
self._send_ack_1()
def _do_pairing(self, helper_debug: DebugLink | None):
self._send_message(messages.ThpPairingRequest(host_name="Trezorlib"))
self._read_message(messages.ButtonRequest)
self._send_message(messages.ButtonAck())
if helper_debug is not None:
helper_debug.press_yes()
self._read_message(messages.ThpPairingRequestApproved)
self._send_message(
messages.ThpSelectMethod(
selected_pairing_method=messages.ThpPairingMethod.SkipPairing
)
)
self._read_message(messages.ThpEndResponse)
self._has_valid_channel = True
return int.from_bytes(trezor_state, "big")
def _read_ack(self):
header, payload = self._read_until_valid_crc_check()

View File

@ -27,11 +27,11 @@ def prepare_protocol_for_pairing(
def get_encrypted_transport_protocol(
client: Client, host_static_randomness: bytes | None = None
) -> ProtocolV2Channel:
protocol = prepare_protocol_for_pairing(
client.protocol = prepare_protocol_for_pairing(
client, host_static_randomness=host_static_randomness
)
protocol._do_pairing(client.debug)
return protocol
client.do_pairing()
return client.protocol
def handle_pairing_request(

View File

@ -47,4 +47,5 @@ def test_handshake(client: Client) -> None:
# TODO - without pairing, the client is damaged and results in fail of the following test
# so far no luck in solving it - it should be also tackled in FW, as it causes unexpected FW error
protocol._do_pairing(client.debug)
client.protocol = protocol
client.do_pairing()

View File

@ -50,10 +50,12 @@ def _prepare_two_hosts(client: Client) -> tuple[ProtocolV2Channel, ProtocolV2Cha
)
protocol_1._do_handshake()
protocol_1._do_pairing(client.debug)
client.protocol = protocol_1
client.do_pairing()
sleep(LOCK_TIME)
protocol_2._do_handshake()
protocol_2._do_pairing(client.debug)
client.protocol = protocol_2
client.do_pairing()
return protocol_1, protocol_2
@ -122,7 +124,8 @@ def test_concurrent_handshakes_1(client: Client) -> None:
# The second host performs action that results
# in the invalidation of the first host's handshake state
protocol_2._do_pairing(helper_debug=client.debug)
client.protocol = protocol_2
client.do_pairing()
# Even after LOCK_TIME passes, the first host's channel cannot
# be resumed

View File

@ -304,16 +304,16 @@ def test_channel_replacement(client: Client) -> None:
client.protocol = get_encrypted_transport_protocol(client, host_static_randomness)
session = client.get_session(passphrase="TREZOR", session_id=b"\x10")
session = client.get_session(passphrase="TREZOR")
address = get_test_address(session)
session_2 = client.get_session(passphrase="ROZERT", session_id=b"\x20")
session_2 = client.get_session(passphrase="ROZERT")
address_2 = get_test_address(session_2)
assert address != address_2
# create new channel using the same host_static_privkey
client.protocol = get_encrypted_transport_protocol(client, host_static_randomness)
session_3 = client.get_session(passphrase="OKIDOKI", session_id=b"\x30")
session_3 = client.get_session(passphrase="OKIDOKI")
address_3 = get_test_address(session_3)
assert address_3 != address_2
@ -333,6 +333,6 @@ def test_channel_replacement(client: Client) -> None:
_ = get_test_address(session_3)
assert str(e_2.value.message) == "Invalid session"
session_4 = client.get_session(passphrase="TREZOR", session_id=b"\x40")
session_4 = client.get_session(passphrase="TREZOR")
super_new_address = get_test_address(session_4)
assert address == super_new_address