diff --git a/python/src/trezorlib/cli/trezorctl.py b/python/src/trezorlib/cli/trezorctl.py index c2aac2a0e..83e1837f2 100755 --- a/python/src/trezorlib/cli/trezorctl.py +++ b/python/src/trezorlib/cli/trezorctl.py @@ -24,9 +24,10 @@ from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, TypeVar, ca import click -from .. import __version__, log, messages, protobuf, ui +from .. import __version__, log, messages, protobuf from ..client import TrezorClient -from ..transport import DeviceIsBusy, enumerate_devices +from ..transport import DeviceIsBusy, new_enumerate_devices +from ..transport.new.client import NewTrezorClient from ..transport.udp import UdpTransport from . import ( AliasedGroup, @@ -54,7 +55,7 @@ from . import ( F = TypeVar("F", bound=Callable) if TYPE_CHECKING: - from ..transport import Transport + from ..transport.new.transport import NewTransport LOG = logging.getLogger(__name__) @@ -281,16 +282,18 @@ def format_device_name(features: messages.Features) -> str: @cli.command(name="list") @click.option("-n", "no_resolve", is_flag=True, help="Do not resolve Trezor names") -def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]: +def list_devices(no_resolve: bool) -> Optional[Iterable["NewTransport"]]: """List connected Trezor devices.""" if no_resolve: - return enumerate_devices() + return new_enumerate_devices() - for transport in enumerate_devices(): + for transport in new_enumerate_devices(): try: - client = TrezorClient(transport, ui=ui.ClickUI()) - description = format_device_name(client.features) - client.end_session() + print("test A") + client = NewTrezorClient(transport) + session = client.get_session() + description = format_device_name(session.features) + # client.end_session() print("after end session") except DeviceIsBusy: description = "Device is in use by another process" diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index 737a201e7..edd12b209 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -14,6 +14,8 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import json import logging import re @@ -1096,6 +1098,7 @@ class TrezorClientDebugLink(TrezorClient): if not hasattr(input_flow, "send"): raise RuntimeError("input_flow should be a generator function") self.ui.input_flow = input_flow + assert input_flow is not None input_flow.send(None) # start the generator def watch_layout(self, watch: bool = True) -> None: diff --git a/python/src/trezorlib/mapping.py b/python/src/trezorlib/mapping.py index 05f214afc..3998f1443 100644 --- a/python/src/trezorlib/mapping.py +++ b/python/src/trezorlib/mapping.py @@ -66,7 +66,6 @@ class ProtobufMapping: print("wire type", wire_type) buf = io.BytesIO() protobuf.dump_message(buf, msg) - print("test") return wire_type, buf.getvalue() def encode_without_wire_type(self, msg: protobuf.MessageType) -> bytes: diff --git a/python/src/trezorlib/transport/__init__.py b/python/src/trezorlib/transport/__init__.py index c4f276a09..7dcda7eb2 100644 --- a/python/src/trezorlib/transport/__init__.py +++ b/python/src/trezorlib/transport/__init__.py @@ -14,17 +14,10 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import logging -from typing import ( - TYPE_CHECKING, - Iterable, - List, - Optional, - Sequence, - Tuple, - Type, - TypeVar, -) +from typing import TYPE_CHECKING, Iterable, List, Sequence, Tuple, Type, TypeVar from ..exceptions import TrezorException from ..mapping import ProtobufMapping @@ -82,8 +75,8 @@ class Transport: def initialize_connection( self, mapping: "ProtobufMapping", - session_id: Optional[bytes] = None, - derive_cardano: Optional[bool] = None, + session_id: bytes | None = None, + derive_cardano: bool | None = None, ): raise NotImplementedError @@ -113,7 +106,7 @@ class Transport: @classmethod def enumerate( - cls: Type["T"], models: Optional[Iterable["TrezorModel"]] = None + cls: Type["T"], models: Iterable["TrezorModel"] | None = None ) -> Iterable["T"]: raise NotImplementedError @@ -145,8 +138,21 @@ def all_transports() -> Iterable[Type["Transport"]]: return set(t for t in transports if t.ENABLED) +def all_new_transports() -> Iterable[Type["NewTransport"]]: + # from .bridge import BridgeTransport + # from .hid import HidTransport + from .new.udp import UdpTransport + from .new.webusb import WebUsbTransport + + transports: Tuple[Type["NewTransport"], ...] = ( + UdpTransport, + WebUsbTransport, + ) + return set(t for t in transports if t.ENABLED) + + def enumerate_devices( - models: Optional[Iterable["TrezorModel"]] = None, + models: Iterable["TrezorModel"] | None = None, ) -> Sequence["Transport"]: devices: List["Transport"] = [] for transport in all_transports(): @@ -163,9 +169,28 @@ def enumerate_devices( return devices -def get_transport( - path: Optional[str] = None, prefix_search: bool = False -) -> "Transport": +from .new.transport import NewTransport + + +def new_enumerate_devices( + models: Iterable["TrezorModel"] | None = None, +) -> Sequence["NewTransport"]: + devices: List["NewTransport"] = [] + for transport in all_new_transports(): + name = transport.__name__ + try: + found = list(transport.enumerate(models)) + LOG.info(f"Enumerating {name}: found {len(found)} devices") + devices.extend(found) + except NotImplementedError: + LOG.error(f"{name} does not implement device enumeration") + except Exception as e: + excname = e.__class__.__name__ + LOG.error(f"Failed to enumerate {name}. {excname}: {e}") + return devices + + +def get_transport(path: str | None = None, prefix_search: bool = False) -> "Transport": if path is None: try: return next(iter(enumerate_devices())) diff --git a/python/src/trezorlib/transport/new/channel_data.py b/python/src/trezorlib/transport/new/channel_data.py new file mode 100644 index 000000000..08cdfe35d --- /dev/null +++ b/python/src/trezorlib/transport/new/channel_data.py @@ -0,0 +1,11 @@ +from __future__ import annotations + + +class ChannelData: + key_request: bytes + key_response: bytes + nonce_request: int + nonce_response: int + channel_id: bytes + sync_bit_send: int + sync_bit_receive: int diff --git a/python/src/trezorlib/transport/new/client.py b/python/src/trezorlib/transport/new/client.py new file mode 100644 index 000000000..025008466 --- /dev/null +++ b/python/src/trezorlib/transport/new/client.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +import logging + +from ... import mapping +from ...mapping import ProtobufMapping +from .channel_data import ChannelData +from .protocol_and_channel import ProtocolAndChannel, ProtocolV1 +from .protocol_v2 import ProtocolV2 +from .session import Session, SessionV1, SessionV2 +from .transport import NewTransport + +LOG = logging.getLogger(__name__) + + +class NewTrezorClient: + def __init__( + self, + transport: NewTransport, + protobuf_mapping: ProtobufMapping | None = None, + protocol: ProtocolAndChannel | None = None, + ) -> None: + self.transport = transport + + if protobuf_mapping is None: + self.mapping = mapping.DEFAULT_MAPPING + else: + self.mapping = protobuf_mapping + print("test B") + + if protocol is None: + print("test C") + self.protocol = self._get_protocol() + else: + self.protocol = protocol + + @classmethod + def resume( + cls, transport: NewTransport, channel_data: ChannelData + ) -> NewTrezorClient: ... + + def get_session( + self, passphrase: str = "", derive_cardano: bool = False + ) -> Session: + if isinstance(self.protocol, ProtocolV1): + return SessionV1.new(self, passphrase, derive_cardano) + if isinstance(self.protocol, ProtocolV2): + return SessionV2.new(self, passphrase, derive_cardano) + raise NotImplementedError # TODO + + def resume_session(self, session_id: bytes) -> Session: + raise NotImplementedError # TODO + + def _get_protocol(self) -> ProtocolAndChannel: + + from ... import mapping, messages + from ...messages import FailureType + from .protocol_and_channel import ProtocolV1 + + self.transport.open() + + protocol = ProtocolV1(self.transport, mapping.DEFAULT_MAPPING) + + protocol.write(messages.Initialize()) + + response = protocol.read() + self.transport.close() + if isinstance(response, messages.Failure): + print("test F1") + if ( + response.code == FailureType.UnexpectedMessage + and response.message == "Invalid protocol" + ): + LOG.debug("Protocol V2 detected") + protocol = ProtocolV2(self.transport, self.mapping) + return protocol diff --git a/python/src/trezorlib/transport/new/protocol_and_channel.py b/python/src/trezorlib/transport/new/protocol_and_channel.py new file mode 100644 index 000000000..083d87928 --- /dev/null +++ b/python/src/trezorlib/transport/new/protocol_and_channel.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +import logging +import struct +import typing as t + +from ...log import DUMP_BYTES +from ...mapping import ProtobufMapping +from .channel_data import ChannelData +from .transport import NewTransport + +LOG = logging.getLogger(__name__) + + +class ProtocolAndChannel: + def __init__( + self, + transport: NewTransport, + mapping: ProtobufMapping, + channel_keys: ChannelData | None = None, + ) -> None: + self.transport = transport + self.mapping = mapping + self.channel_keys = channel_keys + + def close(self) -> None: ... + + # def write(self, session_id: bytes, msg: t.Any) -> None: ... + + # def read(self, session_id: bytes) -> t.Any: ... + + def get_channel_keys(self) -> ChannelData: ... + + +class ProtocolV1(ProtocolAndChannel): + HEADER_LEN = struct.calcsize(">HL") + + def read(self) -> t.Any: + msg_type, msg_bytes = self._read() + LOG.log( + DUMP_BYTES, + f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", + ) + msg = self.mapping.decode(msg_type, msg_bytes) + LOG.debug( + f"received message: {msg.__class__.__name__}", + extra={"protobuf": msg}, + ) + return msg + + def write(self, msg: t.Any) -> None: + LOG.debug( + f"sending message: {msg.__class__.__name__}", + extra={"protobuf": msg}, + ) + msg_type, msg_bytes = self.mapping.encode(msg) + LOG.log( + DUMP_BYTES, + f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", + ) + self._write(msg_type, msg_bytes) + + def _write(self, message_type: int, message_data: bytes) -> None: + print("wooooo") + chunk_size = self.transport.CHUNK_SIZE + header = struct.pack(">HL", message_type, len(message_data)) + buffer = bytearray(b"##" + header + message_data) + print("wooooo") + + while buffer: + # Report ID, data padded to 63 bytes + chunk = b"?" + buffer[: chunk_size - 1] + chunk = chunk.ljust(chunk_size, b"\x00") + self.transport.write_chunk(chunk) + buffer = buffer[63:] + + def _read(self) -> t.Tuple[int, bytes]: + buffer = bytearray() + # Read header with first part of message data + msg_type, datalen, first_chunk = self.read_first() + buffer.extend(first_chunk) + + # Read the rest of the message + while len(buffer) < datalen: + buffer.extend(self.read_next()) + + return msg_type, buffer[:datalen] + + def read_first(self) -> t.Tuple[int, int, bytes]: + chunk = self.transport.read_chunk() + if chunk[:3] != b"?##": + raise RuntimeError("Unexpected magic characters") + try: + msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN]) + except Exception: + raise RuntimeError("Cannot parse header") + + data = chunk[3 + self.HEADER_LEN :] + return msg_type, datalen, data + + def read_next(self) -> bytes: + chunk = self.transport.read_chunk() + if chunk[:1] != b"?": + raise RuntimeError("Unexpected magic characters") + return chunk[1:] + + +class Channel: + id: int + channel_keys: ChannelData | None + + def __init__(self, id: int, keys: ChannelData) -> None: + self.id = id + self.channel_keys = keys + + def read(self) -> t.Any: ... + def write(self, msg: t.Any) -> None: ... diff --git a/python/src/trezorlib/transport/new/protocol_v2.py b/python/src/trezorlib/transport/new/protocol_v2.py new file mode 100644 index 000000000..f5d5c0145 --- /dev/null +++ b/python/src/trezorlib/transport/new/protocol_v2.py @@ -0,0 +1,313 @@ +from __future__ import annotations + +import hashlib +import hmac +import logging +import os +import typing as t +from binascii import hexlify +from enum import IntEnum + +from cryptography.hazmat.primitives.ciphers.aead import AESGCM + +from ... import messages +from ...mapping import ProtobufMapping +from ..thp import checksum, curve25519, thp_io +from ..thp.checksum import CHECKSUM_LENGTH +from ..thp.packet_header import PacketHeader +from .channel_data import ChannelData +from .protocol_and_channel import Channel, ProtocolAndChannel +from .transport import NewTransport + +LOG = logging.getLogger(__name__) + +MANAGEMENT_SESSION_ID: int = 0 + + +def _sha256_of_two(val_1: bytes, val_2: bytes) -> bytes: + hash = hashlib.sha256(val_1) + hash.update(val_2) + return hash.digest() + + +def _hkdf(chaining_key: bytes, input: bytes): + temp_key = hmac.new(chaining_key, input, hashlib.sha256).digest() + output_1 = hmac.new(temp_key, b"\x01", hashlib.sha256).digest() + ctx_output_2 = hmac.new(temp_key, output_1, hashlib.sha256) + ctx_output_2.update(b"\x02") + output_2 = ctx_output_2.digest() + return (output_1, output_2) + + +def _get_iv_from_nonce(nonce: int) -> bytes: + if not nonce <= 0xFFFFFFFFFFFFFFFF: + raise ValueError("Nonce overflow, terminate the channel") + return bytes(4) + nonce.to_bytes(8, "big") + + +class ProtocolV2(ProtocolAndChannel): + key_request: bytes + key_response: bytes + nonce_request: int + nonce_response: int + channel_id: int + sync_bit_send: int + sync_bit_receive: int + has_valid_channel: bool = False + + def __init__( + self, + transport: NewTransport, + mapping: ProtobufMapping, + channel_keys: ChannelData | None = None, + ) -> None: + super().__init__(transport, mapping, channel_keys) + self.channel: Channel | None = None + + def get_channel(self) -> ProtocolV2: + if not self.has_valid_channel: + self._establish_new_channel() + return self + + def read(self, session_id: int) -> t.Any: + header, data = self._read_until_valid_crc_check() + # TODO + + def write(self, session_id: int, msg: t.Any) -> None: + msg_type, msg_data = self.mapping.encode(msg) + self._encrypt_and_write(session_id, msg_type, msg_data, 7) # TODO add ctrl_byte + + def _establish_new_channel(self): + self.sync_bit_send = 0 + self.sync_bit_receive = 0 + # Send channel allocation request + channel_id_request_nonce = os.urandom(8) + thp_io.write_payload_to_wire_and_add_checksum( + self.transport, + PacketHeader.get_channel_allocation_request_header(12), + channel_id_request_nonce, + ) + + # Read channel allocation response + header, payload = self._read_until_valid_crc_check() + if not self._is_valid_channel_allocation_response( + header, payload, channel_id_request_nonce + ): + print("TODO raise exception here, I guess") + + self.channel_id = int.from_bytes(payload[8:10], "big") + self.device_properties = payload[10:] + + # Send handshake init request + ha_init_req_header = PacketHeader(0, self.channel_id, 36) + host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32)) + host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey) + + thp_io.write_payload_to_wire_and_add_checksum( + self.transport, ha_init_req_header, host_ephemeral_pubkey + ) + + # Read ACK + header, payload = self._read_until_valid_crc_check() + if not header.is_ack() or len(payload) > 0: + print("Received message is not a valid ACK ") + + # Read handshake init response + header, payload = self._read_until_valid_crc_check() + self._send_ack_1() + + if not header.is_handshake_init_response(): + print("Received message is not a valid handshake init response message") + + trezor_ephemeral_pubkey = payload[:32] + encrypted_trezor_static_pubkey = payload[32:80] + noise_tag = payload[80:96] + + # TODO check noise tag + print("noise_tag: ", hexlify(noise_tag).decode()) + + # Prepare and send handshake completion request + PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00" + IV_1 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + IV_2 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + h = _sha256_of_two(PROTOCOL_NAME, self.device_properties) + h = _sha256_of_two(h, host_ephemeral_pubkey) + h = _sha256_of_two(h, trezor_ephemeral_pubkey) + ck, k = _hkdf( + PROTOCOL_NAME, + curve25519.multiply(host_ephemeral_privkey, trezor_ephemeral_pubkey), + ) + + aes_ctx = AESGCM(k) + try: + trezor_masked_static_pubkey = aes_ctx.decrypt( + IV_1, encrypted_trezor_static_pubkey, h + ) + # print("masked_key", hexlify(trezor_masked_static_pubkey).decode()) + except Exception as e: + print(type(e)) # TODO how to handle potential exceptions? Q for Matejcik + h = _sha256_of_two(h, encrypted_trezor_static_pubkey) + ck, k = _hkdf( + ck, curve25519.multiply(host_ephemeral_privkey, trezor_masked_static_pubkey) + ) + aes_ctx = AESGCM(k) + + tag_of_empty_string = aes_ctx.encrypt(IV_1, b"", h) + h = _sha256_of_two(h, tag_of_empty_string) + # TODO: search for saved credentials (or possibly not, as we skip pairing phase) + + zeroes_32 = int.to_bytes(0, 32, "little") + temp_host_static_privkey = curve25519.get_private_key(zeroes_32) + temp_host_static_pubkey = curve25519.get_public_key(temp_host_static_privkey) + aes_ctx = AESGCM(k) + encrypted_host_static_pubkey = aes_ctx.encrypt(IV_2, temp_host_static_pubkey, h) + h = _sha256_of_two(h, encrypted_host_static_pubkey) + ck, k = _hkdf( + ck, curve25519.multiply(temp_host_static_privkey, trezor_ephemeral_pubkey) + ) + msg_data = self.mapping.encode_without_wire_type( + messages.ThpHandshakeCompletionReqNoisePayload( + pairing_methods=[ + messages.ThpPairingMethod.NoMethod, + ] + ) + ) + + aes_ctx = AESGCM(k) + + encrypted_payload = aes_ctx.encrypt(IV_1, msg_data, h) + h = _sha256_of_two(h, encrypted_payload) + ha_completion_req_header = PacketHeader( + 0x12, + self.channel_id, + len(encrypted_host_static_pubkey) + + len(encrypted_payload) + + CHECKSUM_LENGTH, + ) + thp_io.write_payload_to_wire_and_add_checksum( + self.transport, + ha_completion_req_header, + encrypted_host_static_pubkey + encrypted_payload, + ) + + # Read ACK + header, payload = self._read_until_valid_crc_check() + if not header.is_ack() or len(payload) > 0: + print("Received message is not a valid ACK ") + + # Read handshake completion response, ignore payload as we do not care about the state + header, _ = self._read_until_valid_crc_check() + if not header.is_handshake_comp_response(): + print("Received message is not a valid handshake completion response") + self._send_ack_2() + + self.key_request, self.key_response = _hkdf(ck, b"") + self.nonce_request = 0 + self.nonce_response = 1 + + # Send StartPairingReqest message + message = messages.ThpStartPairingRequest() + message_type, message_data = self.mapping.encode(message) + + self._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data) + + # Read ACK + header, payload = self._read_until_valid_crc_check() + if not header.is_ack() or len(payload) > 0: + print("Received message is not a valid ACK ") + + # Read + _, msg_type, msg_data = self.read_and_decrypt() + maaa = self.mapping.decode(msg_type, msg_data) + self._send_ack_1() + + assert isinstance(maaa, messages.ThpEndResponse) + self.has_valid_channel = True + + def _get_control_byte(self) -> bytes: + return b"\x42" + + def _send_ack_1(self): + header = PacketHeader(0x20, self.channel_id, 4) + thp_io.write_payload_to_wire_and_add_checksum(self.transport, header, b"") + + def _send_ack_2(self): + header = PacketHeader(0x28, self.channel_id, 4) + thp_io.write_payload_to_wire_and_add_checksum(self.transport, header, b"") + + def _encrypt_and_write( + self, + session_id: int, + message_type: int, + message_data: bytes, + ctrl_byte: int = 0x04, + ) -> None: + assert self.key_request is not None + aes_ctx = AESGCM(self.key_request) + + sid = session_id.to_bytes(1, "big") + msg_type = message_type.to_bytes(2, "big") + data = sid + msg_type + message_data + nonce = _get_iv_from_nonce(self.nonce_request) + self.nonce_request += 1 + encrypted_message = aes_ctx.encrypt(nonce, data, b"") + header = PacketHeader( + ctrl_byte, self.channel_id, len(encrypted_message) + CHECKSUM_LENGTH + ) + + thp_io.write_payload_to_wire_and_add_checksum( + self.transport, header, encrypted_message + ) + + def read_and_decrypt(self) -> t.Tuple[bytes, int, bytes]: + header, raw_payload = self._read_until_valid_crc_check() + if not header.is_encrypted_transport(): + print("Trying to decrypt not encrypted message!") + aes_ctx = AESGCM(self.key_response) + nonce = _get_iv_from_nonce(self.nonce_response) + self.nonce_response += 1 + + message = aes_ctx.decrypt(nonce, raw_payload, b"") + session_id = message[0] + message_type = message[1:3] + message_data = message[3:] + return ( + int.to_bytes(session_id, 1, "big"), + int.from_bytes(message_type, "big"), + message_data, + ) + + def _read_until_valid_crc_check( + self, + ) -> t.Tuple[PacketHeader, bytes]: + is_valid = False + header, payload, chksum = thp_io.read(self.transport) + while not is_valid: + is_valid = checksum.is_valid(chksum, header.to_bytes_init() + payload) + if not is_valid: + print(hexlify(header.to_bytes_init() + payload + chksum)) + LOG.debug("Received a message with invalid checksum") + header, payload, chksum = thp_io.read(self.transport) + + return header, payload + + def _is_valid_channel_allocation_response( + self, header: PacketHeader, payload: bytes, original_nonce: bytes + ) -> bool: + if not header.is_channel_allocation_response(): + print("Received message is not a channel allocation response") + return False + if len(payload) < 10: + print("Invalid channel allocation response payload") + return False + if payload[:8] != original_nonce: + print("Invalid channel allocation response payload (nonce mismatch)") + return False + return True + + class ControlByteType(IntEnum): + CHANNEL_ALLOCATION_RES = 1 + HANDSHAKE_INIT_RES = 2 + HANDSHAKE_COMP_RES = 3 + ACK = 4 + ENCRYPTED_TRANSPORT = 5 diff --git a/python/src/trezorlib/transport/new/session.py b/python/src/trezorlib/transport/new/session.py new file mode 100644 index 000000000..6c04f9eed --- /dev/null +++ b/python/src/trezorlib/transport/new/session.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +import typing as t + +from ...messages import Features, Initialize +from .protocol_and_channel import ProtocolV1 +from .protocol_v2 import ProtocolV2 + +if t.TYPE_CHECKING: + from .client import NewTrezorClient + + +class Session: + features: Features + + def __init__(self, client: NewTrezorClient, id: bytes) -> None: + self.client = client + self.id = id + + @classmethod + def new( + cls, client: NewTrezorClient, passphrase: str, derive_cardano: bool + ) -> Session: + raise NotImplementedError + + def call(self, msg: t.Any) -> t.Any: + raise NotImplementedError + + +class SessionV1(Session): + @classmethod + def new( + cls, client: NewTrezorClient, passphrase: str, derive_cardano: bool + ) -> SessionV1: + assert isinstance(client.protocol, ProtocolV1) + session = SessionV1(client, b"") + cls.features = session.call( + # Initialize(passphrase=passphrase, derive_cardano=derive_cardano) # TODO + Initialize() + ) + session.id = cls.features.session_id + return session + + def call(self, msg: t.Any, should_reinit: bool = False) -> t.Any: + # if should_reinit: + # self.initialize() # TODO + if t.TYPE_CHECKING: + assert isinstance(self.client.protocol, ProtocolV1) + self.client.protocol.write(msg) + return self.client.protocol.read() + + +class SessionV2(Session): + def __init__(self, client: NewTrezorClient, id: bytes) -> None: + super().__init__(client, id) + assert isinstance(client.protocol, ProtocolV2) + self.channel = client.protocol.get_channel() + self.sid = self._convert_id_to_sid(id) + + def call(self, msg: t.Any) -> t.Any: + self.channel.write(self.sid, msg) + return self.channel.read(self.sid) + + def _convert_id_to_sid(self, id: bytes) -> int: + return int.from_bytes(id, "big") # TODO update to extract only sid diff --git a/python/src/trezorlib/transport/new/transport.py b/python/src/trezorlib/transport/new/transport.py new file mode 100644 index 000000000..87c550227 --- /dev/null +++ b/python/src/trezorlib/transport/new/transport.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +import typing as t +from typing import TYPE_CHECKING, Iterable, Type, TypeVar + +from ...exceptions import TrezorException + +if TYPE_CHECKING: + from ...models import TrezorModel + + T = TypeVar("T", bound="NewTransport") + + +class TransportException(TrezorException): + pass + + +class NewTransport: + PATH_PREFIX: str + + @classmethod + def enumerate( + cls: Type["T"], models: Iterable["TrezorModel"] | None = None + ) -> Iterable["T"]: + raise NotImplementedError + + @classmethod + def find_by_path(cls: Type["T"], path: str, prefix_search: bool = False) -> "T": + for device in cls.enumerate(): + + if device.get_path() == path: + return device + + if prefix_search and device.get_path().startswith(path): + return device + + raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}") + + def get_path(self) -> str: ... + + def open(self) -> None: ... + + def close(self) -> None: ... + + def write_chunk(self, chunk: bytes) -> None: ... + + def read_chunk(self) -> bytes: ... + + CHUNK_SIZE: t.ClassVar[int] diff --git a/python/src/trezorlib/transport/new/udp.py b/python/src/trezorlib/transport/new/udp.py new file mode 100644 index 000000000..7e34b5d6c --- /dev/null +++ b/python/src/trezorlib/transport/new/udp.py @@ -0,0 +1,162 @@ +# This file is part of the Trezor project. +# +# Copyright (C) 2012-2024 SatoshiLabs and contributors +# +# This library is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the License along with this library. +# If not, see . + +from __future__ import annotations + +import logging +import socket +import time +from typing import TYPE_CHECKING, Iterable, Tuple + +from ...log import DUMP_PACKETS +from .. import TransportException +from .transport import NewTransport + +if TYPE_CHECKING: + from ...models import TrezorModel + +SOCKET_TIMEOUT = 10 + +LOG = logging.getLogger(__name__) + + +class UdpTransport(NewTransport): + + DEFAULT_HOST = "127.0.0.1" + DEFAULT_PORT = 21324 + PATH_PREFIX = "udp" + ENABLED: bool = True + CHUNK_SIZE = 64 + + def __init__( + self, + device: str | None = None, + ) -> None: + if not device: + host = UdpTransport.DEFAULT_HOST + port = UdpTransport.DEFAULT_PORT + else: + devparts = device.split(":") + host = devparts[0] + port = int(devparts[1]) if len(devparts) > 1 else UdpTransport.DEFAULT_PORT + self.device: Tuple[str, int] = (host, port) + + self.socket: socket.socket | None = None + super().__init__() + + @classmethod + def _try_path(cls, path: str) -> "UdpTransport": + d = cls(path) + try: + d.open() + if d.ping(): + return d + else: + raise TransportException( + f"No Trezor device found at address {d.get_path()}" + ) + except Exception as e: + raise TransportException(f"Error opening {d.get_path()}") from e + + finally: + d.close() + + @classmethod + def enumerate( + cls, _models: Iterable["TrezorModel"] | None = None + ) -> Iterable["UdpTransport"]: + default_path = f"{cls.DEFAULT_HOST}:{cls.DEFAULT_PORT}" + try: + return [cls._try_path(default_path)] + except TransportException: + return [] + + @classmethod + def find_by_path(cls, path: str, prefix_search: bool = False) -> "UdpTransport": + try: + address = path.replace(f"{cls.PATH_PREFIX}:", "") + return cls._try_path(address) + except TransportException: + if not prefix_search: + raise + + if prefix_search: + return super().find_by_path(path, prefix_search) + else: + raise TransportException(f"No UDP device at {path}") + + def get_path(self) -> str: + return "{}:{}:{}".format(self.PATH_PREFIX, *self.device) + + def open(self) -> None: + self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + self.socket.connect(self.device) + self.socket.settimeout(SOCKET_TIMEOUT) + + def close(self) -> None: + if self.socket is not None: + self.socket.close() + self.socket = None + + def write_chunk(self, chunk: bytes) -> None: + assert self.socket is not None + if len(chunk) != 64: + raise TransportException("Unexpected data length") + LOG.log(DUMP_PACKETS, f"sending packet: {chunk.hex()}") + self.socket.sendall(chunk) + + def read_chunk(self) -> bytes: + assert self.socket is not None + while True: + try: + chunk = self.socket.recv(64) + break + except socket.timeout: + continue + LOG.log(DUMP_PACKETS, f"received packet: {chunk.hex()}") + if len(chunk) != 64: + raise TransportException(f"Unexpected chunk size: {len(chunk)}") + return bytearray(chunk) + + def find_debug(self) -> "UdpTransport": + host, port = self.device + return UdpTransport(f"{host}:{port + 1}") + + def wait_until_ready(self, timeout: float = 10) -> None: + try: + self.open() + start = time.monotonic() + while True: + if self.ping(): + break + elapsed = time.monotonic() - start + if elapsed >= timeout: + raise TransportException("Timed out waiting for connection.") + + time.sleep(0.05) + finally: + self.close() + + def ping(self) -> bool: + """Test if the device is listening.""" + assert self.socket is not None + resp = None + try: + self.socket.sendall(b"PINGPING") + resp = self.socket.recv(8) + except Exception: + pass + return resp == b"PONGPONG" diff --git a/python/src/trezorlib/transport/new/webusb.py b/python/src/trezorlib/transport/new/webusb.py new file mode 100644 index 000000000..161d58dab --- /dev/null +++ b/python/src/trezorlib/transport/new/webusb.py @@ -0,0 +1,167 @@ +# This file is part of the Trezor project. +# +# Copyright (C) 2012-2024 SatoshiLabs and contributors +# +# This library is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the License along with this library. +# If not, see . + +from __future__ import annotations + +import atexit +import logging +import sys +import time +from typing import Iterable, List + +from ...log import DUMP_PACKETS +from ...models import TREZORS, TrezorModel +from .. import UDEV_RULES_STR, DeviceIsBusy, TransportException +from .transport import NewTransport + +LOG = logging.getLogger(__name__) + +try: + import usb1 + + USB_IMPORTED = True +except Exception as e: + LOG.warning(f"WebUSB transport is disabled: {e}") + USB_IMPORTED = False + +INTERFACE = 0 +ENDPOINT = 1 +DEBUG_INTERFACE = 1 +DEBUG_ENDPOINT = 2 + + +class WebUsbTransport(NewTransport): + """ + WebUsbTransport implements transport over WebUSB interface. + """ + + PATH_PREFIX = "webusb" + ENABLED = USB_IMPORTED + context = None + CHUNK_SIZE = 64 + + def __init__( + self, + device: "usb1.USBDevice", + debug: bool = False, + ) -> None: + + self.device = device + self.debug = debug + + self.interface = DEBUG_INTERFACE if debug else INTERFACE + self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT + self.handle: usb1.USBDeviceHandle | None = None + + super().__init__() + + @classmethod + def enumerate( + cls, models: Iterable["TrezorModel"] | None = None, usb_reset: bool = False + ) -> Iterable["WebUsbTransport"]: + if cls.context is None: + cls.context = usb1.USBContext() + cls.context.open() + atexit.register(cls.context.close) + + if models is None: + models = TREZORS + usb_ids = [id for model in models for id in model.usb_ids] + devices: List["WebUsbTransport"] = [] + for dev in cls.context.getDeviceIterator(skip_on_error=True): + usb_id = (dev.getVendorID(), dev.getProductID()) + if usb_id not in usb_ids: + continue + if not is_vendor_class(dev): + continue + if usb_reset: + handle = dev.open() + handle.resetDevice() + handle.close() + continue + try: + # workaround for issue #223: + # on certain combinations of Windows USB drivers and libusb versions, + # Trezor is returned twice (possibly because Windows know it as both + # a HID and a WebUSB device), and one of the returned devices is + # non-functional. + dev.getProduct() + devices.append(WebUsbTransport(dev)) + except usb1.USBErrorNotSupported: + pass + return devices + + def get_path(self) -> str: + return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}" + + def open(self) -> None: + self.handle = self.device.open() + if self.handle is None: + if sys.platform.startswith("linux"): + args = (UDEV_RULES_STR,) + else: + args = () + raise IOError("Cannot open device", *args) + try: + self.handle.claimInterface(self.interface) + except usb1.USBErrorAccess as e: + raise DeviceIsBusy(self.device) from e + + def close(self) -> None: + if self.handle is not None: + self.handle.releaseInterface(self.interface) + self.handle.close() + self.handle = None + + def write_chunk(self, chunk: bytes) -> None: + assert self.handle is not None + if len(chunk) != 64: + raise TransportException(f"Unexpected chunk size: {len(chunk)}") + LOG.log(DUMP_PACKETS, f"writing packet: {chunk.hex()}") + self.handle.interruptWrite(self.endpoint, chunk) + + def read_chunk(self) -> bytes: + assert self.handle is not None + endpoint = 0x80 | self.endpoint + while True: + chunk = self.handle.interruptRead(endpoint, 64) + if chunk: + break + else: + time.sleep(0.001) + LOG.log(DUMP_PACKETS, f"read packet: {chunk.hex()}") + if len(chunk) != 64: + raise TransportException(f"Unexpected chunk size: {len(chunk)}") + return chunk + + def find_debug(self) -> "WebUsbTransport": + # For v1 protocol, find debug USB interface for the same serial number + return WebUsbTransport(self.device, debug=True) + + +def is_vendor_class(dev: "usb1.USBDevice") -> bool: + configurationId = 0 + altSettingId = 0 + return ( + dev[configurationId][INTERFACE][altSettingId].getClass() + == usb1.libusb1.LIBUSB_CLASS_VENDOR_SPEC + ) + + +def dev_to_str(dev: "usb1.USBDevice") -> str: + return ":".join( + str(x) for x in ["%03i" % (dev.getBusNumber(),)] + dev.getPortNumberList() + ) diff --git a/python/src/trezorlib/transport/protocol.py b/python/src/trezorlib/transport/protocol.py index b619012ca..b62ee72cc 100644 --- a/python/src/trezorlib/transport/protocol.py +++ b/python/src/trezorlib/transport/protocol.py @@ -174,14 +174,14 @@ class ProtocolBasedTransport(Transport): response = mapping.DEFAULT_MAPPING.decode(response_type, response_data) self.handle.close() if isinstance(response, messages.Failure): - from .protocol_v2 import ProtocolV2 + from .protocol_v2 import DeprecatedProtocolV2 if ( response.code == FailureType.UnexpectedMessage and response.message == "Invalid protocol" ): LOG.debug("Protocol V2 detected") - protocol = ProtocolV2(self.handle) + protocol = DeprecatedProtocolV2(self.handle) return protocol @@ -193,8 +193,8 @@ def _get_protocol(version: int, handle: Handle) -> Protocol: return ProtocolV1(handle) if version == PROTOCOL_VERSION_2: - from .protocol_v2 import ProtocolV2 + from .protocol_v2 import DeprecatedProtocolV2 - return ProtocolV2(handle) + return DeprecatedProtocolV2(handle) raise NotImplementedError diff --git a/python/src/trezorlib/transport/protocol_v2.py b/python/src/trezorlib/transport/protocol_v2.py index fa4d468e3..9bd45531f 100644 --- a/python/src/trezorlib/transport/protocol_v2.py +++ b/python/src/trezorlib/transport/protocol_v2.py @@ -1,19 +1,12 @@ import hashlib import hmac import logging -import os -from binascii import hexlify from enum import IntEnum from typing import Optional, Tuple -from cryptography.hazmat.primitives.ciphers.aead import AESGCM - -from .. import messages from ..mapping import ProtobufMapping from ..protobuf import MessageType from ..transport.protocol import Handle, Protocol -from .thp import checksum, curve25519, thp_io -from .thp.checksum import CHECKSUM_LENGTH from .thp.packet_header import PacketHeader LOG = logging.getLogger(__name__) @@ -40,7 +33,7 @@ def _get_iv_from_nonce(nonce: int) -> bytes: return bytes(4) + nonce.to_bytes(8, "big") -class ProtocolV2(Protocol): +class DeprecatedProtocolV2(Protocol): def __init__(self, handle: Handle) -> None: super().__init__(handle) @@ -50,191 +43,185 @@ class ProtocolV2(Protocol): session_id: Optional[bytes] = None, derive_caradano: Optional[bool] = None, ): - self.session_id: int = 0 - self.sync_bit_send: int = 0 - self.sync_bit_receive: int = 0 - self.mapping = mapping - # Send channel allocation request - channel_id_request_nonce = os.urandom(8) - thp_io.write_payload_to_wire_and_add_checksum( - self.handle, - PacketHeader.get_channel_allocation_request_header(12), - channel_id_request_nonce, - ) + # self.session_id: int = 0 + # self.sync_bit_send: int = 0 + # self.sync_bit_receive: int = 0 + # self.mapping = mapping + # # Send channel allocation request + # channel_id_request_nonce = os.urandom(8) + # thp_io.write_payload_to_wire_and_add_checksum( + # self.handle, + # PacketHeader.get_channel_allocation_request_header(12), + # channel_id_request_nonce, + # ) - # Read channel allocation response - header, payload = self._read_until_valid_crc_check() - if not self._is_valid_channel_allocation_response( - header, payload, channel_id_request_nonce - ): - print("TODO raise exception here, I guess") + # # Read channel allocation response + # header, payload = self._read_until_valid_crc_check() + # if not self._is_valid_channel_allocation_response( + # header, payload, channel_id_request_nonce + # ): + # print("TODO raise exception here, I guess") - self.cid = int.from_bytes(payload[8:10], "big") - self.device_properties = payload[10:] + # self.cid = int.from_bytes(payload[8:10], "big") + # self.device_properties = payload[10:] - # Send handshake init request - ha_init_req_header = PacketHeader(0, self.cid, 36) - host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32)) - host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey) + # # Send handshake init request + # ha_init_req_header = PacketHeader(0, self.cid, 36) + # host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32)) + # host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey) - thp_io.write_payload_to_wire_and_add_checksum( - self.handle, ha_init_req_header, host_ephemeral_pubkey - ) + # thp_io.write_payload_to_wire_and_add_checksum( + # self.handle, ha_init_req_header, host_ephemeral_pubkey + # ) - # Read ACK - header, payload = self._read_until_valid_crc_check() - if not header.is_ack() or len(payload) > 0: - print("Received message is not a valid ACK ") + # # Read ACK + # header, payload = self._read_until_valid_crc_check() + # if not header.is_ack() or len(payload) > 0: + # print("Received message is not a valid ACK ") - # Read handshake init response - header, payload = self._read_until_valid_crc_check() - self._send_ack_1() + # # Read handshake init response + # header, payload = self._read_until_valid_crc_check() + # self._send_ack_1() - if not header.is_handshake_init_response(): - print("Received message is not a valid handshake init response message") + # if not header.is_handshake_init_response(): + # print("Received message is not a valid handshake init response message") - trezor_ephemeral_pubkey = payload[:32] - encrypted_trezor_static_pubkey = payload[32:80] - noise_tag = payload[80:96] + # trezor_ephemeral_pubkey = payload[:32] + # encrypted_trezor_static_pubkey = payload[32:80] + # noise_tag = payload[80:96] - # TODO check noise tag - print("noise_tag: ", hexlify(noise_tag).decode()) + # # TODO check noise tag + # print("noise_tag: ", hexlify(noise_tag).decode()) - # Prepare and send handshake completion request - PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00" - IV_1 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" - IV_2 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - h = _sha256_of_two(PROTOCOL_NAME, self.device_properties) - h = _sha256_of_two(h, host_ephemeral_pubkey) - h = _sha256_of_two(h, trezor_ephemeral_pubkey) - ck, k = _hkdf( - PROTOCOL_NAME, - curve25519.multiply(host_ephemeral_privkey, trezor_ephemeral_pubkey), - ) + # # Prepare and send handshake completion request + # PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00" + # IV_1 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + # IV_2 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + # h = _sha256_of_two(PROTOCOL_NAME, self.device_properties) + # h = _sha256_of_two(h, host_ephemeral_pubkey) + # h = _sha256_of_two(h, trezor_ephemeral_pubkey) + # ck, k = _hkdf( + # PROTOCOL_NAME, + # curve25519.multiply(host_ephemeral_privkey, trezor_ephemeral_pubkey), + # ) - aes_ctx = AESGCM(k) - try: - trezor_masked_static_pubkey = aes_ctx.decrypt( - IV_1, encrypted_trezor_static_pubkey, h - ) - # print("masked_key", hexlify(trezor_masked_static_pubkey).decode()) - except Exception as e: - print(type(e)) # TODO how to handle potential exceptions? Q for Matejcik - h = _sha256_of_two(h, encrypted_trezor_static_pubkey) - ck, k = _hkdf( - ck, curve25519.multiply(host_ephemeral_privkey, trezor_masked_static_pubkey) - ) - aes_ctx = AESGCM(k) + # aes_ctx = AESGCM(k) + # try: + # trezor_masked_static_pubkey = aes_ctx.decrypt( + # IV_1, encrypted_trezor_static_pubkey, h + # ) + # # print("masked_key", hexlify(trezor_masked_static_pubkey).decode()) + # except Exception as e: + # print(type(e)) # TODO how to handle potential exceptions? Q for Matejcik + # h = _sha256_of_two(h, encrypted_trezor_static_pubkey) + # ck, k = _hkdf( + # ck, curve25519.multiply(host_ephemeral_privkey, trezor_masked_static_pubkey) + # ) + # aes_ctx = AESGCM(k) - tag_of_empty_string = aes_ctx.encrypt(IV_1, b"", h) - h = _sha256_of_two(h, tag_of_empty_string) - # TODO: search for saved credentials (or possibly not, as we skip pairing phase) + # tag_of_empty_string = aes_ctx.encrypt(IV_1, b"", h) + # h = _sha256_of_two(h, tag_of_empty_string) + # # TODO: search for saved credentials (or possibly not, as we skip pairing phase) - zeroes_32 = int.to_bytes(0, 32, "little") - temp_host_static_privkey = curve25519.get_private_key(zeroes_32) - temp_host_static_pubkey = curve25519.get_public_key(temp_host_static_privkey) - aes_ctx = AESGCM(k) - encrypted_host_static_pubkey = aes_ctx.encrypt(IV_2, temp_host_static_pubkey, h) - h = _sha256_of_two(h, encrypted_host_static_pubkey) - ck, k = _hkdf( - ck, curve25519.multiply(temp_host_static_privkey, trezor_ephemeral_pubkey) - ) - msg_data = mapping.encode_without_wire_type( - messages.ThpHandshakeCompletionReqNoisePayload( - pairing_methods=[ - messages.ThpPairingMethod.NoMethod, - ] - ) - ) + # zeroes_32 = int.to_bytes(0, 32, "little") + # temp_host_static_privkey = curve25519.get_private_key(zeroes_32) + # temp_host_static_pubkey = curve25519.get_public_key(temp_host_static_privkey) + # aes_ctx = AESGCM(k) + # encrypted_host_static_pubkey = aes_ctx.encrypt(IV_2, temp_host_static_pubkey, h) + # h = _sha256_of_two(h, encrypted_host_static_pubkey) + # ck, k = _hkdf( + # ck, curve25519.multiply(temp_host_static_privkey, trezor_ephemeral_pubkey) + # ) + # msg_data = mapping.encode_without_wire_type( + # messages.ThpHandshakeCompletionReqNoisePayload( + # pairing_methods=[ + # messages.ThpPairingMethod.NoMethod, + # ] + # ) + # ) - aes_ctx = AESGCM(k) + # aes_ctx = AESGCM(k) - encrypted_payload = aes_ctx.encrypt(IV_1, msg_data, h) - h = _sha256_of_two(h, encrypted_payload) - ha_completion_req_header = PacketHeader( - 0x12, - self.cid, - len(encrypted_host_static_pubkey) - + len(encrypted_payload) - + CHECKSUM_LENGTH, - ) - thp_io.write_payload_to_wire_and_add_checksum( - self.handle, - ha_completion_req_header, - encrypted_host_static_pubkey + encrypted_payload, - ) + # encrypted_payload = aes_ctx.encrypt(IV_1, msg_data, h) + # h = _sha256_of_two(h, encrypted_payload) + # ha_completion_req_header = PacketHeader( + # 0x12, + # self.cid, + # len(encrypted_host_static_pubkey) + # + len(encrypted_payload) + # + CHECKSUM_LENGTH, + # ) + # thp_io.write_payload_to_wire_and_add_checksum( + # self.handle, + # ha_completion_req_header, + # encrypted_host_static_pubkey + encrypted_payload, + # ) - # Read ACK - header, payload = self._read_until_valid_crc_check() - if not header.is_ack() or len(payload) > 0: - print("Received message is not a valid ACK ") + # # Read ACK + # header, payload = self._read_until_valid_crc_check() + # if not header.is_ack() or len(payload) > 0: + # print("Received message is not a valid ACK ") - # Read handshake completion response, ignore payload as we do not care about the state - header, _ = self._read_until_valid_crc_check() - if not header.is_handshake_comp_response(): - print("Received message is not a valid handshake completion response") - self._send_ack_2() + # # Read handshake completion response, ignore payload as we do not care about the state + # header, _ = self._read_until_valid_crc_check() + # if not header.is_handshake_comp_response(): + # print("Received message is not a valid handshake completion response") + # self._send_ack_2() - self.key_request, self.key_response = _hkdf(ck, b"") - self.nonce_request: int = 0 - self.nonce_response: int = 1 + # self.key_request, self.key_response = _hkdf(ck, b"") + # self.nonce_request: int = 0 + # self.nonce_response: int = 1 - # Send StartPairingReqest message - message = messages.ThpStartPairingRequest() - message_type, message_data = mapping.encode(message) + # # Send StartPairingReqest message + # message = messages.ThpStartPairingRequest() + # message_type, message_data = mapping.encode(message) - self._encrypt_and_write(message_type.to_bytes(2, "big"), message_data) + # self._encrypt_and_write(message_type.to_bytes(2, "big"), message_data) - # Read ACK - header, payload = self._read_until_valid_crc_check() - if not header.is_ack() or len(payload) > 0: - print("Received message is not a valid ACK ") + # # Read ACK + # header, payload = self._read_until_valid_crc_check() + # if not header.is_ack() or len(payload) > 0: + # print("Received message is not a valid ACK ") - # Read - _, msg_type, msg_data = self.read_and_decrypt() - maaa = mapping.decode(msg_type, msg_data) - self._send_ack_1() + # # Read + # _, msg_type, msg_data = self.read_and_decrypt() + # maaa = mapping.decode(msg_type, msg_data) + # self._send_ack_1() - assert isinstance(maaa, messages.ThpEndResponse) + # assert isinstance(maaa, messages.ThpEndResponse) - # Send get features - message = messages.GetFeatures() - message_type, message_data = mapping.encode(message) + # # Send get features + # message = messages.GetFeatures() + # message_type, message_data = mapping.encode(message) - self.session_id: int = 0 - self._encrypt_and_write(message_type.to_bytes(2, "big"), message_data, 0x14) - _ = thp_io.read(self.handle) - session_id, msg_type, msg_data = self.read_and_decrypt() - features = mapping.decode(msg_type, msg_data) - assert isinstance(features, messages.Features) - features.session_id = int.to_bytes(self.cid, 2, "big") + session_id - self._send_ack_2() - return features + # self.session_id: int = 0 + # self._encrypt_and_write(message_type.to_bytes(2, "big"), message_data, 0x14) + # _ = thp_io.read(self.handle) + # session_id, msg_type, msg_data = self.read_and_decrypt() + # features = mapping.decode(msg_type, msg_data) + # assert isinstance(features, messages.Features) + # features.session_id = int.to_bytes(self.cid, 2, "big") + session_id + # self._send_ack_2() + # return features + ... def _encrypt_and_write( self, message_type: bytes, message_data: bytes, ctrl_byte: int = 0x04 ) -> None: - assert self.key_request is not None - aes_ctx = AESGCM(self.key_request) - data = self.session_id.to_bytes(1, "big") + message_type + message_data - nonce = _get_iv_from_nonce(self.nonce_request) - self.nonce_request += 1 - encrypted_message = aes_ctx.encrypt(nonce, data, b"") - header = PacketHeader( - ctrl_byte, self.cid, len(encrypted_message) + CHECKSUM_LENGTH - ) + # assert self.key_request is not None + # aes_ctx = AESGCM(self.key_request) + # data = self.session_id.to_bytes(1, "big") + message_type + message_data + # nonce = _get_iv_from_nonce(self.nonce_request) + # self.nonce_request += 1 + # encrypted_message = aes_ctx.encrypt(nonce, data, b"") + # header = PacketHeader( + # ctrl_byte, self.cid, len(encrypted_message) + CHECKSUM_LENGTH + # ) - thp_io.write_payload_to_wire_and_add_checksum( - self.handle, header, encrypted_message - ) - - def _send_ack_1(self): - header = PacketHeader(0x20, self.cid, 4) - thp_io.write_payload_to_wire_and_add_checksum(self.handle, header, b"") - - def _send_ack_2(self): - header = PacketHeader(0x28, self.cid, 4) - thp_io.write_payload_to_wire_and_add_checksum(self.handle, header, b"") + # thp_io.write_payload_to_wire_and_add_checksum( + # self.handle, header, encrypted_message + # ) + ... def _write_message(self, message: MessageType, mapping: ProtobufMapping): try: @@ -244,43 +231,46 @@ class ProtocolV2(Protocol): print(type(e)) def write(self, message_type: int, message_data: bytes) -> None: - data = ( - self.session_id.to_bytes(1, "big") - + message_type.to_bytes(2, "big") - + message_data - ) - ctrl_byte = 0x04 - self._write_and_encrypt(data, ctrl_byte) + # data = ( + # self.session_id.to_bytes(1, "big") + # + message_type.to_bytes(2, "big") + # + message_data + # ) + # ctrl_byte = 0x04 + # self._write_and_encrypt(data, ctrl_byte) + ... def _write_and_encrypt(self, data: bytes, ctrl_byte: int) -> None: - aes_ctx = AESGCM(self.key_request) - nonce = _get_iv_from_nonce(self.nonce_request) - self.nonce_request += 1 - encrypted_data = aes_ctx.encrypt(nonce, data, b"") - header = PacketHeader( - ctrl_byte, self.cid, len(encrypted_data) + CHECKSUM_LENGTH - ) - thp_io.write_payload_to_wire_and_add_checksum( - self.handle, header, encrypted_data - ) + # aes_ctx = AESGCM(self.key_request) + # nonce = _get_iv_from_nonce(self.nonce_request) + # self.nonce_request += 1 + # encrypted_data = aes_ctx.encrypt(nonce, data, b"") + # header = PacketHeader( + # ctrl_byte, self.cid, len(encrypted_data) + CHECKSUM_LENGTH + # ) + # thp_io.write_payload_to_wire_and_add_checksum( + # self.handle, header, encrypted_data + # ) + ... def read_and_decrypt(self) -> Tuple[bytes, int, bytes]: - header, raw_payload = self._read_until_valid_crc_check() - if not header.is_encrypted_transport(): - print("Trying to decrypt not encrypted message!") - aes_ctx = AESGCM(self.key_response) - nonce = _get_iv_from_nonce(self.nonce_response) - self.nonce_response += 1 + # header, raw_payload = self._read_until_valid_crc_check() + # if not header.is_encrypted_transport(): + # print("Trying to decrypt not encrypted message!") + # aes_ctx = AESGCM(self.key_response) + # nonce = _get_iv_from_nonce(self.nonce_response) + # self.nonce_response += 1 - message = aes_ctx.decrypt(nonce, raw_payload, b"") - session_id = message[0] - message_type = message[1:3] - message_data = message[3:] - return ( - int.to_bytes(session_id, 1, "big"), - int.from_bytes(message_type, "big"), - message_data, - ) + # message = aes_ctx.decrypt(nonce, raw_payload, b"") + # session_id = message[0] + # message_type = message[1:3] + # message_data = message[3:] + # return ( + # int.to_bytes(session_id, 1, "big"), + # int.from_bytes(message_type, "big"), + # message_data, + # ) + ... def end_session(self, session_id: bytes) -> None: pass @@ -290,22 +280,24 @@ class ProtocolV2(Protocol): return self.start_session("") def start_session(self, passphrase: str) -> bytes: - try: - msg = messages.ThpCreateNewSession(passphrase=passphrase) - except Exception as e: - print(e) - print("s") + # try: + # msg = messages.ThpCreateNewSession(passphrase=passphrase) + # except Exception as e: + # print(e) + # print("s") - self._write_message(msg, self.mapping) - print("p") - response_type, response_data = self._read_until_valid_crc_check() - print(response_type, response_data) - return b"" + # self._write_message(msg, self.mapping) + # print("p") + # response_type, response_data = self._read_until_valid_crc_check() + # print(response_type, response_data) + # return b"" + ... def read(self) -> Tuple[int, bytes]: - header, raw_payload, chksum = thp_io.read(self.handle) - print("Read message", hexlify(raw_payload)) - return (0x00, header.to_bytes_init() + raw_payload + chksum) # TODO change + # header, raw_payload, chksum = thp_io.read(self.handle) + # print("Read message", hexlify(raw_payload)) + # return (0x00, header.to_bytes_init() + raw_payload + chksum) # TODO change + ... def _get_control_byte(self) -> bytes: return b"\x42" @@ -313,16 +305,17 @@ class ProtocolV2(Protocol): def _read_until_valid_crc_check( self, ) -> Tuple[PacketHeader, bytes]: - is_valid = False - header, payload, chksum = thp_io.read(self.handle) - while not is_valid: - is_valid = checksum.is_valid(chksum, header.to_bytes_init() + payload) - if not is_valid: - print(hexlify(header.to_bytes_init() + payload + chksum)) - LOG.debug("Received a message with invalid checksum") - header, payload, chksum = thp_io.read(self.handle) + # is_valid = False + # header, payload, chksum = thp_io.read(self.handle) + # while not is_valid: + # is_valid = checksum.is_valid(chksum, header.to_bytes_init() + payload) + # if not is_valid: + # print(hexlify(header.to_bytes_init() + payload + chksum)) + # LOG.debug("Received a message with invalid checksum") + # header, payload, chksum = thp_io.read(self.handle) - return header, payload + # return header, payload + ... def _is_valid_channel_allocation_response( self, header: PacketHeader, payload: bytes, original_nonce: bytes diff --git a/python/src/trezorlib/transport/thp/thp_io.py b/python/src/trezorlib/transport/thp/thp_io.py index 8107c634e..48efac832 100644 --- a/python/src/trezorlib/transport/thp/thp_io.py +++ b/python/src/trezorlib/transport/thp/thp_io.py @@ -1,15 +1,12 @@ import struct -from binascii import hexlify from typing import Tuple -from ..protocol import Handle +from ..new.transport import NewTransport from ..thp import checksum from .packet_header import PacketHeader INIT_HEADER_LENGTH = 5 CONT_HEADER_LENGTH = 3 -PACKET_LENGTH = 64 -CHECKSUM_LENGTH = 4 MAX_PAYLOAD_LEN = 60000 MESSAGE_TYPE_LENGTH = 2 @@ -17,48 +14,54 @@ CONTINUATION_PACKET = 0x80 def write_payload_to_wire_and_add_checksum( - handle: Handle, header: PacketHeader, transport_payload: bytes + transport: NewTransport, header: PacketHeader, transport_payload: bytes ): chksum: bytes = checksum.compute(header.to_bytes_init() + transport_payload) data = transport_payload + chksum - write_payload_to_wire(handle, header, data) + write_payload_to_wire(transport, header, data) def write_payload_to_wire( - handle: Handle, header: PacketHeader, transport_payload: bytes + transport: NewTransport, header: PacketHeader, transport_payload: bytes ): - handle.open() + transport.open() buffer = bytearray(transport_payload) - chunk = header.to_bytes_init() + buffer[: PACKET_LENGTH - INIT_HEADER_LENGTH] - chunk = chunk.ljust(PACKET_LENGTH, b"\x00") - handle.write_chunk(chunk) + chunk = header.to_bytes_init() + buffer[: transport.CHUNK_SIZE - INIT_HEADER_LENGTH] + chunk = chunk.ljust(transport.CHUNK_SIZE, b"\x00") + transport.write_chunk(chunk) - buffer = buffer[PACKET_LENGTH - INIT_HEADER_LENGTH :] + buffer = buffer[transport.CHUNK_SIZE - INIT_HEADER_LENGTH :] while buffer: - chunk = header.to_bytes_cont() + buffer[: PACKET_LENGTH - CONT_HEADER_LENGTH] - chunk = chunk.ljust(PACKET_LENGTH, b"\x00") - handle.write_chunk(chunk) - buffer = buffer[PACKET_LENGTH - CONT_HEADER_LENGTH :] + chunk = ( + header.to_bytes_cont() + buffer[: transport.CHUNK_SIZE - CONT_HEADER_LENGTH] + ) + chunk = chunk.ljust(transport.CHUNK_SIZE, b"\x00") + transport.write_chunk(chunk) + buffer = buffer[transport.CHUNK_SIZE - CONT_HEADER_LENGTH :] -def read(handle: Handle) -> Tuple[PacketHeader, bytes, bytes]: +def read(transport: NewTransport) -> Tuple[PacketHeader, bytes, bytes]: buffer = bytearray() # Read header with first part of message data - header, first_chunk = read_first(handle) + header, first_chunk = read_first(transport) buffer.extend(first_chunk) # Read the rest of the message while len(buffer) < header.data_length: - buffer.extend(read_next(handle, header.cid)) + buffer.extend(read_next(transport, header.cid)) # print("buffer read (data):", hexlify(buffer).decode()) # print("buffer len (data):", datalen) # TODO check checksum?? or do not strip ? - data_len = header.data_length - CHECKSUM_LENGTH - return header, buffer[:data_len], buffer[data_len : data_len + CHECKSUM_LENGTH] + data_len = header.data_length - checksum.CHECKSUM_LENGTH + return ( + header, + buffer[:data_len], + buffer[data_len : data_len + checksum.CHECKSUM_LENGTH], + ) -def read_first(handle: Handle) -> Tuple[PacketHeader, bytes]: - chunk = handle.read_chunk() +def read_first(transport: NewTransport) -> Tuple[PacketHeader, bytes]: + chunk = transport.read_chunk() try: ctrl_byte, cid, data_length = struct.unpack( PacketHeader.format_str_init, chunk[:INIT_HEADER_LENGTH] @@ -70,8 +73,8 @@ def read_first(handle: Handle) -> Tuple[PacketHeader, bytes]: return PacketHeader(ctrl_byte, cid, data_length), data -def read_next(handle: Handle, cid: int) -> bytes: - chunk = handle.read_chunk() +def read_next(transport: NewTransport, cid: int) -> bytes: + chunk = transport.read_chunk() ctrl_byte, read_cid = struct.unpack( PacketHeader.format_str_cont, chunk[:CONT_HEADER_LENGTH] )