wip trezorlib

M1nd3r/thp-improved
M1nd3r 1 week ago
parent 87d6407d26
commit 11309cccd0

@ -24,9 +24,10 @@ from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, TypeVar, ca
import click
from .. import __version__, log, messages, protobuf, ui
from .. import __version__, log, messages, protobuf
from ..client import TrezorClient
from ..transport import DeviceIsBusy, enumerate_devices
from ..transport import DeviceIsBusy, new_enumerate_devices
from ..transport.new.client import NewTrezorClient
from ..transport.udp import UdpTransport
from . import (
AliasedGroup,
@ -54,7 +55,7 @@ from . import (
F = TypeVar("F", bound=Callable)
if TYPE_CHECKING:
from ..transport import Transport
from ..transport.new.transport import NewTransport
LOG = logging.getLogger(__name__)
@ -281,16 +282,18 @@ def format_device_name(features: messages.Features) -> str:
@cli.command(name="list")
@click.option("-n", "no_resolve", is_flag=True, help="Do not resolve Trezor names")
def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]:
def list_devices(no_resolve: bool) -> Optional[Iterable["NewTransport"]]:
"""List connected Trezor devices."""
if no_resolve:
return enumerate_devices()
return new_enumerate_devices()
for transport in enumerate_devices():
for transport in new_enumerate_devices():
try:
client = TrezorClient(transport, ui=ui.ClickUI())
description = format_device_name(client.features)
client.end_session()
print("test A")
client = NewTrezorClient(transport)
session = client.get_session()
description = format_device_name(session.features)
# client.end_session()
print("after end session")
except DeviceIsBusy:
description = "Device is in use by another process"

@ -14,6 +14,8 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import json
import logging
import re
@ -1096,6 +1098,7 @@ class TrezorClientDebugLink(TrezorClient):
if not hasattr(input_flow, "send"):
raise RuntimeError("input_flow should be a generator function")
self.ui.input_flow = input_flow
assert input_flow is not None
input_flow.send(None) # start the generator
def watch_layout(self, watch: bool = True) -> None:

@ -66,7 +66,6 @@ class ProtobufMapping:
print("wire type", wire_type)
buf = io.BytesIO()
protobuf.dump_message(buf, msg)
print("test")
return wire_type, buf.getvalue()
def encode_without_wire_type(self, msg: protobuf.MessageType) -> bytes:

@ -14,17 +14,10 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import logging
from typing import (
TYPE_CHECKING,
Iterable,
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
)
from typing import TYPE_CHECKING, Iterable, List, Sequence, Tuple, Type, TypeVar
from ..exceptions import TrezorException
from ..mapping import ProtobufMapping
@ -82,8 +75,8 @@ class Transport:
def initialize_connection(
self,
mapping: "ProtobufMapping",
session_id: Optional[bytes] = None,
derive_cardano: Optional[bool] = None,
session_id: bytes | None = None,
derive_cardano: bool | None = None,
):
raise NotImplementedError
@ -113,7 +106,7 @@ class Transport:
@classmethod
def enumerate(
cls: Type["T"], models: Optional[Iterable["TrezorModel"]] = None
cls: Type["T"], models: Iterable["TrezorModel"] | None = None
) -> Iterable["T"]:
raise NotImplementedError
@ -145,8 +138,21 @@ def all_transports() -> Iterable[Type["Transport"]]:
return set(t for t in transports if t.ENABLED)
def all_new_transports() -> Iterable[Type["NewTransport"]]:
# from .bridge import BridgeTransport
# from .hid import HidTransport
from .new.udp import UdpTransport
from .new.webusb import WebUsbTransport
transports: Tuple[Type["NewTransport"], ...] = (
UdpTransport,
WebUsbTransport,
)
return set(t for t in transports if t.ENABLED)
def enumerate_devices(
models: Optional[Iterable["TrezorModel"]] = None,
models: Iterable["TrezorModel"] | None = None,
) -> Sequence["Transport"]:
devices: List["Transport"] = []
for transport in all_transports():
@ -163,9 +169,28 @@ def enumerate_devices(
return devices
def get_transport(
path: Optional[str] = None, prefix_search: bool = False
) -> "Transport":
from .new.transport import NewTransport
def new_enumerate_devices(
models: Iterable["TrezorModel"] | None = None,
) -> Sequence["NewTransport"]:
devices: List["NewTransport"] = []
for transport in all_new_transports():
name = transport.__name__
try:
found = list(transport.enumerate(models))
LOG.info(f"Enumerating {name}: found {len(found)} devices")
devices.extend(found)
except NotImplementedError:
LOG.error(f"{name} does not implement device enumeration")
except Exception as e:
excname = e.__class__.__name__
LOG.error(f"Failed to enumerate {name}. {excname}: {e}")
return devices
def get_transport(path: str | None = None, prefix_search: bool = False) -> "Transport":
if path is None:
try:
return next(iter(enumerate_devices()))

@ -0,0 +1,11 @@
from __future__ import annotations
class ChannelData:
key_request: bytes
key_response: bytes
nonce_request: int
nonce_response: int
channel_id: bytes
sync_bit_send: int
sync_bit_receive: int

@ -0,0 +1,76 @@
from __future__ import annotations
import logging
from ... import mapping
from ...mapping import ProtobufMapping
from .channel_data import ChannelData
from .protocol_and_channel import ProtocolAndChannel, ProtocolV1
from .protocol_v2 import ProtocolV2
from .session import Session, SessionV1, SessionV2
from .transport import NewTransport
LOG = logging.getLogger(__name__)
class NewTrezorClient:
def __init__(
self,
transport: NewTransport,
protobuf_mapping: ProtobufMapping | None = None,
protocol: ProtocolAndChannel | None = None,
) -> None:
self.transport = transport
if protobuf_mapping is None:
self.mapping = mapping.DEFAULT_MAPPING
else:
self.mapping = protobuf_mapping
print("test B")
if protocol is None:
print("test C")
self.protocol = self._get_protocol()
else:
self.protocol = protocol
@classmethod
def resume(
cls, transport: NewTransport, channel_data: ChannelData
) -> NewTrezorClient: ...
def get_session(
self, passphrase: str = "", derive_cardano: bool = False
) -> Session:
if isinstance(self.protocol, ProtocolV1):
return SessionV1.new(self, passphrase, derive_cardano)
if isinstance(self.protocol, ProtocolV2):
return SessionV2.new(self, passphrase, derive_cardano)
raise NotImplementedError # TODO
def resume_session(self, session_id: bytes) -> Session:
raise NotImplementedError # TODO
def _get_protocol(self) -> ProtocolAndChannel:
from ... import mapping, messages
from ...messages import FailureType
from .protocol_and_channel import ProtocolV1
self.transport.open()
protocol = ProtocolV1(self.transport, mapping.DEFAULT_MAPPING)
protocol.write(messages.Initialize())
response = protocol.read()
self.transport.close()
if isinstance(response, messages.Failure):
print("test F1")
if (
response.code == FailureType.UnexpectedMessage
and response.message == "Invalid protocol"
):
LOG.debug("Protocol V2 detected")
protocol = ProtocolV2(self.transport, self.mapping)
return protocol

@ -0,0 +1,117 @@
from __future__ import annotations
import logging
import struct
import typing as t
from ...log import DUMP_BYTES
from ...mapping import ProtobufMapping
from .channel_data import ChannelData
from .transport import NewTransport
LOG = logging.getLogger(__name__)
class ProtocolAndChannel:
def __init__(
self,
transport: NewTransport,
mapping: ProtobufMapping,
channel_keys: ChannelData | None = None,
) -> None:
self.transport = transport
self.mapping = mapping
self.channel_keys = channel_keys
def close(self) -> None: ...
# def write(self, session_id: bytes, msg: t.Any) -> None: ...
# def read(self, session_id: bytes) -> t.Any: ...
def get_channel_keys(self) -> ChannelData: ...
class ProtocolV1(ProtocolAndChannel):
HEADER_LEN = struct.calcsize(">HL")
def read(self) -> t.Any:
msg_type, msg_bytes = self._read()
LOG.log(
DUMP_BYTES,
f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
)
msg = self.mapping.decode(msg_type, msg_bytes)
LOG.debug(
f"received message: {msg.__class__.__name__}",
extra={"protobuf": msg},
)
return msg
def write(self, msg: t.Any) -> None:
LOG.debug(
f"sending message: {msg.__class__.__name__}",
extra={"protobuf": msg},
)
msg_type, msg_bytes = self.mapping.encode(msg)
LOG.log(
DUMP_BYTES,
f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
)
self._write(msg_type, msg_bytes)
def _write(self, message_type: int, message_data: bytes) -> None:
print("wooooo")
chunk_size = self.transport.CHUNK_SIZE
header = struct.pack(">HL", message_type, len(message_data))
buffer = bytearray(b"##" + header + message_data)
print("wooooo")
while buffer:
# Report ID, data padded to 63 bytes
chunk = b"?" + buffer[: chunk_size - 1]
chunk = chunk.ljust(chunk_size, b"\x00")
self.transport.write_chunk(chunk)
buffer = buffer[63:]
def _read(self) -> t.Tuple[int, bytes]:
buffer = bytearray()
# Read header with first part of message data
msg_type, datalen, first_chunk = self.read_first()
buffer.extend(first_chunk)
# Read the rest of the message
while len(buffer) < datalen:
buffer.extend(self.read_next())
return msg_type, buffer[:datalen]
def read_first(self) -> t.Tuple[int, int, bytes]:
chunk = self.transport.read_chunk()
if chunk[:3] != b"?##":
raise RuntimeError("Unexpected magic characters")
try:
msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN])
except Exception:
raise RuntimeError("Cannot parse header")
data = chunk[3 + self.HEADER_LEN :]
return msg_type, datalen, data
def read_next(self) -> bytes:
chunk = self.transport.read_chunk()
if chunk[:1] != b"?":
raise RuntimeError("Unexpected magic characters")
return chunk[1:]
class Channel:
id: int
channel_keys: ChannelData | None
def __init__(self, id: int, keys: ChannelData) -> None:
self.id = id
self.channel_keys = keys
def read(self) -> t.Any: ...
def write(self, msg: t.Any) -> None: ...

