mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-05-21 00:08:46 +00:00
chore(core): adapt emu.py to the new trezorlib
[no changelog]
This commit is contained in:
parent
749ad9149c
commit
49ff6e4830
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import binascii
|
import binascii
|
||||||
from trezorlib.client import TrezorClient
|
from trezorlib.client import TrezorClient
|
||||||
from trezorlib.transport_hid import HidTransport
|
from trezorlib.transport.hid import HidTransport
|
||||||
|
|
||||||
devices = HidTransport.enumerate()
|
devices = HidTransport.enumerate()
|
||||||
if len(devices) > 0:
|
if len(devices) > 0:
|
||||||
|
@ -9,3 +9,4 @@ construct>=2.9,!=2.10.55
|
|||||||
typing_extensions>=4.7.1
|
typing_extensions>=4.7.1
|
||||||
construct-classes>=0.1.2
|
construct-classes>=0.1.2
|
||||||
cryptography>=41
|
cryptography>=41
|
||||||
|
platformdirs>=2
|
||||||
|
@ -30,7 +30,7 @@ from .. import exceptions, transport, ui
|
|||||||
from ..client import PASSPHRASE_ON_DEVICE, ProtocolVersion, TrezorClient
|
from ..client import PASSPHRASE_ON_DEVICE, ProtocolVersion, TrezorClient
|
||||||
from ..messages import Capability
|
from ..messages import Capability
|
||||||
from ..transport import Transport
|
from ..transport import Transport
|
||||||
from ..transport.session import Session, SessionV1
|
from ..transport.session import Session, SessionV1, SessionV2
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -135,10 +135,13 @@ class TrezorConnection:
|
|||||||
|
|
||||||
# Try resume session from id
|
# Try resume session from id
|
||||||
if self.session_id is not None:
|
if self.session_id is not None:
|
||||||
if client.protocol_version is ProtocolVersion.V1:
|
if client.protocol_version is ProtocolVersion.PROTOCOL_V1:
|
||||||
session = SessionV1.resume_from_id(
|
session = SessionV1.resume_from_id(
|
||||||
client=client, session_id=self.session_id
|
client=client, session_id=self.session_id
|
||||||
)
|
)
|
||||||
|
elif client.protocol_version is ProtocolVersion.PROTOCOL_V2:
|
||||||
|
session = SessionV2(client, self.session_id)
|
||||||
|
# TODO fix resumption on THP
|
||||||
else:
|
else:
|
||||||
raise Exception("Unsupported client protocol", client.protocol_version)
|
raise Exception("Unsupported client protocol", client.protocol_version)
|
||||||
if must_resume:
|
if must_resume:
|
||||||
@ -267,6 +270,11 @@ class TrezorConnection:
|
|||||||
empty_passphrase=empty_passphrase,
|
empty_passphrase=empty_passphrase,
|
||||||
must_resume=must_resume,
|
must_resume=must_resume,
|
||||||
)
|
)
|
||||||
|
except exceptions.DeviceLockedException:
|
||||||
|
click.echo(
|
||||||
|
"Device is locked, enter a pin on the device.",
|
||||||
|
err=True,
|
||||||
|
)
|
||||||
except transport.DeviceIsBusy:
|
except transport.DeviceIsBusy:
|
||||||
click.echo("Device is in use by another process.")
|
click.echo("Device is in use by another process.")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
@ -26,6 +26,7 @@ from .tools import parse_path
|
|||||||
from .transport import Transport, get_transport
|
from .transport import Transport, get_transport
|
||||||
from .transport.thp.protocol_and_channel import Channel
|
from .transport.thp.protocol_and_channel import Channel
|
||||||
from .transport.thp.protocol_v1 import ProtocolV1Channel
|
from .transport.thp.protocol_v1 import ProtocolV1Channel
|
||||||
|
from .transport.thp.protocol_v2 import ProtocolV2Channel
|
||||||
|
|
||||||
if t.TYPE_CHECKING:
|
if t.TYPE_CHECKING:
|
||||||
from .transport.session import Session, SessionV1
|
from .transport.session import Session, SessionV1
|
||||||
@ -93,6 +94,8 @@ class TrezorClient:
|
|||||||
|
|
||||||
if isinstance(self.protocol, ProtocolV1Channel):
|
if isinstance(self.protocol, ProtocolV1Channel):
|
||||||
self._protocol_version = ProtocolVersion.V1
|
self._protocol_version = ProtocolVersion.V1
|
||||||
|
elif isinstance(self.protocol, ProtocolV2Channel):
|
||||||
|
self._protocol_version = ProtocolVersion.PROTOCOL_V2
|
||||||
else:
|
else:
|
||||||
raise Exception("Unknown protocol version")
|
raise Exception("Unknown protocol version")
|
||||||
|
|
||||||
@ -121,8 +124,18 @@ class TrezorClient:
|
|||||||
derive_cardano=derive_cardano,
|
derive_cardano=derive_cardano,
|
||||||
)
|
)
|
||||||
derive_seed(session, passphrase)
|
derive_seed(session, passphrase)
|
||||||
|
|
||||||
return session
|
return session
|
||||||
|
if isinstance(self.protocol, ProtocolV2Channel):
|
||||||
|
from .transport.session import SessionV2
|
||||||
|
|
||||||
|
assert isinstance(passphrase, str) or passphrase is None
|
||||||
|
session_id = 1 # TODO fix this with ProtocolV2 session rework
|
||||||
|
if session_id is not None:
|
||||||
|
sid = int.from_bytes(session_id, "big")
|
||||||
|
else:
|
||||||
|
sid = 1
|
||||||
|
assert 0 <= sid <= 255
|
||||||
|
return SessionV2.new(self, passphrase, derive_cardano, sid)
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def get_seedless_session(self) -> Session:
|
def get_seedless_session(self) -> Session:
|
||||||
@ -174,6 +187,15 @@ class TrezorClient:
|
|||||||
|
|
||||||
def _get_protocol(self) -> Channel:
|
def _get_protocol(self) -> Channel:
|
||||||
protocol = ProtocolV1Channel(self.transport, mapping.DEFAULT_MAPPING)
|
protocol = ProtocolV1Channel(self.transport, mapping.DEFAULT_MAPPING)
|
||||||
|
|
||||||
|
protocol.write(messages.Initialize())
|
||||||
|
|
||||||
|
response = protocol.read()
|
||||||
|
self.transport.close()
|
||||||
|
if isinstance(response, messages.Failure):
|
||||||
|
if response.code == messages.FailureType.InvalidProtocol:
|
||||||
|
LOG.debug("Protocol V2 detected")
|
||||||
|
protocol = ProtocolV2Channel(self.transport, self.mapping)
|
||||||
return protocol
|
return protocol
|
||||||
|
|
||||||
def is_outdated(self) -> bool:
|
def is_outdated(self) -> bool:
|
||||||
|
@ -32,13 +32,20 @@ from pathlib import Path
|
|||||||
from mnemonic import Mnemonic
|
from mnemonic import Mnemonic
|
||||||
|
|
||||||
from . import btc, mapping, messages, models, protobuf
|
from . import btc, mapping, messages, models, protobuf
|
||||||
from .client import ProtocolVersion, TrezorClient
|
from .client import (
|
||||||
from .exceptions import Cancelled, TrezorFailure, UnexpectedMessageError
|
MAX_PASSPHRASE_LENGTH,
|
||||||
|
MAX_PIN_LENGTH,
|
||||||
|
PASSPHRASE_ON_DEVICE,
|
||||||
|
ProtocolVersion,
|
||||||
|
TrezorClient,
|
||||||
|
)
|
||||||
|
from .exceptions import Cancelled, PinException, TrezorFailure, UnexpectedMessageError
|
||||||
from .log import DUMP_BYTES
|
from .log import DUMP_BYTES
|
||||||
from .messages import DebugWaitType
|
from .messages import Capability, DebugWaitType
|
||||||
|
from .protobuf import MessageType
|
||||||
from .tools import parse_path
|
from .tools import parse_path
|
||||||
from .transport import Timeout
|
from .transport import Timeout
|
||||||
from .transport.session import Session
|
from .transport.session import Session, SessionV1, derive_seed
|
||||||
from .transport.thp.protocol_v1 import ProtocolV1Channel
|
from .transport.thp.protocol_v1 import ProtocolV1Channel
|
||||||
|
|
||||||
if t.TYPE_CHECKING:
|
if t.TYPE_CHECKING:
|
||||||
@ -522,6 +529,25 @@ class DebugLink:
|
|||||||
raise TrezorFailure(result)
|
raise TrezorFailure(result)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def pairing_info(
|
||||||
|
self,
|
||||||
|
thp_channel_id: bytes | None = None,
|
||||||
|
handshake_hash: bytes | None = None,
|
||||||
|
nfc_secret_host: bytes | None = None,
|
||||||
|
) -> messages.DebugLinkPairingInfo:
|
||||||
|
result = self._call(
|
||||||
|
messages.DebugLinkGetPairingInfo(
|
||||||
|
channel_id=thp_channel_id,
|
||||||
|
handshake_hash=handshake_hash,
|
||||||
|
nfc_secret_host=nfc_secret_host,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
while not isinstance(result, (messages.Failure, messages.DebugLinkPairingInfo)):
|
||||||
|
result = self._read()
|
||||||
|
if isinstance(result, messages.Failure):
|
||||||
|
raise TrezorFailure(result)
|
||||||
|
return result
|
||||||
|
|
||||||
def read_layout(self, wait: bool | None = None) -> LayoutContent:
|
def read_layout(self, wait: bool | None = None) -> LayoutContent:
|
||||||
"""
|
"""
|
||||||
Force waiting for the layout by setting `wait=True`. Force not waiting by
|
Force waiting for the layout by setting `wait=True`. Force not waiting by
|
||||||
@ -788,6 +814,7 @@ class DebugUI:
|
|||||||
|
|
||||||
def __init__(self, debuglink: DebugLink) -> None:
|
def __init__(self, debuglink: DebugLink) -> None:
|
||||||
self.debuglink = debuglink
|
self.debuglink = debuglink
|
||||||
|
self.pins: t.Iterator[str] | None = None
|
||||||
self.clear()
|
self.clear()
|
||||||
|
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
@ -1049,16 +1076,20 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
|
|
||||||
self.sync_responses()
|
self.sync_responses()
|
||||||
|
|
||||||
# So that we can choose right screenshotting logic (T1 vs TT)
|
def __getattr__(self, name: str) -> t.Any:
|
||||||
# and know the supported debug capabilities
|
return getattr(self._session, name)
|
||||||
self.debug.model = self.model
|
|
||||||
self.debug.version = self.version
|
def __setattr__(self, name: str, value: t.Any) -> None:
|
||||||
|
if hasattr(self._session, name):
|
||||||
|
setattr(self._session, name, value)
|
||||||
|
else:
|
||||||
|
self.__dict__[name] = value
|
||||||
|
|
||||||
self.reset_debug_features()
|
self.reset_debug_features()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def layout_type(self) -> LayoutType:
|
def protocol_version(self) -> int:
|
||||||
return self.debug.layout_type
|
return self.client.protocol_version
|
||||||
|
|
||||||
def get_new_client(self) -> TrezorClientDebugLink:
|
def get_new_client(self) -> TrezorClientDebugLink:
|
||||||
new_client = TrezorClientDebugLink(
|
new_client = TrezorClientDebugLink(
|
||||||
@ -1260,8 +1291,10 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
actual_responses = self.actual_responses
|
actual_responses = self.actual_responses
|
||||||
|
|
||||||
# grab a copy of the inputflow generator to raise an exception through it
|
# grab a copy of the inputflow generator to raise an exception through it
|
||||||
if isinstance(self.ui, DebugUI):
|
if isinstance(self.client, TrezorClientDebugLink) and isinstance(
|
||||||
input_flow = self.ui.input_flow
|
self.client.ui, DebugUI
|
||||||
|
):
|
||||||
|
input_flow = self.client.ui.input_flow
|
||||||
else:
|
else:
|
||||||
input_flow = None
|
input_flow = None
|
||||||
|
|
||||||
|
@ -109,3 +109,7 @@ class DerivationOnUninitaizedDeviceError(TrezorException):
|
|||||||
"""Tried to derive seed on uninitialized device.
|
"""Tried to derive seed on uninitialized device.
|
||||||
|
|
||||||
To communicate with uninitialized device, use seedless session instead."""
|
To communicate with uninitialized device, use seedless session instead."""
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceLockedException(TrezorException):
|
||||||
|
pass
|
||||||
|
@ -70,6 +70,16 @@ class ProtobufMapping:
|
|||||||
protobuf.dump_message(buf, msg)
|
protobuf.dump_message(buf, msg)
|
||||||
return wire_type, buf.getvalue()
|
return wire_type, buf.getvalue()
|
||||||
|
|
||||||
|
def encode_without_wire_type(self, msg: protobuf.MessageType) -> bytes:
|
||||||
|
"""Serialize a Python protobuf class.
|
||||||
|
|
||||||
|
Returns the byte representation of the protobuf message.
|
||||||
|
"""
|
||||||
|
|
||||||
|
buf = io.BytesIO()
|
||||||
|
protobuf.dump_message(buf, msg)
|
||||||
|
return buf.getvalue()
|
||||||
|
|
||||||
def decode(self, msg_wire_type: int, msg_bytes: bytes) -> protobuf.MessageType:
|
def decode(self, msg_wire_type: int, msg_bytes: bytes) -> protobuf.MessageType:
|
||||||
"""Deserialize a protobuf message into a Python class."""
|
"""Deserialize a protobuf message into a Python class."""
|
||||||
cls = self.type_to_class[msg_wire_type]
|
cls = self.type_to_class[msg_wire_type]
|
||||||
@ -85,8 +95,9 @@ class ProtobufMapping:
|
|||||||
mapping = cls()
|
mapping = cls()
|
||||||
|
|
||||||
message_types = getattr(module, "MessageType")
|
message_types = getattr(module, "MessageType")
|
||||||
|
thp_message_types = getattr(module, "ThpMessageType")
|
||||||
|
|
||||||
for entry in message_types:
|
for entry in (*message_types, *thp_message_types):
|
||||||
msg_class = getattr(module, entry.name, None)
|
msg_class = getattr(module, entry.name, None)
|
||||||
if msg_class is None:
|
if msg_class is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
352
python/src/trezorlib/messages.py
generated
352
python/src/trezorlib/messages.py
generated
@ -43,6 +43,10 @@ class FailureType(IntEnum):
|
|||||||
PinMismatch = 12
|
PinMismatch = 12
|
||||||
WipeCodeMismatch = 13
|
WipeCodeMismatch = 13
|
||||||
InvalidSession = 14
|
InvalidSession = 14
|
||||||
|
ThpUnallocatedSession = 15
|
||||||
|
InvalidProtocol = 16
|
||||||
|
BufferError = 17
|
||||||
|
DeviceIsBusy = 18
|
||||||
FirmwareError = 99
|
FirmwareError = 99
|
||||||
|
|
||||||
|
|
||||||
@ -400,6 +404,34 @@ class TezosBallotType(IntEnum):
|
|||||||
Pass = 2
|
Pass = 2
|
||||||
|
|
||||||
|
|
||||||
|
class ThpMessageType(IntEnum):
|
||||||
|
ThpCreateNewSession = 1000
|
||||||
|
ThpPairingRequest = 1006
|
||||||
|
ThpPairingRequestApproved = 1007
|
||||||
|
ThpSelectMethod = 1008
|
||||||
|
ThpPairingPreparationsFinished = 1009
|
||||||
|
ThpCredentialRequest = 1010
|
||||||
|
ThpCredentialResponse = 1011
|
||||||
|
ThpEndRequest = 1012
|
||||||
|
ThpEndResponse = 1013
|
||||||
|
ThpCodeEntryCommitment = 1016
|
||||||
|
ThpCodeEntryChallenge = 1017
|
||||||
|
ThpCodeEntryCpaceTrezor = 1018
|
||||||
|
ThpCodeEntryCpaceHostTag = 1019
|
||||||
|
ThpCodeEntrySecret = 1020
|
||||||
|
ThpQrCodeTag = 1024
|
||||||
|
ThpQrCodeSecret = 1025
|
||||||
|
ThpNfcTagHost = 1032
|
||||||
|
ThpNfcTagTrezor = 1033
|
||||||
|
|
||||||
|
|
||||||
|
class ThpPairingMethod(IntEnum):
|
||||||
|
SkipPairing = 1
|
||||||
|
CodeEntry = 2
|
||||||
|
QrCode = 3
|
||||||
|
NFC = 4
|
||||||
|
|
||||||
|
|
||||||
class MessageType(IntEnum):
|
class MessageType(IntEnum):
|
||||||
Initialize = 0
|
Initialize = 0
|
||||||
Ping = 1
|
Ping = 1
|
||||||
@ -501,6 +533,8 @@ class MessageType(IntEnum):
|
|||||||
DebugLinkWatchLayout = 9006
|
DebugLinkWatchLayout = 9006
|
||||||
DebugLinkResetDebugEvents = 9007
|
DebugLinkResetDebugEvents = 9007
|
||||||
DebugLinkOptigaSetSecMax = 9008
|
DebugLinkOptigaSetSecMax = 9008
|
||||||
|
DebugLinkGetPairingInfo = 9009
|
||||||
|
DebugLinkPairingInfo = 9010
|
||||||
EthereumGetPublicKey = 450
|
EthereumGetPublicKey = 450
|
||||||
EthereumPublicKey = 451
|
EthereumPublicKey = 451
|
||||||
EthereumGetAddress = 56
|
EthereumGetAddress = 56
|
||||||
@ -4222,6 +4256,52 @@ class DebugLinkState(protobuf.MessageType):
|
|||||||
self.mnemonic_type = mnemonic_type
|
self.mnemonic_type = mnemonic_type
|
||||||
|
|
||||||
|
|
||||||
|
class DebugLinkGetPairingInfo(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 9009
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("channel_id", "bytes", repeated=False, required=False, default=None),
|
||||||
|
2: protobuf.Field("handshake_hash", "bytes", repeated=False, required=False, default=None),
|
||||||
|
3: protobuf.Field("nfc_secret_host", "bytes", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
channel_id: Optional["bytes"] = None,
|
||||||
|
handshake_hash: Optional["bytes"] = None,
|
||||||
|
nfc_secret_host: Optional["bytes"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.channel_id = channel_id
|
||||||
|
self.handshake_hash = handshake_hash
|
||||||
|
self.nfc_secret_host = nfc_secret_host
|
||||||
|
|
||||||
|
|
||||||
|
class DebugLinkPairingInfo(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 9010
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("channel_id", "bytes", repeated=False, required=False, default=None),
|
||||||
|
2: protobuf.Field("handshake_hash", "bytes", repeated=False, required=False, default=None),
|
||||||
|
3: protobuf.Field("code_entry_code", "uint32", repeated=False, required=False, default=None),
|
||||||
|
4: protobuf.Field("code_qr_code", "bytes", repeated=False, required=False, default=None),
|
||||||
|
5: protobuf.Field("nfc_secret_trezor", "bytes", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
channel_id: Optional["bytes"] = None,
|
||||||
|
handshake_hash: Optional["bytes"] = None,
|
||||||
|
code_entry_code: Optional["int"] = None,
|
||||||
|
code_qr_code: Optional["bytes"] = None,
|
||||||
|
nfc_secret_trezor: Optional["bytes"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.channel_id = channel_id
|
||||||
|
self.handshake_hash = handshake_hash
|
||||||
|
self.code_entry_code = code_entry_code
|
||||||
|
self.code_qr_code = code_qr_code
|
||||||
|
self.nfc_secret_trezor = nfc_secret_trezor
|
||||||
|
|
||||||
|
|
||||||
class DebugLinkStop(protobuf.MessageType):
|
class DebugLinkStop(protobuf.MessageType):
|
||||||
MESSAGE_WIRE_TYPE = 103
|
MESSAGE_WIRE_TYPE = 103
|
||||||
|
|
||||||
@ -7976,8 +8056,68 @@ class TezosManagerTransfer(protobuf.MessageType):
|
|||||||
self.amount = amount
|
self.amount = amount
|
||||||
|
|
||||||
|
|
||||||
class ThpCredentialMetadata(protobuf.MessageType):
|
class ThpDeviceProperties(protobuf.MessageType):
|
||||||
MESSAGE_WIRE_TYPE = None
|
MESSAGE_WIRE_TYPE = None
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("internal_model", "string", repeated=False, required=False, default=None),
|
||||||
|
2: protobuf.Field("model_variant", "uint32", repeated=False, required=False, default=None),
|
||||||
|
3: protobuf.Field("protocol_version_major", "uint32", repeated=False, required=False, default=None),
|
||||||
|
4: protobuf.Field("protocol_version_minor", "uint32", repeated=False, required=False, default=None),
|
||||||
|
5: protobuf.Field("pairing_methods", "ThpPairingMethod", repeated=True, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
pairing_methods: Optional[Sequence["ThpPairingMethod"]] = None,
|
||||||
|
internal_model: Optional["str"] = None,
|
||||||
|
model_variant: Optional["int"] = None,
|
||||||
|
protocol_version_major: Optional["int"] = None,
|
||||||
|
protocol_version_minor: Optional["int"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.pairing_methods: Sequence["ThpPairingMethod"] = pairing_methods if pairing_methods is not None else []
|
||||||
|
self.internal_model = internal_model
|
||||||
|
self.model_variant = model_variant
|
||||||
|
self.protocol_version_major = protocol_version_major
|
||||||
|
self.protocol_version_minor = protocol_version_minor
|
||||||
|
|
||||||
|
|
||||||
|
class ThpHandshakeCompletionReqNoisePayload(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = None
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("host_pairing_credential", "bytes", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
host_pairing_credential: Optional["bytes"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.host_pairing_credential = host_pairing_credential
|
||||||
|
|
||||||
|
|
||||||
|
class ThpCreateNewSession(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1000
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("passphrase", "string", repeated=False, required=False, default=None),
|
||||||
|
2: protobuf.Field("on_device", "bool", repeated=False, required=False, default=None),
|
||||||
|
3: protobuf.Field("derive_cardano", "bool", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
passphrase: Optional["str"] = None,
|
||||||
|
on_device: Optional["bool"] = None,
|
||||||
|
derive_cardano: Optional["bool"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.passphrase = passphrase
|
||||||
|
self.on_device = on_device
|
||||||
|
self.derive_cardano = derive_cardano
|
||||||
|
|
||||||
|
|
||||||
|
class ThpPairingRequest(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1006
|
||||||
FIELDS = {
|
FIELDS = {
|
||||||
1: protobuf.Field("host_name", "string", repeated=False, required=False, default=None),
|
1: protobuf.Field("host_name", "string", repeated=False, required=False, default=None),
|
||||||
}
|
}
|
||||||
@ -7990,6 +8130,216 @@ class ThpCredentialMetadata(protobuf.MessageType):
|
|||||||
self.host_name = host_name
|
self.host_name = host_name
|
||||||
|
|
||||||
|
|
||||||
|
class ThpPairingRequestApproved(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1007
|
||||||
|
|
||||||
|
|
||||||
|
class ThpSelectMethod(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1008
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("selected_pairing_method", "ThpPairingMethod", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
selected_pairing_method: Optional["ThpPairingMethod"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.selected_pairing_method = selected_pairing_method
|
||||||
|
|
||||||
|
|
||||||
|
class ThpPairingPreparationsFinished(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1009
|
||||||
|
|
||||||
|
|
||||||
|
class ThpCodeEntryCommitment(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1016
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("commitment", "bytes", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
commitment: Optional["bytes"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.commitment = commitment
|
||||||
|
|
||||||
|
|
||||||
|
class ThpCodeEntryChallenge(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1017
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("challenge", "bytes", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
challenge: Optional["bytes"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.challenge = challenge
|
||||||
|
|
||||||
|
|
||||||
|
class ThpCodeEntryCpaceTrezor(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1018
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("cpace_trezor_public_key", "bytes", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
cpace_trezor_public_key: Optional["bytes"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.cpace_trezor_public_key = cpace_trezor_public_key
|
||||||
|
|
||||||
|
|
||||||
|
class ThpCodeEntryCpaceHostTag(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1019
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("cpace_host_public_key", "bytes", repeated=False, required=False, default=None),
|
||||||
|
2: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
cpace_host_public_key: Optional["bytes"] = None,
|
||||||
|
tag: Optional["bytes"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.cpace_host_public_key = cpace_host_public_key
|
||||||
|
self.tag = tag
|
||||||
|
|
||||||
|
|
||||||
|
class ThpCodeEntrySecret(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1020
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("secret", "bytes", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
secret: Optional["bytes"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.secret = secret
|
||||||
|
|
||||||
|
|
||||||
|
class ThpQrCodeTag(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1024
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
tag: Optional["bytes"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.tag = tag
|
||||||
|
|
||||||
|
|
||||||
|
class ThpQrCodeSecret(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1025
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("secret", "bytes", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
secret: Optional["bytes"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.secret = secret
|
||||||
|
|
||||||
|
|
||||||
|
class ThpNfcTagHost(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1032
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
tag: Optional["bytes"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.tag = tag
|
||||||
|
|
||||||
|
|
||||||
|
class ThpNfcTagTrezor(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1033
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
tag: Optional["bytes"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.tag = tag
|
||||||
|
|
||||||
|
|
||||||
|
class ThpCredentialRequest(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1010
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("host_static_pubkey", "bytes", repeated=False, required=False, default=None),
|
||||||
|
2: protobuf.Field("autoconnect", "bool", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
host_static_pubkey: Optional["bytes"] = None,
|
||||||
|
autoconnect: Optional["bool"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.host_static_pubkey = host_static_pubkey
|
||||||
|
self.autoconnect = autoconnect
|
||||||
|
|
||||||
|
|
||||||
|
class ThpCredentialResponse(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1011
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("trezor_static_pubkey", "bytes", repeated=False, required=False, default=None),
|
||||||
|
2: protobuf.Field("credential", "bytes", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
trezor_static_pubkey: Optional["bytes"] = None,
|
||||||
|
credential: Optional["bytes"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.trezor_static_pubkey = trezor_static_pubkey
|
||||||
|
self.credential = credential
|
||||||
|
|
||||||
|
|
||||||
|
class ThpEndRequest(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1012
|
||||||
|
|
||||||
|
|
||||||
|
class ThpEndResponse(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1013
|
||||||
|
|
||||||
|
|
||||||
|
class ThpCredentialMetadata(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = None
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("host_name", "string", repeated=False, required=False, default=None),
|
||||||
|
2: protobuf.Field("autoconnect", "bool", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
host_name: Optional["str"] = None,
|
||||||
|
autoconnect: Optional["bool"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.host_name = host_name
|
||||||
|
self.autoconnect = autoconnect
|
||||||
|
|
||||||
|
|
||||||
class ThpPairingCredential(protobuf.MessageType):
|
class ThpPairingCredential(protobuf.MessageType):
|
||||||
MESSAGE_WIRE_TYPE = None
|
MESSAGE_WIRE_TYPE = None
|
||||||
FIELDS = {
|
FIELDS = {
|
||||||
|
@ -22,6 +22,7 @@ import typing as t
|
|||||||
import requests
|
import requests
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
from ..client import ProtocolVersion
|
||||||
from ..log import DUMP_PACKETS
|
from ..log import DUMP_PACKETS
|
||||||
from . import DeviceIsBusy, Transport, TransportException
|
from . import DeviceIsBusy, Transport, TransportException
|
||||||
|
|
||||||
@ -34,6 +35,7 @@ TREZORD_HOST = "http://127.0.0.1:21325"
|
|||||||
TREZORD_ORIGIN_HEADER = {"Origin": "https://python.trezor.io"}
|
TREZORD_ORIGIN_HEADER = {"Origin": "https://python.trezor.io"}
|
||||||
|
|
||||||
TREZORD_VERSION_MODERN = (2, 0, 25)
|
TREZORD_VERSION_MODERN = (2, 0, 25)
|
||||||
|
TREZORD_VERSION_THP_SUPPORT = (2, 0, 31) # TODO add correct value
|
||||||
|
|
||||||
CONNECTION = requests.Session()
|
CONNECTION = requests.Session()
|
||||||
CONNECTION.headers.update(TREZORD_ORIGIN_HEADER)
|
CONNECTION.headers.update(TREZORD_ORIGIN_HEADER)
|
||||||
@ -66,6 +68,44 @@ def is_legacy_bridge() -> bool:
|
|||||||
return get_bridge_version() < TREZORD_VERSION_MODERN
|
return get_bridge_version() < TREZORD_VERSION_MODERN
|
||||||
|
|
||||||
|
|
||||||
|
def supports_protocolV2() -> bool:
|
||||||
|
return get_bridge_version() >= TREZORD_VERSION_THP_SUPPORT
|
||||||
|
|
||||||
|
|
||||||
|
def detect_protocol_version(transport: "BridgeTransport") -> int:
|
||||||
|
from .. import mapping, messages
|
||||||
|
from ..messages import FailureType
|
||||||
|
|
||||||
|
protocol_version = ProtocolVersion.PROTOCOL_V1
|
||||||
|
request_type, request_data = mapping.DEFAULT_MAPPING.encode(messages.Initialize())
|
||||||
|
transport.open()
|
||||||
|
transport.write_chunk(request_type.to_bytes(2, "big") + request_data)
|
||||||
|
response = transport.read_chunk()
|
||||||
|
response_type = int.from_bytes(response[:2], "big")
|
||||||
|
response_data = response[2:]
|
||||||
|
response = mapping.DEFAULT_MAPPING.decode(response_type, response_data)
|
||||||
|
if isinstance(response, messages.Failure):
|
||||||
|
if response.code == FailureType.InvalidProtocol:
|
||||||
|
LOG.debug("Protocol V2 detected")
|
||||||
|
protocol_version = ProtocolVersion.PROTOCOL_V2
|
||||||
|
|
||||||
|
return protocol_version
|
||||||
|
|
||||||
|
|
||||||
|
def _is_transport_valid(transport: "BridgeTransport") -> bool:
|
||||||
|
is_valid = detect_protocol_version(transport) == ProtocolVersion.PROTOCOL_V1
|
||||||
|
if not is_valid:
|
||||||
|
LOG.warning("Detected unsupported Bridge transport!")
|
||||||
|
return is_valid
|
||||||
|
|
||||||
|
|
||||||
|
def filter_invalid_bridge_transports(
|
||||||
|
transports: t.Iterable["BridgeTransport"],
|
||||||
|
) -> t.Sequence["BridgeTransport"]:
|
||||||
|
"""Filters out invalid bridge transports. Keeps only valid ones."""
|
||||||
|
return [t for t in transports if _is_transport_valid(t)]
|
||||||
|
|
||||||
|
|
||||||
class BridgeHandle:
|
class BridgeHandle:
|
||||||
def __init__(self, transport: "BridgeTransport") -> None:
|
def __init__(self, transport: "BridgeTransport") -> None:
|
||||||
self.transport = transport
|
self.transport = transport
|
||||||
@ -160,9 +200,12 @@ class BridgeTransport(Transport):
|
|||||||
) -> t.Iterable["BridgeTransport"]:
|
) -> t.Iterable["BridgeTransport"]:
|
||||||
try:
|
try:
|
||||||
legacy = is_legacy_bridge()
|
legacy = is_legacy_bridge()
|
||||||
return [
|
return filter_invalid_bridge_transports(
|
||||||
BridgeTransport(dev, legacy) for dev in call_bridge("enumerate").json()
|
[
|
||||||
]
|
BridgeTransport(dev, legacy)
|
||||||
|
for dev in call_bridge("enumerate").json()
|
||||||
|
]
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
@ -7,6 +7,7 @@ from .. import exceptions, messages, models
|
|||||||
from ..client import MAX_PIN_LENGTH
|
from ..client import MAX_PIN_LENGTH
|
||||||
from ..protobuf import MessageType
|
from ..protobuf import MessageType
|
||||||
from .thp.protocol_v1 import ProtocolV1Channel
|
from .thp.protocol_v1 import ProtocolV1Channel
|
||||||
|
from .thp.protocol_v2 import ProtocolV2Channel
|
||||||
|
|
||||||
if t.TYPE_CHECKING:
|
if t.TYPE_CHECKING:
|
||||||
from ..client import TrezorClient
|
from ..client import TrezorClient
|
||||||
@ -235,3 +236,50 @@ def derive_seed(session: Session, passphrase: str | object) -> None:
|
|||||||
_passphrase_ack=ack,
|
_passphrase_ack=ack,
|
||||||
)
|
)
|
||||||
session.refresh_features()
|
session.refresh_features()
|
||||||
|
|
||||||
|
|
||||||
|
class SessionV2(Session):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def new(
|
||||||
|
cls,
|
||||||
|
client: TrezorClient,
|
||||||
|
passphrase: str | None,
|
||||||
|
derive_cardano: bool,
|
||||||
|
session_id: int = 0,
|
||||||
|
) -> SessionV2:
|
||||||
|
assert isinstance(client.protocol, ProtocolV2Channel)
|
||||||
|
session = cls(client, session_id.to_bytes(1, "big"))
|
||||||
|
session.call(
|
||||||
|
messages.ThpCreateNewSession(
|
||||||
|
passphrase=passphrase, derive_cardano=derive_cardano
|
||||||
|
),
|
||||||
|
expect=messages.Success,
|
||||||
|
)
|
||||||
|
session.update_id_and_sid(session_id.to_bytes(1, "big"))
|
||||||
|
return session
|
||||||
|
|
||||||
|
def __init__(self, client: TrezorClient, id: bytes) -> None:
|
||||||
|
from ..debuglink import TrezorClientDebugLink
|
||||||
|
|
||||||
|
super().__init__(client, id)
|
||||||
|
assert isinstance(client.protocol, ProtocolV2Channel)
|
||||||
|
|
||||||
|
helper_debug = None
|
||||||
|
if isinstance(client, TrezorClientDebugLink):
|
||||||
|
helper_debug = client.debug
|
||||||
|
self.channel: ProtocolV2Channel = client.protocol.get_channel(helper_debug)
|
||||||
|
self.update_id_and_sid(id)
|
||||||
|
|
||||||
|
def _write(self, msg: t.Any) -> None:
|
||||||
|
LOG.debug("writing message %s", type(msg))
|
||||||
|
self.channel.write(self.sid, msg)
|
||||||
|
|
||||||
|
def _read(self) -> t.Any:
|
||||||
|
msg = self.channel.read(self.sid)
|
||||||
|
LOG.debug("reading message %s", type(msg))
|
||||||
|
return msg
|
||||||
|
|
||||||
|
def update_id_and_sid(self, id: bytes) -> None:
|
||||||
|
self._id = id
|
||||||
|
self.sid = int.from_bytes(id, "big") # TODO update to extract only sid
|
||||||
|
102
python/src/trezorlib/transport/thp/alternating_bit_protocol.py
Normal file
102
python/src/trezorlib/transport/thp/alternating_bit_protocol.py
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
# from storage.cache_thp import ChannelCache
|
||||||
|
# from trezor import log
|
||||||
|
# from trezor.wire.thp import ThpError
|
||||||
|
|
||||||
|
|
||||||
|
# def is_ack_valid(cache: ChannelCache, ack_bit: int) -> bool:
|
||||||
|
# """
|
||||||
|
# Checks if:
|
||||||
|
# - an ACK message is expected
|
||||||
|
# - the received ACK message acknowledges correct sequence number (bit)
|
||||||
|
# """
|
||||||
|
# if not _is_ack_expected(cache):
|
||||||
|
# return False
|
||||||
|
|
||||||
|
# if not _has_ack_correct_sync_bit(cache, ack_bit):
|
||||||
|
# return False
|
||||||
|
|
||||||
|
# return True
|
||||||
|
|
||||||
|
|
||||||
|
# def _is_ack_expected(cache: ChannelCache) -> bool:
|
||||||
|
# is_expected: bool = not is_sending_allowed(cache)
|
||||||
|
# if __debug__ and not is_expected:
|
||||||
|
# log.debug(__name__, "Received unexpected ACK message")
|
||||||
|
# return is_expected
|
||||||
|
|
||||||
|
|
||||||
|
# def _has_ack_correct_sync_bit(cache: ChannelCache, sync_bit: int) -> bool:
|
||||||
|
# is_correct: bool = get_send_seq_bit(cache) == sync_bit
|
||||||
|
# if __debug__ and not is_correct:
|
||||||
|
# log.debug(__name__, "Received ACK message with wrong ack bit")
|
||||||
|
# return is_correct
|
||||||
|
|
||||||
|
|
||||||
|
# def is_sending_allowed(cache: ChannelCache) -> bool:
|
||||||
|
# """
|
||||||
|
# Checks whether sending a message in the provided channel is allowed.
|
||||||
|
|
||||||
|
# Note: Sending a message in a channel before receipt of ACK message for the previously
|
||||||
|
# sent message (in the channel) is prohibited, as it can lead to desynchronization.
|
||||||
|
# """
|
||||||
|
# return bool(cache.sync >> 7)
|
||||||
|
|
||||||
|
|
||||||
|
# def get_send_seq_bit(cache: ChannelCache) -> int:
|
||||||
|
# """
|
||||||
|
# Returns the sequential number (bit) of the next message to be sent
|
||||||
|
# in the provided channel.
|
||||||
|
# """
|
||||||
|
# return (cache.sync & 0x20) >> 5
|
||||||
|
|
||||||
|
|
||||||
|
# def get_expected_receive_seq_bit(cache: ChannelCache) -> int:
|
||||||
|
# """
|
||||||
|
# Returns the (expected) sequential number (bit) of the next message
|
||||||
|
# to be received in the provided channel.
|
||||||
|
# """
|
||||||
|
# return (cache.sync & 0x40) >> 6
|
||||||
|
|
||||||
|
|
||||||
|
# def set_sending_allowed(cache: ChannelCache, sending_allowed: bool) -> None:
|
||||||
|
# """
|
||||||
|
# Set the flag whether sending a message in this channel is allowed or not.
|
||||||
|
# """
|
||||||
|
# cache.sync &= 0x7F
|
||||||
|
# if sending_allowed:
|
||||||
|
# cache.sync |= 0x80
|
||||||
|
|
||||||
|
|
||||||
|
# def set_expected_receive_seq_bit(cache: ChannelCache, seq_bit: int) -> None:
|
||||||
|
# """
|
||||||
|
# Set the expected sequential number (bit) of the next message to be received
|
||||||
|
# in the provided channel
|
||||||
|
# """
|
||||||
|
# if __debug__:
|
||||||
|
# log.debug(__name__, "Set sync receive expected seq bit to %d", seq_bit)
|
||||||
|
# if seq_bit not in (0, 1):
|
||||||
|
# raise ThpError("Unexpected receive sync bit")
|
||||||
|
|
||||||
|
# # set second bit to "seq_bit" value
|
||||||
|
# cache.sync &= 0xBF
|
||||||
|
# if seq_bit:
|
||||||
|
# cache.sync |= 0x40
|
||||||
|
|
||||||
|
|
||||||
|
# def _set_send_seq_bit(cache: ChannelCache, seq_bit: int) -> None:
|
||||||
|
# if seq_bit not in (0, 1):
|
||||||
|
# raise ThpError("Unexpected send seq bit")
|
||||||
|
# if __debug__:
|
||||||
|
# log.debug(__name__, "setting sync send seq bit to %d", seq_bit)
|
||||||
|
# # set third bit to "seq_bit" value
|
||||||
|
# cache.sync &= 0xDF
|
||||||
|
# if seq_bit:
|
||||||
|
# cache.sync |= 0x20
|
||||||
|
|
||||||
|
|
||||||
|
# def set_send_seq_bit_to_opposite(cache: ChannelCache) -> None:
|
||||||
|
# """
|
||||||
|
# Set the sequential bit of the "next message to be send" to the opposite value,
|
||||||
|
# i.e. 1 -> 0 and 0 -> 1
|
||||||
|
# """
|
||||||
|
# _set_send_seq_bit(cache=cache, seq_bit=1 - get_send_seq_bit(cache))
|
47
python/src/trezorlib/transport/thp/channel_data.py
Normal file
47
python/src/trezorlib/transport/thp/channel_data.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from binascii import hexlify
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelData:
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
protocol_version_major: int,
|
||||||
|
protocol_version_minor: int,
|
||||||
|
transport_path: str,
|
||||||
|
channel_id: int,
|
||||||
|
key_request: bytes,
|
||||||
|
key_response: bytes,
|
||||||
|
nonce_request: int,
|
||||||
|
nonce_response: int,
|
||||||
|
sync_bit_send: int,
|
||||||
|
sync_bit_receive: int,
|
||||||
|
handshake_hash: bytes,
|
||||||
|
) -> None:
|
||||||
|
self.protocol_version_major: int = protocol_version_major
|
||||||
|
self.protocol_version_minor: int = protocol_version_minor
|
||||||
|
self.transport_path: str = transport_path
|
||||||
|
self.channel_id: int = channel_id
|
||||||
|
self.key_request: str = hexlify(key_request).decode()
|
||||||
|
self.key_response: str = hexlify(key_response).decode()
|
||||||
|
self.nonce_request: int = nonce_request
|
||||||
|
self.nonce_response: int = nonce_response
|
||||||
|
self.sync_bit_receive: int = sync_bit_receive
|
||||||
|
self.sync_bit_send: int = sync_bit_send
|
||||||
|
self.handshake_hash: str = hexlify(handshake_hash).decode()
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
return {
|
||||||
|
"protocol_version_major": self.protocol_version_major,
|
||||||
|
"protocol_version_minor": self.protocol_version_minor,
|
||||||
|
"transport_path": self.transport_path,
|
||||||
|
"channel_id": self.channel_id,
|
||||||
|
"key_request": self.key_request,
|
||||||
|
"key_response": self.key_response,
|
||||||
|
"nonce_request": self.nonce_request,
|
||||||
|
"nonce_response": self.nonce_response,
|
||||||
|
"sync_bit_send": self.sync_bit_send,
|
||||||
|
"sync_bit_receive": self.sync_bit_receive,
|
||||||
|
"handshake_hash": self.handshake_hash,
|
||||||
|
}
|
146
python/src/trezorlib/transport/thp/channel_database.py
Normal file
146
python/src/trezorlib/transport/thp/channel_database.py
Normal file
@ -0,0 +1,146 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import typing as t
|
||||||
|
|
||||||
|
from .channel_data import ChannelData
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
db: "ChannelDatabase | None" = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_channel_db() -> ChannelDatabase:
|
||||||
|
if db is None:
|
||||||
|
set_channel_database(should_not_store=True)
|
||||||
|
assert db is not None
|
||||||
|
return db
|
||||||
|
|
||||||
|
|
||||||
|
if t.TYPE_CHECKING:
|
||||||
|
from .protocol_and_channel import Channel
|
||||||
|
from .protocol_v2 import ProtocolV2Channel
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelDatabase:
|
||||||
|
|
||||||
|
def load_stored_channels(self) -> t.List[ChannelData]: ...
|
||||||
|
def clear_stored_channels(self) -> None: ...
|
||||||
|
|
||||||
|
def save_channel(self, new_channel: Channel): ...
|
||||||
|
|
||||||
|
def remove_channel(self, transport_path: str) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
|
class DummyChannelDatabase(ChannelDatabase):
|
||||||
|
|
||||||
|
def load_stored_channels(self) -> t.List[ChannelData]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def clear_stored_channels(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def save_channel(self, new_channel: Channel):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def remove_channel(self, transport_path: str) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class JsonChannelDatabase(ChannelDatabase):
|
||||||
|
def __init__(self, data_path: str) -> None:
|
||||||
|
self.data_path = data_path
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def load_stored_channels(self) -> t.List[ChannelData]:
|
||||||
|
dicts = self._read_all_channels()
|
||||||
|
return [dict_to_channel_data(d) for d in dicts]
|
||||||
|
|
||||||
|
def clear_stored_channels(self) -> None:
|
||||||
|
LOG.debug("Clearing contents of %s", self.data_path)
|
||||||
|
with open(self.data_path, "w") as f:
|
||||||
|
json.dump([], f)
|
||||||
|
try:
|
||||||
|
os.remove(self.data_path)
|
||||||
|
except Exception as e:
|
||||||
|
LOG.exception("Failed to delete %s (%s)", self.data_path, str(type(e)))
|
||||||
|
|
||||||
|
def _read_all_channels(self) -> t.List:
|
||||||
|
ensure_file_exists(self.data_path)
|
||||||
|
with open(self.data_path, "r") as f:
|
||||||
|
return json.load(f)
|
||||||
|
|
||||||
|
def _save_all_channels(self, channels: t.List[t.Dict]) -> None:
|
||||||
|
LOG.debug("saving all channels")
|
||||||
|
with open(self.data_path, "w") as f:
|
||||||
|
json.dump(channels, f, indent=4)
|
||||||
|
|
||||||
|
def save_channel(self, new_channel: ProtocolV2Channel):
|
||||||
|
|
||||||
|
LOG.debug("save channel")
|
||||||
|
channels = self._read_all_channels()
|
||||||
|
transport_path = new_channel.transport.get_path()
|
||||||
|
|
||||||
|
# If the channel is found in database: replace the old entry by the new
|
||||||
|
for i, channel in enumerate(channels):
|
||||||
|
if channel["transport_path"] == transport_path:
|
||||||
|
LOG.debug("Modified channel entry for %s", transport_path)
|
||||||
|
channels[i] = new_channel.get_channel_data().to_dict()
|
||||||
|
self._save_all_channels(channels)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Channel was not found: add a new channel entry
|
||||||
|
LOG.debug("Created a new channel entry on path %s", transport_path)
|
||||||
|
channels.append(new_channel.get_channel_data().to_dict())
|
||||||
|
self._save_all_channels(channels)
|
||||||
|
|
||||||
|
def remove_channel(self, transport_path: str) -> None:
|
||||||
|
LOG.debug(
|
||||||
|
"Removing channel with path %s from the channel database.",
|
||||||
|
transport_path,
|
||||||
|
)
|
||||||
|
channels = self._read_all_channels()
|
||||||
|
remaining_channels = [
|
||||||
|
ch for ch in channels if ch["transport_path"] != transport_path
|
||||||
|
]
|
||||||
|
self._save_all_channels(remaining_channels)
|
||||||
|
|
||||||
|
|
||||||
|
def dict_to_channel_data(dict: t.Dict) -> ChannelData:
|
||||||
|
return ChannelData(
|
||||||
|
protocol_version_major=dict["protocol_version_minor"],
|
||||||
|
protocol_version_minor=dict["protocol_version_major"],
|
||||||
|
transport_path=dict["transport_path"],
|
||||||
|
channel_id=dict["channel_id"],
|
||||||
|
key_request=bytes.fromhex(dict["key_request"]),
|
||||||
|
key_response=bytes.fromhex(dict["key_response"]),
|
||||||
|
nonce_request=dict["nonce_request"],
|
||||||
|
nonce_response=dict["nonce_response"],
|
||||||
|
sync_bit_send=dict["sync_bit_send"],
|
||||||
|
sync_bit_receive=dict["sync_bit_receive"],
|
||||||
|
handshake_hash=bytes.fromhex(dict["handshake_hash"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_file_exists(file_path: str) -> None:
|
||||||
|
LOG.debug("checking if file %s exists", file_path)
|
||||||
|
if not os.path.exists(file_path):
|
||||||
|
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||||
|
LOG.debug("File %s does not exist. Creating a new one.", file_path)
|
||||||
|
with open(file_path, "w") as f:
|
||||||
|
json.dump([], f)
|
||||||
|
|
||||||
|
|
||||||
|
def set_channel_database(should_not_store: bool):
|
||||||
|
global db
|
||||||
|
if should_not_store:
|
||||||
|
db = DummyChannelDatabase()
|
||||||
|
else:
|
||||||
|
from platformdirs import user_cache_dir
|
||||||
|
|
||||||
|
APP_NAME = "@trezor" # TODO
|
||||||
|
DATA_PATH = os.path.join(user_cache_dir(appname=APP_NAME), "channel_data.json")
|
||||||
|
|
||||||
|
db = JsonChannelDatabase(DATA_PATH)
|
19
python/src/trezorlib/transport/thp/checksum.py
Normal file
19
python/src/trezorlib/transport/thp/checksum.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
import zlib
|
||||||
|
|
||||||
|
CHECKSUM_LENGTH = 4
|
||||||
|
|
||||||
|
|
||||||
|
def compute(data: bytes) -> bytes:
|
||||||
|
"""
|
||||||
|
Returns a CRC-32 checksum of the provided `data`.
|
||||||
|
"""
|
||||||
|
return zlib.crc32(data).to_bytes(CHECKSUM_LENGTH, "big")
|
||||||
|
|
||||||
|
|
||||||
|
def is_valid(checksum: bytes, data: bytes) -> bool:
|
||||||
|
"""
|
||||||
|
Checks whether the CRC-32 checksum of the `data` is the same
|
||||||
|
as the checksum provided in `checksum`.
|
||||||
|
"""
|
||||||
|
data_checksum = compute(data)
|
||||||
|
return checksum == data_checksum
|
63
python/src/trezorlib/transport/thp/control_byte.py
Normal file
63
python/src/trezorlib/transport/thp/control_byte.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
CODEC_V1 = 0x3F
|
||||||
|
CONTINUATION_PACKET = 0x80
|
||||||
|
HANDSHAKE_INIT_REQ = 0x00
|
||||||
|
HANDSHAKE_INIT_RES = 0x01
|
||||||
|
HANDSHAKE_COMP_REQ = 0x02
|
||||||
|
HANDSHAKE_COMP_RES = 0x03
|
||||||
|
ENCRYPTED_TRANSPORT = 0x04
|
||||||
|
|
||||||
|
CONTINUATION_PACKET_MASK = 0x80
|
||||||
|
ACK_MASK = 0xF7
|
||||||
|
DATA_MASK = 0xE7
|
||||||
|
|
||||||
|
ACK_MESSAGE = 0x20
|
||||||
|
_ERROR = 0x42
|
||||||
|
CHANNEL_ALLOCATION_REQ = 0x40
|
||||||
|
_CHANNEL_ALLOCATION_RES = 0x41
|
||||||
|
|
||||||
|
TREZOR_STATE_UNPAIRED = b"\x00"
|
||||||
|
TREZOR_STATE_PAIRED = b"\x01"
|
||||||
|
|
||||||
|
|
||||||
|
def add_seq_bit_to_ctrl_byte(ctrl_byte: int, seq_bit: int) -> int:
|
||||||
|
if seq_bit == 0:
|
||||||
|
return ctrl_byte & 0xEF
|
||||||
|
if seq_bit == 1:
|
||||||
|
return ctrl_byte | 0x10
|
||||||
|
raise Exception("Unexpected sequence bit")
|
||||||
|
|
||||||
|
|
||||||
|
def add_ack_bit_to_ctrl_byte(ctrl_byte: int, ack_bit: int) -> int:
|
||||||
|
if ack_bit == 0:
|
||||||
|
return ctrl_byte & 0xF7
|
||||||
|
if ack_bit == 1:
|
||||||
|
return ctrl_byte | 0x08
|
||||||
|
raise Exception("Unexpected acknowledgement bit")
|
||||||
|
|
||||||
|
|
||||||
|
def get_seq_bit(ctrl_byte: int) -> int:
|
||||||
|
return (ctrl_byte & 0x10) >> 4
|
||||||
|
|
||||||
|
|
||||||
|
def is_ack(ctrl_byte: int) -> bool:
|
||||||
|
return ctrl_byte & ACK_MASK == ACK_MESSAGE
|
||||||
|
|
||||||
|
|
||||||
|
def is_error(ctrl_byte: int) -> bool:
|
||||||
|
return ctrl_byte == _ERROR
|
||||||
|
|
||||||
|
|
||||||
|
def is_continuation(ctrl_byte: int) -> bool:
|
||||||
|
return ctrl_byte & CONTINUATION_PACKET_MASK == CONTINUATION_PACKET
|
||||||
|
|
||||||
|
|
||||||
|
def is_encrypted_transport(ctrl_byte: int) -> bool:
|
||||||
|
return ctrl_byte & DATA_MASK == ENCRYPTED_TRANSPORT
|
||||||
|
|
||||||
|
|
||||||
|
def is_handshake_init_req(ctrl_byte: int) -> bool:
|
||||||
|
return ctrl_byte & DATA_MASK == HANDSHAKE_INIT_REQ
|
||||||
|
|
||||||
|
|
||||||
|
def is_handshake_comp_req(ctrl_byte: int) -> bool:
|
||||||
|
return ctrl_byte & DATA_MASK == HANDSHAKE_COMP_REQ
|
40
python/src/trezorlib/transport/thp/cpace.py
Normal file
40
python/src/trezorlib/transport/thp/cpace.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
import typing as t
|
||||||
|
from hashlib import sha512
|
||||||
|
|
||||||
|
from . import curve25519
|
||||||
|
|
||||||
|
_PREFIX = b"\x08\x43\x50\x61\x63\x65\x32\x35\x35\x06"
|
||||||
|
_PADDING = b"\x6f\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x20"
|
||||||
|
|
||||||
|
|
||||||
|
class Cpace:
|
||||||
|
"""
|
||||||
|
CPace, a balanced composable PAKE: https://datatracker.ietf.org/doc/draft-irtf-cfrg-cpace/
|
||||||
|
"""
|
||||||
|
|
||||||
|
random_bytes: t.Callable[[int], bytes]
|
||||||
|
|
||||||
|
def __init__(self, handshake_hash: bytes) -> None:
|
||||||
|
self.handshake_hash: bytes = handshake_hash
|
||||||
|
self.shared_secret: bytes
|
||||||
|
self.host_private_key: bytes
|
||||||
|
self.host_public_key: bytes
|
||||||
|
|
||||||
|
def generate_keys_and_secret(
|
||||||
|
self, code_code_entry: bytes, trezor_public_key: bytes
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Generate ephemeral key pair and a shared secret using Elligator2 with X25519.
|
||||||
|
"""
|
||||||
|
sha_ctx = sha512(_PREFIX)
|
||||||
|
sha_ctx.update(code_code_entry)
|
||||||
|
sha_ctx.update(_PADDING)
|
||||||
|
sha_ctx.update(self.handshake_hash)
|
||||||
|
sha_ctx.update(b"\x00")
|
||||||
|
pregenerator = sha_ctx.digest()[:32]
|
||||||
|
generator = curve25519.elligator2(pregenerator)
|
||||||
|
self.host_private_key = self.random_bytes(32)
|
||||||
|
self.host_public_key = curve25519.multiply(self.host_private_key, generator)
|
||||||
|
self.shared_secret = curve25519.multiply(
|
||||||
|
self.host_private_key, trezor_public_key
|
||||||
|
)
|
159
python/src/trezorlib/transport/thp/curve25519.py
Normal file
159
python/src/trezorlib/transport/thp/curve25519.py
Normal file
@ -0,0 +1,159 @@
|
|||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
p = 2**255 - 19
|
||||||
|
J = 486662
|
||||||
|
|
||||||
|
c3 = 19681161376707505956807079304988542015446066515923890162744021073123829784752 # sqrt(-1)
|
||||||
|
c4 = 7237005577332262213973186563042994240829374041602535252466099000494570602493 # (p - 5) // 8
|
||||||
|
a24 = 121666 # (J + 2) // 4
|
||||||
|
|
||||||
|
|
||||||
|
def decode_scalar(scalar: bytes) -> int:
|
||||||
|
# decodeScalar25519 from
|
||||||
|
# https://datatracker.ietf.org/doc/html/rfc7748#section-5
|
||||||
|
|
||||||
|
if len(scalar) != 32:
|
||||||
|
raise ValueError("Invalid length of scalar")
|
||||||
|
|
||||||
|
array = bytearray(scalar)
|
||||||
|
array[0] &= 248
|
||||||
|
array[31] &= 127
|
||||||
|
array[31] |= 64
|
||||||
|
|
||||||
|
return int.from_bytes(array, "little")
|
||||||
|
|
||||||
|
|
||||||
|
def decode_coordinate(coordinate: bytes) -> int:
|
||||||
|
# decodeUCoordinate from
|
||||||
|
# https://datatracker.ietf.org/doc/html/rfc7748#section-5
|
||||||
|
if len(coordinate) != 32:
|
||||||
|
raise ValueError("Invalid length of coordinate")
|
||||||
|
|
||||||
|
array = bytearray(coordinate)
|
||||||
|
array[-1] &= 0x7F
|
||||||
|
return int.from_bytes(array, "little") % p
|
||||||
|
|
||||||
|
|
||||||
|
def encode_coordinate(coordinate: int) -> bytes:
|
||||||
|
# encodeUCoordinate from
|
||||||
|
# https://datatracker.ietf.org/doc/html/rfc7748#section-5
|
||||||
|
return coordinate.to_bytes(32, "little")
|
||||||
|
|
||||||
|
|
||||||
|
def get_private_key(secret: bytes) -> bytes:
|
||||||
|
return decode_scalar(secret).to_bytes(32, "little")
|
||||||
|
|
||||||
|
|
||||||
|
def get_public_key(private_key: bytes) -> bytes:
|
||||||
|
base_point = int.to_bytes(9, 32, "little")
|
||||||
|
return multiply(private_key, base_point)
|
||||||
|
|
||||||
|
|
||||||
|
def multiply(private_scalar: bytes, public_point: bytes):
|
||||||
|
# X25519 from
|
||||||
|
# https://datatracker.ietf.org/doc/html/rfc7748#section-5
|
||||||
|
|
||||||
|
def ladder_operation(
|
||||||
|
x1: int, x2: int, z2: int, x3: int, z3: int
|
||||||
|
) -> Tuple[int, int, int, int]:
|
||||||
|
# https://hyperelliptic.org/EFD/g1p/auto-montgom-xz.html#ladder-ladd-1987-m-3
|
||||||
|
# (x4, z4) = 2 * (x2, z2)
|
||||||
|
# (x5, z5) = (x2, z2) + (x3, z3)
|
||||||
|
# where (x1, 1) = (x3, z3) - (x2, z2)
|
||||||
|
|
||||||
|
a = (x2 + z2) % p
|
||||||
|
aa = (a * a) % p
|
||||||
|
b = (x2 - z2) % p
|
||||||
|
bb = (b * b) % p
|
||||||
|
e = (aa - bb) % p
|
||||||
|
c = (x3 + z3) % p
|
||||||
|
d = (x3 - z3) % p
|
||||||
|
da = (d * a) % p
|
||||||
|
cb = (c * b) % p
|
||||||
|
t0 = (da + cb) % p
|
||||||
|
x5 = (t0 * t0) % p
|
||||||
|
t1 = (da - cb) % p
|
||||||
|
t2 = (t1 * t1) % p
|
||||||
|
z5 = (x1 * t2) % p
|
||||||
|
x4 = (aa * bb) % p
|
||||||
|
t3 = (a24 * e) % p
|
||||||
|
t4 = (bb + t3) % p
|
||||||
|
z4 = (e * t4) % p
|
||||||
|
|
||||||
|
return x4, z4, x5, z5
|
||||||
|
|
||||||
|
def conditional_swap(first: int, second: int, condition: int):
|
||||||
|
# Returns (second, first) if condition is true and (first, second) otherwise
|
||||||
|
# Must be implemented in a way that it is constant time
|
||||||
|
true_mask = -condition
|
||||||
|
false_mask = ~true_mask
|
||||||
|
return (first & false_mask) | (second & true_mask), (second & false_mask) | (
|
||||||
|
first & true_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
k = decode_scalar(private_scalar)
|
||||||
|
u = decode_coordinate(public_point)
|
||||||
|
|
||||||
|
x_1 = u
|
||||||
|
x_2 = 1
|
||||||
|
z_2 = 0
|
||||||
|
x_3 = u
|
||||||
|
z_3 = 1
|
||||||
|
swap = 0
|
||||||
|
|
||||||
|
for i in reversed(range(256)):
|
||||||
|
bit = (k >> i) & 1
|
||||||
|
swap = bit ^ swap
|
||||||
|
(x_2, x_3) = conditional_swap(x_2, x_3, swap)
|
||||||
|
(z_2, z_3) = conditional_swap(z_2, z_3, swap)
|
||||||
|
swap = bit
|
||||||
|
x_2, z_2, x_3, z_3 = ladder_operation(x_1, x_2, z_2, x_3, z_3)
|
||||||
|
|
||||||
|
(x_2, x_3) = conditional_swap(x_2, x_3, swap)
|
||||||
|
(z_2, z_3) = conditional_swap(z_2, z_3, swap)
|
||||||
|
|
||||||
|
x = pow(z_2, p - 2, p) * x_2 % p
|
||||||
|
return encode_coordinate(x)
|
||||||
|
|
||||||
|
|
||||||
|
def elligator2(point: bytes) -> bytes:
|
||||||
|
# map_to_curve_elligator2_curve25519 from
|
||||||
|
# https://www.rfc-editor.org/rfc/rfc9380.html#ell2-opt
|
||||||
|
|
||||||
|
def conditional_move(first: int, second: int, condition: bool):
|
||||||
|
# Returns second if condition is true and first otherwise
|
||||||
|
# Must be implemented in a way that it is constant time
|
||||||
|
true_mask = -condition
|
||||||
|
false_mask = ~true_mask
|
||||||
|
return (first & false_mask) | (second & true_mask)
|
||||||
|
|
||||||
|
u = decode_coordinate(point)
|
||||||
|
tv1 = (u * u) % p
|
||||||
|
tv1 = (2 * tv1) % p
|
||||||
|
xd = (tv1 + 1) % p
|
||||||
|
x1n = (-J) % p
|
||||||
|
tv2 = (xd * xd) % p
|
||||||
|
gxd = (tv2 * xd) % p
|
||||||
|
gx1 = (J * tv1) % p
|
||||||
|
gx1 = (gx1 * x1n) % p
|
||||||
|
gx1 = (gx1 + tv2) % p
|
||||||
|
gx1 = (gx1 * x1n) % p
|
||||||
|
tv3 = (gxd * gxd) % p
|
||||||
|
tv2 = (tv3 * tv3) % p
|
||||||
|
tv3 = (tv3 * gxd) % p
|
||||||
|
tv3 = (tv3 * gx1) % p
|
||||||
|
tv2 = (tv2 * tv3) % p
|
||||||
|
y11 = pow(tv2, c4, p)
|
||||||
|
y11 = (y11 * tv3) % p
|
||||||
|
y12 = (y11 * c3) % p
|
||||||
|
tv2 = (y11 * y11) % p
|
||||||
|
tv2 = (tv2 * gxd) % p
|
||||||
|
e1 = tv2 == gx1
|
||||||
|
y1 = conditional_move(y12, y11, e1)
|
||||||
|
x2n = (x1n * tv1) % p
|
||||||
|
tv2 = (y1 * y1) % p
|
||||||
|
tv2 = (tv2 * gxd) % p
|
||||||
|
e3 = tv2 == gx1
|
||||||
|
xn = conditional_move(x2n, x1n, e3)
|
||||||
|
x = xn * pow(xd, p - 2, p) % p
|
||||||
|
return encode_coordinate(x)
|
82
python/src/trezorlib/transport/thp/message_header.py
Normal file
82
python/src/trezorlib/transport/thp/message_header.py
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
import struct
|
||||||
|
|
||||||
|
CODEC_V1 = 0x3F
|
||||||
|
CONTINUATION_PACKET = 0x80
|
||||||
|
HANDSHAKE_INIT_REQ = 0x00
|
||||||
|
HANDSHAKE_INIT_RES = 0x01
|
||||||
|
HANDSHAKE_COMP_REQ = 0x02
|
||||||
|
HANDSHAKE_COMP_RES = 0x03
|
||||||
|
ENCRYPTED_TRANSPORT = 0x04
|
||||||
|
|
||||||
|
CONTINUATION_PACKET_MASK = 0x80
|
||||||
|
ACK_MASK = 0xF7
|
||||||
|
DATA_MASK = 0xE7
|
||||||
|
|
||||||
|
ACK_MESSAGE = 0x20
|
||||||
|
_ERROR = 0x42
|
||||||
|
CHANNEL_ALLOCATION_REQ = 0x40
|
||||||
|
_CHANNEL_ALLOCATION_RES = 0x41
|
||||||
|
|
||||||
|
TREZOR_STATE_UNPAIRED = b"\x00"
|
||||||
|
TREZOR_STATE_PAIRED = b"\x01"
|
||||||
|
|
||||||
|
BROADCAST_CHANNEL_ID = 0xFFFF
|
||||||
|
|
||||||
|
|
||||||
|
class MessageHeader:
|
||||||
|
format_str_init = ">BHH"
|
||||||
|
format_str_cont = ">BH"
|
||||||
|
|
||||||
|
def __init__(self, ctrl_byte: int, cid: int, length: int) -> None:
|
||||||
|
self.ctrl_byte = ctrl_byte
|
||||||
|
self.cid = cid
|
||||||
|
self.data_length = length
|
||||||
|
|
||||||
|
def to_bytes_init(self) -> bytes:
|
||||||
|
return struct.pack(
|
||||||
|
self.format_str_init, self.ctrl_byte, self.cid, self.data_length
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_bytes_cont(self) -> bytes:
|
||||||
|
return struct.pack(self.format_str_cont, CONTINUATION_PACKET, self.cid)
|
||||||
|
|
||||||
|
def pack_to_init_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None:
|
||||||
|
struct.pack_into(
|
||||||
|
self.format_str_init,
|
||||||
|
buffer,
|
||||||
|
buffer_offset,
|
||||||
|
self.ctrl_byte,
|
||||||
|
self.cid,
|
||||||
|
self.data_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
def pack_to_cont_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None:
|
||||||
|
struct.pack_into(
|
||||||
|
self.format_str_cont, buffer, buffer_offset, CONTINUATION_PACKET, self.cid
|
||||||
|
)
|
||||||
|
|
||||||
|
def is_ack(self) -> bool:
|
||||||
|
return self.ctrl_byte & ACK_MASK == ACK_MESSAGE
|
||||||
|
|
||||||
|
def is_channel_allocation_response(self):
|
||||||
|
return (
|
||||||
|
self.cid == BROADCAST_CHANNEL_ID
|
||||||
|
and self.ctrl_byte == _CHANNEL_ALLOCATION_RES
|
||||||
|
)
|
||||||
|
|
||||||
|
def is_handshake_init_response(self) -> bool:
|
||||||
|
return self.ctrl_byte & DATA_MASK == HANDSHAKE_INIT_RES
|
||||||
|
|
||||||
|
def is_handshake_comp_response(self) -> bool:
|
||||||
|
return self.ctrl_byte & DATA_MASK == HANDSHAKE_COMP_RES
|
||||||
|
|
||||||
|
def is_encrypted_transport(self) -> bool:
|
||||||
|
return self.ctrl_byte & DATA_MASK == ENCRYPTED_TRANSPORT
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_error_header(cls, cid: int, length: int):
|
||||||
|
return cls(_ERROR, cid, length)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_channel_allocation_request_header(cls, length: int):
|
||||||
|
return cls(CHANNEL_ALLOCATION_REQ, BROADCAST_CHANNEL_ID, length)
|
490
python/src/trezorlib/transport/thp/protocol_v2.py
Normal file
490
python/src/trezorlib/transport/thp/protocol_v2.py
Normal file
@ -0,0 +1,490 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import typing as t
|
||||||
|
from binascii import hexlify
|
||||||
|
|
||||||
|
import click
|
||||||
|
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||||
|
|
||||||
|
from ... import exceptions, messages, protobuf
|
||||||
|
from ...mapping import ProtobufMapping
|
||||||
|
from .. import Transport
|
||||||
|
from ..thp import checksum, curve25519, thp_io
|
||||||
|
from ..thp.channel_data import ChannelData
|
||||||
|
from ..thp.checksum import CHECKSUM_LENGTH
|
||||||
|
from ..thp.message_header import MessageHeader
|
||||||
|
from . import control_byte
|
||||||
|
from .channel_database import ChannelDatabase, get_channel_db
|
||||||
|
from .protocol_and_channel import Channel
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DEFAULT_SESSION_ID: int = 0
|
||||||
|
|
||||||
|
if t.TYPE_CHECKING:
|
||||||
|
from ...debuglink import DebugLink
|
||||||
|
MT = t.TypeVar("MT", bound=protobuf.MessageType)
|
||||||
|
|
||||||
|
|
||||||
|
def _sha256_of_two(val_1: bytes, val_2: bytes) -> bytes:
|
||||||
|
hash = hashlib.sha256(val_1)
|
||||||
|
hash.update(val_2)
|
||||||
|
return hash.digest()
|
||||||
|
|
||||||
|
|
||||||
|
def _hkdf(chaining_key: bytes, input: bytes):
|
||||||
|
temp_key = hmac.new(chaining_key, input, hashlib.sha256).digest()
|
||||||
|
output_1 = hmac.new(temp_key, b"\x01", hashlib.sha256).digest()
|
||||||
|
ctx_output_2 = hmac.new(temp_key, output_1, hashlib.sha256)
|
||||||
|
ctx_output_2.update(b"\x02")
|
||||||
|
output_2 = ctx_output_2.digest()
|
||||||
|
return (output_1, output_2)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_iv_from_nonce(nonce: int) -> bytes:
|
||||||
|
if not nonce <= 0xFFFFFFFFFFFFFFFF:
|
||||||
|
raise ValueError("Nonce overflow, terminate the channel")
|
||||||
|
return bytes(4) + nonce.to_bytes(8, "big")
|
||||||
|
|
||||||
|
|
||||||
|
class ProtocolV2Channel(Channel):
|
||||||
|
channel_id: int
|
||||||
|
channel_database: ChannelDatabase
|
||||||
|
key_request: bytes
|
||||||
|
key_response: bytes
|
||||||
|
nonce_request: int
|
||||||
|
nonce_response: int
|
||||||
|
sync_bit_send: int
|
||||||
|
sync_bit_receive: int
|
||||||
|
handshake_hash: bytes
|
||||||
|
|
||||||
|
_has_valid_channel: bool = False
|
||||||
|
_features: messages.Features | None = None
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
transport: Transport,
|
||||||
|
mapping: ProtobufMapping,
|
||||||
|
channel_data: ChannelData | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.channel_database: ChannelDatabase = get_channel_db()
|
||||||
|
super().__init__(transport, mapping)
|
||||||
|
if channel_data is not None:
|
||||||
|
self.channel_id = channel_data.channel_id
|
||||||
|
self.key_request = bytes.fromhex(channel_data.key_request)
|
||||||
|
self.key_response = bytes.fromhex(channel_data.key_response)
|
||||||
|
self.nonce_request = channel_data.nonce_request
|
||||||
|
self.nonce_response = channel_data.nonce_response
|
||||||
|
self.sync_bit_receive = channel_data.sync_bit_receive
|
||||||
|
self.sync_bit_send = channel_data.sync_bit_send
|
||||||
|
self.handshake_hash = bytes.fromhex(channel_data.handshake_hash)
|
||||||
|
self._has_valid_channel = True
|
||||||
|
|
||||||
|
def get_channel(self, helper_debug: DebugLink | None = None) -> ProtocolV2Channel:
|
||||||
|
if not self._has_valid_channel:
|
||||||
|
self._establish_new_channel(helper_debug)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def get_channel_data(self) -> ChannelData:
|
||||||
|
return ChannelData(
|
||||||
|
protocol_version_major=2,
|
||||||
|
protocol_version_minor=2,
|
||||||
|
transport_path=self.transport.get_path(),
|
||||||
|
channel_id=self.channel_id,
|
||||||
|
key_request=self.key_request,
|
||||||
|
key_response=self.key_response,
|
||||||
|
nonce_request=self.nonce_request,
|
||||||
|
nonce_response=self.nonce_response,
|
||||||
|
sync_bit_receive=self.sync_bit_receive,
|
||||||
|
sync_bit_send=self.sync_bit_send,
|
||||||
|
handshake_hash=self.handshake_hash,
|
||||||
|
)
|
||||||
|
|
||||||
|
def read(self, session_id: int) -> t.Any:
|
||||||
|
sid, msg_type, msg_data = self.read_and_decrypt()
|
||||||
|
if sid != session_id:
|
||||||
|
raise Exception("Received messsage on a different session.")
|
||||||
|
self.channel_database.save_channel(self)
|
||||||
|
return self.mapping.decode(msg_type, msg_data)
|
||||||
|
|
||||||
|
def write(self, session_id: int, msg: t.Any) -> None:
|
||||||
|
msg_type, msg_data = self.mapping.encode(msg)
|
||||||
|
self._encrypt_and_write(session_id, msg_type, msg_data)
|
||||||
|
self.channel_database.save_channel(self)
|
||||||
|
|
||||||
|
def get_features(self) -> messages.Features:
|
||||||
|
if not self._has_valid_channel:
|
||||||
|
self._establish_new_channel()
|
||||||
|
if self._features is None:
|
||||||
|
self.update_features()
|
||||||
|
assert self._features is not None
|
||||||
|
return self._features
|
||||||
|
|
||||||
|
def update_features(self) -> None:
|
||||||
|
message = messages.GetFeatures()
|
||||||
|
message_type, message_data = self.mapping.encode(message)
|
||||||
|
self.session_id: int = DEFAULT_SESSION_ID
|
||||||
|
self._encrypt_and_write(DEFAULT_SESSION_ID, message_type, message_data)
|
||||||
|
_ = self._read_until_valid_crc_check() # TODO check ACK
|
||||||
|
_, msg_type, msg_data = self.read_and_decrypt()
|
||||||
|
features = self.mapping.decode(msg_type, msg_data)
|
||||||
|
if not isinstance(features, messages.Features):
|
||||||
|
raise exceptions.TrezorException("Unexpected response to GetFeatures")
|
||||||
|
self._features = features
|
||||||
|
|
||||||
|
def _send_message(
|
||||||
|
self,
|
||||||
|
message: protobuf.MessageType,
|
||||||
|
session_id: int = DEFAULT_SESSION_ID,
|
||||||
|
):
|
||||||
|
message_type, message_data = self.mapping.encode(message)
|
||||||
|
self._encrypt_and_write(session_id, message_type, message_data)
|
||||||
|
self._read_ack()
|
||||||
|
|
||||||
|
def _read_message(self, message_type: type[MT]) -> MT:
|
||||||
|
_, msg_type, msg_data = self.read_and_decrypt()
|
||||||
|
msg = self.mapping.decode(msg_type, msg_data)
|
||||||
|
assert isinstance(msg, message_type)
|
||||||
|
return msg
|
||||||
|
|
||||||
|
def _establish_new_channel(self, helper_debug: DebugLink | None = None) -> None:
|
||||||
|
self._reset_sync_bits()
|
||||||
|
self._do_channel_allocation()
|
||||||
|
self._do_handshake()
|
||||||
|
self._do_pairing(helper_debug)
|
||||||
|
|
||||||
|
def _reset_sync_bits(self) -> None:
|
||||||
|
self.sync_bit_send = 0
|
||||||
|
self.sync_bit_receive = 0
|
||||||
|
|
||||||
|
def _do_channel_allocation(self) -> None:
|
||||||
|
channel_allocation_nonce = os.urandom(8)
|
||||||
|
self._send_channel_allocation_request(channel_allocation_nonce)
|
||||||
|
cid, dp = self._read_channel_allocation_response(channel_allocation_nonce)
|
||||||
|
self.channel_id = cid
|
||||||
|
self.device_properties = dp
|
||||||
|
|
||||||
|
def _send_channel_allocation_request(self, nonce: bytes):
|
||||||
|
thp_io.write_payload_to_wire_and_add_checksum(
|
||||||
|
self.transport,
|
||||||
|
MessageHeader.get_channel_allocation_request_header(12),
|
||||||
|
nonce,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _read_channel_allocation_response(
|
||||||
|
self, expected_nonce: bytes
|
||||||
|
) -> tuple[int, bytes]:
|
||||||
|
header, payload = self._read_until_valid_crc_check()
|
||||||
|
if not self._is_valid_channel_allocation_response(
|
||||||
|
header, payload, expected_nonce
|
||||||
|
):
|
||||||
|
raise Exception("Invalid channel allocation response.")
|
||||||
|
|
||||||
|
channel_id = int.from_bytes(payload[8:10], "big")
|
||||||
|
device_properties = payload[10:]
|
||||||
|
return (channel_id, device_properties)
|
||||||
|
|
||||||
|
def _do_handshake(
|
||||||
|
self, credential: bytes | None = None, host_static_privkey: bytes | None = None
|
||||||
|
):
|
||||||
|
host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32))
|
||||||
|
host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey)
|
||||||
|
|
||||||
|
self._send_handshake_init_request(host_ephemeral_pubkey)
|
||||||
|
self._read_ack()
|
||||||
|
init_response = self._read_handshake_init_response()
|
||||||
|
|
||||||
|
trezor_ephemeral_pubkey = init_response[:32]
|
||||||
|
encrypted_trezor_static_pubkey = init_response[32:80]
|
||||||
|
noise_tag = init_response[80:96]
|
||||||
|
LOG.debug("noise_tag: %s", hexlify(noise_tag).decode())
|
||||||
|
|
||||||
|
# TODO check noise_tag is valid
|
||||||
|
|
||||||
|
ck = self._send_handshake_completion_request(
|
||||||
|
host_ephemeral_pubkey,
|
||||||
|
host_ephemeral_privkey,
|
||||||
|
trezor_ephemeral_pubkey,
|
||||||
|
encrypted_trezor_static_pubkey,
|
||||||
|
credential,
|
||||||
|
host_static_privkey,
|
||||||
|
)
|
||||||
|
self._read_ack()
|
||||||
|
self._read_handshake_completion_response()
|
||||||
|
self.key_request, self.key_response = _hkdf(ck, b"")
|
||||||
|
self.nonce_request = 0
|
||||||
|
self.nonce_response = 1
|
||||||
|
|
||||||
|
def _send_handshake_init_request(self, host_ephemeral_pubkey: bytes) -> None:
|
||||||
|
ha_init_req_header = MessageHeader(0, self.channel_id, 36)
|
||||||
|
|
||||||
|
thp_io.write_payload_to_wire_and_add_checksum(
|
||||||
|
self.transport, ha_init_req_header, host_ephemeral_pubkey
|
||||||
|
)
|
||||||
|
|
||||||
|
def _read_handshake_init_response(self) -> bytes:
|
||||||
|
header, payload = self._read_until_valid_crc_check()
|
||||||
|
self._send_ack_0()
|
||||||
|
|
||||||
|
if header.ctrl_byte == 0x42:
|
||||||
|
if payload == b"\x05":
|
||||||
|
raise exceptions.DeviceLockedException()
|
||||||
|
|
||||||
|
if not header.is_handshake_init_response():
|
||||||
|
LOG.debug("Received message is not a valid handshake init response message")
|
||||||
|
|
||||||
|
click.echo(
|
||||||
|
"Received message is not a valid handshake init response message",
|
||||||
|
err=True,
|
||||||
|
)
|
||||||
|
return payload
|
||||||
|
|
||||||
|
def _send_handshake_completion_request(
|
||||||
|
self,
|
||||||
|
host_ephemeral_pubkey: bytes,
|
||||||
|
host_ephemeral_privkey: bytes,
|
||||||
|
trezor_ephemeral_pubkey: bytes,
|
||||||
|
encrypted_trezor_static_pubkey: bytes,
|
||||||
|
credential: bytes | None = None,
|
||||||
|
host_static_privkey: bytes | None = None,
|
||||||
|
) -> bytes:
|
||||||
|
PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00"
|
||||||
|
IV_1 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
|
||||||
|
IV_2 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
|
||||||
|
h = _sha256_of_two(PROTOCOL_NAME, self.device_properties)
|
||||||
|
h = _sha256_of_two(h, host_ephemeral_pubkey)
|
||||||
|
h = _sha256_of_two(h, trezor_ephemeral_pubkey)
|
||||||
|
ck, k = _hkdf(
|
||||||
|
PROTOCOL_NAME,
|
||||||
|
curve25519.multiply(host_ephemeral_privkey, trezor_ephemeral_pubkey),
|
||||||
|
)
|
||||||
|
|
||||||
|
aes_ctx = AESGCM(k)
|
||||||
|
try:
|
||||||
|
trezor_masked_static_pubkey = aes_ctx.decrypt(
|
||||||
|
IV_1, encrypted_trezor_static_pubkey, h
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
click.echo(
|
||||||
|
f"Exception of type{type(e)}", err=True
|
||||||
|
) # TODO how to handle potential exceptions? Q for Matejcik
|
||||||
|
h = _sha256_of_two(h, encrypted_trezor_static_pubkey)
|
||||||
|
ck, k = _hkdf(
|
||||||
|
ck, curve25519.multiply(host_ephemeral_privkey, trezor_masked_static_pubkey)
|
||||||
|
)
|
||||||
|
aes_ctx = AESGCM(k)
|
||||||
|
|
||||||
|
tag_of_empty_string = aes_ctx.encrypt(IV_1, b"", h)
|
||||||
|
h = _sha256_of_two(h, tag_of_empty_string)
|
||||||
|
|
||||||
|
# TODO: search for saved credentials
|
||||||
|
if host_static_privkey is not None and credential is not None:
|
||||||
|
host_static_pubkey = curve25519.get_public_key(host_static_privkey)
|
||||||
|
else:
|
||||||
|
credential = None
|
||||||
|
zeroes_32 = int.to_bytes(0, 32, "little")
|
||||||
|
temp_host_static_privkey = curve25519.get_private_key(zeroes_32)
|
||||||
|
temp_host_static_pubkey = curve25519.get_public_key(
|
||||||
|
temp_host_static_privkey
|
||||||
|
)
|
||||||
|
host_static_privkey = temp_host_static_privkey
|
||||||
|
host_static_pubkey = temp_host_static_pubkey
|
||||||
|
|
||||||
|
aes_ctx = AESGCM(k)
|
||||||
|
encrypted_host_static_pubkey = aes_ctx.encrypt(IV_2, host_static_pubkey, h)
|
||||||
|
h = _sha256_of_two(h, encrypted_host_static_pubkey)
|
||||||
|
ck, k = _hkdf(
|
||||||
|
ck, curve25519.multiply(host_static_privkey, trezor_ephemeral_pubkey)
|
||||||
|
)
|
||||||
|
msg_data = self.mapping.encode_without_wire_type(
|
||||||
|
messages.ThpHandshakeCompletionReqNoisePayload(
|
||||||
|
host_pairing_credential=credential,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
aes_ctx = AESGCM(k)
|
||||||
|
|
||||||
|
encrypted_payload = aes_ctx.encrypt(IV_1, msg_data, h)
|
||||||
|
h = _sha256_of_two(h, encrypted_payload[:-16])
|
||||||
|
ha_completion_req_header = MessageHeader(
|
||||||
|
0x12,
|
||||||
|
self.channel_id,
|
||||||
|
len(encrypted_host_static_pubkey)
|
||||||
|
+ len(encrypted_payload)
|
||||||
|
+ CHECKSUM_LENGTH,
|
||||||
|
)
|
||||||
|
thp_io.write_payload_to_wire_and_add_checksum(
|
||||||
|
self.transport,
|
||||||
|
ha_completion_req_header,
|
||||||
|
encrypted_host_static_pubkey + encrypted_payload,
|
||||||
|
)
|
||||||
|
self.handshake_hash = h
|
||||||
|
return ck
|
||||||
|
|
||||||
|
def _read_handshake_completion_response(self) -> None:
|
||||||
|
# Read handshake completion response, ignore payload as we do not care about the state
|
||||||
|
header, _ = self._read_until_valid_crc_check()
|
||||||
|
if not header.is_handshake_comp_response():
|
||||||
|
click.echo(
|
||||||
|
"Received message is not a valid handshake completion response",
|
||||||
|
err=True,
|
||||||
|
)
|
||||||
|
self._send_ack_1()
|
||||||
|
|
||||||
|
def _do_pairing(self, helper_debug: DebugLink | None):
|
||||||
|
|
||||||
|
self._send_message(messages.ThpPairingRequest())
|
||||||
|
self._read_message(messages.ButtonRequest)
|
||||||
|
self._send_message(messages.ButtonAck())
|
||||||
|
|
||||||
|
if helper_debug is not None:
|
||||||
|
helper_debug.press_yes()
|
||||||
|
|
||||||
|
self._read_message(messages.ThpPairingRequestApproved)
|
||||||
|
self._send_message(
|
||||||
|
messages.ThpSelectMethod(
|
||||||
|
selected_pairing_method=messages.ThpPairingMethod.SkipPairing
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self._read_message(messages.ThpEndResponse)
|
||||||
|
|
||||||
|
self._has_valid_channel = True
|
||||||
|
|
||||||
|
def _read_ack(self):
|
||||||
|
header, payload = self._read_until_valid_crc_check()
|
||||||
|
if not header.is_ack() or len(payload) > 0:
|
||||||
|
click.echo("Received message is not a valid ACK", err=True)
|
||||||
|
|
||||||
|
def _send_ack_0(self):
|
||||||
|
LOG.debug("sending ack 0")
|
||||||
|
header = MessageHeader(0x20, self.channel_id, 4)
|
||||||
|
thp_io.write_payload_to_wire_and_add_checksum(self.transport, header, b"")
|
||||||
|
|
||||||
|
def _send_ack_1(self):
|
||||||
|
LOG.debug("sending ack 1")
|
||||||
|
header = MessageHeader(0x28, self.channel_id, 4)
|
||||||
|
thp_io.write_payload_to_wire_and_add_checksum(self.transport, header, b"")
|
||||||
|
|
||||||
|
def _encrypt_and_write(
|
||||||
|
self,
|
||||||
|
session_id: int,
|
||||||
|
message_type: int,
|
||||||
|
message_data: bytes,
|
||||||
|
ctrl_byte: int | None = None,
|
||||||
|
) -> None:
|
||||||
|
assert self.key_request is not None
|
||||||
|
aes_ctx = AESGCM(self.key_request)
|
||||||
|
|
||||||
|
if ctrl_byte is None:
|
||||||
|
ctrl_byte = control_byte.add_seq_bit_to_ctrl_byte(0x04, self.sync_bit_send)
|
||||||
|
self.sync_bit_send = 1 - self.sync_bit_send
|
||||||
|
|
||||||
|
sid = session_id.to_bytes(1, "big")
|
||||||
|
msg_type = message_type.to_bytes(2, "big")
|
||||||
|
data = sid + msg_type + message_data
|
||||||
|
nonce = _get_iv_from_nonce(self.nonce_request)
|
||||||
|
self.nonce_request += 1
|
||||||
|
encrypted_message = aes_ctx.encrypt(nonce, data, b"")
|
||||||
|
header = MessageHeader(
|
||||||
|
ctrl_byte, self.channel_id, len(encrypted_message) + CHECKSUM_LENGTH
|
||||||
|
)
|
||||||
|
|
||||||
|
thp_io.write_payload_to_wire_and_add_checksum(
|
||||||
|
self.transport, header, encrypted_message
|
||||||
|
)
|
||||||
|
|
||||||
|
def read_and_decrypt(self) -> t.Tuple[int, int, bytes]:
|
||||||
|
header, raw_payload = self._read_until_valid_crc_check()
|
||||||
|
if control_byte.is_ack(header.ctrl_byte):
|
||||||
|
# TODO fix this recursion
|
||||||
|
return self.read_and_decrypt()
|
||||||
|
if control_byte.is_error(header.ctrl_byte):
|
||||||
|
# TODO check for different channel
|
||||||
|
err = _get_error_from_int(raw_payload[0])
|
||||||
|
raise Exception("Received ThpError: " + err)
|
||||||
|
if not header.is_encrypted_transport():
|
||||||
|
click.echo(
|
||||||
|
"Trying to decrypt not encrypted message! ("
|
||||||
|
+ hexlify(header.to_bytes_init() + raw_payload).decode()
|
||||||
|
+ ")",
|
||||||
|
err=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not control_byte.is_ack(header.ctrl_byte):
|
||||||
|
LOG.debug(
|
||||||
|
"--> Get sequence bit %d %s %s",
|
||||||
|
control_byte.get_seq_bit(header.ctrl_byte),
|
||||||
|
"from control byte",
|
||||||
|
hexlify(header.ctrl_byte.to_bytes(1, "big")).decode(),
|
||||||
|
)
|
||||||
|
if control_byte.get_seq_bit(header.ctrl_byte):
|
||||||
|
self._send_ack_1()
|
||||||
|
else:
|
||||||
|
self._send_ack_0()
|
||||||
|
aes_ctx = AESGCM(self.key_response)
|
||||||
|
nonce = _get_iv_from_nonce(self.nonce_response)
|
||||||
|
self.nonce_response += 1
|
||||||
|
|
||||||
|
message = aes_ctx.decrypt(nonce, raw_payload, b"")
|
||||||
|
session_id = message[0]
|
||||||
|
message_type = message[1:3]
|
||||||
|
message_data = message[3:]
|
||||||
|
return (
|
||||||
|
session_id,
|
||||||
|
int.from_bytes(message_type, "big"),
|
||||||
|
message_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _read_until_valid_crc_check(
|
||||||
|
self,
|
||||||
|
) -> t.Tuple[MessageHeader, bytes]:
|
||||||
|
is_valid = False
|
||||||
|
header, payload, chksum = thp_io.read(self.transport)
|
||||||
|
while not is_valid:
|
||||||
|
is_valid = checksum.is_valid(chksum, header.to_bytes_init() + payload)
|
||||||
|
if not is_valid:
|
||||||
|
click.echo(
|
||||||
|
"Received a message with an invalid checksum:"
|
||||||
|
+ hexlify(header.to_bytes_init() + payload + chksum).decode(),
|
||||||
|
err=True,
|
||||||
|
)
|
||||||
|
header, payload, chksum = thp_io.read(self.transport)
|
||||||
|
|
||||||
|
return header, payload
|
||||||
|
|
||||||
|
def _is_valid_channel_allocation_response(
|
||||||
|
self, header: MessageHeader, payload: bytes, original_nonce: bytes
|
||||||
|
) -> bool:
|
||||||
|
if not header.is_channel_allocation_response():
|
||||||
|
click.echo(
|
||||||
|
"Received message is not a channel allocation response", err=True
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
if len(payload) < 10:
|
||||||
|
click.echo("Invalid channel allocation response payload", err=True)
|
||||||
|
return False
|
||||||
|
if payload[:8] != original_nonce:
|
||||||
|
click.echo(
|
||||||
|
"Invalid channel allocation response payload (nonce mismatch)", err=True
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _get_error_from_int(error_code: int) -> str:
|
||||||
|
# TODO FIXME improve this (ThpErrorType)
|
||||||
|
if error_code == 1:
|
||||||
|
return "TRANSPORT BUSY"
|
||||||
|
if error_code == 2:
|
||||||
|
return "UNALLOCATED CHANNEL"
|
||||||
|
if error_code == 3:
|
||||||
|
return "DECRYPTION FAILED"
|
||||||
|
if error_code == 4:
|
||||||
|
return "INVALID DATA"
|
||||||
|
if error_code == 5:
|
||||||
|
return "DEVICE LOCKED"
|
||||||
|
raise Exception("Not Implemented error case")
|
97
python/src/trezorlib/transport/thp/thp_io.py
Normal file
97
python/src/trezorlib/transport/thp/thp_io.py
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
import struct
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
from .. import Transport
|
||||||
|
from ..thp import checksum
|
||||||
|
from .message_header import MessageHeader
|
||||||
|
|
||||||
|
INIT_HEADER_LENGTH = 5
|
||||||
|
CONT_HEADER_LENGTH = 3
|
||||||
|
MAX_PAYLOAD_LEN = 60000
|
||||||
|
MESSAGE_TYPE_LENGTH = 2
|
||||||
|
|
||||||
|
CONTINUATION_PACKET = 0x80
|
||||||
|
|
||||||
|
|
||||||
|
def write_payload_to_wire_and_add_checksum(
|
||||||
|
transport: Transport, header: MessageHeader, transport_payload: bytes
|
||||||
|
):
|
||||||
|
chksum: bytes = checksum.compute(header.to_bytes_init() + transport_payload)
|
||||||
|
data = transport_payload + chksum
|
||||||
|
write_payload_to_wire(transport, header, data)
|
||||||
|
|
||||||
|
|
||||||
|
def write_payload_to_wire(
|
||||||
|
transport: Transport, header: MessageHeader, transport_payload: bytes
|
||||||
|
):
|
||||||
|
transport.open()
|
||||||
|
buffer = bytearray(transport_payload)
|
||||||
|
if transport.CHUNK_SIZE is None:
|
||||||
|
transport.write_chunk(buffer)
|
||||||
|
return
|
||||||
|
|
||||||
|
chunk = header.to_bytes_init() + buffer[: transport.CHUNK_SIZE - INIT_HEADER_LENGTH]
|
||||||
|
chunk = chunk.ljust(transport.CHUNK_SIZE, b"\x00")
|
||||||
|
transport.write_chunk(chunk)
|
||||||
|
|
||||||
|
buffer = buffer[transport.CHUNK_SIZE - INIT_HEADER_LENGTH :]
|
||||||
|
while buffer:
|
||||||
|
chunk = (
|
||||||
|
header.to_bytes_cont() + buffer[: transport.CHUNK_SIZE - CONT_HEADER_LENGTH]
|
||||||
|
)
|
||||||
|
chunk = chunk.ljust(transport.CHUNK_SIZE, b"\x00")
|
||||||
|
transport.write_chunk(chunk)
|
||||||
|
buffer = buffer[transport.CHUNK_SIZE - CONT_HEADER_LENGTH :]
|
||||||
|
|
||||||
|
|
||||||
|
def read(transport: Transport) -> Tuple[MessageHeader, bytes, bytes]:
|
||||||
|
"""
|
||||||
|
Reads from the given wire transport.
|
||||||
|
|
||||||
|
Returns `Tuple[MessageHeader, bytes, bytes]`:
|
||||||
|
1. `header` (`MessageHeader`): Header of the message.
|
||||||
|
2. `data` (`bytes`): Contents of the message (if any).
|
||||||
|
3. `checksum` (`bytes`): crc32 checksum of the header + data.
|
||||||
|
|
||||||
|
"""
|
||||||
|
buffer = bytearray()
|
||||||
|
|
||||||
|
# Read header with first part of message data
|
||||||
|
header, first_chunk = read_first(transport)
|
||||||
|
buffer.extend(first_chunk)
|
||||||
|
|
||||||
|
# Read the rest of the message
|
||||||
|
while len(buffer) < header.data_length:
|
||||||
|
buffer.extend(read_next(transport, header.cid))
|
||||||
|
|
||||||
|
data_len = header.data_length - checksum.CHECKSUM_LENGTH
|
||||||
|
msg_data = buffer[:data_len]
|
||||||
|
chksum = buffer[data_len : data_len + checksum.CHECKSUM_LENGTH]
|
||||||
|
|
||||||
|
return (header, msg_data, chksum)
|
||||||
|
|
||||||
|
|
||||||
|
def read_first(transport: Transport) -> Tuple[MessageHeader, bytes]:
|
||||||
|
chunk = transport.read_chunk()
|
||||||
|
try:
|
||||||
|
ctrl_byte, cid, data_length = struct.unpack(
|
||||||
|
MessageHeader.format_str_init, chunk[:INIT_HEADER_LENGTH]
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
raise RuntimeError("Cannot parse header")
|
||||||
|
|
||||||
|
data = chunk[INIT_HEADER_LENGTH:]
|
||||||
|
return MessageHeader(ctrl_byte, cid, data_length), data
|
||||||
|
|
||||||
|
|
||||||
|
def read_next(transport: Transport, cid: int) -> bytes:
|
||||||
|
chunk = transport.read_chunk()
|
||||||
|
ctrl_byte, read_cid = struct.unpack(
|
||||||
|
MessageHeader.format_str_cont, chunk[:CONT_HEADER_LENGTH]
|
||||||
|
)
|
||||||
|
if ctrl_byte != CONTINUATION_PACKET:
|
||||||
|
raise RuntimeError("Continuation packet with incorrect control byte")
|
||||||
|
if read_cid != cid:
|
||||||
|
raise RuntimeError("Continuation packet for different channel")
|
||||||
|
|
||||||
|
return chunk[CONT_HEADER_LENGTH:]
|
@ -76,9 +76,9 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# change PIN
|
# change PIN
|
||||||
new_pin = "".join(random.choices(string.digits, k=random.randint(6, 10)))
|
new_pin = "".join(random.choices(string.digits, k=random.randint(6, 10)))
|
||||||
client.set_input_flow(pin_input_flow(client, last_pin, new_pin))
|
session.set_input_flow(pin_input_flow(client, last_pin, new_pin))
|
||||||
device.change_pin(client)
|
device.change_pin(client)
|
||||||
client.set_input_flow(None)
|
session.set_input_flow(None)
|
||||||
last_pin = new_pin
|
last_pin = new_pin
|
||||||
|
|
||||||
print(f"iteration {i}")
|
print(f"iteration {i}")
|
||||||
|
@ -22,7 +22,9 @@ import time
|
|||||||
import typing as t
|
import typing as t
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from time import sleep
|
||||||
|
|
||||||
|
import cryptography
|
||||||
import pytest
|
import pytest
|
||||||
import xdist
|
import xdist
|
||||||
from _pytest.python import IdMaker
|
from _pytest.python import IdMaker
|
||||||
@ -314,11 +316,23 @@ def _client_unlocked(
|
|||||||
should_format = sd_marker.kwargs.get("formatted", True)
|
should_format = sd_marker.kwargs.get("formatted", True)
|
||||||
_raw_client.debug.erase_sd_card(format=should_format)
|
_raw_client.debug.erase_sd_card(format=should_format)
|
||||||
|
|
||||||
if _raw_client.is_invalidated:
|
while True:
|
||||||
_raw_client = _raw_client.get_new_client()
|
try:
|
||||||
session = _raw_client.get_seedless_session()
|
if _raw_client.is_invalidated:
|
||||||
wipe_device(session)
|
_raw_client = _raw_client.get_new_client()
|
||||||
# sleep(1.5) # Makes tests more stable (wait for wipe to finish)
|
session = _raw_client.get_seedless_session()
|
||||||
|
wipe_device(session)
|
||||||
|
sleep(1.5) # Makes tests more stable (wait for wipe to finish)
|
||||||
|
break
|
||||||
|
except cryptography.exceptions.InvalidTag:
|
||||||
|
# Get a new client
|
||||||
|
_raw_client = _get_raw_client(request)
|
||||||
|
|
||||||
|
_raw_client.protocol = None
|
||||||
|
_raw_client.__init__(
|
||||||
|
transport=_raw_client.transport,
|
||||||
|
auto_interact=_raw_client.debug.allow_interactions,
|
||||||
|
)
|
||||||
|
|
||||||
if not _raw_client.features.bootloader_mode:
|
if not _raw_client.features.bootloader_mode:
|
||||||
_raw_client.refresh_features()
|
_raw_client.refresh_features()
|
||||||
|
@ -8,6 +8,7 @@ import pytest
|
|||||||
from _pytest.nodes import Node
|
from _pytest.nodes import Node
|
||||||
from _pytest.outcomes import Failed
|
from _pytest.outcomes import Failed
|
||||||
|
|
||||||
|
from trezorlib.client import ProtocolVersion
|
||||||
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
||||||
|
|
||||||
from . import common
|
from . import common
|
||||||
|
Loading…
Reference in New Issue
Block a user