parent
87d6407d26
commit
11309cccd0
@ -0,0 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class ChannelData:
|
||||
key_request: bytes
|
||||
key_response: bytes
|
||||
nonce_request: int
|
||||
nonce_response: int
|
||||
channel_id: bytes
|
||||
sync_bit_send: int
|
||||
sync_bit_receive: int
|
@ -0,0 +1,76 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from ... import mapping
|
||||
from ...mapping import ProtobufMapping
|
||||
from .channel_data import ChannelData
|
||||
from .protocol_and_channel import ProtocolAndChannel, ProtocolV1
|
||||
from .protocol_v2 import ProtocolV2
|
||||
from .session import Session, SessionV1, SessionV2
|
||||
from .transport import NewTransport
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NewTrezorClient:
|
||||
def __init__(
|
||||
self,
|
||||
transport: NewTransport,
|
||||
protobuf_mapping: ProtobufMapping | None = None,
|
||||
protocol: ProtocolAndChannel | None = None,
|
||||
) -> None:
|
||||
self.transport = transport
|
||||
|
||||
if protobuf_mapping is None:
|
||||
self.mapping = mapping.DEFAULT_MAPPING
|
||||
else:
|
||||
self.mapping = protobuf_mapping
|
||||
print("test B")
|
||||
|
||||
if protocol is None:
|
||||
print("test C")
|
||||
self.protocol = self._get_protocol()
|
||||
else:
|
||||
self.protocol = protocol
|
||||
|
||||
@classmethod
|
||||
def resume(
|
||||
cls, transport: NewTransport, channel_data: ChannelData
|
||||
) -> NewTrezorClient: ...
|
||||
|
||||
def get_session(
|
||||
self, passphrase: str = "", derive_cardano: bool = False
|
||||
) -> Session:
|
||||
if isinstance(self.protocol, ProtocolV1):
|
||||
return SessionV1.new(self, passphrase, derive_cardano)
|
||||
if isinstance(self.protocol, ProtocolV2):
|
||||
return SessionV2.new(self, passphrase, derive_cardano)
|
||||
raise NotImplementedError # TODO
|
||||
|
||||
def resume_session(self, session_id: bytes) -> Session:
|
||||
raise NotImplementedError # TODO
|
||||
|
||||
def _get_protocol(self) -> ProtocolAndChannel:
|
||||
|
||||
from ... import mapping, messages
|
||||
from ...messages import FailureType
|
||||
from .protocol_and_channel import ProtocolV1
|
||||
|
||||
self.transport.open()
|
||||
|
||||
protocol = ProtocolV1(self.transport, mapping.DEFAULT_MAPPING)
|
||||
|
||||
protocol.write(messages.Initialize())
|
||||
|
||||
response = protocol.read()
|
||||
self.transport.close()
|
||||
if isinstance(response, messages.Failure):
|
||||
print("test F1")
|
||||
if (
|
||||
response.code == FailureType.UnexpectedMessage
|
||||
and response.message == "Invalid protocol"
|
||||
):
|
||||
LOG.debug("Protocol V2 detected")
|
||||
protocol = ProtocolV2(self.transport, self.mapping)
|
||||
return protocol
|
@ -0,0 +1,117 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import struct
|
||||
import typing as t
|
||||
|
||||
from ...log import DUMP_BYTES
|
||||
from ...mapping import ProtobufMapping
|
||||
from .channel_data import ChannelData
|
||||
from .transport import NewTransport
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProtocolAndChannel:
|
||||
def __init__(
|
||||
self,
|
||||
transport: NewTransport,
|
||||
mapping: ProtobufMapping,
|
||||
channel_keys: ChannelData | None = None,
|
||||
) -> None:
|
||||
self.transport = transport
|
||||
self.mapping = mapping
|
||||
self.channel_keys = channel_keys
|
||||
|
||||
def close(self) -> None: ...
|
||||
|
||||
# def write(self, session_id: bytes, msg: t.Any) -> None: ...
|
||||
|
||||
# def read(self, session_id: bytes) -> t.Any: ...
|
||||
|
||||
def get_channel_keys(self) -> ChannelData: ...
|
||||
|
||||
|
||||
class ProtocolV1(ProtocolAndChannel):
|
||||
HEADER_LEN = struct.calcsize(">HL")
|
||||
|
||||
def read(self) -> t.Any:
|
||||
msg_type, msg_bytes = self._read()
|
||||
LOG.log(
|
||||
DUMP_BYTES,
|
||||
f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
|
||||
)
|
||||
msg = self.mapping.decode(msg_type, msg_bytes)
|
||||
LOG.debug(
|
||||
f"received message: {msg.__class__.__name__}",
|
||||
extra={"protobuf": msg},
|
||||
)
|
||||
return msg
|
||||
|
||||
def write(self, msg: t.Any) -> None:
|
||||
LOG.debug(
|
||||
f"sending message: {msg.__class__.__name__}",
|
||||
extra={"protobuf": msg},
|
||||
)
|
||||
msg_type, msg_bytes = self.mapping.encode(msg)
|
||||
LOG.log(
|
||||
DUMP_BYTES,
|
||||
f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
|
||||
)
|
||||
self._write(msg_type, msg_bytes)
|
||||
|
||||
def _write(self, message_type: int, message_data: bytes) -> None:
|
||||
print("wooooo")
|
||||
chunk_size = self.transport.CHUNK_SIZE
|
||||
header = struct.pack(">HL", message_type, len(message_data))
|
||||
buffer = bytearray(b"##" + header + message_data)
|
||||
print("wooooo")
|
||||
|
||||
while buffer:
|
||||
# Report ID, data padded to 63 bytes
|
||||
chunk = b"?" + buffer[: chunk_size - 1]
|
||||
chunk = chunk.ljust(chunk_size, b"\x00")
|
||||
self.transport.write_chunk(chunk)
|
||||
buffer = buffer[63:]
|
||||
|
||||
def _read(self) -> t.Tuple[int, bytes]:
|
||||
buffer = bytearray()
|
||||
# Read header with first part of message data
|
||||
msg_type, datalen, first_chunk = self.read_first()
|
||||
buffer.extend(first_chunk)
|
||||
|
||||
# Read the rest of the message
|
||||
while len(buffer) < datalen:
|
||||
buffer.extend(self.read_next())
|
||||
|
||||
return msg_type, buffer[:datalen]
|
||||
|
||||
def read_first(self) -> t.Tuple[int, int, bytes]:
|
||||
chunk = self.transport.read_chunk()
|
||||
if chunk[:3] != b"?##":
|
||||
raise RuntimeError("Unexpected magic characters")
|
||||
try:
|
||||
msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN])
|
||||
except Exception:
|
||||
raise RuntimeError("Cannot parse header")
|
||||
|
||||
data = chunk[3 + self.HEADER_LEN :]
|
||||
return msg_type, datalen, data
|
||||
|
||||
def read_next(self) -> bytes:
|
||||
chunk = self.transport.read_chunk()
|
||||
if chunk[:1] != b"?":
|
||||
raise RuntimeError("Unexpected magic characters")
|
||||
return chunk[1:]
|
||||
|
||||
|
||||
class Channel:
|
||||
id: int
|
||||
channel_keys: ChannelData | None
|
||||
|
||||
def __init__(self, id: int, keys: ChannelData) -> None:
|
||||
self.id = id
|
||||
self.channel_keys = keys
|
||||
|
||||
def read(self) -> t.Any: ...
|
||||
def write(self, msg: t.Any) -> None: ...
|
@ -0,0 +1,313 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import logging
|
||||
import os
|
||||
import typing as t
|
||||
from binascii import hexlify
|
||||
from enum import IntEnum
|
||||
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
|
||||
from ... import messages
|
||||
from ...mapping import ProtobufMapping
|
||||
from ..thp import checksum, curve25519, thp_io
|
||||
from ..thp.checksum import CHECKSUM_LENGTH
|
||||
from ..thp.packet_header import PacketHeader
|
||||
from .channel_data import ChannelData
|
||||
from .protocol_and_channel import Channel, ProtocolAndChannel
|
||||
from .transport import NewTransport
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
MANAGEMENT_SESSION_ID: int = 0
|
||||
|
||||
|
||||
def _sha256_of_two(val_1: bytes, val_2: bytes) -> bytes:
|
||||
hash = hashlib.sha256(val_1)
|
||||
hash.update(val_2)
|
||||
return hash.digest()
|
||||
|
||||
|
||||
def _hkdf(chaining_key: bytes, input: bytes):
|
||||
temp_key = hmac.new(chaining_key, input, hashlib.sha256).digest()
|
||||
output_1 = hmac.new(temp_key, b"\x01", hashlib.sha256).digest()
|
||||
ctx_output_2 = hmac.new(temp_key, output_1, hashlib.sha256)
|
||||
ctx_output_2.update(b"\x02")
|
||||
output_2 = ctx_output_2.digest()
|
||||
return (output_1, output_2)
|
||||
|
||||
|
||||
def _get_iv_from_nonce(nonce: int) -> bytes:
|
||||
if not nonce <= 0xFFFFFFFFFFFFFFFF:
|
||||
raise ValueError("Nonce overflow, terminate the channel")
|
||||
return bytes(4) + nonce.to_bytes(8, "big")
|
||||
|
||||
|
||||
class ProtocolV2(ProtocolAndChannel):
|
||||
key_request: bytes
|
||||
key_response: bytes
|
||||
nonce_request: int
|
||||
nonce_response: int
|
||||
channel_id: int
|
||||
sync_bit_send: int
|
||||
sync_bit_receive: int
|
||||
has_valid_channel: bool = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transport: NewTransport,
|
||||
mapping: ProtobufMapping,
|
||||
channel_keys: ChannelData | None = None,
|
||||
) -> None:
|
||||
super().__init__(transport, mapping, channel_keys)
|
||||
self.channel: Channel | None = None
|
||||
|
||||
def get_channel(self) -> ProtocolV2:
|
||||
if not self.has_valid_channel:
|
||||
self._establish_new_channel()
|
||||
return self
|
||||
|
||||
def read(self, session_id: int) -> t.Any:
|
||||
header, data = self._read_until_valid_crc_check()
|
||||
# TODO
|
||||
|
||||
def write(self, session_id: int, msg: t.Any) -> None:
|
||||
msg_type, msg_data = self.mapping.encode(msg)
|
||||
self._encrypt_and_write(session_id, msg_type, msg_data, 7) # TODO add ctrl_byte
|
||||
|
||||
def _establish_new_channel(self):
|
||||
self.sync_bit_send = 0
|
||||
self.sync_bit_receive = 0
|
||||
# Send channel allocation request
|
||||
channel_id_request_nonce = os.urandom(8)
|
||||
thp_io.write_payload_to_wire_and_add_checksum(
|
||||
self.transport,
|
||||
PacketHeader.get_channel_allocation_request_header(12),
|
||||
channel_id_request_nonce,
|
||||
)
|
||||
|
||||
# Read channel allocation response
|
||||
header, payload = self._read_until_valid_crc_check()
|
||||
if not self._is_valid_channel_allocation_response(
|
||||
header, payload, channel_id_request_nonce
|
||||
):
|
||||
print("TODO raise exception here, I guess")
|
||||
|
||||
self.channel_id = int.from_bytes(payload[8:10], "big")
|
||||
self.device_properties = payload[10:]
|
||||
|
||||
# Send handshake init request
|
||||
ha_init_req_header = PacketHeader(0, self.channel_id, 36)
|
||||
host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32))
|
||||
host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey)
|
||||
|
||||
thp_io.write_payload_to_wire_and_add_checksum(
|
||||
self.transport, ha_init_req_header, host_ephemeral_pubkey
|
||||
)
|
||||
|
||||
# Read ACK
|
||||
header, payload = self._read_until_valid_crc_check()
|
||||
if not header.is_ack() or len(payload) > 0:
|
||||
print("Received message is not a valid ACK ")
|
||||
|
||||
# Read handshake init response
|
||||
header, payload = self._read_until_valid_crc_check()
|
||||
self._send_ack_1()
|
||||
|
||||
if not header.is_handshake_init_response():
|
||||
print("Received message is not a valid handshake init response message")
|
||||
|
||||
trezor_ephemeral_pubkey = payload[:32]
|
||||
encrypted_trezor_static_pubkey = payload[32:80]
|
||||
noise_tag = payload[80:96]
|
||||
|
||||
# TODO check noise tag
|
||||
print("noise_tag: ", hexlify(noise_tag).decode())
|
||||
|
||||
# Prepare and send handshake completion request
|
||||
PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00"
|
||||
IV_1 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
|
||||
IV_2 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
|
||||
h = _sha256_of_two(PROTOCOL_NAME, self.device_properties)
|
||||
h = _sha256_of_two(h, host_ephemeral_pubkey)
|
||||
h = _sha256_of_two(h, trezor_ephemeral_pubkey)
|
||||
ck, k = _hkdf(
|
||||
PROTOCOL_NAME,
|
||||
curve25519.multiply(host_ephemeral_privkey, trezor_ephemeral_pubkey),
|
||||
)
|
||||
|
||||
aes_ctx = AESGCM(k)
|
||||
try:
|
||||
trezor_masked_static_pubkey = aes_ctx.decrypt(
|
||||
IV_1, encrypted_trezor_static_pubkey, h
|
||||
)
|
||||
# print("masked_key", hexlify(trezor_masked_static_pubkey).decode())
|
||||
except Exception as e:
|
||||
print(type(e)) # TODO how to handle potential exceptions? Q for Matejcik
|
||||
h = _sha256_of_two(h, encrypted_trezor_static_pubkey)
|
||||
ck, k = _hkdf(
|
||||
ck, curve25519.multiply(host_ephemeral_privkey, trezor_masked_static_pubkey)
|
||||
)
|
||||
aes_ctx = AESGCM(k)
|
||||
|
||||
tag_of_empty_string = aes_ctx.encrypt(IV_1, b"", h)
|
||||
h = _sha256_of_two(h, tag_of_empty_string)
|
||||
# TODO: search for saved credentials (or possibly not, as we skip pairing phase)
|
||||
|
||||
zeroes_32 = int.to_bytes(0, 32, "little")
|
||||
temp_host_static_privkey = curve25519.get_private_key(zeroes_32)
|
||||
temp_host_static_pubkey = curve25519.get_public_key(temp_host_static_privkey)
|
||||
aes_ctx = AESGCM(k)
|
||||
encrypted_host_static_pubkey = aes_ctx.encrypt(IV_2, temp_host_static_pubkey, h)
|
||||
h = _sha256_of_two(h, encrypted_host_static_pubkey)
|
||||
ck, k = _hkdf(
|
||||
ck, curve25519.multiply(temp_host_static_privkey, trezor_ephemeral_pubkey)
|
||||
)
|
||||
msg_data = self.mapping.encode_without_wire_type(
|
||||
messages.ThpHandshakeCompletionReqNoisePayload(
|
||||
pairing_methods=[
|
||||
messages.ThpPairingMethod.NoMethod,
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
aes_ctx = AESGCM(k)
|
||||
|
||||
encrypted_payload = aes_ctx.encrypt(IV_1, msg_data, h)
|
||||
h = _sha256_of_two(h, encrypted_payload)
|
||||
ha_completion_req_header = PacketHeader(
|
||||
0x12,
|
||||
self.channel_id,
|
||||
len(encrypted_host_static_pubkey)
|
||||
+ len(encrypted_payload)
|
||||
+ CHECKSUM_LENGTH,
|
||||
)
|
||||
thp_io.write_payload_to_wire_and_add_checksum(
|
||||
self.transport,
|
||||
ha_completion_req_header,
|
||||
encrypted_host_static_pubkey + encrypted_payload,
|
||||
)
|
||||
|
||||
# Read ACK
|
||||
header, payload = self._read_until_valid_crc_check()
|
||||
if not header.is_ack() or len(payload) > 0:
|
||||
print("Received message is not a valid ACK ")
|
||||
|
||||
# Read handshake completion response, ignore payload as we do not care about the state
|
||||
header, _ = self._read_until_valid_crc_check()
|
||||
if not header.is_handshake_comp_response():
|
||||
print("Received message is not a valid handshake completion response")
|
||||
self._send_ack_2()
|
||||
|
||||
self.key_request, self.key_response = _hkdf(ck, b"")
|
||||
self.nonce_request = 0
|
||||
self.nonce_response = 1
|
||||
|
||||
# Send StartPairingReqest message
|
||||
message = messages.ThpStartPairingRequest()
|
||||
message_type, message_data = self.mapping.encode(message)
|
||||
|
||||
self._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
|
||||
|
||||
# Read ACK
|
||||
header, payload = self._read_until_valid_crc_check()
|
||||
if not header.is_ack() or len(payload) > 0:
|
||||
print("Received message is not a valid ACK ")
|
||||
|
||||
# Read
|
||||
_, msg_type, msg_data = self.read_and_decrypt()
|
||||
maaa = self.mapping.decode(msg_type, msg_data)
|
||||
self._send_ack_1()
|
||||
|
||||
assert isinstance(maaa, messages.ThpEndResponse)
|
||||
self.has_valid_channel = True
|
||||
|
||||
def _get_control_byte(self) -> bytes:
|
||||
return b"\x42"
|
||||
|
||||
def _send_ack_1(self):
|
||||
header = PacketHeader(0x20, self.channel_id, 4)
|
||||
thp_io.write_payload_to_wire_and_add_checksum(self.transport, header, b"")
|
||||
|
||||
def _send_ack_2(self):
|
||||
header = PacketHeader(0x28, self.channel_id, 4)
|
||||
thp_io.write_payload_to_wire_and_add_checksum(self.transport, header, b"")
|
||||
|
||||
def _encrypt_and_write(
|
||||
self,
|
||||
session_id: int,
|
||||
message_type: int,
|
||||
message_data: bytes,
|
||||
ctrl_byte: int = 0x04,
|
||||
) -> None:
|
||||
assert self.key_request is not None
|
||||
aes_ctx = AESGCM(self.key_request)
|
||||
|
||||
sid = session_id.to_bytes(1, "big")
|
||||
msg_type = message_type.to_bytes(2, "big")
|
||||
data = sid + msg_type + message_data
|
||||
nonce = _get_iv_from_nonce(self.nonce_request)
|
||||
self.nonce_request += 1
|
||||
encrypted_message = aes_ctx.encrypt(nonce, data, b"")
|
||||
header = PacketHeader(
|
||||
ctrl_byte, self.channel_id, len(encrypted_message) + CHECKSUM_LENGTH
|
||||
)
|
||||
|
||||
thp_io.write_payload_to_wire_and_add_checksum(
|
||||
self.transport, header, encrypted_message
|
||||
)
|
||||
|
||||
def read_and_decrypt(self) -> t.Tuple[bytes, int, bytes]:
|
||||
header, raw_payload = self._read_until_valid_crc_check()
|
||||
if not header.is_encrypted_transport():
|
||||
print("Trying to decrypt not encrypted message!")
|
||||
aes_ctx = AESGCM(self.key_response)
|
||||
nonce = _get_iv_from_nonce(self.nonce_response)
|
||||
self.nonce_response += 1
|
||||
|
||||
message = aes_ctx.decrypt(nonce, raw_payload, b"")
|
||||
session_id = message[0]
|
||||
message_type = message[1:3]
|
||||
message_data = message[3:]
|
||||
return (
|
||||
int.to_bytes(session_id, 1, "big"),
|
||||
int.from_bytes(message_type, "big"),
|
||||
message_data,
|
||||
)
|
||||
|
||||
def _read_until_valid_crc_check(
|
||||
self,
|
||||
) -> t.Tuple[PacketHeader, bytes]:
|
||||
is_valid = False
|
||||
header, payload, chksum = thp_io.read(self.transport)
|
||||
while not is_valid:
|
||||
is_valid = checksum.is_valid(chksum, header.to_bytes_init() + payload)
|
||||
if not is_valid:
|
||||
print(hexlify(header.to_bytes_init() + payload + chksum))
|
||||
LOG.debug("Received a message with invalid checksum")
|
||||
header, payload, chksum = thp_io.read(self.transport)
|
||||
|
||||
return header, payload
|
||||
|
||||
def _is_valid_channel_allocation_response(
|
||||
self, header: PacketHeader, payload: bytes, original_nonce: bytes
|
||||
) -> bool:
|
||||
if not header.is_channel_allocation_response():
|
||||
print("Received message is not a channel allocation response")
|
||||
return False
|
||||
if len(payload) < 10:
|
||||
print("Invalid channel allocation response payload")
|
||||
return False
|
||||
if payload[:8] != original_nonce:
|
||||
print("Invalid channel allocation response payload (nonce mismatch)")
|
||||
return False
|
||||
return True
|
||||
|
||||
class ControlByteType(IntEnum):
|
||||
CHANNEL_ALLOCATION_RES = 1
|
||||
HANDSHAKE_INIT_RES = 2
|
||||
HANDSHAKE_COMP_RES = 3
|
||||
ACK = 4
|
||||
ENCRYPTED_TRANSPORT = 5
|
@ -0,0 +1,65 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
from ...messages import Features, Initialize
|
||||
from .protocol_and_channel import ProtocolV1
|
||||
from .protocol_v2 import ProtocolV2
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from .client import NewTrezorClient
|
||||
|
||||
|
||||
class Session:
|
||||
features: Features
|
||||
|
||||
def __init__(self, client: NewTrezorClient, id: bytes) -> None:
|
||||
self.client = client
|
||||
self.id = id
|
||||
|
||||
@classmethod
|
||||
def new(
|
||||
cls, client: NewTrezorClient, passphrase: str, derive_cardano: bool
|
||||
) -> Session:
|
||||
raise NotImplementedError
|
||||
|
||||
def call(self, msg: t.Any) -> t.Any:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SessionV1(Session):
|
||||
@classmethod
|
||||
def new(
|
||||
cls, client: NewTrezorClient, passphrase: str, derive_cardano: bool
|
||||
) -> SessionV1:
|
||||
assert isinstance(client.protocol, ProtocolV1)
|
||||
session = SessionV1(client, b"")
|
||||
cls.features = session.call(
|
||||
# Initialize(passphrase=passphrase, derive_cardano=derive_cardano) # TODO
|
||||
Initialize()
|
||||
)
|
||||
session.id = cls.features.session_id
|
||||
return session
|
||||
|
||||
def call(self, msg: t.Any, should_reinit: bool = False) -> t.Any:
|
||||
# if should_reinit:
|
||||
# self.initialize() # TODO
|
||||
if t.TYPE_CHECKING:
|
||||
assert isinstance(self.client.protocol, ProtocolV1)
|
||||
self.client.protocol.write(msg)
|
||||
return self.client.protocol.read()
|
||||
|
||||
|
||||
class SessionV2(Session):
|
||||
def __init__(self, client: NewTrezorClient, id: bytes) -> None:
|
||||
super().__init__(client, id)
|
||||
assert isinstance(client.protocol, ProtocolV2)
|
||||
self.channel = client.protocol.get_channel()
|
||||
self.sid = self._convert_id_to_sid(id)
|
||||
|
||||
def call(self, msg: t.Any) -> t.Any:
|
||||
self.channel.write(self.sid, msg)
|
||||
return self.channel.read(self.sid)
|
||||
|
||||
def _convert_id_to_sid(self, id: bytes) -> int:
|
||||
return int.from_bytes(id, "big") # TODO update to extract only sid
|
@ -0,0 +1,49 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
from typing import TYPE_CHECKING, Iterable, Type, TypeVar
|
||||
|
||||
from ...exceptions import TrezorException
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...models import TrezorModel
|
||||
|
||||
T = TypeVar("T", bound="NewTransport")
|
||||
|
||||
|
||||
class TransportException(TrezorException):
|
||||
pass
|
||||
|
||||
|
||||
class NewTransport:
|
||||
PATH_PREFIX: str
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls: Type["T"], models: Iterable["TrezorModel"] | None = None
|
||||
) -> Iterable["T"]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def find_by_path(cls: Type["T"], path: str, prefix_search: bool = False) -> "T":
|
||||
for device in cls.enumerate():
|
||||
|
||||
if device.get_path() == path:
|
||||
return device
|
||||
|
||||
if prefix_search and device.get_path().startswith(path):
|
||||
return device
|
||||
|
||||
raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}")
|
||||
|
||||
def get_path(self) -> str: ...
|
||||
|
||||
def open(self) -> None: ...
|
||||
|
||||
def close(self) -> None: ...
|
||||
|
||||
def write_chunk(self, chunk: bytes) -> None: ...
|
||||
|
||||
def read_chunk(self) -> bytes: ...
|
||||
|
||||
CHUNK_SIZE: t.ClassVar[int]
|
@ -0,0 +1,162 @@
|
||||
# This file is part of the Trezor project.
|
||||
#
|
||||
# Copyright (C) 2012-2024 SatoshiLabs and contributors
|
||||
#
|
||||
# This library is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Lesser General Public License version 3
|
||||
# as published by the Free Software Foundation.
|
||||
#
|
||||
# This library is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Lesser General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import socket
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Iterable, Tuple
|
||||
|
||||
from ...log import DUMP_PACKETS
|
||||
from .. import TransportException
|
||||
from .transport import NewTransport
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...models import TrezorModel
|
||||
|
||||
SOCKET_TIMEOUT = 10
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UdpTransport(NewTransport):
|
||||
|
||||
DEFAULT_HOST = "127.0.0.1"
|
||||
DEFAULT_PORT = 21324
|
||||
PATH_PREFIX = "udp"
|
||||
ENABLED: bool = True
|
||||
CHUNK_SIZE = 64
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device: str | None = None,
|
||||
) -> None:
|
||||
if not device:
|
||||
host = UdpTransport.DEFAULT_HOST
|
||||
port = UdpTransport.DEFAULT_PORT
|
||||
else:
|
||||
devparts = device.split(":")
|
||||
host = devparts[0]
|
||||
port = int(devparts[1]) if len(devparts) > 1 else UdpTransport.DEFAULT_PORT
|
||||
self.device: Tuple[str, int] = (host, port)
|
||||
|
||||
self.socket: socket.socket | None = None
|
||||
super().__init__()
|
||||
|
||||
@classmethod
|
||||
def _try_path(cls, path: str) -> "UdpTransport":
|
||||
d = cls(path)
|
||||
try:
|
||||
d.open()
|
||||
if d.ping():
|
||||
return d
|
||||
else:
|
||||
raise TransportException(
|
||||
f"No Trezor device found at address {d.get_path()}"
|
||||
)
|
||||
except Exception as e:
|
||||
raise TransportException(f"Error opening {d.get_path()}") from e
|
||||
|
||||
finally:
|
||||
d.close()
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls, _models: Iterable["TrezorModel"] | None = None
|
||||
) -> Iterable["UdpTransport"]:
|
||||
default_path = f"{cls.DEFAULT_HOST}:{cls.DEFAULT_PORT}"
|
||||
try:
|
||||
return [cls._try_path(default_path)]
|
||||
except TransportException:
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def find_by_path(cls, path: str, prefix_search: bool = False) -> "UdpTransport":
|
||||
try:
|
||||
address = path.replace(f"{cls.PATH_PREFIX}:", "")
|
||||
return cls._try_path(address)
|
||||
except TransportException:
|
||||
if not prefix_search:
|
||||
raise
|
||||
|
||||
if prefix_search:
|
||||
return super().find_by_path(path, prefix_search)
|
||||
else:
|
||||
raise TransportException(f"No UDP device at {path}")
|
||||
|
||||
def get_path(self) -> str:
|
||||
return "{}:{}:{}".format(self.PATH_PREFIX, *self.device)
|
||||
|
||||
def open(self) -> None:
|
||||
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
self.socket.connect(self.device)
|
||||
self.socket.settimeout(SOCKET_TIMEOUT)
|
||||
|
||||
def close(self) -> None:
|
||||
if self.socket is not None:
|
||||
self.socket.close()
|
||||
self.socket = None
|
||||
|
||||
def write_chunk(self, chunk: bytes) -> None:
|
||||
assert self.socket is not None
|
||||
if len(chunk) != 64:
|
||||
raise TransportException("Unexpected data length")
|
||||
LOG.log(DUMP_PACKETS, f"sending packet: {chunk.hex()}")
|
||||
self.socket.sendall(chunk)
|
||||
|
||||
def read_chunk(self) -> bytes:
|
||||
assert self.socket is not None
|
||||
while True:
|
||||
try:
|
||||
chunk = self.socket.recv(64)
|
||||
break
|
||||
except socket.timeout:
|
||||
continue
|
||||
LOG.log(DUMP_PACKETS, f"received packet: {chunk.hex()}")
|
||||
if len(chunk) != 64:
|
||||
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
|
||||
return bytearray(chunk)
|
||||
|
||||
def find_debug(self) -> "UdpTransport":
|
||||
host, port = self.device
|
||||
return UdpTransport(f"{host}:{port + 1}")
|
||||
|
||||
def wait_until_ready(self, timeout: float = 10) -> None:
|
||||
try:
|
||||
self.open()
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
if self.ping():
|
||||
break
|
||||
elapsed = time.monotonic() - start
|
||||
if elapsed >= timeout:
|
||||
raise TransportException("Timed out waiting for connection.")
|
||||
|
||||
time.sleep(0.05)
|
||||
finally:
|
||||
self.close()
|
||||
|
||||
def ping(self) -> bool:
|
||||
"""Test if the device is listening."""
|
||||
assert self.socket is not None
|
||||
resp = None
|
||||
try:
|
||||
self.socket.sendall(b"PINGPING")
|
||||
resp = self.socket.recv(8)
|
||||
except Exception:
|
||||
pass
|
||||
return resp == b"PONGPONG"
|
@ -0,0 +1,167 @@
|
||||
# This file is part of the Trezor project.
|
||||
#
|
||||
# Copyright (C) 2012-2024 SatoshiLabs and contributors
|
||||
#
|
||||
# This library is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Lesser General Public License version 3
|
||||
# as published by the Free Software Foundation.
|
||||
#
|
||||
# This library is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Lesser General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import atexit
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from typing import Iterable, List
|
||||
|
||||
from ...log import DUMP_PACKETS
|
||||
from ...models import TREZORS, TrezorModel
|
||||
from .. import UDEV_RULES_STR, DeviceIsBusy, TransportException
|
||||
from .transport import NewTransport
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import usb1
|
||||
|
||||
USB_IMPORTED = True
|
||||
except Exception as e:
|
||||
LOG.warning(f"WebUSB transport is disabled: {e}")
|
||||
USB_IMPORTED = False
|
||||
|
||||
INTERFACE = 0
|
||||
ENDPOINT = 1
|
||||
DEBUG_INTERFACE = 1
|
||||
DEBUG_ENDPOINT = 2
|
||||
|
||||
|
||||
class WebUsbTransport(NewTransport):
|
||||
"""
|
||||
WebUsbTransport implements transport over WebUSB interface.
|
||||
"""
|
||||
|
||||
PATH_PREFIX = "webusb"
|
||||
ENABLED = USB_IMPORTED
|
||||
context = None
|
||||
CHUNK_SIZE = 64
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device: "usb1.USBDevice",
|
||||
debug: bool = False,
|
||||
) -> None:
|
||||
|
||||
self.device = device
|
||||
self.debug = debug
|
||||
|
||||
self.interface = DEBUG_INTERFACE if debug else INTERFACE
|
||||
self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT
|
||||
self.handle: usb1.USBDeviceHandle | None = None
|
||||
|
||||
super().__init__()
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls, models: Iterable["TrezorModel"] | None = None, usb_reset: bool = False
|
||||
) -> Iterable["WebUsbTransport"]:
|
||||
if cls.context is None:
|
||||
cls.context = usb1.USBContext()
|
||||
cls.context.open()
|
||||
atexit.register(cls.context.close)
|
||||
|
||||
if models is None:
|
||||
models = TREZORS
|
||||
usb_ids = [id for model in models for id in model.usb_ids]
|
||||
devices: List["WebUsbTransport"] = []
|
||||
for dev in cls.context.getDeviceIterator(skip_on_error=True):
|
||||
usb_id = (dev.getVendorID(), dev.getProductID())
|
||||
if usb_id not in usb_ids:
|
||||
continue
|
||||
if not is_vendor_class(dev):
|
||||
continue
|
||||
if usb_reset:
|
||||
handle = dev.open()
|
||||
handle.resetDevice()
|
||||
handle.close()
|
||||
continue
|
||||
try:
|
||||
# workaround for issue #223:
|
||||
# on certain combinations of Windows USB drivers and libusb versions,
|
||||
# Trezor is returned twice (possibly because Windows know it as both
|
||||
# a HID and a WebUSB device), and one of the returned devices is
|
||||
# non-functional.
|
||||
dev.getProduct()
|
||||
devices.append(WebUsbTransport(dev))
|
||||
except usb1.USBErrorNotSupported:
|
||||
pass
|
||||
return devices
|
||||
|
||||
def get_path(self) -> str:
|
||||
return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}"
|
||||
|
||||
def open(self) -> None:
|
||||
self.handle = self.device.open()
|
||||
if self.handle is None:
|
||||
if sys.platform.startswith("linux"):
|
||||
args = (UDEV_RULES_STR,)
|
||||
else:
|
||||
args = ()
|
||||
raise IOError("Cannot open device", *args)
|
||||
try:
|
||||
self.handle.claimInterface(self.interface)
|
||||
except usb1.USBErrorAccess as e:
|
||||
raise DeviceIsBusy(self.device) from e
|
||||
|
||||
def close(self) -> None:
|
||||
if self.handle is not None:
|
||||
self.handle.releaseInterface(self.interface)
|
||||
self.handle.close()
|
||||
self.handle = None
|
||||
|
||||
def write_chunk(self, chunk: bytes) -> None:
|
||||
assert self.handle is not None
|
||||
if len(chunk) != 64:
|
||||
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
|
||||
LOG.log(DUMP_PACKETS, f"writing packet: {chunk.hex()}")
|
||||
self.handle.interruptWrite(self.endpoint, chunk)
|
||||
|
||||
def read_chunk(self) -> bytes:
|
||||
assert self.handle is not None
|
||||
endpoint = 0x80 | self.endpoint
|
||||
while True:
|
||||
chunk = self.handle.interruptRead(endpoint, 64)
|
||||
if chunk:
|
||||
break
|
||||
else:
|
||||
time.sleep(0.001)
|
||||
LOG.log(DUMP_PACKETS, f"read packet: {chunk.hex()}")
|
||||
if len(chunk) != 64:
|
||||
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
|
||||
return chunk
|
||||
|
||||
def find_debug(self) -> "WebUsbTransport":
|
||||
# For v1 protocol, find debug USB interface for the same serial number
|
||||
return WebUsbTransport(self.device, debug=True)
|
||||
|
||||
|
||||
def is_vendor_class(dev: "usb1.USBDevice") -> bool:
|
||||
configurationId = 0
|
||||
altSettingId = 0
|
||||
return (
|
||||
dev[configurationId][INTERFACE][altSettingId].getClass()
|
||||
== usb1.libusb1.LIBUSB_CLASS_VENDOR_SPEC
|
||||
)
|
||||
|
||||
|
||||
def dev_to_str(dev: "usb1.USBDevice") -> str:
|
||||
return ":".join(
|
||||
str(x) for x in ["%03i" % (dev.getBusNumber(),)] + dev.getPortNumberList()
|
||||
)
|
Loading…
Reference in new issue