@ -0,0 +1,313 @@
from __future__ import annotations
import hashlib
import hmac
import logging
import os
import typing as t
from binascii import hexlify
from enum import IntEnum
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from ... import messages
from ...mapping import ProtobufMapping
from ..thp import checksum, curve25519, thp_io
from ..thp.checksum import CHECKSUM_LENGTH
from ..thp.packet_header import PacketHeader
from .channel_data import ChannelData
from .protocol_and_channel import Channel, ProtocolAndChannel
from .transport import NewTransport
LOG = logging.getLogger(__name__)
MANAGEMENT_SESSION_ID: int = 0
def _sha256_of_two(val_1: bytes, val_2: bytes) -> bytes:
hash = hashlib.sha256(val_1)
hash.update(val_2)
return hash.digest()
def _hkdf(chaining_key: bytes, input: bytes):
temp_key = hmac.new(chaining_key, input, hashlib.sha256).digest()
output_1 = hmac.new(temp_key, b"\x01", hashlib.sha256).digest()
ctx_output_2 = hmac.new(temp_key, output_1, hashlib.sha256)
ctx_output_2.update(b"\x02")
output_2 = ctx_output_2.digest()
return (output_1, output_2)
def _get_iv_from_nonce(nonce: int) -> bytes:
if not nonce <= 0xFFFFFFFFFFFFFFFF:
raise ValueError("Nonce overflow, terminate the channel")
return bytes(4) + nonce.to_bytes(8, "big")
class ProtocolV2(ProtocolAndChannel):
key_request: bytes
key_response: bytes
nonce_request: int
nonce_response: int
channel_id: int
sync_bit_send: int
sync_bit_receive: int
has_valid_channel: bool = False
def __init__(
self,
transport: NewTransport,
mapping: ProtobufMapping,
channel_keys: ChannelData | None = None,
) -> None:
super().__init__(transport, mapping, channel_keys)
self.channel: Channel | None = None
def get_channel(self) -> ProtocolV2:
if not self.has_valid_channel:
self._establish_new_channel()
return self
def read(self, session_id: int) -> t.Any:
header, data = self._read_until_valid_crc_check()
# TODO
def write(self, session_id: int, msg: t.Any) -> None:
msg_type, msg_data = self.mapping.encode(msg)
self._encrypt_and_write(session_id, msg_type, msg_data, 7) # TODO add ctrl_byte
def _establish_new_channel(self):
self.sync_bit_send = 0
self.sync_bit_receive = 0
# Send channel allocation request
channel_id_request_nonce = os.urandom(8)
thp_io.write_payload_to_wire_and_add_checksum(
self.transport,
PacketHeader.get_channel_allocation_request_header(12),
channel_id_request_nonce,
)
# Read channel allocation response
header, payload = self._read_until_valid_crc_check()
if not self._is_valid_channel_allocation_response(
header, payload, channel_id_request_nonce
):
print("TODO raise exception here, I guess")
self.channel_id = int.from_bytes(payload[8:10], "big")
self.device_properties = payload[10:]
# Send handshake init request
ha_init_req_header = PacketHeader(0, self.channel_id, 36)
host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32))
host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey)
thp_io.write_payload_to_wire_and_add_checksum(
self.transport, ha_init_req_header, host_ephemeral_pubkey
)
# Read ACK
header, payload = self._read_until_valid_crc_check()
if not header.is_ack() or len(payload) > 0:
print("Received message is not a valid ACK ")
# Read handshake init response
header, payload = self._read_until_valid_crc_check()
self._send_ack_1()
if not header.is_handshake_init_response():
print("Received message is not a valid handshake init response message")
trezor_ephemeral_pubkey = payload[:32]
encrypted_trezor_static_pubkey = payload[32:80]
noise_tag = payload[80:96]
# TODO check noise tag
print("noise_tag: ", hexlify(noise_tag).decode())
# Prepare and send handshake completion request
PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00"
IV_1 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
IV_2 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
h = _sha256_of_two(PROTOCOL_NAME, self.device_properties)
h = _sha256_of_two(h, host_ephemeral_pubkey)
h = _sha256_of_two(h, trezor_ephemeral_pubkey)
ck, k = _hkdf(
PROTOCOL_NAME,
curve25519.multiply(host_ephemeral_privkey, trezor_ephemeral_pubkey),
)
aes_ctx = AESGCM(k)
try:
trezor_masked_static_pubkey = aes_ctx.decrypt(
IV_1, encrypted_trezor_static_pubkey, h
)
# print("masked_key", hexlify(trezor_masked_static_pubkey).decode())
except Exception as e:
print(type(e)) # TODO how to handle potential exceptions? Q for Matejcik
h = _sha256_of_two(h, encrypted_trezor_static_pubkey)
ck, k = _hkdf(
ck, curve25519.multiply(host_ephemeral_privkey, trezor_masked_static_pubkey)
)
aes_ctx = AESGCM(k)
tag_of_empty_string = aes_ctx.encrypt(IV_1, b"", h)
h = _sha256_of_two(h, tag_of_empty_string)
# TODO: search for saved credentials (or possibly not, as we skip pairing phase)
zeroes_32 = int.to_bytes(0, 32, "little")
temp_host_static_privkey = curve25519.get_private_key(zeroes_32)
temp_host_static_pubkey = curve25519.get_public_key(temp_host_static_privkey)
aes_ctx = AESGCM(k)
encrypted_host_static_pubkey = aes_ctx.encrypt(IV_2, temp_host_static_pubkey, h)
h = _sha256_of_two(h, encrypted_host_static_pubkey)
ck, k = _hkdf(
ck, curve25519.multiply(temp_host_static_privkey, trezor_ephemeral_pubkey)
)
msg_data = self.mapping.encode_without_wire_type(
messages.ThpHandshakeCompletionReqNoisePayload(
pairing_methods=[
messages.ThpPairingMethod.NoMethod,
]
)
)
aes_ctx = AESGCM(k)
encrypted_payload = aes_ctx.encrypt(IV_1, msg_data, h)
h = _sha256_of_two(h, encrypted_payload)
ha_completion_req_header = PacketHeader(
0x12,
self.channel_id,
len(encrypted_host_static_pubkey)
+ len(encrypted_payload)
+ CHECKSUM_LENGTH,
)
thp_io.write_payload_to_wire_and_add_checksum(
self.transport,
ha_completion_req_header,
encrypted_host_static_pubkey + encrypted_payload,
)
# Read ACK
header, payload = self._read_until_valid_crc_check()
if not header.is_ack() or len(payload) > 0:
print("Received message is not a valid ACK ")
# Read handshake completion response, ignore payload as we do not care about the state
header, _ = self._read_until_valid_crc_check()
if not header.is_handshake_comp_response():
print("Received message is not a valid handshake completion response")
self._send_ack_2()
self.key_request, self.key_response = _hkdf(ck, b"")
self.nonce_request = 0
self.nonce_response = 1
# Send StartPairingReqest message
message = messages.ThpStartPairingRequest()
message_type, message_data = self.mapping.encode(message)
self._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
# Read ACK
header, payload = self._read_until_valid_crc_check()
if not header.is_ack() or len(payload) > 0:
print("Received message is not a valid ACK ")
# Read
_, msg_type, msg_data = self.read_and_decrypt()
maaa = self.mapping.decode(msg_type, msg_data)
self._send_ack_1()
assert isinstance(maaa, messages.ThpEndResponse)
self.has_valid_channel = True
def _get_control_byte(self) -> bytes:
return b"\x42"
def _send_ack_1(self):
header = PacketHeader(0x20, self.channel_id, 4)
thp_io.write_payload_to_wire_and_add_checksum(self.transport, header, b"")
def _send_ack_2(self):
header = PacketHeader(0x28, self.channel_id, 4)
thp_io.write_payload_to_wire_and_add_checksum(self.transport, header, b"")
def _encrypt_and_write(
self,
session_id: int,
message_type: int,
message_data: bytes,
ctrl_byte: int = 0x04,
) -> None:
assert self.key_request is not None
aes_ctx = AESGCM(self.key_request)
sid = session_id.to_bytes(1, "big")
msg_type = message_type.to_bytes(2, "big")
data = sid + msg_type + message_data
nonce = _get_iv_from_nonce(self.nonce_request)
self.nonce_request += 1
encrypted_message = aes_ctx.encrypt(nonce, data, b"")
header = PacketHeader(
ctrl_byte, self.channel_id, len(encrypted_message) + CHECKSUM_LENGTH
)
thp_io.write_payload_to_wire_and_add_checksum(
self.transport, header, encrypted_message
)
def read_and_decrypt(self) -> t.Tuple[bytes, int, bytes]:
header, raw_payload = self._read_until_valid_crc_check()
if not header.is_encrypted_transport():
print("Trying to decrypt not encrypted message!")
aes_ctx = AESGCM(self.key_response)
nonce = _get_iv_from_nonce(self.nonce_response)
self.nonce_response += 1
message = aes_ctx.decrypt(nonce, raw_payload, b"")
session_id = message[0]
message_type = message[1:3]
message_data = message[3:]
return (
int.to_bytes(session_id, 1, "big"),
int.from_bytes(message_type, "big"),
message_data,
)
def _read_until_valid_crc_check(
self,
) -> t.Tuple[PacketHeader, bytes]:
is_valid = False
header, payload, chksum = thp_io.read(self.transport)
while not is_valid:
is_valid = checksum.is_valid(chksum, header.to_bytes_init() + payload)
if not is_valid:
print(hexlify(header.to_bytes_init() + payload + chksum))
LOG.debug("Received a message with invalid checksum")
header, payload, chksum = thp_io.read(self.transport)
return header, payload
def _is_valid_channel_allocation_response(
self, header: PacketHeader, payload: bytes, original_nonce: bytes
) -> bool:
if not header.is_channel_allocation_response():
print("Received message is not a channel allocation response")
return False
if len(payload) < 10:
print("Invalid channel allocation response payload")
return False
if payload[:8] != original_nonce:
print("Invalid channel allocation response payload (nonce mismatch)")
return False
return True
class ControlByteType(IntEnum):
CHANNEL_ALLOCATION_RES = 1
HANDSHAKE_INIT_RES = 2
HANDSHAKE_COMP_RES = 3
ACK = 4
ENCRYPTED_TRANSPORT = 5

