1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-02-22 12:32:02 +00:00

feat(python): implement session based trezorlib

This commit is contained in:
M1nd3r 2025-02-04 15:19:56 +01:00
parent 7b09cde110
commit b92961cb28
4 changed files with 429 additions and 12 deletions

View File

@ -27,6 +27,7 @@ from .transport import Transport, get_transport
from .transport.thp.channel_data import ChannelData
from .transport.thp.protocol_and_channel import Channel
from .transport.thp.protocol_v1 import ProtocolV1Channel
from .transport.thp.protocol_v2 import ProtocolV2Channel
if t.TYPE_CHECKING:
from .transport.session import Session
@ -86,6 +87,8 @@ class TrezorClient:
if isinstance(self.protocol, ProtocolV1Channel):
self._protocol_version = ProtocolVersion.PROTOCOL_V1
elif isinstance(self.protocol, ProtocolV2Channel):
self._protocol_version = ProtocolVersion.PROTOCOL_V2
else:
self._protocol_version = ProtocolVersion.UNKNOWN
@ -98,7 +101,29 @@ class TrezorClient:
) -> TrezorClient:
if protobuf_mapping is None:
protobuf_mapping = mapping.DEFAULT_MAPPING
protocol = ProtocolV1Channel(transport, protobuf_mapping, channel_data)
protocol_v1 = ProtocolV1Channel(transport, protobuf_mapping)
if channel_data.protocol_version_major >= 2:
try:
protocol_v1.write(messages.Ping(message="Sanity check - to resume"))
except Exception as e:
print(type(e))
response = protocol_v1.read()
if (
isinstance(response, messages.Failure)
and response.code == messages.FailureType.InvalidProtocol
):
protocol = ProtocolV2Channel(transport, protobuf_mapping, channel_data)
protocol.write(0, messages.Ping())
response = protocol.read(0)
if not isinstance(response, messages.Success):
LOG.debug("Failed to resume ProtocolV2Channel")
raise Exception("Failed to resume ProtocolV2Channel")
LOG.debug("Protocol V2 detected - can be resumed")
else:
LOG.debug("Failed to resume ProtocolV2Channel")
raise Exception("Failed to resume ProtocolV2Channel")
else:
protocol = ProtocolV1Channel(transport, protobuf_mapping, channel_data)
return TrezorClient(transport, protobuf_mapping, protocol)
def get_session(
@ -112,23 +137,28 @@ class TrezorClient:
Will fail if the device is not initialized
"""
from .transport.session import SessionV1
from .transport.session import SessionV1, SessionV2
if isinstance(self.protocol, ProtocolV1Channel):
return SessionV1.new(self, passphrase, derive_cardano)
raise NotImplementedError
if isinstance(self.protocol, ProtocolV2Channel):
assert isinstance(passphrase, str) or passphrase is None
return SessionV2.new(self, passphrase, derive_cardano, session_id)
raise NotImplementedError # TODO
def resume_session(self, session: Session):
"""
Note: this function potentially modifies the input session.
"""
from .debuglink import SessionDebugWrapper
from .transport.session import SessionV1
from .transport.session import SessionV1, SessionV2
if isinstance(session, SessionDebugWrapper):
session = session._session
if isinstance(session, SessionV1):
if isinstance(session, SessionV2):
return session
elif isinstance(session, SessionV1):
session.init_session()
return session
@ -136,7 +166,7 @@ class TrezorClient:
raise NotImplementedError
def get_seedless_session(self, new_session: bool = False) -> Session:
from .transport.session import SessionV1
from .transport.session import SessionV1, SessionV2
if not new_session and self._seedless_session is not None:
return self._seedless_session
@ -146,6 +176,8 @@ class TrezorClient:
passphrase="",
derive_cardano=False,
)
elif isinstance(self.protocol, ProtocolV2Channel):
self._seedless_session = SessionV2(client=self, id=b"\x00")
assert self._seedless_session is not None
return self._seedless_session
@ -197,8 +229,12 @@ class TrezorClient:
protocol.write(messages.Initialize())
_ = protocol.read()
response = protocol.read()
self.transport.close()
if isinstance(response, messages.Failure):
if response.code == messages.FailureType.InvalidProtocol:
LOG.debug("Protocol V2 detected")
protocol = ProtocolV2Channel(self.transport, self.mapping)
return protocol

View File

@ -534,6 +534,25 @@ class DebugLink:
raise TrezorFailure(result)
return result
def pairing_info(
self,
thp_channel_id: bytes | None = None,
handshake_hash: bytes | None = None,
nfc_secret_host: bytes | None = None,
) -> messages.DebugLinkPairingInfo:
result = self._call(
messages.DebugLinkGetPairingInfo(
channel_id=thp_channel_id,
handshake_hash=handshake_hash,
nfc_secret_host=nfc_secret_host,
)
)
while not isinstance(result, (messages.Failure, messages.DebugLinkPairingInfo)):
result = self._read()
if isinstance(result, messages.Failure):
raise TrezorFailure(result)
return result
def read_layout(self, wait: bool | None = None) -> LayoutContent:
"""
Force waiting for the layout by setting `wait=True`. Force not waiting by
@ -1326,8 +1345,9 @@ class TrezorClientDebugLink(TrezorClient):
return send_passphrase(None, None)
try:
if isinstance(session, SessionV1) or isinstance(
session, SessionDebugWrapper
if isinstance(session, SessionV1) or (
isinstance(session, SessionDebugWrapper)
and isinstance(session._session, SessionV1)
):
passphrase = self.ui.get_passphrase(
available_on_device=available_on_device
@ -1335,7 +1355,7 @@ class TrezorClientDebugLink(TrezorClient):
if passphrase is None:
passphrase = session.passphrase
else:
raise NotImplementedError
passphrase = session.passphrase
except Cancelled:
session.call_raw(messages.Cancel())
raise

View File

@ -67,6 +67,16 @@ class ProtobufMapping:
protobuf.dump_message(buf, msg)
return wire_type, buf.getvalue()
def encode_without_wire_type(self, msg: protobuf.MessageType) -> bytes:
"""Serialize a Python protobuf class.
Returns the byte representation of the protobuf message.
"""
buf = io.BytesIO()
protobuf.dump_message(buf, msg)
return buf.getvalue()
def decode(self, msg_wire_type: int, msg_bytes: bytes) -> protobuf.MessageType:
"""Deserialize a protobuf message into a Python class."""
cls = self.type_to_class[msg_wire_type]
@ -82,8 +92,9 @@ class ProtobufMapping:
mapping = cls()
message_types = getattr(module, "MessageType")
thp_message_types = getattr(module, "ThpMessageType")
for entry in message_types:
for entry in (*message_types, *thp_message_types):
msg_class = getattr(module, entry.name, None)
if msg_class is None:
raise ValueError(

View File

@ -43,6 +43,10 @@ class FailureType(IntEnum):
PinMismatch = 12
WipeCodeMismatch = 13
InvalidSession = 14
ThpUnallocatedSession = 15
InvalidProtocol = 16
BufferError = 17
DeviceIsBusy = 18
FirmwareError = 99
@ -400,6 +404,34 @@ class TezosBallotType(IntEnum):
Pass = 2
class ThpMessageType(IntEnum):
ThpCreateNewSession = 1000
ThpPairingRequest = 1006
ThpPairingRequestApproved = 1007
ThpSelectMethod = 1008
ThpPairingPreparationsFinished = 1009
ThpCredentialRequest = 1010
ThpCredentialResponse = 1011
ThpEndRequest = 1012
ThpEndResponse = 1013
ThpCodeEntryCommitment = 1016
ThpCodeEntryChallenge = 1017
ThpCodeEntryCpaceTrezor = 1018
ThpCodeEntryCpaceHostTag = 1019
ThpCodeEntrySecret = 1020
ThpQrCodeTag = 1024
ThpQrCodeSecret = 1025
ThpNfcTagHost = 1032
ThpNfcTagTrezor = 1033
class ThpPairingMethod(IntEnum):
SkipPairing = 1
CodeEntry = 2
QrCode = 3
NFC = 4
class MessageType(IntEnum):
Initialize = 0
Ping = 1
@ -500,6 +532,8 @@ class MessageType(IntEnum):
DebugLinkWatchLayout = 9006
DebugLinkResetDebugEvents = 9007
DebugLinkOptigaSetSecMax = 9008
DebugLinkGetPairingInfo = 9009
DebugLinkPairingInfo = 9010
EthereumGetPublicKey = 450
EthereumPublicKey = 451
EthereumGetAddress = 56
@ -4203,6 +4237,52 @@ class DebugLinkState(protobuf.MessageType):
self.mnemonic_type = mnemonic_type
class DebugLinkGetPairingInfo(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 9009
FIELDS = {
1: protobuf.Field("channel_id", "bytes", repeated=False, required=False, default=None),
2: protobuf.Field("handshake_hash", "bytes", repeated=False, required=False, default=None),
3: protobuf.Field("nfc_secret_host", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
channel_id: Optional["bytes"] = None,
handshake_hash: Optional["bytes"] = None,
nfc_secret_host: Optional["bytes"] = None,
) -> None:
self.channel_id = channel_id
self.handshake_hash = handshake_hash
self.nfc_secret_host = nfc_secret_host
class DebugLinkPairingInfo(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 9010
FIELDS = {
1: protobuf.Field("channel_id", "bytes", repeated=False, required=False, default=None),
2: protobuf.Field("handshake_hash", "bytes", repeated=False, required=False, default=None),
3: protobuf.Field("code_entry_code", "uint32", repeated=False, required=False, default=None),
4: protobuf.Field("code_qr_code", "bytes", repeated=False, required=False, default=None),
5: protobuf.Field("nfc_secret_trezor", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
channel_id: Optional["bytes"] = None,
handshake_hash: Optional["bytes"] = None,
code_entry_code: Optional["int"] = None,
code_qr_code: Optional["bytes"] = None,
nfc_secret_trezor: Optional["bytes"] = None,
) -> None:
self.channel_id = channel_id
self.handshake_hash = handshake_hash
self.code_entry_code = code_entry_code
self.code_qr_code = code_qr_code
self.nfc_secret_trezor = nfc_secret_trezor
class DebugLinkStop(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 103
@ -7863,8 +7943,68 @@ class TezosManagerTransfer(protobuf.MessageType):
self.amount = amount
class ThpCredentialMetadata(protobuf.MessageType):
class ThpDeviceProperties(protobuf.MessageType):
MESSAGE_WIRE_TYPE = None
FIELDS = {
1: protobuf.Field("internal_model", "string", repeated=False, required=False, default=None),
2: protobuf.Field("model_variant", "uint32", repeated=False, required=False, default=None),
3: protobuf.Field("protocol_version_major", "uint32", repeated=False, required=False, default=None),
4: protobuf.Field("protocol_version_minor", "uint32", repeated=False, required=False, default=None),
5: protobuf.Field("pairing_methods", "ThpPairingMethod", repeated=True, required=False, default=None),
}
def __init__(
self,
*,
pairing_methods: Optional[Sequence["ThpPairingMethod"]] = None,
internal_model: Optional["str"] = None,
model_variant: Optional["int"] = None,
protocol_version_major: Optional["int"] = None,
protocol_version_minor: Optional["int"] = None,
) -> None:
self.pairing_methods: Sequence["ThpPairingMethod"] = pairing_methods if pairing_methods is not None else []
self.internal_model = internal_model
self.model_variant = model_variant
self.protocol_version_major = protocol_version_major
self.protocol_version_minor = protocol_version_minor
class ThpHandshakeCompletionReqNoisePayload(protobuf.MessageType):
MESSAGE_WIRE_TYPE = None
FIELDS = {
1: protobuf.Field("host_pairing_credential", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
host_pairing_credential: Optional["bytes"] = None,
) -> None:
self.host_pairing_credential = host_pairing_credential
class ThpCreateNewSession(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1000
FIELDS = {
1: protobuf.Field("passphrase", "string", repeated=False, required=False, default=None),
2: protobuf.Field("on_device", "bool", repeated=False, required=False, default=None),
3: protobuf.Field("derive_cardano", "bool", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
passphrase: Optional["str"] = None,
on_device: Optional["bool"] = None,
derive_cardano: Optional["bool"] = None,
) -> None:
self.passphrase = passphrase
self.on_device = on_device
self.derive_cardano = derive_cardano
class ThpPairingRequest(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1006
FIELDS = {
1: protobuf.Field("host_name", "string", repeated=False, required=False, default=None),
}
@ -7877,6 +8017,216 @@ class ThpCredentialMetadata(protobuf.MessageType):
self.host_name = host_name
class ThpPairingRequestApproved(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1007
class ThpSelectMethod(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1008
FIELDS = {
1: protobuf.Field("selected_pairing_method", "ThpPairingMethod", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
selected_pairing_method: Optional["ThpPairingMethod"] = None,
) -> None:
self.selected_pairing_method = selected_pairing_method
class ThpPairingPreparationsFinished(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1009
class ThpCodeEntryCommitment(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1016
FIELDS = {
1: protobuf.Field("commitment", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
commitment: Optional["bytes"] = None,
) -> None:
self.commitment = commitment
class ThpCodeEntryChallenge(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1017
FIELDS = {
1: protobuf.Field("challenge", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
challenge: Optional["bytes"] = None,
) -> None:
self.challenge = challenge
class ThpCodeEntryCpaceTrezor(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1018
FIELDS = {
1: protobuf.Field("cpace_trezor_public_key", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
cpace_trezor_public_key: Optional["bytes"] = None,
) -> None:
self.cpace_trezor_public_key = cpace_trezor_public_key
class ThpCodeEntryCpaceHostTag(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1019
FIELDS = {
1: protobuf.Field("cpace_host_public_key", "bytes", repeated=False, required=False, default=None),
2: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
cpace_host_public_key: Optional["bytes"] = None,
tag: Optional["bytes"] = None,
) -> None:
self.cpace_host_public_key = cpace_host_public_key
self.tag = tag
class ThpCodeEntrySecret(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1020
FIELDS = {
1: protobuf.Field("secret", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
secret: Optional["bytes"] = None,
) -> None:
self.secret = secret
class ThpQrCodeTag(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1024
FIELDS = {
1: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
tag: Optional["bytes"] = None,
) -> None:
self.tag = tag
class ThpQrCodeSecret(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1025
FIELDS = {
1: protobuf.Field("secret", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
secret: Optional["bytes"] = None,
) -> None:
self.secret = secret
class ThpNfcTagHost(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1032
FIELDS = {
1: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
tag: Optional["bytes"] = None,
) -> None:
self.tag = tag
class ThpNfcTagTrezor(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1033
FIELDS = {
1: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
tag: Optional["bytes"] = None,
) -> None:
self.tag = tag
class ThpCredentialRequest(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1010
FIELDS = {
1: protobuf.Field("host_static_pubkey", "bytes", repeated=False, required=False, default=None),
2: protobuf.Field("autoconnect", "bool", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
host_static_pubkey: Optional["bytes"] = None,
autoconnect: Optional["bool"] = None,
) -> None:
self.host_static_pubkey = host_static_pubkey
self.autoconnect = autoconnect
class ThpCredentialResponse(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1011
FIELDS = {
1: protobuf.Field("trezor_static_pubkey", "bytes", repeated=False, required=False, default=None),
2: protobuf.Field("credential", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
trezor_static_pubkey: Optional["bytes"] = None,
credential: Optional["bytes"] = None,
) -> None:
self.trezor_static_pubkey = trezor_static_pubkey
self.credential = credential
class ThpEndRequest(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1012
class ThpEndResponse(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1013
class ThpCredentialMetadata(protobuf.MessageType):
MESSAGE_WIRE_TYPE = None
FIELDS = {
1: protobuf.Field("host_name", "string", repeated=False, required=False, default=None),
2: protobuf.Field("autoconnect", "bool", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
host_name: Optional["str"] = None,
autoconnect: Optional["bool"] = None,
) -> None:
self.host_name = host_name
self.autoconnect = autoconnect
class ThpPairingCredential(protobuf.MessageType):
MESSAGE_WIRE_TYPE = None
FIELDS = {