From f6ff8529c6f101f2c1244a5280849f9ec1fab433 Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Fri, 30 Aug 2024 10:56:25 +0200 Subject: [PATCH] TEMPORARY WIP - trezorlib --- python/src/trezorlib/cli/firmware.py | 1 + python/src/trezorlib/cli/trezorctl.py | 1 + python/src/trezorlib/client.py | 38 +- python/src/trezorlib/debuglink.py | 8 +- python/src/trezorlib/mapping.py | 3 +- python/src/trezorlib/tools.py | 1 + python/src/trezorlib/transport/protocol.py | 4 +- python/src/trezorlib/transport/protocol_v1.py | 3 + python/src/trezorlib/transport/protocol_v2.py | 340 ++++++++++++++++++ .../trezorlib/transport/thp/packet_header.py | 82 +++++ python/src/trezorlib/transport/thp/thp_io.py | 89 +++++ python/src/trezorlib/transport/webusb.py | 2 + 12 files changed, 562 insertions(+), 10 deletions(-) create mode 100644 python/src/trezorlib/transport/thp/packet_header.py create mode 100644 python/src/trezorlib/transport/thp/thp_io.py diff --git a/python/src/trezorlib/cli/firmware.py b/python/src/trezorlib/cli/firmware.py index 4376a4f283..69ce57ce92 100644 --- a/python/src/trezorlib/cli/firmware.py +++ b/python/src/trezorlib/cli/firmware.py @@ -653,6 +653,7 @@ def update( against downloaded firmware fingerprint. Otherwise fingerprint is checked against data.trezor.io information, if available. """ + print("client context") with obj.client_context() as client: if sum(bool(x) for x in (filename, url, version)) > 1: click.echo("You can use only one of: filename, url, version.") diff --git a/python/src/trezorlib/cli/trezorctl.py b/python/src/trezorlib/cli/trezorctl.py index f3a037652e..c2aac2a0ed 100755 --- a/python/src/trezorlib/cli/trezorctl.py +++ b/python/src/trezorlib/cli/trezorctl.py @@ -291,6 +291,7 @@ def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]: client = TrezorClient(transport, ui=ui.ClickUI()) description = format_device_name(client.features) client.end_session() + print("after end session") except DeviceIsBusy: description = "Device is in use by another process" except Exception: diff --git a/python/src/trezorlib/client.py b/python/src/trezorlib/client.py index d5df4d0ba5..2eb5419acc 100644 --- a/python/src/trezorlib/client.py +++ b/python/src/trezorlib/client.py @@ -134,12 +134,24 @@ class TrezorClient(Generic[UI]): self.session_id = session_id if _init_device: self.init_device(session_id=session_id, derive_cardano=derive_cardano) + self.resume_session() def open(self) -> None: if self.session_counter == 0: + session_id = self.transport.resume_session(b"") + if self.session_id != session_id: + print("Failed to resume session, allocated a new session") + self.session_id = session_id self.transport.deprecated_begin_session() self.session_counter += 1 + def resume_session(self) -> None: + print("resume session") + new_id = self.transport.resume_session(self.session_id or b"") + if self.session_id != new_id: + print("Failed to resume session, allocated a new session") + self.session_id = new_id + def close(self) -> None: self.session_counter = max(self.session_counter - 1, 0) if self.session_counter == 0: @@ -151,8 +163,13 @@ class TrezorClient(Generic[UI]): def call_raw(self, msg: "MessageType") -> "MessageType": __tracebackhide__ = True # for pytest # pylint: disable=W0612 + print("self.call_raw-start") + self._raw_write(msg) - return self._raw_read() + print("self.call_raw-after write") + x = self._raw_read() + print("self.call_raw-end") + return x def _raw_write(self, msg: "MessageType") -> None: __tracebackhide__ = True # for pytest # pylint: disable=W0612 @@ -169,7 +186,9 @@ class TrezorClient(Generic[UI]): def _raw_read(self) -> "MessageType": __tracebackhide__ = True # for pytest # pylint: disable=W0612 + print("raw read - start") msg_type, msg_bytes = self.transport.read() + print("type/data", msg_type, msg_bytes) LOG.log( DUMP_BYTES, f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", @@ -253,6 +272,7 @@ class TrezorClient(Generic[UI]): @session def call(self, msg: "MessageType") -> "MessageType": + print("self.call-start") self.check_firmware_version() resp = self.call_raw(msg) while True: @@ -263,10 +283,13 @@ class TrezorClient(Generic[UI]): elif isinstance(resp, messages.ButtonRequest): resp = self._callback_button(resp) elif isinstance(resp, messages.Failure): + print("self.call-failure") + if resp.code == messages.FailureType.ActionCancelled: raise exceptions.Cancelled raise exceptions.TrezorFailure(resp) else: + print("self.call-end") return resp def _refresh_features(self, features: messages.Features) -> None: @@ -311,7 +334,7 @@ class TrezorClient(Generic[UI]): self._refresh_features(resp) return resp - @session + # @session def init_device( self, *, @@ -352,11 +375,14 @@ class TrezorClient(Generic[UI]): elif session_id is not None: self.session_id = session_id + print("before init conn") + resp = self.transport.initialize_connection( mapping=self.mapping, session_id=session_id, derive_cardano=derive_cardano, ) + print("here") if isinstance(resp, messages.Failure): # can happen if `derive_cardano` does not match the current session raise exceptions.TrezorFailure(resp) @@ -377,6 +403,7 @@ class TrezorClient(Generic[UI]): # exchange happens. reported_session_id = resp.session_id self._refresh_features(resp) + print("there:", reported_session_id) return reported_session_id def is_outdated(self) -> bool: @@ -467,14 +494,19 @@ class TrezorClient(Generic[UI]): This is a no-op in bootloader mode, as it does not support session management. """ # since: 2.3.4, 1.9.4 + print("end session") try: if not self.features.bootloader_mode: - self.call(messages.EndSession()) + self.transport.end_session(self.session_id or b"") + # self.call(messages.EndSession()) except exceptions.TrezorFailure: # A failure most likely means that the FW version does not support # the EndSession call. We ignore the failure and clear the local session_id. # The client-side end result is identical. pass + except ValueError as e: + print(e) + print(e.args) self.session_id = None @session diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index e23c236678..737a201e7b 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -48,7 +48,7 @@ from .client import TrezorClient from .exceptions import TrezorFailure from .log import DUMP_BYTES from .messages import DebugWaitType -from .tools import expect, session +from .tools import expect if TYPE_CHECKING: from typing_extensions import Protocol @@ -1086,7 +1086,7 @@ class TrezorClientDebugLink(TrezorClient): """ if not self.in_with_statement: raise RuntimeError("Must be called inside 'with' statement") - + if input_flow is None: self.ui.input_flow = None return @@ -1287,7 +1287,7 @@ class TrezorClientDebugLink(TrezorClient): # Start by canceling whatever is on screen. This will work to cancel T1 PIN # prompt, which is in TINY mode and does not respond to `Ping`. cancel_msg = mapping.DEFAULT_MAPPING.encode(messages.Cancel()) - self.transport.begin_session() + self.transport.deprecated_begin_session() try: self.transport.write(*cancel_msg) @@ -1302,7 +1302,7 @@ class TrezorClientDebugLink(TrezorClient): except Exception: pass finally: - self.transport.end_session() + self.transport.end_session(self.session_id or b"") def mnemonic_callback(self, _) -> str: word, pos = self.debug.read_recovery_word() diff --git a/python/src/trezorlib/mapping.py b/python/src/trezorlib/mapping.py index e4a2d1e2ab..05f214afcc 100644 --- a/python/src/trezorlib/mapping.py +++ b/python/src/trezorlib/mapping.py @@ -63,9 +63,10 @@ class ProtobufMapping: wire_type = self.class_to_type_override.get(type(msg), msg.MESSAGE_WIRE_TYPE) if wire_type is None: raise ValueError("Cannot encode class without wire type") - + print("wire type", wire_type) buf = io.BytesIO() protobuf.dump_message(buf, msg) + print("test") return wire_type, buf.getvalue() def encode_without_wire_type(self, msg: protobuf.MessageType) -> bytes: diff --git a/python/src/trezorlib/tools.py b/python/src/trezorlib/tools.py index 4fd1558ec2..e00fde6fcb 100644 --- a/python/src/trezorlib/tools.py +++ b/python/src/trezorlib/tools.py @@ -297,6 +297,7 @@ def session( return f(client, *args, **kwargs) finally: client.close() + print("wrap end") return wrapped_f diff --git a/python/src/trezorlib/transport/protocol.py b/python/src/trezorlib/transport/protocol.py index dcad8378e0..b619012cae 100644 --- a/python/src/trezorlib/transport/protocol.py +++ b/python/src/trezorlib/transport/protocol.py @@ -95,7 +95,7 @@ class Protocol: def resume_session(self, session_id: bytes) -> bytes: raise NotImplementedError - def end_session(self, session_id: bytes) -> bytes: + def end_session(self, session_id: bytes) -> None: raise NotImplementedError # XXX we might be able to remove this now that TrezorClient does session handling @@ -147,7 +147,7 @@ class ProtocolBasedTransport(Transport): def resume_session(self, session_id: bytes) -> bytes: return self.protocol.resume_session(session_id) - def end_session(self, session_id: bytes) -> bytes: + def end_session(self, session_id: bytes) -> None: return self.protocol.end_session(session_id) def deprecated_begin_session(self) -> None: diff --git a/python/src/trezorlib/transport/protocol_v1.py b/python/src/trezorlib/transport/protocol_v1.py index c5d18c0042..7896e974dc 100644 --- a/python/src/trezorlib/transport/protocol_v1.py +++ b/python/src/trezorlib/transport/protocol_v1.py @@ -70,3 +70,6 @@ class ProtocolV1(Protocol): if chunk[:1] != b"?": raise RuntimeError("Unexpected magic characters") return chunk[1:] + + def end_session(self, session_id: bytes) -> None: + return super().end_session(session_id) diff --git a/python/src/trezorlib/transport/protocol_v2.py b/python/src/trezorlib/transport/protocol_v2.py index 455d4e1a01..fa4d468e3e 100644 --- a/python/src/trezorlib/transport/protocol_v2.py +++ b/python/src/trezorlib/transport/protocol_v2.py @@ -1,6 +1,346 @@ +import hashlib +import hmac +import logging +import os +from binascii import hexlify +from enum import IntEnum +from typing import Optional, Tuple + +from cryptography.hazmat.primitives.ciphers.aead import AESGCM + +from .. import messages +from ..mapping import ProtobufMapping +from ..protobuf import MessageType from ..transport.protocol import Handle, Protocol +from .thp import checksum, curve25519, thp_io +from .thp.checksum import CHECKSUM_LENGTH +from .thp.packet_header import PacketHeader + +LOG = logging.getLogger(__name__) + + +def _sha256_of_two(val_1: bytes, val_2: bytes) -> bytes: + hash = hashlib.sha256(val_1) + hash.update(val_2) + return hash.digest() + + +def _hkdf(chaining_key: bytes, input: bytes): + temp_key = hmac.new(chaining_key, input, hashlib.sha256).digest() + output_1 = hmac.new(temp_key, b"\x01", hashlib.sha256).digest() + ctx_output_2 = hmac.new(temp_key, output_1, hashlib.sha256) + ctx_output_2.update(b"\x02") + output_2 = ctx_output_2.digest() + return (output_1, output_2) + + +def _get_iv_from_nonce(nonce: int) -> bytes: + if not nonce <= 0xFFFFFFFFFFFFFFFF: + raise ValueError("Nonce overflow, terminate the channel") + return bytes(4) + nonce.to_bytes(8, "big") class ProtocolV2(Protocol): def __init__(self, handle: Handle) -> None: super().__init__(handle) + + def initialize_connection( + self, + mapping: ProtobufMapping, + session_id: Optional[bytes] = None, + derive_caradano: Optional[bool] = None, + ): + self.session_id: int = 0 + self.sync_bit_send: int = 0 + self.sync_bit_receive: int = 0 + self.mapping = mapping + # Send channel allocation request + channel_id_request_nonce = os.urandom(8) + thp_io.write_payload_to_wire_and_add_checksum( + self.handle, + PacketHeader.get_channel_allocation_request_header(12), + channel_id_request_nonce, + ) + + # Read channel allocation response + header, payload = self._read_until_valid_crc_check() + if not self._is_valid_channel_allocation_response( + header, payload, channel_id_request_nonce + ): + print("TODO raise exception here, I guess") + + self.cid = int.from_bytes(payload[8:10], "big") + self.device_properties = payload[10:] + + # Send handshake init request + ha_init_req_header = PacketHeader(0, self.cid, 36) + host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32)) + host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey) + + thp_io.write_payload_to_wire_and_add_checksum( + self.handle, ha_init_req_header, host_ephemeral_pubkey + ) + + # Read ACK + header, payload = self._read_until_valid_crc_check() + if not header.is_ack() or len(payload) > 0: + print("Received message is not a valid ACK ") + + # Read handshake init response + header, payload = self._read_until_valid_crc_check() + self._send_ack_1() + + if not header.is_handshake_init_response(): + print("Received message is not a valid handshake init response message") + + trezor_ephemeral_pubkey = payload[:32] + encrypted_trezor_static_pubkey = payload[32:80] + noise_tag = payload[80:96] + + # TODO check noise tag + print("noise_tag: ", hexlify(noise_tag).decode()) + + # Prepare and send handshake completion request + PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00" + IV_1 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + IV_2 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + h = _sha256_of_two(PROTOCOL_NAME, self.device_properties) + h = _sha256_of_two(h, host_ephemeral_pubkey) + h = _sha256_of_two(h, trezor_ephemeral_pubkey) + ck, k = _hkdf( + PROTOCOL_NAME, + curve25519.multiply(host_ephemeral_privkey, trezor_ephemeral_pubkey), + ) + + aes_ctx = AESGCM(k) + try: + trezor_masked_static_pubkey = aes_ctx.decrypt( + IV_1, encrypted_trezor_static_pubkey, h + ) + # print("masked_key", hexlify(trezor_masked_static_pubkey).decode()) + except Exception as e: + print(type(e)) # TODO how to handle potential exceptions? Q for Matejcik + h = _sha256_of_two(h, encrypted_trezor_static_pubkey) + ck, k = _hkdf( + ck, curve25519.multiply(host_ephemeral_privkey, trezor_masked_static_pubkey) + ) + aes_ctx = AESGCM(k) + + tag_of_empty_string = aes_ctx.encrypt(IV_1, b"", h) + h = _sha256_of_two(h, tag_of_empty_string) + # TODO: search for saved credentials (or possibly not, as we skip pairing phase) + + zeroes_32 = int.to_bytes(0, 32, "little") + temp_host_static_privkey = curve25519.get_private_key(zeroes_32) + temp_host_static_pubkey = curve25519.get_public_key(temp_host_static_privkey) + aes_ctx = AESGCM(k) + encrypted_host_static_pubkey = aes_ctx.encrypt(IV_2, temp_host_static_pubkey, h) + h = _sha256_of_two(h, encrypted_host_static_pubkey) + ck, k = _hkdf( + ck, curve25519.multiply(temp_host_static_privkey, trezor_ephemeral_pubkey) + ) + msg_data = mapping.encode_without_wire_type( + messages.ThpHandshakeCompletionReqNoisePayload( + pairing_methods=[ + messages.ThpPairingMethod.NoMethod, + ] + ) + ) + + aes_ctx = AESGCM(k) + + encrypted_payload = aes_ctx.encrypt(IV_1, msg_data, h) + h = _sha256_of_two(h, encrypted_payload) + ha_completion_req_header = PacketHeader( + 0x12, + self.cid, + len(encrypted_host_static_pubkey) + + len(encrypted_payload) + + CHECKSUM_LENGTH, + ) + thp_io.write_payload_to_wire_and_add_checksum( + self.handle, + ha_completion_req_header, + encrypted_host_static_pubkey + encrypted_payload, + ) + + # Read ACK + header, payload = self._read_until_valid_crc_check() + if not header.is_ack() or len(payload) > 0: + print("Received message is not a valid ACK ") + + # Read handshake completion response, ignore payload as we do not care about the state + header, _ = self._read_until_valid_crc_check() + if not header.is_handshake_comp_response(): + print("Received message is not a valid handshake completion response") + self._send_ack_2() + + self.key_request, self.key_response = _hkdf(ck, b"") + self.nonce_request: int = 0 + self.nonce_response: int = 1 + + # Send StartPairingReqest message + message = messages.ThpStartPairingRequest() + message_type, message_data = mapping.encode(message) + + self._encrypt_and_write(message_type.to_bytes(2, "big"), message_data) + + # Read ACK + header, payload = self._read_until_valid_crc_check() + if not header.is_ack() or len(payload) > 0: + print("Received message is not a valid ACK ") + + # Read + _, msg_type, msg_data = self.read_and_decrypt() + maaa = mapping.decode(msg_type, msg_data) + self._send_ack_1() + + assert isinstance(maaa, messages.ThpEndResponse) + + # Send get features + message = messages.GetFeatures() + message_type, message_data = mapping.encode(message) + + self.session_id: int = 0 + self._encrypt_and_write(message_type.to_bytes(2, "big"), message_data, 0x14) + _ = thp_io.read(self.handle) + session_id, msg_type, msg_data = self.read_and_decrypt() + features = mapping.decode(msg_type, msg_data) + assert isinstance(features, messages.Features) + features.session_id = int.to_bytes(self.cid, 2, "big") + session_id + self._send_ack_2() + return features + + def _encrypt_and_write( + self, message_type: bytes, message_data: bytes, ctrl_byte: int = 0x04 + ) -> None: + assert self.key_request is not None + aes_ctx = AESGCM(self.key_request) + data = self.session_id.to_bytes(1, "big") + message_type + message_data + nonce = _get_iv_from_nonce(self.nonce_request) + self.nonce_request += 1 + encrypted_message = aes_ctx.encrypt(nonce, data, b"") + header = PacketHeader( + ctrl_byte, self.cid, len(encrypted_message) + CHECKSUM_LENGTH + ) + + thp_io.write_payload_to_wire_and_add_checksum( + self.handle, header, encrypted_message + ) + + def _send_ack_1(self): + header = PacketHeader(0x20, self.cid, 4) + thp_io.write_payload_to_wire_and_add_checksum(self.handle, header, b"") + + def _send_ack_2(self): + header = PacketHeader(0x28, self.cid, 4) + thp_io.write_payload_to_wire_and_add_checksum(self.handle, header, b"") + + def _write_message(self, message: MessageType, mapping: ProtobufMapping): + try: + message_type, message_data = mapping.encode(message) + self.write(message_type, message_data) + except Exception as e: + print(type(e)) + + def write(self, message_type: int, message_data: bytes) -> None: + data = ( + self.session_id.to_bytes(1, "big") + + message_type.to_bytes(2, "big") + + message_data + ) + ctrl_byte = 0x04 + self._write_and_encrypt(data, ctrl_byte) + + def _write_and_encrypt(self, data: bytes, ctrl_byte: int) -> None: + aes_ctx = AESGCM(self.key_request) + nonce = _get_iv_from_nonce(self.nonce_request) + self.nonce_request += 1 + encrypted_data = aes_ctx.encrypt(nonce, data, b"") + header = PacketHeader( + ctrl_byte, self.cid, len(encrypted_data) + CHECKSUM_LENGTH + ) + thp_io.write_payload_to_wire_and_add_checksum( + self.handle, header, encrypted_data + ) + + def read_and_decrypt(self) -> Tuple[bytes, int, bytes]: + header, raw_payload = self._read_until_valid_crc_check() + if not header.is_encrypted_transport(): + print("Trying to decrypt not encrypted message!") + aes_ctx = AESGCM(self.key_response) + nonce = _get_iv_from_nonce(self.nonce_response) + self.nonce_response += 1 + + message = aes_ctx.decrypt(nonce, raw_payload, b"") + session_id = message[0] + message_type = message[1:3] + message_data = message[3:] + return ( + int.to_bytes(session_id, 1, "big"), + int.from_bytes(message_type, "big"), + message_data, + ) + + def end_session(self, session_id: bytes) -> None: + pass + + def resume_session(self, session_id: bytes) -> bytes: + print("protocol 2 resume session") + return self.start_session("") + + def start_session(self, passphrase: str) -> bytes: + try: + msg = messages.ThpCreateNewSession(passphrase=passphrase) + except Exception as e: + print(e) + print("s") + + self._write_message(msg, self.mapping) + print("p") + response_type, response_data = self._read_until_valid_crc_check() + print(response_type, response_data) + return b"" + + def read(self) -> Tuple[int, bytes]: + header, raw_payload, chksum = thp_io.read(self.handle) + print("Read message", hexlify(raw_payload)) + return (0x00, header.to_bytes_init() + raw_payload + chksum) # TODO change + + def _get_control_byte(self) -> bytes: + return b"\x42" + + def _read_until_valid_crc_check( + self, + ) -> Tuple[PacketHeader, bytes]: + is_valid = False + header, payload, chksum = thp_io.read(self.handle) + while not is_valid: + is_valid = checksum.is_valid(chksum, header.to_bytes_init() + payload) + if not is_valid: + print(hexlify(header.to_bytes_init() + payload + chksum)) + LOG.debug("Received a message with invalid checksum") + header, payload, chksum = thp_io.read(self.handle) + + return header, payload + + def _is_valid_channel_allocation_response( + self, header: PacketHeader, payload: bytes, original_nonce: bytes + ) -> bool: + if not header.is_channel_allocation_response(): + print("Received message is not a channel allocation response") + return False + if len(payload) < 10: + print("Invalid channel allocation response payload") + return False + if payload[:8] != original_nonce: + print("Invalid channel allocation response payload (nonce mismatch)") + return False + return True + + class ControlByteType(IntEnum): + CHANNEL_ALLOCATION_RES = 1 + HANDSHAKE_INIT_RES = 2 + HANDSHAKE_COMP_RES = 3 + ACK = 4 + ENCRYPTED_TRANSPORT = 5 diff --git a/python/src/trezorlib/transport/thp/packet_header.py b/python/src/trezorlib/transport/thp/packet_header.py new file mode 100644 index 0000000000..b282f9f46f --- /dev/null +++ b/python/src/trezorlib/transport/thp/packet_header.py @@ -0,0 +1,82 @@ +import struct + +CODEC_V1 = 0x3F +CONTINUATION_PACKET = 0x80 +HANDSHAKE_INIT_REQ = 0x00 +HANDSHAKE_INIT_RES = 0x01 +HANDSHAKE_COMP_REQ = 0x02 +HANDSHAKE_COMP_RES = 0x03 +ENCRYPTED_TRANSPORT = 0x04 + +CONTINUATION_PACKET_MASK = 0x80 +ACK_MASK = 0xF7 +DATA_MASK = 0xE7 + +ACK_MESSAGE = 0x20 +_ERROR = 0x42 +CHANNEL_ALLOCATION_REQ = 0x40 +_CHANNEL_ALLOCATION_RES = 0x41 + +TREZOR_STATE_UNPAIRED = b"\x00" +TREZOR_STATE_PAIRED = b"\x01" + +BROADCAST_CHANNEL_ID = 0xFFFF + + +class PacketHeader: + format_str_init = ">BHH" + format_str_cont = ">BH" + + def __init__(self, ctrl_byte: int, cid: int, length: int) -> None: + self.ctrl_byte = ctrl_byte + self.cid = cid + self.data_length = length + + def to_bytes_init(self) -> bytes: + return struct.pack( + self.format_str_init, self.ctrl_byte, self.cid, self.data_length + ) + + def to_bytes_cont(self) -> bytes: + return struct.pack(self.format_str_cont, CONTINUATION_PACKET, self.cid) + + def pack_to_init_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None: + struct.pack_into( + self.format_str_init, + buffer, + buffer_offset, + self.ctrl_byte, + self.cid, + self.data_length, + ) + + def pack_to_cont_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None: + struct.pack_into( + self.format_str_cont, buffer, buffer_offset, CONTINUATION_PACKET, self.cid + ) + + def is_ack(self) -> bool: + return self.ctrl_byte & ACK_MASK == ACK_MESSAGE + + def is_channel_allocation_response(self): + return ( + self.cid == BROADCAST_CHANNEL_ID + and self.ctrl_byte == _CHANNEL_ALLOCATION_RES + ) + + def is_handshake_init_response(self) -> bool: + return self.ctrl_byte & DATA_MASK == HANDSHAKE_INIT_RES + + def is_handshake_comp_response(self) -> bool: + return self.ctrl_byte & DATA_MASK == HANDSHAKE_COMP_RES + + def is_encrypted_transport(self) -> bool: + return self.ctrl_byte & DATA_MASK == ENCRYPTED_TRANSPORT + + @classmethod + def get_error_header(cls, cid: int, length: int): + return cls(_ERROR, cid, length) + + @classmethod + def get_channel_allocation_request_header(cls, length: int): + return cls(CHANNEL_ALLOCATION_REQ, BROADCAST_CHANNEL_ID, length) diff --git a/python/src/trezorlib/transport/thp/thp_io.py b/python/src/trezorlib/transport/thp/thp_io.py new file mode 100644 index 0000000000..b62ebb87bb --- /dev/null +++ b/python/src/trezorlib/transport/thp/thp_io.py @@ -0,0 +1,89 @@ +import struct +from binascii import hexlify +from typing import Tuple + +from ..protocol import Handle +from ..thp import checksum +from .packet_header import PacketHeader + +INIT_HEADER_LENGTH = 5 +CONT_HEADER_LENGTH = 3 +PACKET_LENGTH = 64 +CHECKSUM_LENGTH = 4 +MAX_PAYLOAD_LEN = 60000 +MESSAGE_TYPE_LENGTH = 2 + +CONTINUATION_PACKET = 0x80 + + +def write_payload_to_wire_and_add_checksum( + handle: Handle, header: PacketHeader, transport_payload: bytes +): + chksum: bytes = checksum.compute(header.to_bytes_init() + transport_payload) + data = transport_payload + chksum + write_payload_to_wire(handle, header, data) + print("WOO") + + +def write_payload_to_wire( + handle: Handle, header: PacketHeader, transport_payload: bytes +): + print("tttt") + handle.open() + buffer = bytearray(transport_payload) + chunk = header.to_bytes_init() + buffer[: PACKET_LENGTH - INIT_HEADER_LENGTH] + print("x") + chunk = chunk.ljust(PACKET_LENGTH, b"\x00") + print("y") + print(hexlify(chunk)) + handle.write_chunk(chunk) + print("fgh") + + buffer = buffer[PACKET_LENGTH - INIT_HEADER_LENGTH :] + while buffer: + chunk = header.to_bytes_cont() + buffer[: PACKET_LENGTH - CONT_HEADER_LENGTH] + chunk = chunk.ljust(PACKET_LENGTH, b"\x00") + handle.write_chunk(chunk) + buffer = buffer[PACKET_LENGTH - CONT_HEADER_LENGTH :] + + +def read(handle: Handle) -> Tuple[PacketHeader, bytes, bytes]: + buffer = bytearray() + # Read header with first part of message data + header, first_chunk = read_first(handle) + buffer.extend(first_chunk) + + # Read the rest of the message + while len(buffer) < header.data_length: + buffer.extend(read_next(handle, header.cid)) + # print("buffer read (data):", hexlify(buffer).decode()) + # print("buffer len (data):", datalen) + # TODO check checksum?? or do not strip ? + data_len = header.data_length - CHECKSUM_LENGTH + return header, buffer[:data_len], buffer[data_len : data_len + CHECKSUM_LENGTH] + + +def read_first(handle: Handle) -> Tuple[PacketHeader, bytes]: + chunk = handle.read_chunk() + try: + ctrl_byte, cid, data_length = struct.unpack( + PacketHeader.format_str_init, chunk[:INIT_HEADER_LENGTH] + ) + except Exception: + raise RuntimeError("Cannot parse header") + + data = chunk[INIT_HEADER_LENGTH:] + return PacketHeader(ctrl_byte, cid, data_length), data + + +def read_next(handle: Handle, cid: int) -> bytes: + chunk = handle.read_chunk() + ctrl_byte, read_cid = struct.unpack( + PacketHeader.format_str_cont, chunk[:CONT_HEADER_LENGTH] + ) + if ctrl_byte != CONTINUATION_PACKET: + raise RuntimeError("Continuation packet with incorrect control byte") + if read_cid != cid: + raise RuntimeError("Continuation packet for different channel") + + return chunk[CONT_HEADER_LENGTH:] diff --git a/python/src/trezorlib/transport/webusb.py b/python/src/trezorlib/transport/webusb.py index 31c7106475..5ef43d6edc 100644 --- a/python/src/trezorlib/transport/webusb.py +++ b/python/src/trezorlib/transport/webusb.py @@ -69,7 +69,9 @@ class WebUsbHandle: self.handle = None def write_chunk(self, chunk: bytes) -> None: + print("ti") assert self.handle is not None + print("te") if len(chunk) != 64: raise TransportException(f"Unexpected chunk size: {len(chunk)}") LOG.log(DUMP_PACKETS, f"writing packet: {chunk.hex()}")