@ -0,0 +1,65 @@
from __future__ import annotations
import typing as t
from ...messages import Features, Initialize
from .protocol_and_channel import ProtocolV1
from .protocol_v2 import ProtocolV2
if t.TYPE_CHECKING:
from .client import NewTrezorClient
class Session:
features: Features
def __init__(self, client: NewTrezorClient, id: bytes) -> None:
self.client = client
self.id = id
@classmethod
def new(
cls, client: NewTrezorClient, passphrase: str, derive_cardano: bool
) -> Session:
raise NotImplementedError
def call(self, msg: t.Any) -> t.Any:
raise NotImplementedError
class SessionV1(Session):
@classmethod
def new(
cls, client: NewTrezorClient, passphrase: str, derive_cardano: bool
) -> SessionV1:
assert isinstance(client.protocol, ProtocolV1)
session = SessionV1(client, b"")
cls.features = session.call(
# Initialize(passphrase=passphrase, derive_cardano=derive_cardano) # TODO
Initialize()
)
session.id = cls.features.session_id
return session
def call(self, msg: t.Any, should_reinit: bool = False) -> t.Any:
# if should_reinit:
# self.initialize() # TODO
if t.TYPE_CHECKING:
assert isinstance(self.client.protocol, ProtocolV1)
self.client.protocol.write(msg)
return self.client.protocol.read()
class SessionV2(Session):
def __init__(self, client: NewTrezorClient, id: bytes) -> None:
super().__init__(client, id)
assert isinstance(client.protocol, ProtocolV2)
self.channel = client.protocol.get_channel()
self.sid = self._convert_id_to_sid(id)
def call(self, msg: t.Any) -> t.Any:
self.channel.write(self.sid, msg)
return self.channel.read(self.sid)
def _convert_id_to_sid(self, id: bytes) -> int:
return int.from_bytes(id, "big") # TODO update to extract only sid

