mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-04-15 14:55:43 +00:00
wip
This commit is contained in:
parent
bab685036b
commit
a937756f14
@ -26,7 +26,7 @@ from .tools import parse_path
|
||||
from .transport import Transport, get_transport
|
||||
from .transport.thp.protocol_and_channel import Channel
|
||||
from .transport.thp.protocol_v1 import ProtocolV1Channel
|
||||
from .transport.thp.protocol_v2 import ProtocolV2Channel
|
||||
from .transport.thp.protocol_v2 import ProtocolV2Channel, TrezorState
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from .transport.session import Session, SessionV1
|
||||
@ -62,6 +62,7 @@ class TrezorClient:
|
||||
_setup_pin: str | None = None # Should be used only by conftest
|
||||
_last_active_session: SessionV1 | None = None
|
||||
|
||||
_session_id_counter: int = 0
|
||||
def __init__(
|
||||
self,
|
||||
transport: Transport,
|
||||
@ -99,6 +100,26 @@ class TrezorClient:
|
||||
else:
|
||||
raise Exception("Unknown protocol version")
|
||||
|
||||
def do_pairing(self) -> None:
|
||||
from .transport.session import SessionV2
|
||||
|
||||
assert self.protocol_version == ProtocolVersion.V2
|
||||
session = SessionV2(client=self, id=b"\x00")
|
||||
session.call(
|
||||
messages.ThpPairingRequest(host_name="Trezorlib"),
|
||||
expect=messages.ThpPairingRequestApproved,
|
||||
skip_firmware_version_check=True,
|
||||
)
|
||||
session.call(
|
||||
messages.ThpSelectMethod(
|
||||
selected_pairing_method=messages.ThpPairingMethod.SkipPairing
|
||||
),
|
||||
expect=messages.ThpEndResponse,
|
||||
skip_firmware_version_check=True,
|
||||
)
|
||||
assert isinstance(self.protocol, ProtocolV2Channel)
|
||||
self.protocol._has_valid_channel = True
|
||||
|
||||
def get_session(
|
||||
self,
|
||||
passphrase: str | object = "",
|
||||
@ -128,17 +149,21 @@ class TrezorClient:
|
||||
if isinstance(self.protocol, ProtocolV2Channel):
|
||||
from .transport.session import SessionV2
|
||||
|
||||
if self.protocol.trezor_state is TrezorState.UNPAIRED:
|
||||
self.do_pairing()
|
||||
|
||||
if passphrase is SEEDLESS:
|
||||
return SessionV2(self, id=b"\x00")
|
||||
|
||||
if self._session_id_counter >= 255:
|
||||
self._session_id_counter = 0
|
||||
assert isinstance(passphrase, str) or passphrase is None
|
||||
session_id = b"\x01" # TODO fix this with ProtocolV2 session rework
|
||||
if session_id is not None:
|
||||
sid = int.from_bytes(session_id, "big")
|
||||
else:
|
||||
sid = 1
|
||||
assert 0 <= sid <= 255
|
||||
return SessionV2.new(self, passphrase, derive_cardano, sid)
|
||||
self._session_id_counter += 1
|
||||
|
||||
return SessionV2.new(
|
||||
self, passphrase, derive_cardano, self._session_id_counter
|
||||
)
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def get_seedless_session(self) -> Session:
|
||||
@ -150,11 +175,20 @@ class TrezorClient:
|
||||
@property
|
||||
def features(self) -> messages.Features:
|
||||
if self._features is None:
|
||||
self._features = self.protocol.get_features()
|
||||
self._features = self._get_features()
|
||||
self.check_firmware_version(warn_only=True)
|
||||
assert self._features is not None
|
||||
return self._features
|
||||
|
||||
def _get_features(self) -> messages.Features:
|
||||
if isinstance(self.protocol, ProtocolV2Channel):
|
||||
if (
|
||||
self.protocol.trezor_state is TrezorState.UNPAIRED
|
||||
or not self.protocol._has_valid_channel
|
||||
):
|
||||
self.do_pairing()
|
||||
return self.protocol.get_features()
|
||||
|
||||
@property
|
||||
def protocol_version(self) -> int:
|
||||
return self._protocol_version
|
||||
|
@ -1030,6 +1030,8 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
# without special DebugLink interface provided
|
||||
# by the device.
|
||||
|
||||
protocol: ProtocolV1Channel | ProtocolV2Channel
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transport: Transport,
|
||||
@ -1075,8 +1077,8 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
# and know the supported debug capabilities
|
||||
if self.protocol_version is ProtocolVersion.V2:
|
||||
assert isinstance(self.protocol, ProtocolV2Channel)
|
||||
self.protocol._helper_debug = self.debug
|
||||
self.protocol = self.protocol.get_channel()
|
||||
self.do_pairing()
|
||||
# self.protocol = self.protocol.get_channel()
|
||||
self.debug.model = self.model
|
||||
self.debug.version = self.version
|
||||
|
||||
|
@ -26,9 +26,11 @@ class Session:
|
||||
self,
|
||||
msg: MessageType,
|
||||
expect: type[MT] = MessageType,
|
||||
skip_firmware_version_check: bool = False,
|
||||
_passphrase_ack: messages.PassphraseAck | None = None,
|
||||
) -> MT:
|
||||
self.client.check_firmware_version()
|
||||
if not skip_firmware_version_check:
|
||||
self.client.check_firmware_version()
|
||||
resp = self.call_raw(msg)
|
||||
|
||||
while True:
|
||||
@ -260,14 +262,11 @@ class SessionV2(Session):
|
||||
return session
|
||||
|
||||
def __init__(self, client: TrezorClient, id: bytes) -> None:
|
||||
from ..debuglink import TrezorClientDebugLink
|
||||
|
||||
super().__init__(client, id)
|
||||
assert isinstance(client.protocol, ProtocolV2Channel)
|
||||
|
||||
if isinstance(client, TrezorClientDebugLink):
|
||||
client.protocol._helper_debug = client.debug
|
||||
self.channel: ProtocolV2Channel = client.protocol.get_channel()
|
||||
self.channel: ProtocolV2Channel = client.protocol
|
||||
self.update_id_and_sid(id)
|
||||
|
||||
def _write(self, msg: t.Any) -> None:
|
||||
|
@ -1,14 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
from ... import messages
|
||||
from ...mapping import ProtobufMapping
|
||||
from .. import Transport
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Channel:
|
||||
|
||||
@ -25,9 +18,3 @@ class Channel:
|
||||
|
||||
def update_features(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def read(self, timeout: float | None = None) -> t.Any:
|
||||
raise NotImplementedError
|
||||
|
||||
def write(self, msg: t.Any) -> None:
|
||||
raise NotImplementedError
|
||||
|
@ -4,6 +4,7 @@ import logging
|
||||
import os
|
||||
import typing as t
|
||||
from binascii import hexlify
|
||||
from enum import IntEnum
|
||||
|
||||
from noise.connection import Keypair, NoiseConnection
|
||||
|
||||
@ -21,10 +22,15 @@ LOG = logging.getLogger(__name__)
|
||||
DEFAULT_SESSION_ID: int = 0
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ...debuglink import DebugLink
|
||||
pass
|
||||
MT = t.TypeVar("MT", bound=protobuf.MessageType)
|
||||
|
||||
|
||||
class TrezorState(IntEnum):
|
||||
UNPAIRED = 0x00
|
||||
PAIRED = 0x01
|
||||
|
||||
|
||||
class ProtocolV2Channel(Channel):
|
||||
channel_id: int
|
||||
sync_bit_send: int
|
||||
@ -33,18 +39,20 @@ class ProtocolV2Channel(Channel):
|
||||
|
||||
_has_valid_channel: bool = False
|
||||
_features: messages.Features | None = None
|
||||
_helper_debug: DebugLink | None = None
|
||||
trezor_state: int = TrezorState.UNPAIRED
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transport: Transport,
|
||||
mapping: ProtobufMapping,
|
||||
credential: bytes | None = None,
|
||||
) -> None:
|
||||
super().__init__(transport, mapping)
|
||||
self.trezor_state = self.prepare_channel_without_pairing(credential=credential)
|
||||
|
||||
def get_channel(self) -> ProtocolV2Channel:
|
||||
if not self._has_valid_channel:
|
||||
self._establish_new_channel(self._helper_debug)
|
||||
raise RuntimeError("Channel is invalidated")
|
||||
return self
|
||||
|
||||
def read(self, session_id: int) -> t.Any:
|
||||
@ -61,7 +69,7 @@ class ProtocolV2Channel(Channel):
|
||||
|
||||
def get_features(self) -> messages.Features:
|
||||
if not self._has_valid_channel:
|
||||
self._establish_new_channel(self._helper_debug)
|
||||
raise RuntimeError("Channel is invalidated")
|
||||
if self._features is None:
|
||||
self.update_features()
|
||||
assert self._features is not None
|
||||
@ -96,11 +104,10 @@ class ProtocolV2Channel(Channel):
|
||||
assert isinstance(msg, message_type)
|
||||
return msg
|
||||
|
||||
def _establish_new_channel(self, helper_debug: DebugLink | None = None) -> None:
|
||||
def prepare_channel_without_pairing(self, credential: bytes | None = None) -> int:
|
||||
self._reset_sync_bits()
|
||||
self._do_channel_allocation()
|
||||
self._do_handshake()
|
||||
self._do_pairing(helper_debug)
|
||||
return self._do_handshake(credential=credential)
|
||||
|
||||
def _reset_sync_bits(self) -> None:
|
||||
self.sync_bit_send = 0
|
||||
@ -148,7 +155,7 @@ class ProtocolV2Channel(Channel):
|
||||
self,
|
||||
credential: bytes | None = None,
|
||||
host_static_randomness: bytes | None = None,
|
||||
):
|
||||
) -> int:
|
||||
|
||||
randomness_static = host_static_randomness or os.urandom(32)
|
||||
|
||||
@ -160,7 +167,7 @@ class ProtocolV2Channel(Channel):
|
||||
credential,
|
||||
)
|
||||
self._read_ack()
|
||||
self._read_handshake_completion_response()
|
||||
return self._read_handshake_completion_response()
|
||||
|
||||
def _send_handshake_init_request(self) -> None:
|
||||
ha_init_req_header = MessageHeader(0, self.channel_id, 36)
|
||||
@ -215,7 +222,7 @@ class ProtocolV2Channel(Channel):
|
||||
)
|
||||
self.handshake_hash = self._noise.get_handshake_hash()
|
||||
|
||||
def _read_handshake_completion_response(self) -> None:
|
||||
def _read_handshake_completion_response(self) -> int:
|
||||
# Read handshake completion response, ignore payload as we do not care about the state
|
||||
header, data = self._read_until_valid_crc_check()
|
||||
if not header.is_handshake_comp_response():
|
||||
@ -228,25 +235,7 @@ class ProtocolV2Channel(Channel):
|
||||
print("trezor state:", trezor_state)
|
||||
assert trezor_state == b"\x00" or trezor_state == b"\x01"
|
||||
self._send_ack_1()
|
||||
|
||||
def _do_pairing(self, helper_debug: DebugLink | None):
|
||||
|
||||
self._send_message(messages.ThpPairingRequest(host_name="Trezorlib"))
|
||||
self._read_message(messages.ButtonRequest)
|
||||
self._send_message(messages.ButtonAck())
|
||||
|
||||
if helper_debug is not None:
|
||||
helper_debug.press_yes()
|
||||
|
||||
self._read_message(messages.ThpPairingRequestApproved)
|
||||
self._send_message(
|
||||
messages.ThpSelectMethod(
|
||||
selected_pairing_method=messages.ThpPairingMethod.SkipPairing
|
||||
)
|
||||
)
|
||||
self._read_message(messages.ThpEndResponse)
|
||||
|
||||
self._has_valid_channel = True
|
||||
return int.from_bytes(trezor_state, "big")
|
||||
|
||||
def _read_ack(self):
|
||||
header, payload = self._read_until_valid_crc_check()
|
||||
|
@ -27,11 +27,11 @@ def prepare_protocol_for_pairing(
|
||||
def get_encrypted_transport_protocol(
|
||||
client: Client, host_static_randomness: bytes | None = None
|
||||
) -> ProtocolV2Channel:
|
||||
protocol = prepare_protocol_for_pairing(
|
||||
client.protocol = prepare_protocol_for_pairing(
|
||||
client, host_static_randomness=host_static_randomness
|
||||
)
|
||||
protocol._do_pairing(client.debug)
|
||||
return protocol
|
||||
client.do_pairing()
|
||||
return client.protocol
|
||||
|
||||
|
||||
def handle_pairing_request(
|
||||
|
@ -47,4 +47,5 @@ def test_handshake(client: Client) -> None:
|
||||
|
||||
# TODO - without pairing, the client is damaged and results in fail of the following test
|
||||
# so far no luck in solving it - it should be also tackled in FW, as it causes unexpected FW error
|
||||
protocol._do_pairing(client.debug)
|
||||
client.protocol = protocol
|
||||
client.do_pairing()
|
||||
|
@ -50,10 +50,12 @@ def _prepare_two_hosts(client: Client) -> tuple[ProtocolV2Channel, ProtocolV2Cha
|
||||
)
|
||||
protocol_1._do_handshake()
|
||||
|
||||
protocol_1._do_pairing(client.debug)
|
||||
client.protocol = protocol_1
|
||||
client.do_pairing()
|
||||
sleep(LOCK_TIME)
|
||||
protocol_2._do_handshake()
|
||||
protocol_2._do_pairing(client.debug)
|
||||
client.protocol = protocol_2
|
||||
client.do_pairing()
|
||||
|
||||
return protocol_1, protocol_2
|
||||
|
||||
@ -122,7 +124,8 @@ def test_concurrent_handshakes_1(client: Client) -> None:
|
||||
|
||||
# The second host performs action that results
|
||||
# in the invalidation of the first host's handshake state
|
||||
protocol_2._do_pairing(helper_debug=client.debug)
|
||||
client.protocol = protocol_2
|
||||
client.do_pairing()
|
||||
|
||||
# Even after LOCK_TIME passes, the first host's channel cannot
|
||||
# be resumed
|
||||
|
@ -304,16 +304,16 @@ def test_channel_replacement(client: Client) -> None:
|
||||
|
||||
client.protocol = get_encrypted_transport_protocol(client, host_static_randomness)
|
||||
|
||||
session = client.get_session(passphrase="TREZOR", session_id=b"\x10")
|
||||
session = client.get_session(passphrase="TREZOR")
|
||||
address = get_test_address(session)
|
||||
|
||||
session_2 = client.get_session(passphrase="ROZERT", session_id=b"\x20")
|
||||
session_2 = client.get_session(passphrase="ROZERT")
|
||||
address_2 = get_test_address(session_2)
|
||||
assert address != address_2
|
||||
|
||||
# create new channel using the same host_static_privkey
|
||||
client.protocol = get_encrypted_transport_protocol(client, host_static_randomness)
|
||||
session_3 = client.get_session(passphrase="OKIDOKI", session_id=b"\x30")
|
||||
session_3 = client.get_session(passphrase="OKIDOKI")
|
||||
address_3 = get_test_address(session_3)
|
||||
assert address_3 != address_2
|
||||
|
||||
@ -333,6 +333,6 @@ def test_channel_replacement(client: Client) -> None:
|
||||
_ = get_test_address(session_3)
|
||||
assert str(e_2.value.message) == "Invalid session"
|
||||
|
||||
session_4 = client.get_session(passphrase="TREZOR", session_id=b"\x40")
|
||||
session_4 = client.get_session(passphrase="TREZOR")
|
||||
super_new_address = get_test_address(session_4)
|
||||
assert address == super_new_address
|
||||
|
Loading…
Reference in New Issue
Block a user