1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-05-21 00:08:46 +00:00

chore(core): adapt emu.py to the new trezorlib

[no changelog]
This commit is contained in:
M1nd3r 2025-02-04 14:55:54 +01:00
parent e45e72b1f9
commit bc89c4916b
23 changed files with 1808 additions and 28 deletions

View File

@ -2,7 +2,7 @@
import binascii import binascii
from trezorlib.client import TrezorClient from trezorlib.client import TrezorClient
from trezorlib.transport_hid import HidTransport from trezorlib.transport.hid import HidTransport
devices = HidTransport.enumerate() devices = HidTransport.enumerate()
if len(devices) > 0: if len(devices) > 0:

View File

@ -9,3 +9,4 @@ construct>=2.9,!=2.10.55
typing_extensions>=4.7.1 typing_extensions>=4.7.1
construct-classes>=0.1.2 construct-classes>=0.1.2
cryptography>=41 cryptography>=41
platformdirs>=2

View File

@ -30,7 +30,7 @@ from .. import exceptions, transport, ui
from ..client import PASSPHRASE_ON_DEVICE, ProtocolVersion, TrezorClient from ..client import PASSPHRASE_ON_DEVICE, ProtocolVersion, TrezorClient
from ..messages import Capability from ..messages import Capability
from ..transport import Transport from ..transport import Transport
from ..transport.session import Session, SessionV1 from ..transport.session import Session, SessionV1, SessionV2
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -135,10 +135,13 @@ class TrezorConnection:
# Try resume session from id # Try resume session from id
if self.session_id is not None: if self.session_id is not None:
if client.protocol_version is ProtocolVersion.V1: if client.protocol_version is ProtocolVersion.PROTOCOL_V1:
session = SessionV1.resume_from_id( session = SessionV1.resume_from_id(
client=client, session_id=self.session_id client=client, session_id=self.session_id
) )
elif client.protocol_version is ProtocolVersion.PROTOCOL_V2:
session = SessionV2(client, self.session_id)
# TODO fix resumption on THP
else: else:
raise Exception("Unsupported client protocol", client.protocol_version) raise Exception("Unsupported client protocol", client.protocol_version)
if must_resume: if must_resume:
@ -267,6 +270,11 @@ class TrezorConnection:
empty_passphrase=empty_passphrase, empty_passphrase=empty_passphrase,
must_resume=must_resume, must_resume=must_resume,
) )
except exceptions.DeviceLockedException:
click.echo(
"Device is locked, enter a pin on the device.",
err=True,
)
except transport.DeviceIsBusy: except transport.DeviceIsBusy:
click.echo("Device is in use by another process.") click.echo("Device is in use by another process.")
sys.exit(1) sys.exit(1)

View File

@ -26,6 +26,7 @@ from .tools import parse_path
from .transport import Transport, get_transport from .transport import Transport, get_transport
from .transport.thp.protocol_and_channel import Channel from .transport.thp.protocol_and_channel import Channel
from .transport.thp.protocol_v1 import ProtocolV1Channel from .transport.thp.protocol_v1 import ProtocolV1Channel
from .transport.thp.protocol_v2 import ProtocolV2Channel
if t.TYPE_CHECKING: if t.TYPE_CHECKING:
from .transport.session import Session, SessionV1 from .transport.session import Session, SessionV1
@ -93,6 +94,8 @@ class TrezorClient:
if isinstance(self.protocol, ProtocolV1Channel): if isinstance(self.protocol, ProtocolV1Channel):
self._protocol_version = ProtocolVersion.V1 self._protocol_version = ProtocolVersion.V1
elif isinstance(self.protocol, ProtocolV2Channel):
self._protocol_version = ProtocolVersion.PROTOCOL_V2
else: else:
raise Exception("Unknown protocol version") raise Exception("Unknown protocol version")
@ -121,8 +124,18 @@ class TrezorClient:
derive_cardano=derive_cardano, derive_cardano=derive_cardano,
) )
derive_seed(session, passphrase) derive_seed(session, passphrase)
return session return session
if isinstance(self.protocol, ProtocolV2Channel):
from .transport.session import SessionV2
assert isinstance(passphrase, str) or passphrase is None
session_id = 1 # TODO fix this with ProtocolV2 session rework
if session_id is not None:
sid = int.from_bytes(session_id, "big")
else:
sid = 1
assert 0 <= sid <= 255
return SessionV2.new(self, passphrase, derive_cardano, sid)
raise NotImplementedError raise NotImplementedError
def get_seedless_session(self) -> Session: def get_seedless_session(self) -> Session:
@ -174,6 +187,15 @@ class TrezorClient:
def _get_protocol(self) -> Channel: def _get_protocol(self) -> Channel:
protocol = ProtocolV1Channel(self.transport, mapping.DEFAULT_MAPPING) protocol = ProtocolV1Channel(self.transport, mapping.DEFAULT_MAPPING)
protocol.write(messages.Initialize())
response = protocol.read()
self.transport.close()
if isinstance(response, messages.Failure):
if response.code == messages.FailureType.InvalidProtocol:
LOG.debug("Protocol V2 detected")
protocol = ProtocolV2Channel(self.transport, self.mapping)
return protocol return protocol
def is_outdated(self) -> bool: def is_outdated(self) -> bool:

View File