@ -0,0 +1,49 @@
from __future__ import annotations
import typing as t
from typing import TYPE_CHECKING, Iterable, Type, TypeVar
from ...exceptions import TrezorException
if TYPE_CHECKING:
from ...models import TrezorModel
T = TypeVar("T", bound="NewTransport")
class TransportException(TrezorException):
pass
class NewTransport:
PATH_PREFIX: str
@classmethod
def enumerate(
cls: Type["T"], models: Iterable["TrezorModel"] | None = None
) -> Iterable["T"]:
raise NotImplementedError
@classmethod
def find_by_path(cls: Type["T"], path: str, prefix_search: bool = False) -> "T":
for device in cls.enumerate():
if device.get_path() == path:
return device
if prefix_search and device.get_path().startswith(path):
return device
raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}")
def get_path(self) -> str: ...
def open(self) -> None: ...
def close(self) -> None: ...
def write_chunk(self, chunk: bytes) -> None: ...
def read_chunk(self) -> bytes: ...
CHUNK_SIZE: t.ClassVar[int]

@ -0,0 +1,162 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2024 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import logging
import socket
import time
from typing import TYPE_CHECKING, Iterable, Tuple
from ...log import DUMP_PACKETS
from .. import TransportException
from .transport import NewTransport
if TYPE_CHECKING:
from ...models import TrezorModel
SOCKET_TIMEOUT = 10
LOG = logging.getLogger(__name__)
class UdpTransport(NewTransport):
DEFAULT_HOST = "127.0.0.1"
DEFAULT_PORT = 21324
PATH_PREFIX = "udp"
ENABLED: bool = True
CHUNK_SIZE = 64
def __init__(
self,
device: str | None = None,
) -> None:
if not device:
host = UdpTransport.DEFAULT_HOST
port = UdpTransport.DEFAULT_PORT
else:
devparts = device.split(":")
host = devparts[0]
port = int(devparts[1]) if len(devparts) > 1 else UdpTransport.DEFAULT_PORT
self.device: Tuple[str, int] = (host, port)
self.socket: socket.socket | None = None
super().__init__()
@classmethod
def _try_path(cls, path: str) -> "UdpTransport":
d = cls(path)
try:
d.open()
if d.ping():
return d
else:
raise TransportException(
f"No Trezor device found at address {d.get_path()}"
)
except Exception as e:
raise TransportException(f"Error opening {d.get_path()}") from e
finally:
d.close()
@classmethod
def enumerate(
cls, _models: Iterable["TrezorModel"] | None = None
) -> Iterable["UdpTransport"]:
default_path = f"{cls.DEFAULT_HOST}:{cls.DEFAULT_PORT}"
try:
return [cls._try_path(default_path)]
except TransportException:
return []
@classmethod
def find_by_path(cls, path: str, prefix_search: bool = False) -> "UdpTransport":
try:
address = path.replace(f"{cls.PATH_PREFIX}:", "")
return cls._try_path(address)
except TransportException:
if not prefix_search:
raise
if prefix_search:
return super().find_by_path(path, prefix_search)
else:
raise TransportException(f"No UDP device at {path}")
def get_path(self) -> str:
return "{}:{}:{}".format(self.PATH_PREFIX, *self.device)
def open(self) -> None:
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.socket.connect(self.device)
self.socket.settimeout(SOCKET_TIMEOUT)
def close(self) -> None:
if self.socket is not None:
self.socket.close()
self.socket = None
def write_chunk(self, chunk: bytes) -> None:
assert self.socket is not None
if len(chunk) != 64:
raise TransportException("Unexpected data length")
LOG.log(DUMP_PACKETS, f"sending packet: {chunk.hex()}")
self.socket.sendall(chunk)
def read_chunk(self) -> bytes:
assert self.socket is not None
while True:
try:
chunk = self.socket.recv(64)
break
except socket.timeout:
continue
LOG.log(DUMP_PACKETS, f"received packet: {chunk.hex()}")
if len(chunk) != 64:
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
return bytearray(chunk)
def find_debug(self) -> "UdpTransport":
host, port = self.device
return UdpTransport(f"{host}:{port + 1}")
def wait_until_ready(self, timeout: float = 10) -> None:
try:
self.open()
start = time.monotonic()
while True:
if self.ping():
break
elapsed = time.monotonic() - start
if elapsed >= timeout:
raise TransportException("Timed out waiting for connection.")
time.sleep(0.05)
finally:
self.close()
def ping(self) -> bool:
"""Test if the device is listening."""
assert self.socket is not None
resp = None
try:
self.socket.sendall(b"PINGPING")
resp = self.socket.recv(8)
except Exception:
pass
return resp == b"PONGPONG"

