mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-05-05 16:39:08 +00:00
wip
This commit is contained in:
parent
cdb2521912
commit
d05e10abbd
@ -26,7 +26,7 @@ from .tools import parse_path
|
|||||||
from .transport import Transport, get_transport
|
from .transport import Transport, get_transport
|
||||||
from .transport.thp.protocol_and_channel import Channel
|
from .transport.thp.protocol_and_channel import Channel
|
||||||
from .transport.thp.protocol_v1 import ProtocolV1Channel
|
from .transport.thp.protocol_v1 import ProtocolV1Channel
|
||||||
from .transport.thp.protocol_v2 import ProtocolV2Channel
|
from .transport.thp.protocol_v2 import ProtocolV2Channel, TrezorState
|
||||||
|
|
||||||
if t.TYPE_CHECKING:
|
if t.TYPE_CHECKING:
|
||||||
from .transport.session import Session, SessionV1
|
from .transport.session import Session, SessionV1
|
||||||
@ -62,6 +62,7 @@ class TrezorClient:
|
|||||||
_setup_pin: str | None = None # Should be used only by conftest
|
_setup_pin: str | None = None # Should be used only by conftest
|
||||||
_last_active_session: SessionV1 | None = None
|
_last_active_session: SessionV1 | None = None
|
||||||
|
|
||||||
|
_session_id_counter: int = 0
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
transport: Transport,
|
transport: Transport,
|
||||||
@ -99,6 +100,26 @@ class TrezorClient:
|
|||||||
else:
|
else:
|
||||||
raise Exception("Unknown protocol version")
|
raise Exception("Unknown protocol version")
|
||||||
|
|
||||||
|
def do_pairing(self) -> None:
|
||||||
|
from .transport.session import SessionV2
|
||||||
|
|
||||||
|
assert self.protocol_version == ProtocolVersion.V2
|
||||||
|
session = SessionV2(client=self, id=b"\x00")
|
||||||
|
session.call(
|
||||||
|
messages.ThpPairingRequest(host_name="Trezorlib"),
|
||||||
|
expect=messages.ThpPairingRequestApproved,
|
||||||
|
skip_firmware_version_check=True,
|
||||||
|
)
|
||||||
|
session.call(
|
||||||
|
messages.ThpSelectMethod(
|
||||||
|
selected_pairing_method=messages.ThpPairingMethod.SkipPairing
|
||||||
|
),
|
||||||
|
expect=messages.ThpEndResponse,
|
||||||
|
skip_firmware_version_check=True,
|
||||||
|
)
|
||||||
|
assert isinstance(self.protocol, ProtocolV2Channel)
|
||||||
|
self.protocol._has_valid_channel = True
|
||||||
|
|
||||||
def get_session(
|
def get_session(
|
||||||
self,
|
self,
|
||||||
passphrase: str | object = "",
|
passphrase: str | object = "",
|
||||||
@ -128,17 +149,21 @@ class TrezorClient:
|
|||||||
if isinstance(self.protocol, ProtocolV2Channel):
|
if isinstance(self.protocol, ProtocolV2Channel):
|
||||||
from .transport.session import SessionV2
|
from .transport.session import SessionV2
|
||||||
|
|
||||||
|
if self.protocol.trezor_state is TrezorState.UNPAIRED:
|
||||||
|
self.do_pairing()
|
||||||
|
|
||||||
if passphrase is SEEDLESS:
|
if passphrase is SEEDLESS:
|
||||||
return SessionV2(self, id=b"\x00")
|
return SessionV2(self, id=b"\x00")
|
||||||
|
|
||||||
|
if self._session_id_counter >= 255:
|
||||||
|
self._session_id_counter = 0
|
||||||
assert isinstance(passphrase, str) or passphrase is None
|
assert isinstance(passphrase, str) or passphrase is None
|
||||||
session_id = b"\x01" # TODO fix this with ProtocolV2 session rework
|
self._session_id_counter += 1
|
||||||
if session_id is not None:
|
|
||||||
sid = int.from_bytes(session_id, "big")
|
return SessionV2.new(
|
||||||
else:
|
self, passphrase, derive_cardano, self._session_id_counter
|
||||||
sid = 1
|
)
|
||||||
assert 0 <= sid <= 255
|
|
||||||
return SessionV2.new(self, passphrase, derive_cardano, sid)
|
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def get_seedless_session(self) -> Session:
|
def get_seedless_session(self) -> Session:
|
||||||
@ -150,11 +175,20 @@ class TrezorClient:
|
|||||||
@property
|
@property
|
||||||
def features(self) -> messages.Features:
|
def features(self) -> messages.Features:
|
||||||
if self._features is None:
|
if self._features is None:
|
||||||
self._features = self.protocol.get_features()
|
self._features = self._get_features()
|
||||||
self.check_firmware_version(warn_only=True)
|
self.check_firmware_version(warn_only=True)
|
||||||
assert self._features is not None
|
assert self._features is not None
|
||||||
return self._features
|
return self._features
|
||||||
|
|
||||||
|
def _get_features(self) -> messages.Features:
|
||||||
|
if isinstance(self.protocol, ProtocolV2Channel):
|
||||||
|
if (
|
||||||
|
self.protocol.trezor_state is TrezorState.UNPAIRED
|
||||||
|
or not self.protocol._has_valid_channel
|
||||||
|
):
|
||||||
|
self.do_pairing()
|
||||||
|
return self.protocol.get_features()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def protocol_version(self) -> int:
|
def protocol_version(self) -> int:
|
||||||
return self._protocol_version
|
return self._protocol_version
|
||||||
|
@ -1030,6 +1030,8 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
# without special DebugLink interface provided
|
# without special DebugLink interface provided
|
||||||
# by the device.
|
# by the device.
|
||||||
|
|
||||||
|
protocol: ProtocolV1Channel | ProtocolV2Channel
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
transport: Transport,
|
transport: Transport,
|
||||||
@ -1075,8 +1077,8 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
# and know the supported debug capabilities
|
# and know the supported debug capabilities
|
||||||
if self.protocol_version is ProtocolVersion.V2:
|
if self.protocol_version is ProtocolVersion.V2:
|
||||||
assert isinstance(self.protocol, ProtocolV2Channel)
|
assert isinstance(self.protocol, ProtocolV2Channel)
|
||||||
self.protocol._helper_debug = self.debug
|
self.do_pairing()
|
||||||
self.protocol = self.protocol.get_channel()
|
# self.protocol = self.protocol.get_channel()
|
||||||
self.debug.model = self.model
|
self.debug.model = self.model
|
||||||
self.debug.version = self.version
|
self.debug.version = self.version
|
||||||
|
|
||||||
|
@ -26,8 +26,10 @@ class Session:
|
|||||||
self,
|
self,
|
||||||
msg: MessageType,
|
msg: MessageType,
|
||||||
expect: type[MT] = MessageType,
|
expect: type[MT] = MessageType,
|
||||||
|
skip_firmware_version_check: bool = False,
|
||||||
_passphrase_ack: messages.PassphraseAck | None = None,
|
_passphrase_ack: messages.PassphraseAck | None = None,
|
||||||
) -> MT:
|
) -> MT:
|
||||||
|
if not skip_firmware_version_check:
|
||||||
self.client.check_firmware_version()
|
self.client.check_firmware_version()
|
||||||
resp = self.call_raw(msg)
|
resp = self.call_raw(msg)
|
||||||
|
|
||||||
@ -260,14 +262,11 @@ class SessionV2(Session):
|
|||||||
return session
|
return session
|
||||||
|
|
||||||
def __init__(self, client: TrezorClient, id: bytes) -> None:
|
def __init__(self, client: TrezorClient, id: bytes) -> None:
|
||||||
from ..debuglink import TrezorClientDebugLink
|
|
||||||
|
|
||||||
super().__init__(client, id)
|
super().__init__(client, id)
|
||||||
assert isinstance(client.protocol, ProtocolV2Channel)
|
assert isinstance(client.protocol, ProtocolV2Channel)
|
||||||
|
|
||||||
if isinstance(client, TrezorClientDebugLink):
|
self.channel: ProtocolV2Channel = client.protocol
|
||||||
client.protocol._helper_debug = client.debug
|
|
||||||
self.channel: ProtocolV2Channel = client.protocol.get_channel()
|
|
||||||
self.update_id_and_sid(id)
|
self.update_id_and_sid(id)
|
||||||
|
|
||||||
def _write(self, msg: t.Any) -> None:
|
def _write(self, msg: t.Any) -> None:
|
||||||
|
@ -1,14 +1,7 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import typing as t
|
|
||||||
|
|
||||||
from ... import messages
|
from ... import messages
|
||||||
from ...mapping import ProtobufMapping
|
from ...mapping import ProtobufMapping
|
||||||
from .. import Transport
|
from .. import Transport
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class Channel:
|
class Channel:
|
||||||
|
|
||||||
@ -25,9 +18,3 @@ class Channel:
|
|||||||
|
|
||||||
def update_features(self) -> None:
|
def update_features(self) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def read(self, timeout: float | None = None) -> t.Any:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def write(self, msg: t.Any) -> None:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
@ -4,6 +4,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import typing as t
|
import typing as t
|
||||||
from binascii import hexlify
|
from binascii import hexlify
|
||||||
|
from enum import IntEnum
|
||||||
|
|
||||||
from noise.connection import Keypair, NoiseConnection
|
from noise.connection import Keypair, NoiseConnection
|
||||||
|
|
||||||
@ -21,10 +22,15 @@ LOG = logging.getLogger(__name__)
|
|||||||
DEFAULT_SESSION_ID: int = 0
|
DEFAULT_SESSION_ID: int = 0
|
||||||
|
|
||||||
if t.TYPE_CHECKING:
|
if t.TYPE_CHECKING:
|
||||||
from ...debuglink import DebugLink
|
pass
|
||||||
MT = t.TypeVar("MT", bound=protobuf.MessageType)
|
MT = t.TypeVar("MT", bound=protobuf.MessageType)
|
||||||
|
|
||||||
|
|
||||||
|
class TrezorState(IntEnum):
|
||||||
|
UNPAIRED = 0x00
|
||||||
|
PAIRED = 0x01
|
||||||
|
|
||||||
|
|
||||||
class ProtocolV2Channel(Channel):
|
class ProtocolV2Channel(Channel):
|
||||||
channel_id: int
|
channel_id: int
|
||||||
sync_bit_send: int
|
sync_bit_send: int
|
||||||
@ -33,18 +39,20 @@ class ProtocolV2Channel(Channel):
|
|||||||
|
|
||||||
_has_valid_channel: bool = False
|
_has_valid_channel: bool = False
|
||||||
_features: messages.Features | None = None
|
_features: messages.Features | None = None
|
||||||
_helper_debug: DebugLink | None = None
|
trezor_state: int = TrezorState.UNPAIRED
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
transport: Transport,
|
transport: Transport,
|
||||||
mapping: ProtobufMapping,
|
mapping: ProtobufMapping,
|
||||||
|
credential: bytes | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(transport, mapping)
|
super().__init__(transport, mapping)
|
||||||
|
self.trezor_state = self.prepare_channel_without_pairing(credential=credential)
|
||||||
|
|
||||||
def get_channel(self) -> ProtocolV2Channel:
|
def get_channel(self) -> ProtocolV2Channel:
|
||||||
if not self._has_valid_channel:
|
if not self._has_valid_channel:
|
||||||
self._establish_new_channel(self._helper_debug)
|
raise RuntimeError("Channel is invalidated")
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def read(self, session_id: int) -> t.Any:
|
def read(self, session_id: int) -> t.Any:
|
||||||
@ -61,7 +69,7 @@ class ProtocolV2Channel(Channel):
|
|||||||
|
|
||||||
def get_features(self) -> messages.Features:
|
def get_features(self) -> messages.Features:
|
||||||
if not self._has_valid_channel:
|
if not self._has_valid_channel:
|
||||||
self._establish_new_channel(self._helper_debug)
|
raise RuntimeError("Channel is invalidated")
|
||||||
if self._features is None:
|
if self._features is None:
|
||||||
self.update_features()
|
self.update_features()
|
||||||
assert self._features is not None
|
assert self._features is not None
|
||||||
@ -96,11 +104,10 @@ class ProtocolV2Channel(Channel):
|
|||||||
assert isinstance(msg, message_type)
|
assert isinstance(msg, message_type)
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
def _establish_new_channel(self, helper_debug: DebugLink | None = None) -> None:
|
def prepare_channel_without_pairing(self, credential: bytes | None = None) -> int:
|
||||||
self._reset_sync_bits()
|
self._reset_sync_bits()
|
||||||
self._do_channel_allocation()
|
self._do_channel_allocation()
|
||||||
self._do_handshake()
|
return self._do_handshake(credential=credential)
|
||||||
self._do_pairing(helper_debug)
|
|
||||||
|
|
||||||
def _reset_sync_bits(self) -> None:
|
def _reset_sync_bits(self) -> None:
|
||||||
self.sync_bit_send = 0
|
self.sync_bit_send = 0
|
||||||
@ -148,7 +155,7 @@ class ProtocolV2Channel(Channel):
|
|||||||
self,
|
self,
|
||||||
credential: bytes | None = None,
|
credential: bytes | None = None,
|
||||||
host_static_randomness: bytes | None = None,
|
host_static_randomness: bytes | None = None,
|
||||||
):
|
) -> int:
|
||||||
|
|
||||||
randomness_static = host_static_randomness or os.urandom(32)
|
randomness_static = host_static_randomness or os.urandom(32)
|
||||||
|
|
||||||
@ -160,7 +167,7 @@ class ProtocolV2Channel(Channel):
|
|||||||
credential,
|
credential,
|
||||||
)
|
)
|
||||||
self._read_ack()
|
self._read_ack()
|
||||||
self._read_handshake_completion_response()
|
return self._read_handshake_completion_response()
|
||||||
|
|
||||||
def _send_handshake_init_request(self) -> None:
|
def _send_handshake_init_request(self) -> None:
|
||||||
ha_init_req_header = MessageHeader(0, self.channel_id, 36)
|
ha_init_req_header = MessageHeader(0, self.channel_id, 36)
|
||||||
@ -215,7 +222,7 @@ class ProtocolV2Channel(Channel):
|
|||||||
)
|
)
|
||||||
self.handshake_hash = self._noise.get_handshake_hash()
|
self.handshake_hash = self._noise.get_handshake_hash()
|
||||||
|
|
||||||
def _read_handshake_completion_response(self) -> None:
|
def _read_handshake_completion_response(self) -> int:
|
||||||
# Read handshake completion response, ignore payload as we do not care about the state
|
# Read handshake completion response, ignore payload as we do not care about the state
|
||||||
header, data = self._read_until_valid_crc_check()
|
header, data = self._read_until_valid_crc_check()
|
||||||
if not header.is_handshake_comp_response():
|
if not header.is_handshake_comp_response():
|
||||||
@ -228,25 +235,7 @@ class ProtocolV2Channel(Channel):
|
|||||||
print("trezor state:", trezor_state)
|
print("trezor state:", trezor_state)
|
||||||
assert trezor_state == b"\x00" or trezor_state == b"\x01"
|
assert trezor_state == b"\x00" or trezor_state == b"\x01"
|
||||||
self._send_ack_1()
|
self._send_ack_1()
|
||||||
|
return int.from_bytes(trezor_state, "big")
|
||||||
def _do_pairing(self, helper_debug: DebugLink | None):
|
|
||||||
|
|
||||||
self._send_message(messages.ThpPairingRequest(host_name="Trezorlib"))
|
|
||||||
self._read_message(messages.ButtonRequest)
|
|
||||||
self._send_message(messages.ButtonAck())
|
|
||||||
|
|
||||||
if helper_debug is not None:
|
|
||||||
helper_debug.press_yes()
|
|
||||||
|
|
||||||
self._read_message(messages.ThpPairingRequestApproved)
|
|
||||||
self._send_message(
|
|
||||||
messages.ThpSelectMethod(
|
|
||||||
selected_pairing_method=messages.ThpPairingMethod.SkipPairing
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self._read_message(messages.ThpEndResponse)
|
|
||||||
|
|
||||||
self._has_valid_channel = True
|
|
||||||
|
|
||||||
def _read_ack(self):
|
def _read_ack(self):
|
||||||
header, payload = self._read_until_valid_crc_check()
|
header, payload = self._read_until_valid_crc_check()
|
||||||
|
@ -27,11 +27,11 @@ def prepare_protocol_for_pairing(
|
|||||||
def get_encrypted_transport_protocol(
|
def get_encrypted_transport_protocol(
|
||||||
client: Client, host_static_randomness: bytes | None = None
|
client: Client, host_static_randomness: bytes | None = None
|
||||||
) -> ProtocolV2Channel:
|
) -> ProtocolV2Channel:
|
||||||
protocol = prepare_protocol_for_pairing(
|
client.protocol = prepare_protocol_for_pairing(
|
||||||
client, host_static_randomness=host_static_randomness
|
client, host_static_randomness=host_static_randomness
|
||||||
)
|
)
|
||||||
protocol._do_pairing(client.debug)
|
client.do_pairing()
|
||||||
return protocol
|
return client.protocol
|
||||||
|
|
||||||
|
|
||||||
def handle_pairing_request(
|
def handle_pairing_request(
|
||||||
|
@ -47,4 +47,5 @@ def test_handshake(client: Client) -> None:
|
|||||||
|
|
||||||
# TODO - without pairing, the client is damaged and results in fail of the following test
|
# TODO - without pairing, the client is damaged and results in fail of the following test
|
||||||
# so far no luck in solving it - it should be also tackled in FW, as it causes unexpected FW error
|
# so far no luck in solving it - it should be also tackled in FW, as it causes unexpected FW error
|
||||||
protocol._do_pairing(client.debug)
|
client.protocol = protocol
|
||||||
|
client.do_pairing()
|
||||||
|
@ -50,10 +50,12 @@ def _prepare_two_hosts(client: Client) -> tuple[ProtocolV2Channel, ProtocolV2Cha
|
|||||||
)
|
)
|
||||||
protocol_1._do_handshake()
|
protocol_1._do_handshake()
|
||||||
|
|
||||||
protocol_1._do_pairing(client.debug)
|
client.protocol = protocol_1
|
||||||
|
client.do_pairing()
|
||||||
sleep(LOCK_TIME)
|
sleep(LOCK_TIME)
|
||||||
protocol_2._do_handshake()
|
protocol_2._do_handshake()
|
||||||
protocol_2._do_pairing(client.debug)
|
client.protocol = protocol_2
|
||||||
|
client.do_pairing()
|
||||||
|
|
||||||
return protocol_1, protocol_2
|
return protocol_1, protocol_2
|
||||||
|
|
||||||
@ -122,7 +124,8 @@ def test_concurrent_handshakes_1(client: Client) -> None:
|
|||||||
|
|
||||||
# The second host performs action that results
|
# The second host performs action that results
|
||||||
# in the invalidation of the first host's handshake state
|
# in the invalidation of the first host's handshake state
|
||||||
protocol_2._do_pairing(helper_debug=client.debug)
|
client.protocol = protocol_2
|
||||||
|
client.do_pairing()
|
||||||
|
|
||||||
# Even after LOCK_TIME passes, the first host's channel cannot
|
# Even after LOCK_TIME passes, the first host's channel cannot
|
||||||
# be resumed
|
# be resumed
|
||||||
|
@ -304,16 +304,16 @@ def test_channel_replacement(client: Client) -> None:
|
|||||||
|
|
||||||
client.protocol = get_encrypted_transport_protocol(client, host_static_randomness)
|
client.protocol = get_encrypted_transport_protocol(client, host_static_randomness)
|
||||||
|
|
||||||
session = client.get_session(passphrase="TREZOR", session_id=b"\x10")
|
session = client.get_session(passphrase="TREZOR")
|
||||||
address = get_test_address(session)
|
address = get_test_address(session)
|
||||||
|
|
||||||
session_2 = client.get_session(passphrase="ROZERT", session_id=b"\x20")
|
session_2 = client.get_session(passphrase="ROZERT")
|
||||||
address_2 = get_test_address(session_2)
|
address_2 = get_test_address(session_2)
|
||||||
assert address != address_2
|
assert address != address_2
|
||||||
|
|
||||||
# create new channel using the same host_static_privkey
|
# create new channel using the same host_static_privkey
|
||||||
client.protocol = get_encrypted_transport_protocol(client, host_static_randomness)
|
client.protocol = get_encrypted_transport_protocol(client, host_static_randomness)
|
||||||
session_3 = client.get_session(passphrase="OKIDOKI", session_id=b"\x30")
|
session_3 = client.get_session(passphrase="OKIDOKI")
|
||||||
address_3 = get_test_address(session_3)
|
address_3 = get_test_address(session_3)
|
||||||
assert address_3 != address_2
|
assert address_3 != address_2
|
||||||
|
|
||||||
@ -333,6 +333,6 @@ def test_channel_replacement(client: Client) -> None:
|
|||||||
_ = get_test_address(session_3)
|
_ = get_test_address(session_3)
|
||||||
assert str(e_2.value.message) == "Invalid session"
|
assert str(e_2.value.message) == "Invalid session"
|
||||||
|
|
||||||
session_4 = client.get_session(passphrase="TREZOR", session_id=b"\x40")
|
session_4 = client.get_session(passphrase="TREZOR")
|
||||||
super_new_address = get_test_address(session_4)
|
super_new_address = get_test_address(session_4)
|
||||||
assert address == super_new_address
|
assert address == super_new_address
|
||||||
|
Loading…
Reference in New Issue
Block a user