@ -32,13 +32,20 @@ from pathlib import Path
from mnemonic import Mnemonic from mnemonic import Mnemonic
from . import btc, mapping, messages, models, protobuf from . import btc, mapping, messages, models, protobuf
from .client import ProtocolVersion, TrezorClient from .client import (
from .exceptions import Cancelled, TrezorFailure, UnexpectedMessageError MAX_PASSPHRASE_LENGTH,
MAX_PIN_LENGTH,
PASSPHRASE_ON_DEVICE,
ProtocolVersion,
TrezorClient,
)
from .exceptions import Cancelled, PinException, TrezorFailure, UnexpectedMessageError
from .log import DUMP_BYTES from .log import DUMP_BYTES
from .messages import DebugWaitType from .messages import Capability, DebugWaitType
from .protobuf import MessageType
from .tools import parse_path from .tools import parse_path
from .transport import Timeout from .transport import Timeout
from .transport.session import Session from .transport.session import Session, SessionV1, derive_seed
from .transport.thp.protocol_v1 import ProtocolV1Channel from .transport.thp.protocol_v1 import ProtocolV1Channel
if t.TYPE_CHECKING: if t.TYPE_CHECKING:
@ -542,6 +549,25 @@ class DebugLink:
raise TrezorFailure(result) raise TrezorFailure(result)
return result return result
def pairing_info(
self,
thp_channel_id: bytes | None = None,
handshake_hash: bytes | None = None,
nfc_secret_host: bytes | None = None,
) -> messages.DebugLinkPairingInfo:
result = self._call(
messages.DebugLinkGetPairingInfo(
channel_id=thp_channel_id,
handshake_hash=handshake_hash,
nfc_secret_host=nfc_secret_host,
)
)
while not isinstance(result, (messages.Failure, messages.DebugLinkPairingInfo)):
result = self._read()
if isinstance(result, messages.Failure):
raise TrezorFailure(result)
return result
def read_layout(self, wait: bool | None = None) -> LayoutContent: def read_layout(self, wait: bool | None = None) -> LayoutContent:
""" """
Force waiting for the layout by setting `wait=True`. Force not waiting by Force waiting for the layout by setting `wait=True`. Force not waiting by
@ -808,6 +834,7 @@ class DebugUI:
def __init__(self, debuglink: DebugLink) -> None: def __init__(self, debuglink: DebugLink) -> None:
self.debuglink = debuglink self.debuglink = debuglink
self.pins: t.Iterator[str] | None = None
self.clear() self.clear()
def clear(self) -> None: def clear(self) -> None:
@ -1078,16 +1105,20 @@ class TrezorClientDebugLink(TrezorClient):
self.sync_responses() self.sync_responses()
# So that we can choose right screenshotting logic (T1 vs TT) def __getattr__(self, name: str) -> t.Any:
# and know the supported debug capabilities return getattr(self._session, name)
self.debug.model = self.model
self.debug.version = self.version def __setattr__(self, name: str, value: t.Any) -> None:
if hasattr(self._session, name):
setattr(self._session, name, value)
else:
self.__dict__[name] = value
self.reset_debug_features() self.reset_debug_features()
@property @property
def layout_type(self) -> LayoutType: def protocol_version(self) -> int:
return self.debug.layout_type return self.client.protocol_version
def get_new_client(self) -> TrezorClientDebugLink: def get_new_client(self) -> TrezorClientDebugLink:
new_client = TrezorClientDebugLink( new_client = TrezorClientDebugLink(
@ -1289,8 +1320,10 @@ class TrezorClientDebugLink(TrezorClient):
actual_responses = self.actual_responses actual_responses = self.actual_responses
# grab a copy of the inputflow generator to raise an exception through it # grab a copy of the inputflow generator to raise an exception through it
if isinstance(self.ui, DebugUI): if isinstance(self.client, TrezorClientDebugLink) and isinstance(
input_flow = self.ui.input_flow self.client.ui, DebugUI
):
input_flow = self.client.ui.input_flow
else: else:
input_flow = None input_flow = None

View File

@ -109,3 +109,7 @@ class DerivationOnUninitaizedDeviceError(TrezorException):
"""Tried to derive seed on uninitialized device. """Tried to derive seed on uninitialized device.
To communicate with uninitialized device, use seedless session instead.""" To communicate with uninitialized device, use seedless session instead."""
class DeviceLockedException(TrezorException):
pass

View File

@ -70,6 +70,16 @@ class ProtobufMapping:
protobuf.dump_message(buf, msg) protobuf.dump_message(buf, msg)
return wire_type, buf.getvalue() return wire_type, buf.getvalue()
def encode_without_wire_type(self, msg: protobuf.MessageType) -> bytes:
"""Serialize a Python protobuf class.
Returns the byte representation of the protobuf message.
"""
buf = io.BytesIO()
protobuf.dump_message(buf, msg)
return buf.getvalue()
def decode(self, msg_wire_type: int, msg_bytes: bytes) -> protobuf.MessageType: def decode(self, msg_wire_type: int, msg_bytes: bytes) -> protobuf.MessageType:
"""Deserialize a protobuf message into a Python class.""" """Deserialize a protobuf message into a Python class."""
cls = self.type_to_class[msg_wire_type] cls = self.type_to_class[msg_wire_type]
@ -85,8 +95,9 @@ class ProtobufMapping:
mapping = cls() mapping = cls()
message_types = getattr(module, "MessageType") message_types = getattr(module, "MessageType")
thp_message_types = getattr(module, "ThpMessageType")
for entry in message_types: for entry in (*message_types, *thp_message_types):
msg_class = getattr(module, entry.name, None) msg_class = getattr(module, entry.name, None)
if msg_class is None: if msg_class is None:
raise ValueError( raise ValueError(

View File

@ -43,6 +43,10 @@ class FailureType(IntEnum):
PinMismatch = 12 PinMismatch = 12
WipeCodeMismatch = 13 WipeCodeMismatch = 13
InvalidSession = 14 InvalidSession = 14
ThpUnallocatedSession = 15
InvalidProtocol = 16
BufferError = 17
DeviceIsBusy = 18
FirmwareError = 99 FirmwareError = 99
@ -400,6 +404,34 @@ class TezosBallotType(IntEnum):
Pass = 2 Pass = 2
class ThpMessageType(IntEnum):
ThpCreateNewSession = 1000
ThpPairingRequest = 1006
ThpPairingRequestApproved = 1007
ThpSelectMethod = 1008
ThpPairingPreparationsFinished = 1009
ThpCredentialRequest = 1010
ThpCredentialResponse = 1011
ThpEndRequest = 1012
ThpEndResponse = 1013
ThpCodeEntryCommitment = 1016
ThpCodeEntryChallenge = 1017
ThpCodeEntryCpaceTrezor = 1018
ThpCodeEntryCpaceHostTag = 1019
ThpCodeEntrySecret = 1020
ThpQrCodeTag = 1024
ThpQrCodeSecret = 1025
ThpNfcTagHost = 1032
ThpNfcTagTrezor = 1033
class ThpPairingMethod(IntEnum):
SkipPairing = 1
CodeEntry = 2
QrCode = 3
NFC = 4
class MessageType(IntEnum): class MessageType(IntEnum):
Initialize = 0 Initialize = 0
Ping = 1 Ping = 1
@ -501,6 +533,8 @@ class MessageType(IntEnum):
DebugLinkWatchLayout = 9006 DebugLinkWatchLayout = 9006
DebugLinkResetDebugEvents = 9007 DebugLinkResetDebugEvents = 9007
DebugLinkOptigaSetSecMax = 9008 DebugLinkOptigaSetSecMax = 9008
DebugLinkGetPairingInfo = 9009
DebugLinkPairingInfo = 9010
EthereumGetPublicKey = 450 EthereumGetPublicKey = 450
EthereumPublicKey = 451 EthereumPublicKey = 451
EthereumGetAddress = 56 EthereumGetAddress = 56
@ -4222,6 +4256,52 @@ class DebugLinkState(protobuf.MessageType):
self.mnemonic_type = mnemonic_type self.mnemonic_type = mnemonic_type
class DebugLinkGetPairingInfo(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 9009
FIELDS = {
1: protobuf.Field("channel_id", "bytes", repeated=False, required=False, default=None),
2: protobuf.Field("handshake_hash", "bytes", repeated=False, required=False, default=None),
3: protobuf.Field("nfc_secret_host", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
channel_id: Optional["bytes"] = None,
handshake_hash: Optional["bytes"] = None,
nfc_secret_host: Optional["bytes"] = None,
) -> None:
self.channel_id = channel_id
self.handshake_hash = handshake_hash
self.nfc_secret_host = nfc_secret_host
class DebugLinkPairingInfo(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 9010
FIELDS = {
1: protobuf.Field("channel_id", "bytes", repeated=False, required=False, default=None),
2: protobuf.Field("handshake_hash", "bytes", repeated=False, required=False, default=None),
3: protobuf.Field("code_entry_code", "uint32", repeated=False, required=False, default=None),
4: protobuf.Field("code_qr_code", "bytes", repeated=False, required=False, default=None),
5: protobuf.Field("nfc_secret_trezor", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
channel_id: Optional["bytes"] = None,
handshake_hash: Optional["bytes"] = None,
code_entry_code: Optional["int"] = None,
code_qr_code: Optional["bytes"] = None,
nfc_secret_trezor: Optional["bytes"] = None,
) -> None:
self.channel_id = channel_id
self.handshake_hash = handshake_hash
self.code_entry_code = code_entry_code
self.code_qr_code = code_qr_code
self.nfc_secret_trezor = nfc_secret_trezor
class DebugLinkStop(protobuf.MessageType): class DebugLinkStop(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 103 MESSAGE_WIRE_TYPE = 103
@ -7976,8 +8056,68 @@ class TezosManagerTransfer(protobuf.MessageType):
self.amount = amount self.amount = amount
class ThpCredentialMetadata(protobuf.MessageType): class ThpDeviceProperties(protobuf.MessageType):
MESSAGE_WIRE_TYPE = None MESSAGE_WIRE_TYPE = None
FIELDS = {
1: protobuf.Field("internal_model", "string", repeated=False, required=False, default=None),
2: protobuf.Field("model_variant", "uint32", repeated=False, required=False, default=None),
3: protobuf.Field("protocol_version_major", "uint32", repeated=False, required=False, default=None),
4: protobuf.Field("protocol_version_minor", "uint32", repeated=False, required=False, default=None),
5: protobuf.Field("pairing_methods", "ThpPairingMethod", repeated=True, required=False, default=None),
}
def __init__(
self,
*,
pairing_methods: Optional[Sequence["ThpPairingMethod"]] = None,
internal_model: Optional["str"] = None,
model_variant: Optional["int"] = None,
protocol_version_major: Optional["int"] = None,
protocol_version_minor: Optional["int"] = None,
) -> None:
self.pairing_methods: Sequence["ThpPairingMethod"] = pairing_methods if pairing_methods is not None else []
self.internal_model = internal_model
self.model_variant = model_variant
self.protocol_version_major = protocol_version_major
self.protocol_version_minor = protocol_version_minor
class ThpHandshakeCompletionReqNoisePayload(protobuf.MessageType):
MESSAGE_WIRE_TYPE = None
FIELDS = {
1: protobuf.Field("host_pairing_credential", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
host_pairing_credential: Optional["bytes"] = None,
) -> None:
self.host_pairing_credential = host_pairing_credential
class ThpCreateNewSession(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1000
FIELDS = {
1: protobuf.Field("passphrase", "string", repeated=False, required=False, default=None),
2: protobuf.Field("on_device", "bool", repeated=False, required=False, default=None),
3: protobuf.Field("derive_cardano", "bool", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
passphrase: Optional["str"] = None,
on_device: Optional["bool"] = None,
derive_cardano: Optional["bool"] = None,
) -> None:
self.passphrase = passphrase
self.on_device = on_device
self.derive_cardano = derive_cardano
class ThpPairingRequest(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1006
FIELDS = { FIELDS = {
1: protobuf.Field("host_name", "string", repeated=False, required=False, default=None), 1: protobuf.Field("host_name", "string", repeated=False, required=False, default=None),
} }
@ -7990,6 +8130,216 @@ class ThpCredentialMetadata(protobuf.MessageType):
self.host_name = host_name self.host_name = host_name
class ThpPairingRequestApproved(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1007
class ThpSelectMethod(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1008
FIELDS = {
1: protobuf.Field("selected_pairing_method", "ThpPairingMethod", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
selected_pairing_method: Optional["ThpPairingMethod"] = None,
) -> None:
self.selected_pairing_method = selected_pairing_method
class ThpPairingPreparationsFinished(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1009
class ThpCodeEntryCommitment(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1016
FIELDS = {
1: protobuf.Field("commitment", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
commitment: Optional["bytes"] = None,
) -> None:
self.commitment = commitment
class ThpCodeEntryChallenge(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1017
FIELDS = {
1: protobuf.Field("challenge", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
challenge: Optional["bytes"] = None,
) -> None:
self.challenge = challenge
class ThpCodeEntryCpaceTrezor(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1018
FIELDS = {
1: protobuf.Field("cpace_trezor_public_key", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
cpace_trezor_public_key: Optional["bytes"] = None,
) -> None:
self.cpace_trezor_public_key = cpace_trezor_public_key
class ThpCodeEntryCpaceHostTag(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1019
FIELDS = {
1: protobuf.Field("cpace_host_public_key", "bytes", repeated=False, required=False, default=None),
2: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
cpace_host_public_key: Optional["bytes"] = None,
tag: Optional["bytes"] = None,
) -> None:
self.cpace_host_public_key = cpace_host_public_key
self.tag = tag
class ThpCodeEntrySecret(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1020
FIELDS = {
1: protobuf.Field("secret", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
secret: Optional["bytes"] = None,
) -> None:
self.secret = secret
class ThpQrCodeTag(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1024
FIELDS = {
1: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
tag: Optional["bytes"] = None,
) -> None:
self.tag = tag
class ThpQrCodeSecret(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1025
FIELDS = {
1: protobuf.Field("secret", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
secret: Optional["bytes"] = None,
) -> None:
self.secret = secret
class ThpNfcTagHost(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1032
FIELDS = {
1: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
tag: Optional["bytes"] = None,
) -> None:
self.tag = tag
class ThpNfcTagTrezor(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1033
FIELDS = {
1: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
tag: Optional["bytes"] = None,
) -> None:
self.tag = tag
class ThpCredentialRequest(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1010
FIELDS = {
1: protobuf.Field("host_static_pubkey", "bytes", repeated=False, required=False, default=None),
2: protobuf.Field("autoconnect", "bool", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
host_static_pubkey: Optional["bytes"] = None,
autoconnect: Optional["bool"] = None,
) -> None:
self.host_static_pubkey = host_static_pubkey
self.autoconnect = autoconnect
class ThpCredentialResponse(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1011
FIELDS = {
1: protobuf.Field("trezor_static_pubkey", "bytes", repeated=False, required=False, default=None),
2: protobuf.Field("credential", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
trezor_static_pubkey: Optional["bytes"] = None,
credential: Optional["bytes"] = None,
) -> None:
self.trezor_static_pubkey = trezor_static_pubkey
self.credential = credential
class ThpEndRequest(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1012
class ThpEndResponse(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1013
class ThpCredentialMetadata(protobuf.MessageType):
MESSAGE_WIRE_TYPE = None
FIELDS = {
1: protobuf.Field("host_name", "string", repeated=False, required=False, default=None),
2: protobuf.Field("autoconnect", "bool", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
host_name: Optional["str"] = None,
autoconnect: Optional["bool"] = None,
) -> None:
self.host_name = host_name
self.autoconnect = autoconnect
class ThpPairingCredential(protobuf.MessageType): class ThpPairingCredential(protobuf.MessageType):
MESSAGE_WIRE_TYPE = None MESSAGE_WIRE_TYPE = None
FIELDS = { FIELDS = {

View File

@ -22,6 +22,7 @@ import typing as t
import requests import requests
from typing_extensions import Self from typing_extensions import Self
from ..client import ProtocolVersion
from ..log import DUMP_PACKETS from ..log import DUMP_PACKETS
from . import DeviceIsBusy, Transport, TransportException from . import DeviceIsBusy, Transport, TransportException
@ -34,6 +35,7 @@ TREZORD_HOST = "http://127.0.0.1:21325"
TREZORD_ORIGIN_HEADER = {"Origin": "https://python.trezor.io"} TREZORD_ORIGIN_HEADER = {"Origin": "https://python.trezor.io"}
TREZORD_VERSION_MODERN = (2, 0, 25) TREZORD_VERSION_MODERN = (2, 0, 25)
TREZORD_VERSION_THP_SUPPORT = (2, 0, 31) # TODO add correct value
CONNECTION = requests.Session() CONNECTION = requests.Session()
CONNECTION.headers.update(TREZORD_ORIGIN_HEADER) CONNECTION.headers.update(TREZORD_ORIGIN_HEADER)
@ -66,6 +68,44 @@ def is_legacy_bridge() -> bool:
return get_bridge_version() < TREZORD_VERSION_MODERN return get_bridge_version() < TREZORD_VERSION_MODERN
def supports_protocolV2() -> bool:
return get_bridge_version() >= TREZORD_VERSION_THP_SUPPORT
def detect_protocol_version(transport: "BridgeTransport") -> int:
from .. import mapping, messages
from ..messages import FailureType
protocol_version = ProtocolVersion.PROTOCOL_V1
request_type, request_data = mapping.DEFAULT_MAPPING.encode(messages.Initialize())
transport.open()
transport.write_chunk(request_type.to_bytes(2, "big") + request_data)
response = transport.read_chunk()
response_type = int.from_bytes(response[:2], "big")
response_data = response[2:]
response = mapping.DEFAULT_MAPPING.decode(response_type, response_data)
if isinstance(response, messages.Failure):
if response.code == FailureType.InvalidProtocol:
LOG.debug("Protocol V2 detected")
protocol_version = ProtocolVersion.PROTOCOL_V2
return protocol_version
def _is_transport_valid(transport: "BridgeTransport") -> bool:
is_valid = detect_protocol_version(transport) == ProtocolVersion.PROTOCOL_V1
if not is_valid:
LOG.warning("Detected unsupported Bridge transport!")
return is_valid
def filter_invalid_bridge_transports(
transports: t.Iterable["BridgeTransport"],
) -> t.Sequence["BridgeTransport"]:
"""Filters out invalid bridge transports. Keeps only valid ones."""
return [t for t in transports if _is_transport_valid(t)]
class BridgeHandle: class BridgeHandle:
def __init__(self, transport: "BridgeTransport") -> None: def __init__(self, transport: "BridgeTransport") -> None:
self.transport = transport self.transport = transport
@ -160,9 +200,12 @@ class BridgeTransport(Transport):
) -> t.Iterable["BridgeTransport"]: ) -> t.Iterable["BridgeTransport"]:
try: try:
legacy = is_legacy_bridge() legacy = is_legacy_bridge()
return [ return filter_invalid_bridge_transports(
BridgeTransport(dev, legacy) for dev in call_bridge("enumerate").json() [
] BridgeTransport(dev, legacy)
for dev in call_bridge("enumerate").json()
]
)
except Exception: except Exception:
return [] return []

View File

@ -7,6 +7,7 @@ from .. import exceptions, messages, models
from ..client import MAX_PIN_LENGTH from ..client import MAX_PIN_LENGTH
from ..protobuf import MessageType from ..protobuf import MessageType
from .thp.protocol_v1 import ProtocolV1Channel from .thp.protocol_v1 import ProtocolV1Channel
from .thp.protocol_v2 import ProtocolV2Channel
if t.TYPE_CHECKING: if t.TYPE_CHECKING:
from ..client import TrezorClient from ..client import TrezorClient
@ -235,3 +236,50 @@ def derive_seed(session: Session, passphrase: str | object) -> None:
_passphrase_ack=ack, _passphrase_ack=ack,
) )
session.refresh_features() session.refresh_features()
class SessionV2(Session):
@classmethod
def new(
cls,
client: TrezorClient,
passphrase: str | None,
derive_cardano: bool,
session_id: int = 0,
) -> SessionV2:
assert isinstance(client.protocol, ProtocolV2Channel)
session = cls(client, session_id.to_bytes(1, "big"))
session.call(
messages.ThpCreateNewSession(
passphrase=passphrase, derive_cardano=derive_cardano
),
expect=messages.Success,
)
session.update_id_and_sid(session_id.to_bytes(1, "big"))
return session
def __init__(self, client: TrezorClient, id: bytes) -> None:
from ..debuglink import TrezorClientDebugLink
super().__init__(client, id)
assert isinstance(client.protocol, ProtocolV2Channel)
helper_debug = None
if isinstance(client, TrezorClientDebugLink):
helper_debug = client.debug
self.channel: ProtocolV2Channel = client.protocol.get_channel(helper_debug)
self.update_id_and_sid(id)
def _write(self, msg: t.Any) -> None:
LOG.debug("writing message %s", type(msg))
self.channel.write(self.sid, msg)
def _read(self) -> t.Any:
msg = self.channel.read(self.sid)
LOG.debug("reading message %s", type(msg))
return msg
def update_id_and_sid(self, id: bytes) -> None:
self._id = id
self.sid = int.from_bytes(id, "big") # TODO update to extract only sid

View File

@ -0,0 +1,102 @@
# from storage.cache_thp import ChannelCache
# from trezor import log
# from trezor.wire.thp import ThpError
# def is_ack_valid(cache: ChannelCache, ack_bit: int) -> bool:
# """
# Checks if:
# - an ACK message is expected
# - the received ACK message acknowledges correct sequence number (bit)
# """
# if not _is_ack_expected(cache):
# return False
# if not _has_ack_correct_sync_bit(cache, ack_bit):
# return False
# return True
# def _is_ack_expected(cache: ChannelCache) -> bool:
# is_expected: bool = not is_sending_allowed(cache)
# if __debug__ and not is_expected:
# log.debug(__name__, "Received unexpected ACK message")
# return is_expected
# def _has_ack_correct_sync_bit(cache: ChannelCache, sync_bit: int) -> bool:
# is_correct: bool = get_send_seq_bit(cache) == sync_bit
# if __debug__ and not is_correct:
# log.debug(__name__, "Received ACK message with wrong ack bit")
# return is_correct
# def is_sending_allowed(cache: ChannelCache) -> bool:
# """
# Checks whether sending a message in the provided channel is allowed.
# Note: Sending a message in a channel before receipt of ACK message for the previously
# sent message (in the channel) is prohibited, as it can lead to desynchronization.
# """
# return bool(cache.sync >> 7)
# def get_send_seq_bit(cache: ChannelCache) -> int:
# """
# Returns the sequential number (bit) of the next message to be sent
# in the provided channel.
# """
# return (cache.sync & 0x20) >> 5
# def get_expected_receive_seq_bit(cache: ChannelCache) -> int:
# """
# Returns the (expected) sequential number (bit) of the next message
# to be received in the provided channel.
# """
# return (cache.sync & 0x40) >> 6
# def set_sending_allowed(cache: ChannelCache, sending_allowed: bool) -> None:
# """
# Set the flag whether sending a message in this channel is allowed or not.
# """
# cache.sync &= 0x7F
# if sending_allowed:
# cache.sync |= 0x80
# def set_expected_receive_seq_bit(cache: ChannelCache, seq_bit: int) -> None:
# """
# Set the expected sequential number (bit) of the next message to be received
# in the provided channel
# """
# if __debug__:
# log.debug(__name__, "Set sync receive expected seq bit to %d", seq_bit)
# if seq_bit not in (0, 1):
# raise ThpError("Unexpected receive sync bit")
# # set second bit to "seq_bit" value
# cache.sync &= 0xBF
# if seq_bit:
# cache.sync |= 0x40
# def _set_send_seq_bit(cache: ChannelCache, seq_bit: int) -> None:
# if seq_bit not in (0, 1):
# raise ThpError("Unexpected send seq bit")
# if __debug__:
# log.debug(__name__, "setting sync send seq bit to %d", seq_bit)
# # set third bit to "seq_bit" value
# cache.sync &= 0xDF
# if seq_bit:
# cache.sync |= 0x20
# def set_send_seq_bit_to_opposite(cache: ChannelCache) -> None:
# """
# Set the sequential bit of the "next message to be send" to the opposite value,
# i.e. 1 -> 0 and 0 -> 1
# """
# _set_send_seq_bit(cache=cache, seq_bit=1 - get_send_seq_bit(cache))

View File

@ -0,0 +1,47 @@
from __future__ import annotations
from binascii import hexlify
class ChannelData:
def __init__(
self,
protocol_version_major: int,
protocol_version_minor: int,
transport_path: str,
channel_id: int,
key_request: bytes,
key_response: bytes,
nonce_request: int,
nonce_response: int,
sync_bit_send: int,
sync_bit_receive: int,
handshake_hash: bytes,
) -> None:
self.protocol_version_major: int = protocol_version_major
self.protocol_version_minor: int = protocol_version_minor
self.transport_path: str = transport_path
self.channel_id: int = channel_id
self.key_request: str = hexlify(key_request).decode()
self.key_response: str = hexlify(key_response).decode()
self.nonce_request: int = nonce_request
self.nonce_response: int = nonce_response
self.sync_bit_receive: int = sync_bit_receive
self.sync_bit_send: int = sync_bit_send
self.handshake_hash: str = hexlify(handshake_hash).decode()
def to_dict(self):
return {
"protocol_version_major": self.protocol_version_major,
"protocol_version_minor": self.protocol_version_minor,
"transport_path": self.transport_path,
"channel_id": self.channel_id,
"key_request": self.key_request,
"key_response": self.key_response,
"nonce_request": self.nonce_request,
"nonce_response": self.nonce_response,
"sync_bit_send": self.sync_bit_send,
"sync_bit_receive": self.sync_bit_receive,
"handshake_hash": self.handshake_hash,
}

View File

@ -0,0 +1,146 @@
from __future__ import annotations
import json
import logging
import os
import typing as t
from .channel_data import ChannelData
LOG = logging.getLogger(__name__)
db: "ChannelDatabase | None" = None
def get_channel_db() -> ChannelDatabase:
if db is None:
set_channel_database(should_not_store=True)
assert db is not None
return db
if t.TYPE_CHECKING:
from .protocol_and_channel import Channel
from .protocol_v2 import ProtocolV2Channel
class ChannelDatabase:
def load_stored_channels(self) -> t.List[ChannelData]: ...
def clear_stored_channels(self) -> None: ...
def save_channel(self, new_channel: Channel): ...
def remove_channel(self, transport_path: str) -> None: ...
class DummyChannelDatabase(ChannelDatabase):
def load_stored_channels(self) -> t.List[ChannelData]:
return []
def clear_stored_channels(self) -> None:
pass
def save_channel(self, new_channel: Channel):
pass
def remove_channel(self, transport_path: str) -> None:
pass
class JsonChannelDatabase(ChannelDatabase):
def __init__(self, data_path: str) -> None:
self.data_path = data_path
super().__init__()
def load_stored_channels(self) -> t.List[ChannelData]:
dicts = self._read_all_channels()
return [dict_to_channel_data(d) for d in dicts]
def clear_stored_channels(self) -> None:
LOG.debug("Clearing contents of %s", self.data_path)
with open(self.data_path, "w") as f:
json.dump([], f)
try:
os.remove(self.data_path)
except Exception as e:
LOG.exception("Failed to delete %s (%s)", self.data_path, str(type(e)))
def _read_all_channels(self) -> t.List:
ensure_file_exists(self.data_path)
with open(self.data_path, "r") as f:
return json.load(f)
def _save_all_channels(self, channels: t.List[t.Dict]) -> None:
LOG.debug("saving all channels")
with open(self.data_path, "w") as f:
json.dump(channels, f, indent=4)
def save_channel(self, new_channel: ProtocolV2Channel):
LOG.debug("save channel")
channels = self._read_all_channels()
transport_path = new_channel.transport.get_path()
# If the channel is found in database: replace the old entry by the new
for i, channel in enumerate(channels):
if channel["transport_path"] == transport_path:
LOG.debug("Modified channel entry for %s", transport_path)
channels[i] = new_channel.get_channel_data().to_dict()
self._save_all_channels(channels)
return
# Channel was not found: add a new channel entry
LOG.debug("Created a new channel entry on path %s", transport_path)
channels.append(new_channel.get_channel_data().to_dict())
self._save_all_channels(channels)
def remove_channel(self, transport_path: str) -> None:
LOG.debug(
"Removing channel with path %s from the channel database.",
transport_path,
)
channels = self._read_all_channels()
remaining_channels = [
ch for ch in channels if ch["transport_path"] != transport_path
]
self._save_all_channels(remaining_channels)
def dict_to_channel_data(dict: t.Dict) -> ChannelData:
return ChannelData(
protocol_version_major=dict["protocol_version_minor"],
protocol_version_minor=dict["protocol_version_major"],
transport_path=dict["transport_path"],
channel_id=dict["channel_id"],
key_request=bytes.fromhex(dict["key_request"]),
key_response=bytes.fromhex(dict["key_response"]),
nonce_request=dict["nonce_request"],
nonce_response=dict["nonce_response"],
sync_bit_send=dict["sync_bit_send"],
sync_bit_receive=dict["sync_bit_receive"],
handshake_hash=bytes.fromhex(dict["handshake_hash"]),
)
def ensure_file_exists(file_path: str) -> None:
LOG.debug("checking if file %s exists", file_path)
if not os.path.exists(file_path):
os.makedirs(os.path.dirname(file_path), exist_ok=True)
LOG.debug("File %s does not exist. Creating a new one.", file_path)
with open(file_path, "w") as f:
json.dump([], f)
def set_channel_database(should_not_store: bool):
global db
if should_not_store:
db = DummyChannelDatabase()
else:
from platformdirs import user_cache_dir
APP_NAME = "@trezor" # TODO
DATA_PATH = os.path.join(user_cache_dir(appname=APP_NAME), "channel_data.json")
db = JsonChannelDatabase(DATA_PATH)

View File

@ -0,0 +1,19 @@
import zlib
CHECKSUM_LENGTH = 4
def compute(data: bytes) -> bytes:
"""
Returns a CRC-32 checksum of the provided `data`.
"""
return zlib.crc32(data).to_bytes(CHECKSUM_LENGTH, "big")
def is_valid(checksum: bytes, data: bytes) -> bool:
"""
Checks whether the CRC-32 checksum of the `data` is the same
as the checksum provided in `checksum`.
"""
data_checksum = compute(data)
return checksum == data_checksum

View File

@ -0,0 +1,63 @@
CODEC_V1 = 0x3F
CONTINUATION_PACKET = 0x80
HANDSHAKE_INIT_REQ = 0x00
HANDSHAKE_INIT_RES = 0x01
HANDSHAKE_COMP_REQ = 0x02
HANDSHAKE_COMP_RES = 0x03
ENCRYPTED_TRANSPORT = 0x04
CONTINUATION_PACKET_MASK = 0x80
ACK_MASK = 0xF7
DATA_MASK = 0xE7
ACK_MESSAGE = 0x20
_ERROR = 0x42
CHANNEL_ALLOCATION_REQ = 0x40
_CHANNEL_ALLOCATION_RES = 0x41
TREZOR_STATE_UNPAIRED = b"\x00"
TREZOR_STATE_PAIRED = b"\x01"
def add_seq_bit_to_ctrl_byte(ctrl_byte: int, seq_bit: int) -> int:
if seq_bit == 0:
return ctrl_byte & 0xEF
if seq_bit == 1:
return ctrl_byte | 0x10
raise Exception("Unexpected sequence bit")
def add_ack_bit_to_ctrl_byte(ctrl_byte: int, ack_bit: int) -> int:
if ack_bit == 0:
return ctrl_byte & 0xF7
if ack_bit == 1:
return ctrl_byte | 0x08
raise Exception("Unexpected acknowledgement bit")
def get_seq_bit(ctrl_byte: int) -> int:
return (ctrl_byte & 0x10) >> 4
def is_ack(ctrl_byte: int) -> bool:
return ctrl_byte & ACK_MASK == ACK_MESSAGE
def is_error(ctrl_byte: int) -> bool:
return ctrl_byte == _ERROR
def is_continuation(ctrl_byte: int) -> bool:
return ctrl_byte & CONTINUATION_PACKET_MASK == CONTINUATION_PACKET
def is_encrypted_transport(ctrl_byte: int) -> bool:
return ctrl_byte & DATA_MASK == ENCRYPTED_TRANSPORT
def is_handshake_init_req(ctrl_byte: int) -> bool:
return ctrl_byte & DATA_MASK == HANDSHAKE_INIT_REQ
def is_handshake_comp_req(ctrl_byte: int) -> bool:
return ctrl_byte & DATA_MASK == HANDSHAKE_COMP_REQ

View File

@ -0,0 +1,40 @@
import typing as t
from hashlib import sha512
from . import curve25519
_PREFIX = b"\x08\x43\x50\x61\x63\x65\x32\x35\x35\x06"
_PADDING = b"\x6f\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x20"
class Cpace:
"""
CPace, a balanced composable PAKE: https://datatracker.ietf.org/doc/draft-irtf-cfrg-cpace/
"""
random_bytes: t.Callable[[int], bytes]
def __init__(self, handshake_hash: bytes) -> None:
self.handshake_hash: bytes = handshake_hash
self.shared_secret: bytes
self.host_private_key: bytes
self.host_public_key: bytes
def generate_keys_and_secret(
self, code_code_entry: bytes, trezor_public_key: bytes
) -> None:
"""
Generate ephemeral key pair and a shared secret using Elligator2 with X25519.
"""
sha_ctx = sha512(_PREFIX)
sha_ctx.update(code_code_entry)
sha_ctx.update(_PADDING)
sha_ctx.update(self.handshake_hash)
sha_ctx.update(b"\x00")
pregenerator = sha_ctx.digest()[:32]
generator = curve25519.elligator2(pregenerator)
self.host_private_key = self.random_bytes(32)
self.host_public_key = curve25519.multiply(self.host_private_key, generator)
self.shared_secret = curve25519.multiply(
self.host_private_key, trezor_public_key
)

View File

@ -0,0 +1,159 @@
from typing import Tuple
p = 2**255 - 19
J = 486662
c3 = 19681161376707505956807079304988542015446066515923890162744021073123829784752 # sqrt(-1)
c4 = 7237005577332262213973186563042994240829374041602535252466099000494570602493 # (p - 5) // 8
a24 = 121666 # (J + 2) // 4
def decode_scalar(scalar: bytes) -> int:
# decodeScalar25519 from
# https://datatracker.ietf.org/doc/html/rfc7748#section-5
if len(scalar) != 32:
raise ValueError("Invalid length of scalar")
array = bytearray(scalar)
array[0] &= 248
array[31] &= 127
array[31] |= 64
return int.from_bytes(array, "little")
def decode_coordinate(coordinate: bytes) -> int:
# decodeUCoordinate from
# https://datatracker.ietf.org/doc/html/rfc7748#section-5
if len(coordinate) != 32:
raise ValueError("Invalid length of coordinate")
array = bytearray(coordinate)
array[-1] &= 0x7F
return int.from_bytes(array, "little") % p
def encode_coordinate(coordinate: int) -> bytes:
# encodeUCoordinate from
# https://datatracker.ietf.org/doc/html/rfc7748#section-5
return coordinate.to_bytes(32, "little")
def get_private_key(secret: bytes) -> bytes:
return decode_scalar(secret).to_bytes(32, "little")
def get_public_key(private_key: bytes) -> bytes:
base_point = int.to_bytes(9, 32, "little")
return multiply(private_key, base_point)
def multiply(private_scalar: bytes, public_point: bytes):
# X25519 from
# https://datatracker.ietf.org/doc/html/rfc7748#section-5
def ladder_operation(
x1: int, x2: int, z2: int, x3: int, z3: int
) -> Tuple[int, int, int, int]:
# https://hyperelliptic.org/EFD/g1p/auto-montgom-xz.html#ladder-ladd-1987-m-3
# (x4, z4) = 2 * (x2, z2)
# (x5, z5) = (x2, z2) + (x3, z3)
# where (x1, 1) = (x3, z3) - (x2, z2)
a = (x2 + z2) % p
aa = (a * a) % p
b = (x2 - z2) % p
bb = (b * b) % p
e = (aa - bb) % p
c = (x3 + z3) % p
d = (x3 - z3) % p
da = (d * a) % p
cb = (c * b) % p
t0 = (da + cb) % p
x5 = (t0 * t0) % p
t1 = (da - cb) % p
t2 = (t1 * t1) % p
z5 = (x1 * t2) % p
x4 = (aa * bb) % p
t3 = (a24 * e) % p
t4 = (bb + t3) % p
z4 = (e * t4) % p
return x4, z4, x5, z5
def conditional_swap(first: int, second: int, condition: int):
# Returns (second, first) if condition is true and (first, second) otherwise
# Must be implemented in a way that it is constant time
true_mask = -condition
false_mask = ~true_mask
return (first & false_mask) | (second & true_mask), (second & false_mask) | (
first & true_mask
)
k = decode_scalar(private_scalar)
u = decode_coordinate(public_point)
x_1 = u
x_2 = 1
z_2 = 0
x_3 = u
z_3 = 1
swap = 0
for i in reversed(range(256)):
bit = (k >> i) & 1
swap = bit ^ swap
(x_2, x_3) = conditional_swap(x_2, x_3, swap)
(z_2, z_3) = conditional_swap(z_2, z_3, swap)
swap = bit
x_2, z_2, x_3, z_3 = ladder_operation(x_1, x_2, z_2, x_3, z_3)
(x_2, x_3) = conditional_swap(x_2, x_3, swap)
(z_2, z_3) = conditional_swap(z_2, z_3, swap)
x = pow(z_2, p - 2, p) * x_2 % p
return encode_coordinate(x)
def elligator2(point: bytes) -> bytes:
# map_to_curve_elligator2_curve25519 from
# https://www.rfc-editor.org/rfc/rfc9380.html#ell2-opt
def conditional_move(first: int, second: int, condition: bool):
# Returns second if condition is true and first otherwise
# Must be implemented in a way that it is constant time
true_mask = -condition
false_mask = ~true_mask
return (first & false_mask) | (second & true_mask)
u = decode_coordinate(point)
tv1 = (u * u) % p
tv1 = (2 * tv1) % p
xd = (tv1 + 1) % p
x1n = (-J) % p
tv2 = (xd * xd) % p
gxd = (tv2 * xd) % p
gx1 = (J * tv1) % p
gx1 = (gx1 * x1n) % p
gx1 = (gx1 + tv2) % p
gx1 = (gx1 * x1n) % p
tv3 = (gxd * gxd) % p
tv2 = (tv3 * tv3) % p
tv3 = (tv3 * gxd) % p
tv3 = (tv3 * gx1) % p
tv2 = (tv2 * tv3) % p
y11 = pow(tv2, c4, p)
y11 = (y11 * tv3) % p
y12 = (y11 * c3) % p
tv2 = (y11 * y11) % p
tv2 = (tv2 * gxd) % p
e1 = tv2 == gx1
y1 = conditional_move(y12, y11, e1)
x2n = (x1n * tv1) % p
tv2 = (y1 * y1) % p
tv2 = (tv2 * gxd) % p
e3 = tv2 == gx1
xn = conditional_move(x2n, x1n, e3)
x = xn * pow(xd, p - 2, p) % p
return encode_coordinate(x)

View File

@ -0,0 +1,82 @@
import struct
CODEC_V1 = 0x3F
CONTINUATION_PACKET = 0x80
HANDSHAKE_INIT_REQ = 0x00
HANDSHAKE_INIT_RES = 0x01
HANDSHAKE_COMP_REQ = 0x02
HANDSHAKE_COMP_RES = 0x03
ENCRYPTED_TRANSPORT = 0x04
CONTINUATION_PACKET_MASK = 0x80
ACK_MASK = 0xF7
DATA_MASK = 0xE7
ACK_MESSAGE = 0x20
_ERROR = 0x42
CHANNEL_ALLOCATION_REQ = 0x40
_CHANNEL_ALLOCATION_RES = 0x41
TREZOR_STATE_UNPAIRED = b"\x00"
TREZOR_STATE_PAIRED = b"\x01"
BROADCAST_CHANNEL_ID = 0xFFFF
class MessageHeader:
format_str_init = ">BHH"
format_str_cont = ">BH"
def __init__(self, ctrl_byte: int, cid: int, length: int) -> None:
self.ctrl_byte = ctrl_byte
self.cid = cid
self.data_length = length
def to_bytes_init(self) -> bytes:
return struct.pack(
self.format_str_init, self.ctrl_byte, self.cid, self.data_length
)
def to_bytes_cont(self) -> bytes:
return struct.pack(self.format_str_cont, CONTINUATION_PACKET, self.cid)
def pack_to_init_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None:
struct.pack_into(
self.format_str_init,
buffer,
buffer_offset,
self.ctrl_byte,
self.cid,
self.data_length,
)
def pack_to_cont_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None:
struct.pack_into(
self.format_str_cont, buffer, buffer_offset, CONTINUATION_PACKET, self.cid
)
def is_ack(self) -> bool:
return self.ctrl_byte & ACK_MASK == ACK_MESSAGE
def is_channel_allocation_response(self):
return (
self.cid == BROADCAST_CHANNEL_ID
and self.ctrl_byte == _CHANNEL_ALLOCATION_RES
)
def is_handshake_init_response(self) -> bool:
return self.ctrl_byte & DATA_MASK == HANDSHAKE_INIT_RES
def is_handshake_comp_response(self) -> bool:
return self.ctrl_byte & DATA_MASK == HANDSHAKE_COMP_RES
def is_encrypted_transport(self) -> bool:
return self.ctrl_byte & DATA_MASK == ENCRYPTED_TRANSPORT
@classmethod
def get_error_header(cls, cid: int, length: int):
return cls(_ERROR, cid, length)
@classmethod
def get_channel_allocation_request_header(cls, length: int):
return cls(CHANNEL_ALLOCATION_REQ, BROADCAST_CHANNEL_ID, length)

View File

@ -0,0 +1,490 @@
from __future__ import annotations
import hashlib
import hmac
import logging
import os
import typing as t
from binascii import hexlify
import click
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from ... import exceptions, messages, protobuf
from ...mapping import ProtobufMapping
from .. import Transport
from ..thp import checksum, curve25519, thp_io
from ..thp.channel_data import ChannelData
from ..thp.checksum import CHECKSUM_LENGTH
from ..thp.message_header import MessageHeader
from . import control_byte
from .channel_database import ChannelDatabase, get_channel_db
from .protocol_and_channel import Channel
LOG = logging.getLogger(__name__)
DEFAULT_SESSION_ID: int = 0
if t.TYPE_CHECKING:
from ...debuglink import DebugLink
MT = t.TypeVar("MT", bound=protobuf.MessageType)
def _sha256_of_two(val_1: bytes, val_2: bytes) -> bytes:
hash = hashlib.sha256(val_1)
hash.update(val_2)
return hash.digest()
def _hkdf(chaining_key: bytes, input: bytes):
temp_key = hmac.new(chaining_key, input, hashlib.sha256).digest()
output_1 = hmac.new(temp_key, b"\x01", hashlib.sha256).digest()
ctx_output_2 = hmac.new(temp_key, output_1, hashlib.sha256)
ctx_output_2.update(b"\x02")
output_2 = ctx_output_2.digest()
return (output_1, output_2)
def _get_iv_from_nonce(nonce: int) -> bytes:
if not nonce <= 0xFFFFFFFFFFFFFFFF:
raise ValueError("Nonce overflow, terminate the channel")
return bytes(4) + nonce.to_bytes(8, "big")
class ProtocolV2Channel(Channel):
channel_id: int
channel_database: ChannelDatabase
key_request: bytes
key_response: bytes
nonce_request: int
nonce_response: int
sync_bit_send: int
sync_bit_receive: int
handshake_hash: bytes
_has_valid_channel: bool = False
_features: messages.Features | None = None
def __init__(
self,
transport: Transport,
mapping: ProtobufMapping,
channel_data: ChannelData | None = None,
) -> None:
self.channel_database: ChannelDatabase = get_channel_db()
super().__init__(transport, mapping)
if channel_data is not None:
self.channel_id = channel_data.channel_id
self.key_request = bytes.fromhex(channel_data.key_request)
self.key_response = bytes.fromhex(channel_data.key_response)
self.nonce_request = channel_data.nonce_request
self.nonce_response = channel_data.nonce_response
self.sync_bit_receive = channel_data.sync_bit_receive
self.sync_bit_send = channel_data.sync_bit_send
self.handshake_hash = bytes.fromhex(channel_data.handshake_hash)
self._has_valid_channel = True
def get_channel(self, helper_debug: DebugLink | None = None) -> ProtocolV2Channel:
if not self._has_valid_channel:
self._establish_new_channel(helper_debug)
return self
def get_channel_data(self) -> ChannelData:
return ChannelData(
protocol_version_major=2,
protocol_version_minor=2,
transport_path=self.transport.get_path(),
channel_id=self.channel_id,
key_request=self.key_request,
key_response=self.key_response,
nonce_request=self.nonce_request,
nonce_response=self.nonce_response,
sync_bit_receive=self.sync_bit_receive,
sync_bit_send=self.sync_bit_send,
handshake_hash=self.handshake_hash,
)
def read(self, session_id: int) -> t.Any:
sid, msg_type, msg_data = self.read_and_decrypt()
if sid != session_id:
raise Exception("Received messsage on a different session.")
self.channel_database.save_channel(self)
return self.mapping.decode(msg_type, msg_data)
def write(self, session_id: int, msg: t.Any) -> None:
msg_type, msg_data = self.mapping.encode(msg)
self._encrypt_and_write(session_id, msg_type, msg_data)
self.channel_database.save_channel(self)
def get_features(self) -> messages.Features:
if not self._has_valid_channel:
self._establish_new_channel()
if self._features is None:
self.update_features()
assert self._features is not None
return self._features
def update_features(self) -> None:
message = messages.GetFeatures()
message_type, message_data = self.mapping.encode(message)
self.session_id: int = DEFAULT_SESSION_ID
self._encrypt_and_write(DEFAULT_SESSION_ID, message_type, message_data)
_ = self._read_until_valid_crc_check() # TODO check ACK
_, msg_type, msg_data = self.read_and_decrypt()
features = self.mapping.decode(msg_type, msg_data)
if not isinstance(features, messages.Features):
raise exceptions.TrezorException("Unexpected response to GetFeatures")
self._features = features
def _send_message(
self,
message: protobuf.MessageType,
session_id: int = DEFAULT_SESSION_ID,
):
message_type, message_data = self.mapping.encode(message)
self._encrypt_and_write(session_id, message_type, message_data)
self._read_ack()
def _read_message(self, message_type: type[MT]) -> MT:
_, msg_type, msg_data = self.read_and_decrypt()
msg = self.mapping.decode(msg_type, msg_data)
assert isinstance(msg, message_type)
return msg
def _establish_new_channel(self, helper_debug: DebugLink | None = None) -> None:
self._reset_sync_bits()
self._do_channel_allocation()
self._do_handshake()
self._do_pairing(helper_debug)
def _reset_sync_bits(self) -> None:
self.sync_bit_send = 0
self.sync_bit_receive = 0
def _do_channel_allocation(self) -> None:
channel_allocation_nonce = os.urandom(8)
self._send_channel_allocation_request(channel_allocation_nonce)
cid, dp = self._read_channel_allocation_response(channel_allocation_nonce)
self.channel_id = cid
self.device_properties = dp
def _send_channel_allocation_request(self, nonce: bytes):
thp_io.write_payload_to_wire_and_add_checksum(
self.transport,
MessageHeader.get_channel_allocation_request_header(12),
nonce,
)
def _read_channel_allocation_response(
self, expected_nonce: bytes
) -> tuple[int, bytes]:
header, payload = self._read_until_valid_crc_check()
if not self._is_valid_channel_allocation_response(
header, payload, expected_nonce
):
raise Exception("Invalid channel allocation response.")
channel_id = int.from_bytes(payload[8:10], "big")
device_properties = payload[10:]
return (channel_id, device_properties)
def _do_handshake(
self, credential: bytes | None = None, host_static_privkey: bytes | None = None
):
host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32))
host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey)
self._send_handshake_init_request(host_ephemeral_pubkey)
self._read_ack()
init_response = self._read_handshake_init_response()
trezor_ephemeral_pubkey = init_response[:32]
encrypted_trezor_static_pubkey = init_response[32:80]
noise_tag = init_response[80:96]
LOG.debug("noise_tag: %s", hexlify(noise_tag).decode())
# TODO check noise_tag is valid
ck = self._send_handshake_completion_request(
host_ephemeral_pubkey,
host_ephemeral_privkey,
trezor_ephemeral_pubkey,
encrypted_trezor_static_pubkey,
credential,
host_static_privkey,
)
self._read_ack()
self._read_handshake_completion_response()
self.key_request, self.key_response = _hkdf(ck, b"")
self.nonce_request = 0
self.nonce_response = 1
def _send_handshake_init_request(self, host_ephemeral_pubkey: bytes) -> None:
ha_init_req_header = MessageHeader(0, self.channel_id, 36)
thp_io.write_payload_to_wire_and_add_checksum(
self.transport, ha_init_req_header, host_ephemeral_pubkey
)
def _read_handshake_init_response(self) -> bytes:
header, payload = self._read_until_valid_crc_check()
self._send_ack_0()
if header.ctrl_byte == 0x42:
if payload == b"\x05":
raise exceptions.DeviceLockedException()
if not header.is_handshake_init_response():
LOG.debug("Received message is not a valid handshake init response message")
click.echo(
"Received message is not a valid handshake init response message",
err=True,
)
return payload
def _send_handshake_completion_request(
self,
host_ephemeral_pubkey: bytes,
host_ephemeral_privkey: bytes,
trezor_ephemeral_pubkey: bytes,
encrypted_trezor_static_pubkey: bytes,
credential: bytes | None = None,
host_static_privkey: bytes | None = None,
) -> bytes:
PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00"
IV_1 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
IV_2 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
h = _sha256_of_two(PROTOCOL_NAME, self.device_properties)
h = _sha256_of_two(h, host_ephemeral_pubkey)
h = _sha256_of_two(h, trezor_ephemeral_pubkey)
ck, k = _hkdf(
PROTOCOL_NAME,
curve25519.multiply(host_ephemeral_privkey, trezor_ephemeral_pubkey),
)
aes_ctx = AESGCM(k)
try:
trezor_masked_static_pubkey = aes_ctx.decrypt(
IV_1, encrypted_trezor_static_pubkey, h
)
except Exception as e:
click.echo(
f"Exception of type{type(e)}", err=True
) # TODO how to handle potential exceptions? Q for Matejcik
h = _sha256_of_two(h, encrypted_trezor_static_pubkey)
ck, k = _hkdf(
ck, curve25519.multiply(host_ephemeral_privkey, trezor_masked_static_pubkey)
)
aes_ctx = AESGCM(k)
tag_of_empty_string = aes_ctx.encrypt(IV_1, b"", h)
h = _sha256_of_two(h, tag_of_empty_string)
# TODO: search for saved credentials
if host_static_privkey is not None and credential is not None:
host_static_pubkey = curve25519.get_public_key(host_static_privkey)
else:
credential = None
zeroes_32 = int.to_bytes(0, 32, "little")
temp_host_static_privkey = curve25519.get_private_key(zeroes_32)
temp_host_static_pubkey = curve25519.get_public_key(
temp_host_static_privkey
)
host_static_privkey = temp_host_static_privkey
host_static_pubkey = temp_host_static_pubkey
aes_ctx = AESGCM(k)
encrypted_host_static_pubkey = aes_ctx.encrypt(IV_2, host_static_pubkey, h)
h = _sha256_of_two(h, encrypted_host_static_pubkey)
ck, k = _hkdf(
ck, curve25519.multiply(host_static_privkey, trezor_ephemeral_pubkey)
)
msg_data = self.mapping.encode_without_wire_type(
messages.ThpHandshakeCompletionReqNoisePayload(
host_pairing_credential=credential,
)
)
aes_ctx = AESGCM(k)
encrypted_payload = aes_ctx.encrypt(IV_1, msg_data, h)
h = _sha256_of_two(h, encrypted_payload[:-16])
ha_completion_req_header = MessageHeader(
0x12,
self.channel_id,
len(encrypted_host_static_pubkey)
+ len(encrypted_payload)
+ CHECKSUM_LENGTH,
)
thp_io.write_payload_to_wire_and_add_checksum(
self.transport,
ha_completion_req_header,
encrypted_host_static_pubkey + encrypted_payload,
)
self.handshake_hash = h
return ck
def _read_handshake_completion_response(self) -> None:
# Read handshake completion response, ignore payload as we do not care about the state
header, _ = self._read_until_valid_crc_check()
if not header.is_handshake_comp_response():
click.echo(
"Received message is not a valid handshake completion response",
err=True,
)
self._send_ack_1()
def _do_pairing(self, helper_debug: DebugLink | None):
self._send_message(messages.ThpPairingRequest())
self._read_message(messages.ButtonRequest)
self._send_message(messages.ButtonAck())
if helper_debug is not None:
helper_debug.press_yes()
self._read_message(messages.ThpPairingRequestApproved)
self._send_message(
messages.ThpSelectMethod(
selected_pairing_method=messages.ThpPairingMethod.SkipPairing
)
)
self._read_message(messages.ThpEndResponse)
self._has_valid_channel = True
def _read_ack(self):
header, payload = self._read_until_valid_crc_check()
if not header.is_ack() or len(payload) > 0:
click.echo("Received message is not a valid ACK", err=True)
def _send_ack_0(self):
LOG.debug("sending ack 0")
header = MessageHeader(0x20, self.channel_id, 4)
thp_io.write_payload_to_wire_and_add_checksum(self.transport, header, b"")
def _send_ack_1(self):
LOG.debug("sending ack 1")
header = MessageHeader(0x28, self.channel_id, 4)
thp_io.write_payload_to_wire_and_add_checksum(self.transport, header, b"")
def _encrypt_and_write(
self,
session_id: int,
message_type: int,
message_data: bytes,
ctrl_byte: int | None = None,
) -> None:
assert self.key_request is not None
aes_ctx = AESGCM(self.key_request)
if ctrl_byte is None:
ctrl_byte = control_byte.add_seq_bit_to_ctrl_byte(0x04, self.sync_bit_send)
self.sync_bit_send = 1 - self.sync_bit_send
sid = session_id.to_bytes(1, "big")
msg_type = message_type.to_bytes(2, "big")
data = sid + msg_type + message_data
nonce = _get_iv_from_nonce(self.nonce_request)
self.nonce_request += 1
encrypted_message = aes_ctx.encrypt(nonce, data, b"")
header = MessageHeader(
ctrl_byte, self.channel_id, len(encrypted_message) + CHECKSUM_LENGTH
)
thp_io.write_payload_to_wire_and_add_checksum(
self.transport, header, encrypted_message
)
def read_and_decrypt(self) -> t.Tuple[int, int, bytes]:
header, raw_payload = self._read_until_valid_crc_check()
if control_byte.is_ack(header.ctrl_byte):
# TODO fix this recursion
return self.read_and_decrypt()
if control_byte.is_error(header.ctrl_byte):
# TODO check for different channel
err = _get_error_from_int(raw_payload[0])
raise Exception("Received ThpError: " + err)
if not header.is_encrypted_transport():
click.echo(
"Trying to decrypt not encrypted message! ("
+ hexlify(header.to_bytes_init() + raw_payload).decode()
+ ")",
err=True,
)
if not control_byte.is_ack(header.ctrl_byte):
LOG.debug(
"--> Get sequence bit %d %s %s",
control_byte.get_seq_bit(header.ctrl_byte),
"from control byte",
hexlify(header.ctrl_byte.to_bytes(1, "big")).decode(),
)
if control_byte.get_seq_bit(header.ctrl_byte):
self._send_ack_1()
else:
self._send_ack_0()
aes_ctx = AESGCM(self.key_response)
nonce = _get_iv_from_nonce(self.nonce_response)
self.nonce_response += 1
message = aes_ctx.decrypt(nonce, raw_payload, b"")
session_id = message[0]
message_type = message[1:3]
message_data = message[3:]
return (
session_id,
int.from_bytes(message_type, "big"),
message_data,
)
def _read_until_valid_crc_check(
self,
) -> t.Tuple[MessageHeader, bytes]:
is_valid = False
header, payload, chksum = thp_io.read(self.transport)
while not is_valid:
is_valid = checksum.is_valid(chksum, header.to_bytes_init() + payload)
if not is_valid:
click.echo(
"Received a message with an invalid checksum:"
+ hexlify(header.to_bytes_init() + payload + chksum).decode(),
err=True,
)
header, payload, chksum = thp_io.read(self.transport)
return header, payload
def _is_valid_channel_allocation_response(
self, header: MessageHeader, payload: bytes, original_nonce: bytes
) -> bool:
if not header.is_channel_allocation_response():
click.echo(
"Received message is not a channel allocation response", err=True
)
return False
if len(payload) < 10:
click.echo("Invalid channel allocation response payload", err=True)
return False
if payload[:8] != original_nonce:
click.echo(
"Invalid channel allocation response payload (nonce mismatch)", err=True
)
return False
return True
def _get_error_from_int(error_code: int) -> str:
# TODO FIXME improve this (ThpErrorType)
if error_code == 1:
return "TRANSPORT BUSY"
if error_code == 2:
return "UNALLOCATED CHANNEL"
if error_code == 3:
return "DECRYPTION FAILED"
if error_code == 4:
return "INVALID DATA"
if error_code == 5:
return "DEVICE LOCKED"
raise Exception("Not Implemented error case")

View File

@ -0,0 +1,97 @@
import struct
from typing import Tuple
from .. import Transport
from ..thp import checksum
from .message_header import MessageHeader
INIT_HEADER_LENGTH = 5
CONT_HEADER_LENGTH = 3
MAX_PAYLOAD_LEN = 60000
MESSAGE_TYPE_LENGTH = 2
CONTINUATION_PACKET = 0x80
def write_payload_to_wire_and_add_checksum(
transport: Transport, header: MessageHeader, transport_payload: bytes
):
chksum: bytes = checksum.compute(header.to_bytes_init() + transport_payload)
data = transport_payload + chksum
write_payload_to_wire(transport, header, data)
def write_payload_to_wire(
transport: Transport, header: MessageHeader, transport_payload: bytes
):
transport.open()
buffer = bytearray(transport_payload)
if transport.CHUNK_SIZE is None:
transport.write_chunk(buffer)
return
chunk = header.to_bytes_init() + buffer[: transport.CHUNK_SIZE - INIT_HEADER_LENGTH]
chunk = chunk.ljust(transport.CHUNK_SIZE, b"\x00")
transport.write_chunk(chunk)
buffer = buffer[transport.CHUNK_SIZE - INIT_HEADER_LENGTH :]
while buffer:
chunk = (
header.to_bytes_cont() + buffer[: transport.CHUNK_SIZE - CONT_HEADER_LENGTH]
)
chunk = chunk.ljust(transport.CHUNK_SIZE, b"\x00")
transport.write_chunk(chunk)
buffer = buffer[transport.CHUNK_SIZE - CONT_HEADER_LENGTH :]
def read(transport: Transport) -> Tuple[MessageHeader, bytes, bytes]:
"""
Reads from the given wire transport.
Returns `Tuple[MessageHeader, bytes, bytes]`:
1. `header` (`MessageHeader`): Header of the message.
2. `data` (`bytes`): Contents of the message (if any).
3. `checksum` (`bytes`): crc32 checksum of the header + data.
"""
buffer = bytearray()
# Read header with first part of message data
header, first_chunk = read_first(transport)
buffer.extend(first_chunk)
# Read the rest of the message
while len(buffer) < header.data_length:
buffer.extend(read_next(transport, header.cid))
data_len = header.data_length - checksum.CHECKSUM_LENGTH
msg_data = buffer[:data_len]
chksum = buffer[data_len : data_len + checksum.CHECKSUM_LENGTH]
return (header, msg_data, chksum)
def read_first(transport: Transport) -> Tuple[MessageHeader, bytes]:
chunk = transport.read_chunk()
try:
ctrl_byte, cid, data_length = struct.unpack(
MessageHeader.format_str_init, chunk[:INIT_HEADER_LENGTH]
)
except Exception:
raise RuntimeError("Cannot parse header")
data = chunk[INIT_HEADER_LENGTH:]
return MessageHeader(ctrl_byte, cid, data_length), data
def read_next(transport: Transport, cid: int) -> bytes:
chunk = transport.read_chunk()
ctrl_byte, read_cid = struct.unpack(
MessageHeader.format_str_cont, chunk[:CONT_HEADER_LENGTH]
)
if ctrl_byte != CONTINUATION_PACKET:
raise RuntimeError("Continuation packet with incorrect control byte")
if read_cid != cid:
raise RuntimeError("Continuation packet for different channel")
return chunk[CONT_HEADER_LENGTH:]

View File

@ -76,9 +76,9 @@ if __name__ == "__main__":
# change PIN # change PIN
new_pin = "".join(random.choices(string.digits, k=random.randint(6, 10))) new_pin = "".join(random.choices(string.digits, k=random.randint(6, 10)))
client.set_input_flow(pin_input_flow(client, last_pin, new_pin)) session.set_input_flow(pin_input_flow(client, last_pin, new_pin))
device.change_pin(client) device.change_pin(client)
client.set_input_flow(None) session.set_input_flow(None)
last_pin = new_pin last_pin = new_pin
print(f"iteration {i}") print(f"iteration {i}")

View File

@ -22,7 +22,9 @@ import time
import typing as t import typing as t
from enum import IntEnum from enum import IntEnum
from pathlib import Path from pathlib import Path
from time import sleep
import cryptography
import pytest import pytest
import xdist import xdist
from _pytest.python import IdMaker from _pytest.python import IdMaker
@ -315,11 +317,23 @@ def _client_unlocked(
should_format = sd_marker.kwargs.get("formatted", True) should_format = sd_marker.kwargs.get("formatted", True)
_raw_client.debug.erase_sd_card(format=should_format) _raw_client.debug.erase_sd_card(format=should_format)
if _raw_client.is_invalidated: while True:
_raw_client = _raw_client.get_new_client() try:
session = _raw_client.get_seedless_session() if _raw_client.is_invalidated:
wipe_device(session) _raw_client = _raw_client.get_new_client()
# sleep(1.5) # Makes tests more stable (wait for wipe to finish) session = _raw_client.get_seedless_session()
wipe_device(session)
sleep(1.5) # Makes tests more stable (wait for wipe to finish)
break
except cryptography.exceptions.InvalidTag:
# Get a new client
_raw_client = _get_raw_client(request)
_raw_client.protocol = None
_raw_client.__init__(
transport=_raw_client.transport,
auto_interact=_raw_client.debug.allow_interactions,
)
if not _raw_client.features.bootloader_mode: if not _raw_client.features.bootloader_mode:
_raw_client.refresh_features() _raw_client.refresh_features()

View File

@ -8,6 +8,7 @@ import pytest
from _pytest.nodes import Node from _pytest.nodes import Node
from _pytest.outcomes import Failed from _pytest.outcomes import Failed
from trezorlib.client import ProtocolVersion
from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import TrezorClientDebugLink as Client
from . import common from . import common