@ -0,0 +1,167 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2024 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import atexit
import logging
import sys
import time
from typing import Iterable, List
from ...log import DUMP_PACKETS
from ...models import TREZORS, TrezorModel
from .. import UDEV_RULES_STR, DeviceIsBusy, TransportException
from .transport import NewTransport
LOG = logging.getLogger(__name__)
try:
import usb1
USB_IMPORTED = True
except Exception as e:
LOG.warning(f"WebUSB transport is disabled: {e}")
USB_IMPORTED = False
INTERFACE = 0
ENDPOINT = 1
DEBUG_INTERFACE = 1
DEBUG_ENDPOINT = 2
class WebUsbTransport(NewTransport):
"""
WebUsbTransport implements transport over WebUSB interface.
"""
PATH_PREFIX = "webusb"
ENABLED = USB_IMPORTED
context = None
CHUNK_SIZE = 64
def __init__(
self,
device: "usb1.USBDevice",
debug: bool = False,
) -> None:
self.device = device
self.debug = debug
self.interface = DEBUG_INTERFACE if debug else INTERFACE
self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT
self.handle: usb1.USBDeviceHandle | None = None
super().__init__()
@classmethod
def enumerate(
cls, models: Iterable["TrezorModel"] | None = None, usb_reset: bool = False
) -> Iterable["WebUsbTransport"]:
if cls.context is None:
cls.context = usb1.USBContext()
cls.context.open()
atexit.register(cls.context.close)
if models is None:
models = TREZORS
usb_ids = [id for model in models for id in model.usb_ids]
devices: List["WebUsbTransport"] = []
for dev in cls.context.getDeviceIterator(skip_on_error=True):
usb_id = (dev.getVendorID(), dev.getProductID())
if usb_id not in usb_ids:
continue
if not is_vendor_class(dev):
continue
if usb_reset:
handle = dev.open()
handle.resetDevice()
handle.close()
continue
try:
# workaround for issue #223:
# on certain combinations of Windows USB drivers and libusb versions,
# Trezor is returned twice (possibly because Windows know it as both
# a HID and a WebUSB device), and one of the returned devices is
# non-functional.
dev.getProduct()
devices.append(WebUsbTransport(dev))
except usb1.USBErrorNotSupported:
pass
return devices
def get_path(self) -> str:
return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}"
def open(self) -> None:
self.handle = self.device.open()
if self.handle is None:
if sys.platform.startswith("linux"):
args = (UDEV_RULES_STR,)
else:
args = ()
raise IOError("Cannot open device", *args)
try:
self.handle.claimInterface(self.interface)
except usb1.USBErrorAccess as e:
raise DeviceIsBusy(self.device) from e
def close(self) -> None:
if self.handle is not None:
self.handle.releaseInterface(self.interface)
self.handle.close()
self.handle = None
def write_chunk(self, chunk: bytes) -> None:
assert self.handle is not None
if len(chunk) != 64:
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
LOG.log(DUMP_PACKETS, f"writing packet: {chunk.hex()}")
self.handle.interruptWrite(self.endpoint, chunk)
def read_chunk(self) -> bytes:
assert self.handle is not None
endpoint = 0x80 | self.endpoint
while True:
chunk = self.handle.interruptRead(endpoint, 64)
if chunk:
break
else:
time.sleep(0.001)
LOG.log(DUMP_PACKETS, f"read packet: {chunk.hex()}")
if len(chunk) != 64:
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
return chunk
def find_debug(self) -> "WebUsbTransport":
# For v1 protocol, find debug USB interface for the same serial number
return WebUsbTransport(self.device, debug=True)
def is_vendor_class(dev: "usb1.USBDevice") -> bool:
configurationId = 0
altSettingId = 0
return (
dev[configurationId][INTERFACE][altSettingId].getClass()
== usb1.libusb1.LIBUSB_CLASS_VENDOR_SPEC
)
def dev_to_str(dev: "usb1.USBDevice") -> str:
return ":".join(
str(x) for x in ["%03i" % (dev.getBusNumber(),)] + dev.getPortNumberList()
)

@ -174,14 +174,14 @@ class ProtocolBasedTransport(Transport):
response = mapping.DEFAULT_MAPPING.decode(response_type, response_data)
self.handle.close()
if isinstance(response, messages.Failure):
from .protocol_v2 import ProtocolV2
from .protocol_v2 import DeprecatedProtocolV2
if (
response.code == FailureType.UnexpectedMessage
and response.message == "Invalid protocol"
):
LOG.debug("Protocol V2 detected")
protocol = ProtocolV2(self.handle)
protocol = DeprecatedProtocolV2(self.handle)
return protocol
@ -193,8 +193,8 @@ def _get_protocol(version: int, handle: Handle) -> Protocol:
return ProtocolV1(handle)
if version == PROTOCOL_VERSION_2:
from .protocol_v2 import ProtocolV2
from .protocol_v2 import DeprecatedProtocolV2
return ProtocolV2(handle)
return DeprecatedProtocolV2(handle)
raise NotImplementedError

@ -1,19 +1,12 @@
import hashlib
import hmac
import logging
import os
from binascii import hexlify
from enum import IntEnum
from typing import Optional, Tuple
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from .. import messages
from ..mapping import ProtobufMapping
from ..protobuf import MessageType
from ..transport.protocol import Handle, Protocol
from .thp import checksum, curve25519, thp_io
from .thp.checksum import CHECKSUM_LENGTH
from .thp.packet_header import PacketHeader
LOG = logging.getLogger(__name__)
@ -40,7 +33,7 @@ def _get_iv_from_nonce(nonce: int) -> bytes:
return bytes(4) + nonce.to_bytes(8, "big")
class ProtocolV2(Protocol):
class DeprecatedProtocolV2(Protocol):
def __init__(self, handle: Handle) -> None:
super().__init__(handle)
@ -50,191 +43,185 @@ class ProtocolV2(Protocol):
session_id: Optional[bytes] = None,
derive_caradano: Optional[bool] = None,
):
self.session_id: int = 0
self.sync_bit_send: int = 0
self.sync_bit_receive: int = 0
self.mapping = mapping
# Send channel allocation request
channel_id_request_nonce = os.urandom(8)
thp_io.write_payload_to_wire_and_add_checksum(
self.handle,
PacketHeader.get_channel_allocation_request_header(12),
channel_id_request_nonce,
)
# Read channel allocation response
header, payload = self._read_until_valid_crc_check()
if not self._is_valid_channel_allocation_response(
header, payload, channel_id_request_nonce
):
print("TODO raise exception here, I guess")
self.cid = int.from_bytes(payload[8:10], "big")
self.device_properties = payload[10:]
# Send handshake init request
ha_init_req_header = PacketHeader(0, self.cid, 36)
host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32))
host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey)
thp_io.write_payload_to_wire_and_add_checksum(
self.handle, ha_init_req_header, host_ephemeral_pubkey
)
# Read ACK
header, payload = self._read_until_valid_crc_check()
if not header.is_ack() or len(payload) > 0:
print("Received message is not a valid ACK ")
# Read handshake init response
header, payload = self._read_until_valid_crc_check()
self._send_ack_1()
if not header.is_handshake_init_response():
print("Received message is not a valid handshake init response message")
trezor_ephemeral_pubkey = payload[:32]
encrypted_trezor_static_pubkey = payload[32:80]
noise_tag = payload[80:96]
# TODO check noise tag
print("noise_tag: ", hexlify(noise_tag).decode())
# Prepare and send handshake completion request
PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00"
IV_1 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
IV_2 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
h = _sha256_of_two(PROTOCOL_NAME, self.device_properties)
h = _sha256_of_two(h, host_ephemeral_pubkey)
h = _sha256_of_two(h, trezor_ephemeral_pubkey)
ck, k = _hkdf(
PROTOCOL_NAME,
curve25519.multiply(host_ephemeral_privkey, trezor_ephemeral_pubkey),
)
aes_ctx = AESGCM(k)
try:
trezor_masked_static_pubkey = aes_ctx.decrypt(
IV_1, encrypted_trezor_static_pubkey, h
)
# print("masked_key", hexlify(trezor_masked_static_pubkey).decode())
except Exception as e:
print(type(e)) # TODO how to handle potential exceptions? Q for Matejcik
h = _sha256_of_two(h, encrypted_trezor_static_pubkey)
ck, k = _hkdf(
ck, curve25519.multiply(host_ephemeral_privkey, trezor_masked_static_pubkey)
)
aes_ctx = AESGCM(k)
tag_of_empty_string = aes_ctx.encrypt(IV_1, b"", h)
h = _sha256_of_two(h, tag_of_empty_string)
# TODO: search for saved credentials (or possibly not, as we skip pairing phase)
zeroes_32 = int.to_bytes(0, 32, "little")
temp_host_static_privkey = curve25519.get_private_key(zeroes_32)
temp_host_static_pubkey = curve25519.get_public_key(temp_host_static_privkey)
aes_ctx = AESGCM(k)
encrypted_host_static_pubkey = aes_ctx.encrypt(IV_2, temp_host_static_pubkey, h)
h = _sha256_of_two(h, encrypted_host_static_pubkey)
ck, k = _hkdf(
ck, curve25519.multiply(temp_host_static_privkey, trezor_ephemeral_pubkey)
)
msg_data = mapping.encode_without_wire_type(
messages.ThpHandshakeCompletionReqNoisePayload(
pairing_methods=[
messages.ThpPairingMethod.NoMethod,
]
)
)
aes_ctx = AESGCM(k)
encrypted_payload = aes_ctx.encrypt(IV_1, msg_data, h)
h = _sha256_of_two(h, encrypted_payload)
ha_completion_req_header = PacketHeader(
0x12,
self.cid,
len(encrypted_host_static_pubkey)
+ len(encrypted_payload)
+ CHECKSUM_LENGTH,
)
thp_io.write_payload_to_wire_and_add_checksum(
self.handle,
ha_completion_req_header,
encrypted_host_static_pubkey + encrypted_payload,
)
# Read ACK
header, payload = self._read_until_valid_crc_check()
if not header.is_ack() or len(payload) > 0:
print("Received message is not a valid ACK ")
# Read handshake completion response, ignore payload as we do not care about the state
header, _ = self._read_until_valid_crc_check()
if not header.is_handshake_comp_response():
print("Received message is not a valid handshake completion response")
self._send_ack_2()
self.key_request, self.key_response = _hkdf(ck, b"")
self.nonce_request: int = 0
self.nonce_response: int = 1
# Send StartPairingReqest message
message = messages.ThpStartPairingRequest()
message_type, message_data = mapping.encode(message)
self._encrypt_and_write(message_type.to_bytes(2, "big"), message_data)
# Read ACK
header, payload = self._read_until_valid_crc_check()
if not header.is_ack() or len(payload) > 0:
print("Received message is not a valid ACK ")
# Read
_, msg_type, msg_data = self.read_and_decrypt()
maaa = mapping.decode(msg_type, msg_data)
self._send_ack_1()
assert isinstance(maaa, messages.ThpEndResponse)
# Send get features
message = messages.GetFeatures()
message_type, message_data = mapping.encode(message)
self.session_id: int = 0
self._encrypt_and_write(message_type.to_bytes(2, "big"), message_data, 0x14)
_ = thp_io.read(self.handle)
session_id, msg_type, msg_data = self.read_and_decrypt()
features = mapping.decode(msg_type, msg_data)
assert isinstance(features, messages.Features)
features.session_id = int.to_bytes(self.cid, 2, "big") + session_id
self._send_ack_2()
return features
# self.session_id: int = 0
# self.sync_bit_send: int = 0
# self.sync_bit_receive: int = 0
# self.mapping = mapping
# # Send channel allocation request
# channel_id_request_nonce = os.urandom(8)
# thp_io.write_payload_to_wire_and_add_checksum(
# self.handle,
# PacketHeader.get_channel_allocation_request_header(12),
# channel_id_request_nonce,
# )
# # Read channel allocation response
# header, payload = self._read_until_valid_crc_check()
# if not self._is_valid_channel_allocation_response(
# header, payload, channel_id_request_nonce
# ):
# print("TODO raise exception here, I guess")
# self.cid = int.from_bytes(payload[8:10], "big")
# self.device_properties = payload[10:]
# # Send handshake init request
# ha_init_req_header = PacketHeader(0, self.cid, 36)
# host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32))
# host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey)
# thp_io.write_payload_to_wire_and_add_checksum(
# self.handle, ha_init_req_header, host_ephemeral_pubkey
# )
# # Read ACK
# header, payload = self._read_until_valid_crc_check()
# if not header.is_ack() or len(payload) > 0:
# print("Received message is not a valid ACK ")
# # Read handshake init response
# header, payload = self._read_until_valid_crc_check()
# self._send_ack_1()
# if not header.is_handshake_init_response():
# print("Received message is not a valid handshake init response message")
# trezor_ephemeral_pubkey = payload[:32]
# encrypted_trezor_static_pubkey = payload[32:80]
# noise_tag = payload[80:96]
# # TODO check noise tag
# print("noise_tag: ", hexlify(noise_tag).decode())
# # Prepare and send handshake completion request
# PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00"
# IV_1 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
# IV_2 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
# h = _sha256_of_two(PROTOCOL_NAME, self.device_properties)
# h = _sha256_of_two(h, host_ephemeral_pubkey)
# h = _sha256_of_two(h, trezor_ephemeral_pubkey)
# ck, k = _hkdf(
# PROTOCOL_NAME,
# curve25519.multiply(host_ephemeral_privkey, trezor_ephemeral_pubkey),
# )
# aes_ctx = AESGCM(k)
# try:
# trezor_masked_static_pubkey = aes_ctx.decrypt(
# IV_1, encrypted_trezor_static_pubkey, h
# )
# # print("masked_key", hexlify(trezor_masked_static_pubkey).decode())
# except Exception as e:
# print(type(e)) # TODO how to handle potential exceptions? Q for Matejcik
# h = _sha256_of_two(h, encrypted_trezor_static_pubkey)
# ck, k = _hkdf(
# ck, curve25519.multiply(host_ephemeral_privkey, trezor_masked_static_pubkey)
# )
# aes_ctx = AESGCM(k)
# tag_of_empty_string = aes_ctx.encrypt(IV_1, b"", h)
# h = _sha256_of_two(h, tag_of_empty_string)
# # TODO: search for saved credentials (or possibly not, as we skip pairing phase)
# zeroes_32 = int.to_bytes(0, 32, "little")
# temp_host_static_privkey = curve25519.get_private_key(zeroes_32)
# temp_host_static_pubkey = curve25519.get_public_key(temp_host_static_privkey)
# aes_ctx = AESGCM(k)
# encrypted_host_static_pubkey = aes_ctx.encrypt(IV_2, temp_host_static_pubkey, h)
# h = _sha256_of_two(h, encrypted_host_static_pubkey)
# ck, k = _hkdf(
# ck, curve25519.multiply(temp_host_static_privkey, trezor_ephemeral_pubkey)
# )
# msg_data = mapping.encode_without_wire_type(
# messages.ThpHandshakeCompletionReqNoisePayload(
# pairing_methods=[
# messages.ThpPairingMethod.NoMethod,
# ]
# )
# )
# aes_ctx = AESGCM(k)
# encrypted_payload = aes_ctx.encrypt(IV_1, msg_data, h)
# h = _sha256_of_two(h, encrypted_payload)
# ha_completion_req_header = PacketHeader(
# 0x12,
# self.cid,
# len(encrypted_host_static_pubkey)
# + len(encrypted_payload)
# + CHECKSUM_LENGTH,
# )
# thp_io.write_payload_to_wire_and_add_checksum(
# self.handle,
# ha_completion_req_header,
# encrypted_host_static_pubkey + encrypted_payload,
# )
# # Read ACK
# header, payload = self._read_until_valid_crc_check()
# if not header.is_ack() or len(payload) > 0:
# print("Received message is not a valid ACK ")
# # Read handshake completion response, ignore payload as we do not care about the state
# header, _ = self._read_until_valid_crc_check()
# if not header.is_handshake_comp_response():
# print("Received message is not a valid handshake completion response")
# self._send_ack_2()
# self.key_request, self.key_response = _hkdf(ck, b"")
# self.nonce_request: int = 0
# self.nonce_response: int = 1
# # Send StartPairingReqest message
# message = messages.ThpStartPairingRequest()
# message_type, message_data = mapping.encode(message)
# self._encrypt_and_write(message_type.to_bytes(2, "big"), message_data)
# # Read ACK
# header, payload = self._read_until_valid_crc_check()
# if not header.is_ack() or len(payload) > 0:
# print("Received message is not a valid ACK ")
# # Read
# _, msg_type, msg_data = self.read_and_decrypt()
# maaa = mapping.decode(msg_type, msg_data)
# self._send_ack_1()
# assert isinstance(maaa, messages.ThpEndResponse)
# # Send get features
# message = messages.GetFeatures()
# message_type, message_data = mapping.encode(message)
# self.session_id: int = 0
# self._encrypt_and_write(message_type.to_bytes(2, "big"), message_data, 0x14)
# _ = thp_io.read(self.handle)
# session_id, msg_type, msg_data = self.read_and_decrypt()
# features = mapping.decode(msg_type, msg_data)
# assert isinstance(features, messages.Features)
# features.session_id = int.to_bytes(self.cid, 2, "big") + session_id
# self._send_ack_2()
# return features
...
def _encrypt_and_write(
self, message_type: bytes, message_data: bytes, ctrl_byte: int = 0x04
) -> None:
assert self.key_request is not None
aes_ctx = AESGCM(self.key_request)
data = self.session_id.to_bytes(1, "big") + message_type + message_data
nonce = _get_iv_from_nonce(self.nonce_request)
self.nonce_request += 1
encrypted_message = aes_ctx.encrypt(nonce, data, b"")
header = PacketHeader(
ctrl_byte, self.cid, len(encrypted_message) + CHECKSUM_LENGTH
)
thp_io.write_payload_to_wire_and_add_checksum(
self.handle, header, encrypted_message
)
def _send_ack_1(self):
header = PacketHeader(0x20, self.cid, 4)
thp_io.write_payload_to_wire_and_add_checksum(self.handle, header, b"")
def _send_ack_2(self):
header = PacketHeader(0x28, self.cid, 4)
thp_io.write_payload_to_wire_and_add_checksum(self.handle, header, b"")
# assert self.key_request is not None
# aes_ctx = AESGCM(self.key_request)
# data = self.session_id.to_bytes(1, "big") + message_type + message_data
# nonce = _get_iv_from_nonce(self.nonce_request)
# self.nonce_request += 1
# encrypted_message = aes_ctx.encrypt(nonce, data, b"")
# header = PacketHeader(
# ctrl_byte, self.cid, len(encrypted_message) + CHECKSUM_LENGTH
# )
# thp_io.write_payload_to_wire_and_add_checksum(
# self.handle, header, encrypted_message
# )
...
def _write_message(self, message: MessageType, mapping: ProtobufMapping):
try:
@ -244,43 +231,46 @@ class ProtocolV2(Protocol):
print(type(e))
def write(self, message_type: int, message_data: bytes) -> None:
data = (
self.session_id.to_bytes(1, "big")
+ message_type.to_bytes(2, "big")
+ message_data
)
ctrl_byte = 0x04
self._write_and_encrypt(data, ctrl_byte)
# data = (
# self.session_id.to_bytes(1, "big")
# + message_type.to_bytes(2, "big")
# + message_data
# )
# ctrl_byte = 0x04
# self._write_and_encrypt(data, ctrl_byte)
...
def _write_and_encrypt(self, data: bytes, ctrl_byte: int) -> None:
aes_ctx = AESGCM(self.key_request)
nonce = _get_iv_from_nonce(self.nonce_request)
self.nonce_request += 1
encrypted_data = aes_ctx.encrypt(nonce, data, b"")
header = PacketHeader(
ctrl_byte, self.cid, len(encrypted_data) + CHECKSUM_LENGTH
)
thp_io.write_payload_to_wire_and_add_checksum(
self.handle, header, encrypted_data
)
# aes_ctx = AESGCM(self.key_request)
# nonce = _get_iv_from_nonce(self.nonce_request)
# self.nonce_request += 1
# encrypted_data = aes_ctx.encrypt(nonce, data, b"")
# header = PacketHeader(
# ctrl_byte, self.cid, len(encrypted_data) + CHECKSUM_LENGTH
# )
# thp_io.write_payload_to_wire_and_add_checksum(
# self.handle, header, encrypted_data
# )
...
def read_and_decrypt(self) -> Tuple[bytes, int, bytes]:
header, raw_payload = self._read_until_valid_crc_check()
if not header.is_encrypted_transport():
print("Trying to decrypt not encrypted message!")
aes_ctx = AESGCM(self.key_response)
nonce = _get_iv_from_nonce(self.nonce_response)
self.nonce_response += 1
message = aes_ctx.decrypt(nonce, raw_payload, b"")
session_id = message[0]
message_type = message[1:3]
message_data = message[3:]
return (
int.to_bytes(session_id, 1, "big"),
int.from_bytes(message_type, "big"),
message_data,
)
# header, raw_payload = self._read_until_valid_crc_check()
# if not header.is_encrypted_transport():
# print("Trying to decrypt not encrypted message!")
# aes_ctx = AESGCM(self.key_response)
# nonce = _get_iv_from_nonce(self.nonce_response)
# self.nonce_response += 1
# message = aes_ctx.decrypt(nonce, raw_payload, b"")
# session_id = message[0]
# message_type = message[1:3]
# message_data = message[3:]
# return (
# int.to_bytes(session_id, 1, "big"),
# int.from_bytes(message_type, "big"),
# message_data,
# )
...
def end_session(self, session_id: bytes) -> None:
pass
@ -290,22 +280,24 @@ class ProtocolV2(Protocol):
return self.start_session("")
def start_session(self, passphrase: str) -> bytes:
try:
msg = messages.ThpCreateNewSession(passphrase=passphrase)
except Exception as e:
print(e)
print("s")
self._write_message(msg, self.mapping)
print("p")
response_type, response_data = self._read_until_valid_crc_check()
print(response_type, response_data)
return b""
# try:
# msg = messages.ThpCreateNewSession(passphrase=passphrase)
# except Exception as e:
# print(e)
# print("s")
# self._write_message(msg, self.mapping)
# print("p")
# response_type, response_data = self._read_until_valid_crc_check()
# print(response_type, response_data)
# return b""
...
def read(self) -> Tuple[int, bytes]:
header, raw_payload, chksum = thp_io.read(self.handle)
print("Read message", hexlify(raw_payload))
return (0x00, header.to_bytes_init() + raw_payload + chksum) # TODO change
# header, raw_payload, chksum = thp_io.read(self.handle)
# print("Read message", hexlify(raw_payload))
# return (0x00, header.to_bytes_init() + raw_payload + chksum) # TODO change
...
def _get_control_byte(self) -> bytes:
return b"\x42"
@ -313,16 +305,17 @@ class ProtocolV2(Protocol):
def _read_until_valid_crc_check(
self,
) -> Tuple[PacketHeader, bytes]:
is_valid = False
header, payload, chksum = thp_io.read(self.handle)
while not is_valid:
is_valid = checksum.is_valid(chksum, header.to_bytes_init() + payload)
if not is_valid:
print(hexlify(header.to_bytes_init() + payload + chksum))
LOG.debug("Received a message with invalid checksum")
header, payload, chksum = thp_io.read(self.handle)
return header, payload
# is_valid = False
# header, payload, chksum = thp_io.read(self.handle)
# while not is_valid:
# is_valid = checksum.is_valid(chksum, header.to_bytes_init() + payload)
# if not is_valid:
# print(hexlify(header.to_bytes_init() + payload + chksum))
# LOG.debug("Received a message with invalid checksum")
# header, payload, chksum = thp_io.read(self.handle)
# return header, payload
...
def _is_valid_channel_allocation_response(
self, header: PacketHeader, payload: bytes, original_nonce: bytes

@ -1,15 +1,12 @@
import struct
from binascii import hexlify
from typing import Tuple
from ..protocol import Handle
from ..new.transport import NewTransport
from ..thp import checksum
from .packet_header import PacketHeader
INIT_HEADER_LENGTH = 5
CONT_HEADER_LENGTH = 3
PACKET_LENGTH = 64
CHECKSUM_LENGTH = 4
MAX_PAYLOAD_LEN = 60000
MESSAGE_TYPE_LENGTH = 2
@ -17,48 +14,54 @@ CONTINUATION_PACKET = 0x80
def write_payload_to_wire_and_add_checksum(
handle: Handle, header: PacketHeader, transport_payload: bytes
transport: NewTransport, header: PacketHeader, transport_payload: bytes
):
chksum: bytes = checksum.compute(header.to_bytes_init() + transport_payload)
data = transport_payload + chksum
write_payload_to_wire(handle, header, data)
write_payload_to_wire(transport, header, data)
def write_payload_to_wire(
handle: Handle, header: PacketHeader, transport_payload: bytes
transport: NewTransport, header: PacketHeader, transport_payload: bytes
):
handle.open()
transport.open()
buffer = bytearray(transport_payload)
chunk = header.to_bytes_init() + buffer[: PACKET_LENGTH - INIT_HEADER_LENGTH]
chunk = chunk.ljust(PACKET_LENGTH, b"\x00")
handle.write_chunk(chunk)
chunk = header.to_bytes_init() + buffer[: transport.CHUNK_SIZE - INIT_HEADER_LENGTH]
chunk = chunk.ljust(transport.CHUNK_SIZE, b"\x00")
transport.write_chunk(chunk)
buffer = buffer[PACKET_LENGTH - INIT_HEADER_LENGTH :]
buffer = buffer[transport.CHUNK_SIZE - INIT_HEADER_LENGTH :]
while buffer:
chunk = header.to_bytes_cont() + buffer[: PACKET_LENGTH - CONT_HEADER_LENGTH]
chunk = chunk.ljust(PACKET_LENGTH, b"\x00")
handle.write_chunk(chunk)
buffer = buffer[PACKET_LENGTH - CONT_HEADER_LENGTH :]
chunk = (
header.to_bytes_cont() + buffer[: transport.CHUNK_SIZE - CONT_HEADER_LENGTH]
)
chunk = chunk.ljust(transport.CHUNK_SIZE, b"\x00")
transport.write_chunk(chunk)
buffer = buffer[transport.CHUNK_SIZE - CONT_HEADER_LENGTH :]
def read(handle: Handle) -> Tuple[PacketHeader, bytes, bytes]:
def read(transport: NewTransport) -> Tuple[PacketHeader, bytes, bytes]:
buffer = bytearray()
# Read header with first part of message data
header, first_chunk = read_first(handle)
header, first_chunk = read_first(transport)
buffer.extend(first_chunk)
# Read the rest of the message
while len(buffer) < header.data_length:
buffer.extend(read_next(handle, header.cid))
buffer.extend(read_next(transport, header.cid))
# print("buffer read (data):", hexlify(buffer).decode())
# print("buffer len (data):", datalen)
# TODO check checksum?? or do not strip ?
data_len = header.data_length - CHECKSUM_LENGTH
return header, buffer[:data_len], buffer[data_len : data_len + CHECKSUM_LENGTH]
data_len = header.data_length - checksum.CHECKSUM_LENGTH
return (
header,
buffer[:data_len],
buffer[data_len : data_len + checksum.CHECKSUM_LENGTH],
)
def read_first(handle: Handle) -> Tuple[PacketHeader, bytes]:
chunk = handle.read_chunk()
def read_first(transport: NewTransport) -> Tuple[PacketHeader, bytes]:
chunk = transport.read_chunk()
try:
ctrl_byte, cid, data_length = struct.unpack(
PacketHeader.format_str_init, chunk[:INIT_HEADER_LENGTH]
@ -70,8 +73,8 @@ def read_first(handle: Handle) -> Tuple[PacketHeader, bytes]:
return PacketHeader(ctrl_byte, cid, data_length), data
def read_next(handle: Handle, cid: int) -> bytes:
chunk = handle.read_chunk()
def read_next(transport: NewTransport, cid: int) -> bytes:
chunk = transport.read_chunk()
ctrl_byte, read_cid = struct.unpack(
PacketHeader.format_str_cont, chunk[:CONT_HEADER_LENGTH]
)

Loading…
Cancel
Save