mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-02-22 12:32:02 +00:00
wip trezorlib
This commit is contained in:
parent
87d6407d26
commit
11309cccd0
@ -24,9 +24,10 @@ from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, TypeVar, ca
|
||||
|
||||
import click
|
||||
|
||||
from .. import __version__, log, messages, protobuf, ui
|
||||
from .. import __version__, log, messages, protobuf
|
||||
from ..client import TrezorClient
|
||||
from ..transport import DeviceIsBusy, enumerate_devices
|
||||
from ..transport import DeviceIsBusy, new_enumerate_devices
|
||||
from ..transport.new.client import NewTrezorClient
|
||||
from ..transport.udp import UdpTransport
|
||||
from . import (
|
||||
AliasedGroup,
|
||||
@ -54,7 +55,7 @@ from . import (
|
||||
F = TypeVar("F", bound=Callable)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..transport import Transport
|
||||
from ..transport.new.transport import NewTransport
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@ -281,16 +282,18 @@ def format_device_name(features: messages.Features) -> str:
|
||||
|
||||
@cli.command(name="list")
|
||||
@click.option("-n", "no_resolve", is_flag=True, help="Do not resolve Trezor names")
|
||||
def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]:
|
||||
def list_devices(no_resolve: bool) -> Optional[Iterable["NewTransport"]]:
|
||||
"""List connected Trezor devices."""
|
||||
if no_resolve:
|
||||
return enumerate_devices()
|
||||
return new_enumerate_devices()
|
||||
|
||||
for transport in enumerate_devices():
|
||||
for transport in new_enumerate_devices():
|
||||
try:
|
||||
client = TrezorClient(transport, ui=ui.ClickUI())
|
||||
description = format_device_name(client.features)
|
||||
client.end_session()
|
||||
print("test A")
|
||||
client = NewTrezorClient(transport)
|
||||
session = client.get_session()
|
||||
description = format_device_name(session.features)
|
||||
# client.end_session()
|
||||
print("after end session")
|
||||
except DeviceIsBusy:
|
||||
description = "Device is in use by another process"
|
||||
|
@ -14,6 +14,8 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
@ -1096,6 +1098,7 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
if not hasattr(input_flow, "send"):
|
||||
raise RuntimeError("input_flow should be a generator function")
|
||||
self.ui.input_flow = input_flow
|
||||
assert input_flow is not None
|
||||
input_flow.send(None) # start the generator
|
||||
|
||||
def watch_layout(self, watch: bool = True) -> None:
|
||||
|
@ -66,7 +66,6 @@ class ProtobufMapping:
|
||||
print("wire type", wire_type)
|
||||
buf = io.BytesIO()
|
||||
protobuf.dump_message(buf, msg)
|
||||
print("test")
|
||||
return wire_type, buf.getvalue()
|
||||
|
||||
def encode_without_wire_type(self, msg: protobuf.MessageType) -> bytes:
|
||||
|
@ -14,17 +14,10 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
)
|
||||
from typing import TYPE_CHECKING, Iterable, List, Sequence, Tuple, Type, TypeVar
|
||||
|
||||
from ..exceptions import TrezorException
|
||||
from ..mapping import ProtobufMapping
|
||||
@ -82,8 +75,8 @@ class Transport:
|
||||
def initialize_connection(
|
||||
self,
|
||||
mapping: "ProtobufMapping",
|
||||
session_id: Optional[bytes] = None,
|
||||
derive_cardano: Optional[bool] = None,
|
||||
session_id: bytes | None = None,
|
||||
derive_cardano: bool | None = None,
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
@ -113,7 +106,7 @@ class Transport:
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls: Type["T"], models: Optional[Iterable["TrezorModel"]] = None
|
||||
cls: Type["T"], models: Iterable["TrezorModel"] | None = None
|
||||
) -> Iterable["T"]:
|
||||
raise NotImplementedError
|
||||
|
||||
@ -145,8 +138,21 @@ def all_transports() -> Iterable[Type["Transport"]]:
|
||||
return set(t for t in transports if t.ENABLED)
|
||||
|
||||
|
||||
def all_new_transports() -> Iterable[Type["NewTransport"]]:
|
||||
# from .bridge import BridgeTransport
|
||||
# from .hid import HidTransport
|
||||
from .new.udp import UdpTransport
|
||||
from .new.webusb import WebUsbTransport
|
||||
|
||||
transports: Tuple[Type["NewTransport"], ...] = (
|
||||
UdpTransport,
|
||||
WebUsbTransport,
|
||||
)
|
||||
return set(t for t in transports if t.ENABLED)
|
||||
|
||||
|
||||
def enumerate_devices(
|
||||
models: Optional[Iterable["TrezorModel"]] = None,
|
||||
models: Iterable["TrezorModel"] | None = None,
|
||||
) -> Sequence["Transport"]:
|
||||
devices: List["Transport"] = []
|
||||
for transport in all_transports():
|
||||
@ -163,9 +169,28 @@ def enumerate_devices(
|
||||
return devices
|
||||
|
||||
|
||||
def get_transport(
|
||||
path: Optional[str] = None, prefix_search: bool = False
|
||||
) -> "Transport":
|
||||
from .new.transport import NewTransport
|
||||
|
||||
|
||||
def new_enumerate_devices(
|
||||
models: Iterable["TrezorModel"] | None = None,
|
||||
) -> Sequence["NewTransport"]:
|
||||
devices: List["NewTransport"] = []
|
||||
for transport in all_new_transports():
|
||||
name = transport.__name__
|
||||
try:
|
||||
found = list(transport.enumerate(models))
|
||||
LOG.info(f"Enumerating {name}: found {len(found)} devices")
|
||||
devices.extend(found)
|
||||
except NotImplementedError:
|
||||
LOG.error(f"{name} does not implement device enumeration")
|
||||
except Exception as e:
|
||||
excname = e.__class__.__name__
|
||||
LOG.error(f"Failed to enumerate {name}. {excname}: {e}")
|
||||
return devices
|
||||
|
||||
|
||||
def get_transport(path: str | None = None, prefix_search: bool = False) -> "Transport":
|
||||
if path is None:
|
||||
try:
|
||||
return next(iter(enumerate_devices()))
|
||||
|
11
python/src/trezorlib/transport/new/channel_data.py
Normal file
11
python/src/trezorlib/transport/new/channel_data.py
Normal file
@ -0,0 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class ChannelData:
|
||||
key_request: bytes
|
||||
key_response: bytes
|
||||
nonce_request: int
|
||||
nonce_response: int
|
||||
channel_id: bytes
|
||||
sync_bit_send: int
|
||||
sync_bit_receive: int
|
76
python/src/trezorlib/transport/new/client.py
Normal file
76
python/src/trezorlib/transport/new/client.py
Normal file
@ -0,0 +1,76 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from ... import mapping
|
||||
from ...mapping import ProtobufMapping
|
||||
from .channel_data import ChannelData
|
||||
from .protocol_and_channel import ProtocolAndChannel, ProtocolV1
|
||||
from .protocol_v2 import ProtocolV2
|
||||
from .session import Session, SessionV1, SessionV2
|
||||
from .transport import NewTransport
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NewTrezorClient:
|
||||
def __init__(
|
||||
self,
|
||||
transport: NewTransport,
|
||||
protobuf_mapping: ProtobufMapping | None = None,
|
||||
protocol: ProtocolAndChannel | None = None,
|
||||
) -> None:
|
||||
self.transport = transport
|
||||
|
||||
if protobuf_mapping is None:
|
||||
self.mapping = mapping.DEFAULT_MAPPING
|
||||
else:
|
||||
self.mapping = protobuf_mapping
|
||||
print("test B")
|
||||
|
||||
if protocol is None:
|
||||
print("test C")
|
||||
self.protocol = self._get_protocol()
|
||||
else:
|
||||
self.protocol = protocol
|
||||
|
||||
@classmethod
|
||||
def resume(
|
||||
cls, transport: NewTransport, channel_data: ChannelData
|
||||
) -> NewTrezorClient: ...
|
||||
|
||||
def get_session(
|
||||
self, passphrase: str = "", derive_cardano: bool = False
|
||||
) -> Session:
|
||||
if isinstance(self.protocol, ProtocolV1):
|
||||
return SessionV1.new(self, passphrase, derive_cardano)
|
||||
if isinstance(self.protocol, ProtocolV2):
|
||||
return SessionV2.new(self, passphrase, derive_cardano)
|
||||
raise NotImplementedError # TODO
|
||||
|
||||
def resume_session(self, session_id: bytes) -> Session:
|
||||
raise NotImplementedError # TODO
|
||||
|
||||
def _get_protocol(self) -> ProtocolAndChannel:
|
||||
|
||||
from ... import mapping, messages
|
||||
from ...messages import FailureType
|
||||
from .protocol_and_channel import ProtocolV1
|
||||
|
||||
self.transport.open()
|
||||
|
||||
protocol = ProtocolV1(self.transport, mapping.DEFAULT_MAPPING)
|
||||
|
||||
protocol.write(messages.Initialize())
|
||||
|
||||
response = protocol.read()
|
||||
self.transport.close()
|
||||
if isinstance(response, messages.Failure):
|
||||
print("test F1")
|
||||
if (
|
||||
response.code == FailureType.UnexpectedMessage
|
||||
and response.message == "Invalid protocol"
|
||||
):
|
||||
LOG.debug("Protocol V2 detected")
|
||||
protocol = ProtocolV2(self.transport, self.mapping)
|
||||
return protocol
|
117
python/src/trezorlib/transport/new/protocol_and_channel.py
Normal file
117
python/src/trezorlib/transport/new/protocol_and_channel.py
Normal file
@ -0,0 +1,117 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import struct
|
||||
import typing as t
|
||||
|
||||
from ...log import DUMP_BYTES
|
||||
from ...mapping import ProtobufMapping
|
||||
from .channel_data import ChannelData
|
||||
from .transport import NewTransport
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProtocolAndChannel:
|
||||
def __init__(
|
||||
self,
|
||||
transport: NewTransport,
|
||||
mapping: ProtobufMapping,
|
||||
channel_keys: ChannelData | None = None,
|
||||
) -> None:
|
||||
self.transport = transport
|
||||
self.mapping = mapping
|
||||
self.channel_keys = channel_keys
|
||||
|
||||
def close(self) -> None: ...
|
||||
|
||||
# def write(self, session_id: bytes, msg: t.Any) -> None: ...
|
||||
|
||||
# def read(self, session_id: bytes) -> t.Any: ...
|
||||
|
||||
def get_channel_keys(self) -> ChannelData: ...
|
||||
|
||||
|
||||
class ProtocolV1(ProtocolAndChannel):
|
||||
HEADER_LEN = struct.calcsize(">HL")
|
||||
|
||||
def read(self) -> t.Any:
|
||||
msg_type, msg_bytes = self._read()
|
||||
LOG.log(
|
||||
DUMP_BYTES,
|
||||
f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
|
||||
)
|
||||
msg = self.mapping.decode(msg_type, msg_bytes)
|
||||
LOG.debug(
|
||||
f"received message: {msg.__class__.__name__}",
|
||||
extra={"protobuf": msg},
|
||||
)
|
||||
return msg
|
||||
|
||||
def write(self, msg: t.Any) -> None:
|
||||
LOG.debug(
|
||||
f"sending message: {msg.__class__.__name__}",
|
||||
extra={"protobuf": msg},
|
||||
)
|
||||
msg_type, msg_bytes = self.mapping.encode(msg)
|
||||
LOG.log(
|
||||
DUMP_BYTES,
|
||||
f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
|
||||
)
|
||||
self._write(msg_type, msg_bytes)
|
||||
|
||||
def _write(self, message_type: int, message_data: bytes) -> None:
|
||||
print("wooooo")
|
||||
chunk_size = self.transport.CHUNK_SIZE
|
||||
header = struct.pack(">HL", message_type, len(message_data))
|
||||
buffer = bytearray(b"##" + header + message_data)
|
||||
print("wooooo")
|
||||
|
||||
while buffer:
|
||||
# Report ID, data padded to 63 bytes
|
||||
chunk = b"?" + buffer[: chunk_size - 1]
|
||||
chunk = chunk.ljust(chunk_size, b"\x00")
|
||||
self.transport.write_chunk(chunk)
|
||||
buffer = buffer[63:]
|
||||
|
||||
def _read(self) -> t.Tuple[int, bytes]:
|
||||
buffer = bytearray()
|
||||
# Read header with first part of message data
|
||||
msg_type, datalen, first_chunk = self.read_first()
|
||||
buffer.extend(first_chunk)
|
||||
|
||||
# Read the rest of the message
|
||||
while len(buffer) < datalen:
|
||||
buffer.extend(self.read_next())
|
||||
|
||||
return msg_type, buffer[:datalen]
|
||||
|
||||
def read_first(self) -> t.Tuple[int, int, bytes]:
|
||||
chunk = self.transport.read_chunk()
|
||||
if chunk[:3] != b"?##":
|
||||
raise RuntimeError("Unexpected magic characters")
|
||||
try:
|
||||
msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN])
|
||||
except Exception:
|
||||
raise RuntimeError("Cannot parse header")
|
||||
|
||||
data = chunk[3 + self.HEADER_LEN :]
|
||||
return msg_type, datalen, data
|
||||
|
||||
def read_next(self) -> bytes:
|
||||
chunk = self.transport.read_chunk()
|
||||
if chunk[:1] != b"?":
|
||||
raise RuntimeError("Unexpected magic characters")
|
||||
return chunk[1:]
|
||||
|
||||
|
||||
class Channel:
|
||||
id: int
|
||||
channel_keys: ChannelData | None
|
||||
|
||||
def __init__(self, id: int, keys: ChannelData) -> None:
|
||||
self.id = id
|
||||
self.channel_keys = keys
|
||||
|
||||
def read(self) -> t.Any: ...
|
||||
def write(self, msg: t.Any) -> None: ...
|
313
python/src/trezorlib/transport/new/protocol_v2.py
Normal file
313
python/src/trezorlib/transport/new/protocol_v2.py
Normal file
@ -0,0 +1,313 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import logging
|
||||
import os
|
||||
import typing as t
|
||||
from binascii import hexlify
|
||||
from enum import IntEnum
|
||||
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
|
||||
from ... import messages
|
||||
from ...mapping import ProtobufMapping
|
||||
from ..thp import checksum, curve25519, thp_io
|
||||
from ..thp.checksum import CHECKSUM_LENGTH
|
||||
from ..thp.packet_header import PacketHeader
|
||||
from .channel_data import ChannelData
|
||||
from .protocol_and_channel import Channel, ProtocolAndChannel
|
||||
from .transport import NewTransport
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
MANAGEMENT_SESSION_ID: int = 0
|
||||
|
||||
|
||||
def _sha256_of_two(val_1: bytes, val_2: bytes) -> bytes:
|
||||
hash = hashlib.sha256(val_1)
|
||||
hash.update(val_2)
|
||||
return hash.digest()
|
||||
|
||||
|
||||
def _hkdf(chaining_key: bytes, input: bytes):
|
||||
temp_key = hmac.new(chaining_key, input, hashlib.sha256).digest()
|
||||
output_1 = hmac.new(temp_key, b"\x01", hashlib.sha256).digest()
|
||||
ctx_output_2 = hmac.new(temp_key, output_1, hashlib.sha256)
|
||||
ctx_output_2.update(b"\x02")
|
||||
output_2 = ctx_output_2.digest()
|
||||
return (output_1, output_2)
|
||||
|
||||
|
||||
def _get_iv_from_nonce(nonce: int) -> bytes:
|
||||
if not nonce <= 0xFFFFFFFFFFFFFFFF:
|
||||
raise ValueError("Nonce overflow, terminate the channel")
|
||||
return bytes(4) + nonce.to_bytes(8, "big")
|
||||
|
||||
|
||||
class ProtocolV2(ProtocolAndChannel):
|
||||
key_request: bytes
|
||||
key_response: bytes
|
||||
nonce_request: int
|
||||
nonce_response: int
|
||||
channel_id: int
|
||||
sync_bit_send: int
|
||||
sync_bit_receive: int
|
||||
has_valid_channel: bool = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transport: NewTransport,
|
||||
mapping: ProtobufMapping,
|
||||
channel_keys: ChannelData | None = None,
|
||||
) -> None:
|
||||
super().__init__(transport, mapping, channel_keys)
|
||||
self.channel: Channel | None = None
|
||||
|
||||
def get_channel(self) -> ProtocolV2:
|
||||
if not self.has_valid_channel:
|
||||
self._establish_new_channel()
|
||||
return self
|
||||
|
||||
def read(self, session_id: int) -> t.Any:
|
||||
header, data = self._read_until_valid_crc_check()
|
||||
# TODO
|
||||
|
||||
def write(self, session_id: int, msg: t.Any) -> None:
|
||||
msg_type, msg_data = self.mapping.encode(msg)
|
||||
self._encrypt_and_write(session_id, msg_type, msg_data, 7) # TODO add ctrl_byte
|
||||
|
||||
def _establish_new_channel(self):
|
||||
self.sync_bit_send = 0
|
||||
self.sync_bit_receive = 0
|
||||
# Send channel allocation request
|
||||
channel_id_request_nonce = os.urandom(8)
|
||||
thp_io.write_payload_to_wire_and_add_checksum(
|
||||
self.transport,
|
||||
PacketHeader.get_channel_allocation_request_header(12),
|
||||
channel_id_request_nonce,
|
||||
)
|
||||
|
||||
# Read channel allocation response
|
||||
header, payload = self._read_until_valid_crc_check()
|
||||
if not self._is_valid_channel_allocation_response(
|
||||
header, payload, channel_id_request_nonce
|
||||
):
|
||||
print("TODO raise exception here, I guess")
|
||||
|
||||
self.channel_id = int.from_bytes(payload[8:10], "big")
|
||||
self.device_properties = payload[10:]
|
||||
|
||||
# Send handshake init request
|
||||
ha_init_req_header = PacketHeader(0, self.channel_id, 36)
|
||||
host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32))
|
||||
host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey)
|
||||
|
||||
thp_io.write_payload_to_wire_and_add_checksum(
|
||||
self.transport, ha_init_req_header, host_ephemeral_pubkey
|
||||
)
|
||||
|
||||
# Read ACK
|
||||
header, payload = self._read_until_valid_crc_check()
|
||||
if not header.is_ack() or len(payload) > 0:
|
||||
print("Received message is not a valid ACK ")
|
||||
|
||||
# Read handshake init response
|
||||
header, payload = self._read_until_valid_crc_check()
|
||||
self._send_ack_1()
|
||||
|
||||
if not header.is_handshake_init_response():
|
||||
print("Received message is not a valid handshake init response message")
|
||||
|
||||
trezor_ephemeral_pubkey = payload[:32]
|
||||
encrypted_trezor_static_pubkey = payload[32:80]
|
||||
noise_tag = payload[80:96]
|
||||
|
||||
# TODO check noise tag
|
||||
print("noise_tag: ", hexlify(noise_tag).decode())
|
||||
|
||||
# Prepare and send handshake completion request
|
||||
PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00"
|
||||
IV_1 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
|
||||
IV_2 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
|
||||
h = _sha256_of_two(PROTOCOL_NAME, self.device_properties)
|
||||
h = _sha256_of_two(h, host_ephemeral_pubkey)
|
||||
h = _sha256_of_two(h, trezor_ephemeral_pubkey)
|
||||
ck, k = _hkdf(
|
||||
PROTOCOL_NAME,
|
||||
curve25519.multiply(host_ephemeral_privkey, trezor_ephemeral_pubkey),
|
||||
)
|
||||
|
||||
aes_ctx = AESGCM(k)
|
||||
try:
|
||||
trezor_masked_static_pubkey = aes_ctx.decrypt(
|
||||
IV_1, encrypted_trezor_static_pubkey, h
|
||||
)
|
||||
# print("masked_key", hexlify(trezor_masked_static_pubkey).decode())
|
||||
except Exception as e:
|
||||
print(type(e)) # TODO how to handle potential exceptions? Q for Matejcik
|
||||
h = _sha256_of_two(h, encrypted_trezor_static_pubkey)
|
||||
ck, k = _hkdf(
|
||||
ck, curve25519.multiply(host_ephemeral_privkey, trezor_masked_static_pubkey)
|
||||
)
|
||||
aes_ctx = AESGCM(k)
|
||||
|
||||
tag_of_empty_string = aes_ctx.encrypt(IV_1, b"", h)
|
||||
h = _sha256_of_two(h, tag_of_empty_string)
|
||||
# TODO: search for saved credentials (or possibly not, as we skip pairing phase)
|
||||
|
||||
zeroes_32 = int.to_bytes(0, 32, "little")
|
||||
temp_host_static_privkey = curve25519.get_private_key(zeroes_32)
|
||||
temp_host_static_pubkey = curve25519.get_public_key(temp_host_static_privkey)
|
||||
aes_ctx = AESGCM(k)
|
||||
encrypted_host_static_pubkey = aes_ctx.encrypt(IV_2, temp_host_static_pubkey, h)
|
||||
h = _sha256_of_two(h, encrypted_host_static_pubkey)
|
||||
ck, k = _hkdf(
|
||||
ck, curve25519.multiply(temp_host_static_privkey, trezor_ephemeral_pubkey)
|
||||
)
|
||||
msg_data = self.mapping.encode_without_wire_type(
|
||||
messages.ThpHandshakeCompletionReqNoisePayload(
|
||||
pairing_methods=[
|
||||
messages.ThpPairingMethod.NoMethod,
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
aes_ctx = AESGCM(k)
|
||||
|
||||
encrypted_payload = aes_ctx.encrypt(IV_1, msg_data, h)
|
||||
h = _sha256_of_two(h, encrypted_payload)
|
||||
ha_completion_req_header = PacketHeader(
|
||||
0x12,
|
||||
self.channel_id,
|
||||
len(encrypted_host_static_pubkey)
|
||||
+ len(encrypted_payload)
|
||||
+ CHECKSUM_LENGTH,
|
||||
)
|
||||
thp_io.write_payload_to_wire_and_add_checksum(
|
||||
self.transport,
|
||||
ha_completion_req_header,
|
||||
encrypted_host_static_pubkey + encrypted_payload,
|
||||
)
|
||||
|
||||
# Read ACK
|
||||
header, payload = self._read_until_valid_crc_check()
|
||||
if not header.is_ack() or len(payload) > 0:
|
||||
print("Received message is not a valid ACK ")
|
||||
|
||||
# Read handshake completion response, ignore payload as we do not care about the state
|
||||
header, _ = self._read_until_valid_crc_check()
|
||||
if not header.is_handshake_comp_response():
|
||||
print("Received message is not a valid handshake completion response")
|
||||
self._send_ack_2()
|
||||
|
||||
self.key_request, self.key_response = _hkdf(ck, b"")
|
||||
self.nonce_request = 0
|
||||
self.nonce_response = 1
|
||||
|
||||
# Send StartPairingReqest message
|
||||
message = messages.ThpStartPairingRequest()
|
||||
message_type, message_data = self.mapping.encode(message)
|
||||
|
||||
self._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
|
||||
|
||||
# Read ACK
|
||||
header, payload = self._read_until_valid_crc_check()
|
||||
if not header.is_ack() or len(payload) > 0:
|
||||
print("Received message is not a valid ACK ")
|
||||
|
||||
# Read
|
||||
_, msg_type, msg_data = self.read_and_decrypt()
|
||||
maaa = self.mapping.decode(msg_type, msg_data)
|
||||
self._send_ack_1()
|
||||
|
||||
assert isinstance(maaa, messages.ThpEndResponse)
|
||||
self.has_valid_channel = True
|
||||
|
||||
def _get_control_byte(self) -> bytes:
|
||||
return b"\x42"
|
||||
|
||||
def _send_ack_1(self):
|
||||
header = PacketHeader(0x20, self.channel_id, 4)
|
||||
thp_io.write_payload_to_wire_and_add_checksum(self.transport, header, b"")
|
||||
|
||||
def _send_ack_2(self):
|
||||
header = PacketHeader(0x28, self.channel_id, 4)
|
||||
thp_io.write_payload_to_wire_and_add_checksum(self.transport, header, b"")
|
||||
|
||||
def _encrypt_and_write(
|
||||
self,
|
||||
session_id: int,
|
||||
message_type: int,
|
||||
message_data: bytes,
|
||||
ctrl_byte: int = 0x04,
|
||||
) -> None:
|
||||
assert self.key_request is not None
|
||||
aes_ctx = AESGCM(self.key_request)
|
||||
|
||||
sid = session_id.to_bytes(1, "big")
|
||||
msg_type = message_type.to_bytes(2, "big")
|
||||
data = sid + msg_type + message_data
|
||||
nonce = _get_iv_from_nonce(self.nonce_request)
|
||||
self.nonce_request += 1
|
||||
encrypted_message = aes_ctx.encrypt(nonce, data, b"")
|
||||
header = PacketHeader(
|
||||
ctrl_byte, self.channel_id, len(encrypted_message) + CHECKSUM_LENGTH
|
||||
)
|
||||
|
||||
thp_io.write_payload_to_wire_and_add_checksum(
|
||||
self.transport, header, encrypted_message
|
||||
)
|
||||
|
||||
def read_and_decrypt(self) -> t.Tuple[bytes, int, bytes]:
|
||||
header, raw_payload = self._read_until_valid_crc_check()
|
||||
if not header.is_encrypted_transport():
|
||||
print("Trying to decrypt not encrypted message!")
|
||||
aes_ctx = AESGCM(self.key_response)
|
||||
nonce = _get_iv_from_nonce(self.nonce_response)
|
||||
self.nonce_response += 1
|
||||
|
||||
message = aes_ctx.decrypt(nonce, raw_payload, b"")
|
||||
session_id = message[0]
|
||||
message_type = message[1:3]
|
||||
message_data = message[3:]
|
||||
return (
|
||||
int.to_bytes(session_id, 1, "big"),
|
||||
int.from_bytes(message_type, "big"),
|
||||
message_data,
|
||||
)
|
||||
|
||||
def _read_until_valid_crc_check(
|
||||
self,
|
||||
) -> t.Tuple[PacketHeader, bytes]:
|
||||
is_valid = False
|
||||
header, payload, chksum = thp_io.read(self.transport)
|
||||
while not is_valid:
|
||||
is_valid = checksum.is_valid(chksum, header.to_bytes_init() + payload)
|
||||
if not is_valid:
|
||||
print(hexlify(header.to_bytes_init() + payload + chksum))
|
||||
LOG.debug("Received a message with invalid checksum")
|
||||
header, payload, chksum = thp_io.read(self.transport)
|
||||
|
||||
return header, payload
|
||||
|
||||
def _is_valid_channel_allocation_response(
|
||||
self, header: PacketHeader, payload: bytes, original_nonce: bytes
|
||||
) -> bool:
|
||||
if not header.is_channel_allocation_response():
|
||||
print("Received message is not a channel allocation response")
|
||||
return False
|
||||
if len(payload) < 10:
|
||||
print("Invalid channel allocation response payload")
|
||||
return False
|
||||
if payload[:8] != original_nonce:
|
||||
print("Invalid channel allocation response payload (nonce mismatch)")
|
||||
return False
|
||||
return True
|
||||
|
||||
class ControlByteType(IntEnum):
|
||||
CHANNEL_ALLOCATION_RES = 1
|
||||
HANDSHAKE_INIT_RES = 2
|
||||
HANDSHAKE_COMP_RES = 3
|
||||
ACK = 4
|
||||
ENCRYPTED_TRANSPORT = 5
|
65
python/src/trezorlib/transport/new/session.py
Normal file
65
python/src/trezorlib/transport/new/session.py
Normal file
@ -0,0 +1,65 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
from ...messages import Features, Initialize
|
||||
from .protocol_and_channel import ProtocolV1
|
||||
from .protocol_v2 import ProtocolV2
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from .client import NewTrezorClient
|
||||
|
||||
|
||||
class Session:
|
||||
features: Features
|
||||
|
||||
def __init__(self, client: NewTrezorClient, id: bytes) -> None:
|
||||
self.client = client
|
||||
self.id = id
|
||||
|
||||
@classmethod
|
||||
def new(
|
||||
cls, client: NewTrezorClient, passphrase: str, derive_cardano: bool
|
||||
) -> Session:
|
||||
raise NotImplementedError
|
||||
|
||||
def call(self, msg: t.Any) -> t.Any:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SessionV1(Session):
|
||||
@classmethod
|
||||
def new(
|
||||
cls, client: NewTrezorClient, passphrase: str, derive_cardano: bool
|
||||
) -> SessionV1:
|
||||
assert isinstance(client.protocol, ProtocolV1)
|
||||
session = SessionV1(client, b"")
|
||||
cls.features = session.call(
|
||||
# Initialize(passphrase=passphrase, derive_cardano=derive_cardano) # TODO
|
||||
Initialize()
|
||||
)
|
||||
session.id = cls.features.session_id
|
||||
return session
|
||||
|
||||
def call(self, msg: t.Any, should_reinit: bool = False) -> t.Any:
|
||||
# if should_reinit:
|
||||
# self.initialize() # TODO
|
||||
if t.TYPE_CHECKING:
|
||||
assert isinstance(self.client.protocol, ProtocolV1)
|
||||
self.client.protocol.write(msg)
|
||||
return self.client.protocol.read()
|
||||
|
||||
|
||||
class SessionV2(Session):
|
||||
def __init__(self, client: NewTrezorClient, id: bytes) -> None:
|
||||
super().__init__(client, id)
|
||||
assert isinstance(client.protocol, ProtocolV2)
|
||||
self.channel = client.protocol.get_channel()
|
||||
self.sid = self._convert_id_to_sid(id)
|
||||
|
||||
def call(self, msg: t.Any) -> t.Any:
|
||||
self.channel.write(self.sid, msg)
|
||||
return self.channel.read(self.sid)
|
||||
|
||||
def _convert_id_to_sid(self, id: bytes) -> int:
|
||||
return int.from_bytes(id, "big") # TODO update to extract only sid
|
49
python/src/trezorlib/transport/new/transport.py
Normal file
49
python/src/trezorlib/transport/new/transport.py
Normal file
@ -0,0 +1,49 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
from typing import TYPE_CHECKING, Iterable, Type, TypeVar
|
||||
|
||||
from ...exceptions import TrezorException
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...models import TrezorModel
|
||||
|
||||
T = TypeVar("T", bound="NewTransport")
|
||||
|
||||
|
||||
class TransportException(TrezorException):
|
||||
pass
|
||||
|
||||
|
||||
class NewTransport:
|
||||
PATH_PREFIX: str
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls: Type["T"], models: Iterable["TrezorModel"] | None = None
|
||||
) -> Iterable["T"]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def find_by_path(cls: Type["T"], path: str, prefix_search: bool = False) -> "T":
|
||||
for device in cls.enumerate():
|
||||
|
||||
if device.get_path() == path:
|
||||
return device
|
||||
|
||||
if prefix_search and device.get_path().startswith(path):
|
||||
return device
|
||||
|
||||
raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}")
|
||||
|
||||
def get_path(self) -> str: ...
|
||||
|
||||
def open(self) -> None: ...
|
||||
|
||||
def close(self) -> None: ...
|
||||
|
||||
def write_chunk(self, chunk: bytes) -> None: ...
|
||||
|
||||
def read_chunk(self) -> bytes: ...
|
||||
|
||||
CHUNK_SIZE: t.ClassVar[int]
|
162
python/src/trezorlib/transport/new/udp.py
Normal file
162
python/src/trezorlib/transport/new/udp.py
Normal file
@ -0,0 +1,162 @@
|
||||
# This file is part of the Trezor project.
|
||||
#
|
||||
# Copyright (C) 2012-2024 SatoshiLabs and contributors
|
||||
#
|
||||
# This library is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Lesser General Public License version 3
|
||||
# as published by the Free Software Foundation.
|
||||
#
|
||||
# This library is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Lesser General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import socket
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Iterable, Tuple
|
||||
|
||||
from ...log import DUMP_PACKETS
|
||||
from .. import TransportException
|
||||
from .transport import NewTransport
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...models import TrezorModel
|
||||
|
||||
SOCKET_TIMEOUT = 10
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UdpTransport(NewTransport):
|
||||
|
||||
DEFAULT_HOST = "127.0.0.1"
|
||||
DEFAULT_PORT = 21324
|
||||
PATH_PREFIX = "udp"
|
||||
ENABLED: bool = True
|
||||
CHUNK_SIZE = 64
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device: str | None = None,
|
||||
) -> None:
|
||||
if not device:
|
||||
host = UdpTransport.DEFAULT_HOST
|
||||
port = UdpTransport.DEFAULT_PORT
|
||||
else:
|
||||
devparts = device.split(":")
|
||||
host = devparts[0]
|
||||
port = int(devparts[1]) if len(devparts) > 1 else UdpTransport.DEFAULT_PORT
|
||||
self.device: Tuple[str, int] = (host, port)
|
||||
|
||||
self.socket: socket.socket | None = None
|
||||
super().__init__()
|
||||
|
||||
@classmethod
|
||||
def _try_path(cls, path: str) -> "UdpTransport":
|
||||
d = cls(path)
|
||||
try:
|
||||
d.open()
|
||||
if d.ping():
|
||||
return d
|
||||
else:
|
||||
raise TransportException(
|
||||
f"No Trezor device found at address {d.get_path()}"
|
||||
)
|
||||
except Exception as e:
|
||||
raise TransportException(f"Error opening {d.get_path()}") from e
|
||||
|
||||
finally:
|
||||
d.close()
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls, _models: Iterable["TrezorModel"] | None = None
|
||||
) -> Iterable["UdpTransport"]:
|
||||
default_path = f"{cls.DEFAULT_HOST}:{cls.DEFAULT_PORT}"
|
||||
try:
|
||||
return [cls._try_path(default_path)]
|
||||
except TransportException:
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def find_by_path(cls, path: str, prefix_search: bool = False) -> "UdpTransport":
|
||||
try:
|
||||
address = path.replace(f"{cls.PATH_PREFIX}:", "")
|
||||
return cls._try_path(address)
|
||||
except TransportException:
|
||||
if not prefix_search:
|
||||
raise
|
||||
|
||||
if prefix_search:
|
||||
return super().find_by_path(path, prefix_search)
|
||||
else:
|
||||
raise TransportException(f"No UDP device at {path}")
|
||||
|
||||
def get_path(self) -> str:
|
||||
return "{}:{}:{}".format(self.PATH_PREFIX, *self.device)
|
||||
|
||||
def open(self) -> None:
|
||||
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
self.socket.connect(self.device)
|
||||
self.socket.settimeout(SOCKET_TIMEOUT)
|
||||
|
||||
def close(self) -> None:
|
||||
if self.socket is not None:
|
||||
self.socket.close()
|
||||
self.socket = None
|
||||
|
||||
def write_chunk(self, chunk: bytes) -> None:
|
||||
assert self.socket is not None
|
||||
if len(chunk) != 64:
|
||||
raise TransportException("Unexpected data length")
|
||||
LOG.log(DUMP_PACKETS, f"sending packet: {chunk.hex()}")
|
||||
self.socket.sendall(chunk)
|
||||
|
||||
def read_chunk(self) -> bytes:
|
||||
assert self.socket is not None
|
||||
while True:
|
||||
try:
|
||||
chunk = self.socket.recv(64)
|
||||
break
|
||||
except socket.timeout:
|
||||
continue
|
||||
LOG.log(DUMP_PACKETS, f"received packet: {chunk.hex()}")
|
||||
if len(chunk) != 64:
|
||||
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
|
||||
return bytearray(chunk)
|
||||
|
||||
def find_debug(self) -> "UdpTransport":
|
||||
host, port = self.device
|
||||
return UdpTransport(f"{host}:{port + 1}")
|
||||
|
||||
def wait_until_ready(self, timeout: float = 10) -> None:
|
||||
try:
|
||||
self.open()
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
if self.ping():
|
||||
break
|
||||
elapsed = time.monotonic() - start
|
||||
if elapsed >= timeout:
|
||||
raise TransportException("Timed out waiting for connection.")
|
||||
|
||||
time.sleep(0.05)
|
||||
finally:
|
||||
self.close()
|
||||
|
||||
def ping(self) -> bool:
|
||||
"""Test if the device is listening."""
|
||||
assert self.socket is not None
|
||||
resp = None
|
||||
try:
|
||||
self.socket.sendall(b"PINGPING")
|
||||
resp = self.socket.recv(8)
|
||||
except Exception:
|
||||
pass
|
||||
return resp == b"PONGPONG"
|
167
python/src/trezorlib/transport/new/webusb.py
Normal file
167
python/src/trezorlib/transport/new/webusb.py
Normal file
@ -0,0 +1,167 @@
|
||||
# This file is part of the Trezor project.
|
||||
#
|
||||
# Copyright (C) 2012-2024 SatoshiLabs and contributors
|
||||
#
|
||||
# This library is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Lesser General Public License version 3
|
||||
# as published by the Free Software Foundation.
|
||||
#
|
||||
# This library is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Lesser General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import atexit
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from typing import Iterable, List
|
||||
|
||||
from ...log import DUMP_PACKETS
|
||||
from ...models import TREZORS, TrezorModel
|
||||
from .. import UDEV_RULES_STR, DeviceIsBusy, TransportException
|
||||
from .transport import NewTransport
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import usb1
|
||||
|
||||
USB_IMPORTED = True
|
||||
except Exception as e:
|
||||
LOG.warning(f"WebUSB transport is disabled: {e}")
|
||||
USB_IMPORTED = False
|
||||
|
||||
INTERFACE = 0
|
||||
ENDPOINT = 1
|
||||
DEBUG_INTERFACE = 1
|
||||
DEBUG_ENDPOINT = 2
|
||||
|
||||
|
||||
class WebUsbTransport(NewTransport):
|
||||
"""
|
||||
WebUsbTransport implements transport over WebUSB interface.
|
||||
"""
|
||||
|
||||
PATH_PREFIX = "webusb"
|
||||
ENABLED = USB_IMPORTED
|
||||
context = None
|
||||
CHUNK_SIZE = 64
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device: "usb1.USBDevice",
|
||||
debug: bool = False,
|
||||
) -> None:
|
||||
|
||||
self.device = device
|
||||
self.debug = debug
|
||||
|
||||
self.interface = DEBUG_INTERFACE if debug else INTERFACE
|
||||
self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT
|
||||
self.handle: usb1.USBDeviceHandle | None = None
|
||||
|
||||
super().__init__()
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls, models: Iterable["TrezorModel"] | None = None, usb_reset: bool = False
|
||||
) -> Iterable["WebUsbTransport"]:
|
||||
if cls.context is None:
|
||||
cls.context = usb1.USBContext()
|
||||
cls.context.open()
|
||||
atexit.register(cls.context.close)
|
||||
|
||||
if models is None:
|
||||
models = TREZORS
|
||||
usb_ids = [id for model in models for id in model.usb_ids]
|
||||
devices: List["WebUsbTransport"] = []
|
||||
for dev in cls.context.getDeviceIterator(skip_on_error=True):
|
||||
usb_id = (dev.getVendorID(), dev.getProductID())
|
||||
if usb_id not in usb_ids:
|
||||
continue
|
||||
if not is_vendor_class(dev):
|
||||
continue
|
||||
if usb_reset:
|
||||
handle = dev.open()
|
||||
handle.resetDevice()
|
||||
handle.close()
|
||||
continue
|
||||
try:
|
||||
# workaround for issue #223:
|
||||
# on certain combinations of Windows USB drivers and libusb versions,
|
||||
# Trezor is returned twice (possibly because Windows know it as both
|
||||
# a HID and a WebUSB device), and one of the returned devices is
|
||||
# non-functional.
|
||||
dev.getProduct()
|
||||
devices.append(WebUsbTransport(dev))
|
||||
except usb1.USBErrorNotSupported:
|
||||
pass
|
||||
return devices
|
||||
|
||||
def get_path(self) -> str:
|
||||
return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}"
|
||||
|
||||
def open(self) -> None:
|
||||
self.handle = self.device.open()
|
||||
if self.handle is None:
|
||||
if sys.platform.startswith("linux"):
|
||||
args = (UDEV_RULES_STR,)
|
||||
else:
|
||||
args = ()
|
||||
raise IOError("Cannot open device", *args)
|
||||
try:
|
||||
self.handle.claimInterface(self.interface)
|
||||
except usb1.USBErrorAccess as e:
|
||||
raise DeviceIsBusy(self.device) from e
|
||||
|
||||
def close(self) -> None:
|
||||
if self.handle is not None:
|
||||
self.handle.releaseInterface(self.interface)
|
||||
self.handle.close()
|
||||
self.handle = None
|
||||
|
||||
def write_chunk(self, chunk: bytes) -> None:
|
||||
assert self.handle is not None
|
||||
if len(chunk) != 64:
|
||||
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
|
||||
LOG.log(DUMP_PACKETS, f"writing packet: {chunk.hex()}")
|
||||
self.handle.interruptWrite(self.endpoint, chunk)
|
||||
|
||||
def read_chunk(self) -> bytes:
|
||||
assert self.handle is not None
|
||||
endpoint = 0x80 | self.endpoint
|
||||
while True:
|
||||
chunk = self.handle.interruptRead(endpoint, 64)
|
||||
if chunk:
|
||||
break
|
||||
else:
|
||||
time.sleep(0.001)
|
||||
LOG.log(DUMP_PACKETS, f"read packet: {chunk.hex()}")
|
||||
if len(chunk) != 64:
|
||||
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
|
||||
return chunk
|
||||
|
||||
def find_debug(self) -> "WebUsbTransport":
|
||||
# For v1 protocol, find debug USB interface for the same serial number
|
||||
return WebUsbTransport(self.device, debug=True)
|
||||
|
||||
|
||||
def is_vendor_class(dev: "usb1.USBDevice") -> bool:
|
||||
configurationId = 0
|
||||
altSettingId = 0
|
||||
return (
|
||||
dev[configurationId][INTERFACE][altSettingId].getClass()
|
||||
== usb1.libusb1.LIBUSB_CLASS_VENDOR_SPEC
|
||||
)
|
||||
|
||||
|
||||
def dev_to_str(dev: "usb1.USBDevice") -> str:
|
||||
return ":".join(
|
||||
str(x) for x in ["%03i" % (dev.getBusNumber(),)] + dev.getPortNumberList()
|
||||
)
|
@ -174,14 +174,14 @@ class ProtocolBasedTransport(Transport):
|
||||
response = mapping.DEFAULT_MAPPING.decode(response_type, response_data)
|
||||
self.handle.close()
|
||||
if isinstance(response, messages.Failure):
|
||||
from .protocol_v2 import ProtocolV2
|
||||
from .protocol_v2 import DeprecatedProtocolV2
|
||||
|
||||
if (
|
||||
response.code == FailureType.UnexpectedMessage
|
||||
and response.message == "Invalid protocol"
|
||||
):
|
||||
LOG.debug("Protocol V2 detected")
|
||||
protocol = ProtocolV2(self.handle)
|
||||
protocol = DeprecatedProtocolV2(self.handle)
|
||||
|
||||
return protocol
|
||||
|
||||
@ -193,8 +193,8 @@ def _get_protocol(version: int, handle: Handle) -> Protocol:
|
||||
return ProtocolV1(handle)
|
||||
|
||||
if version == PROTOCOL_VERSION_2:
|
||||
from .protocol_v2 import ProtocolV2
|
||||
from .protocol_v2 import DeprecatedProtocolV2
|
||||
|
||||
return ProtocolV2(handle)
|
||||
return DeprecatedProtocolV2(handle)
|
||||
|
||||
raise NotImplementedError
|
||||
|
@ -1,19 +1,12 @@
|
||||
import hashlib
|
||||
import hmac
|
||||
import logging
|
||||
import os
|
||||
from binascii import hexlify
|
||||
from enum import IntEnum
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
|
||||
from .. import messages
|
||||
from ..mapping import ProtobufMapping
|
||||
from ..protobuf import MessageType
|
||||
from ..transport.protocol import Handle, Protocol
|
||||
from .thp import checksum, curve25519, thp_io
|
||||
from .thp.checksum import CHECKSUM_LENGTH
|
||||
from .thp.packet_header import PacketHeader
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
@ -40,7 +33,7 @@ def _get_iv_from_nonce(nonce: int) -> bytes:
|
||||
return bytes(4) + nonce.to_bytes(8, "big")
|
||||
|
||||
|
||||
class ProtocolV2(Protocol):
|
||||
class DeprecatedProtocolV2(Protocol):
|
||||
def __init__(self, handle: Handle) -> None:
|
||||
super().__init__(handle)
|
||||
|
||||
@ -50,191 +43,185 @@ class ProtocolV2(Protocol):
|
||||
session_id: Optional[bytes] = None,
|
||||
derive_caradano: Optional[bool] = None,
|
||||
):
|
||||
self.session_id: int = 0
|
||||
self.sync_bit_send: int = 0
|
||||
self.sync_bit_receive: int = 0
|
||||
self.mapping = mapping
|
||||
# Send channel allocation request
|
||||
channel_id_request_nonce = os.urandom(8)
|
||||
thp_io.write_payload_to_wire_and_add_checksum(
|
||||
self.handle,
|
||||
PacketHeader.get_channel_allocation_request_header(12),
|
||||
channel_id_request_nonce,
|
||||
)
|
||||
# self.session_id: int = 0
|
||||
# self.sync_bit_send: int = 0
|
||||
# self.sync_bit_receive: int = 0
|
||||
# self.mapping = mapping
|
||||
# # Send channel allocation request
|
||||
# channel_id_request_nonce = os.urandom(8)
|
||||
# thp_io.write_payload_to_wire_and_add_checksum(
|
||||
# self.handle,
|
||||
# PacketHeader.get_channel_allocation_request_header(12),
|
||||
# channel_id_request_nonce,
|
||||
# )
|
||||
|
||||
# Read channel allocation response
|
||||
header, payload = self._read_until_valid_crc_check()
|
||||
if not self._is_valid_channel_allocation_response(
|
||||
header, payload, channel_id_request_nonce
|
||||
):
|
||||
print("TODO raise exception here, I guess")
|
||||
# # Read channel allocation response
|
||||
# header, payload = self._read_until_valid_crc_check()
|
||||
# if not self._is_valid_channel_allocation_response(
|
||||
# header, payload, channel_id_request_nonce
|
||||
# ):
|
||||
# print("TODO raise exception here, I guess")
|
||||
|
||||
self.cid = int.from_bytes(payload[8:10], "big")
|
||||
self.device_properties = payload[10:]
|
||||
# self.cid = int.from_bytes(payload[8:10], "big")
|
||||
# self.device_properties = payload[10:]
|
||||
|
||||
# Send handshake init request
|
||||
ha_init_req_header = PacketHeader(0, self.cid, 36)
|
||||
host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32))
|
||||
host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey)
|
||||
# # Send handshake init request
|
||||
# ha_init_req_header = PacketHeader(0, self.cid, 36)
|
||||
# host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32))
|
||||
# host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey)
|
||||
|
||||
thp_io.write_payload_to_wire_and_add_checksum(
|
||||
self.handle, ha_init_req_header, host_ephemeral_pubkey
|
||||
)
|
||||
# thp_io.write_payload_to_wire_and_add_checksum(
|
||||
# self.handle, ha_init_req_header, host_ephemeral_pubkey
|
||||
# )
|
||||
|
||||
# Read ACK
|
||||
header, payload = self._read_until_valid_crc_check()
|
||||
if not header.is_ack() or len(payload) > 0:
|
||||
print("Received message is not a valid ACK ")
|
||||
# # Read ACK
|
||||
# header, payload = self._read_until_valid_crc_check()
|
||||
# if not header.is_ack() or len(payload) > 0:
|
||||
# print("Received message is not a valid ACK ")
|
||||
|
||||
# Read handshake init response
|
||||
header, payload = self._read_until_valid_crc_check()
|
||||
self._send_ack_1()
|
||||
# # Read handshake init response
|
||||
# header, payload = self._read_until_valid_crc_check()
|
||||
# self._send_ack_1()
|
||||
|
||||
if not header.is_handshake_init_response():
|
||||
print("Received message is not a valid handshake init response message")
|
||||
# if not header.is_handshake_init_response():
|
||||
# print("Received message is not a valid handshake init response message")
|
||||
|
||||
trezor_ephemeral_pubkey = payload[:32]
|
||||
encrypted_trezor_static_pubkey = payload[32:80]
|
||||
noise_tag = payload[80:96]
|
||||
# trezor_ephemeral_pubkey = payload[:32]
|
||||
# encrypted_trezor_static_pubkey = payload[32:80]
|
||||
# noise_tag = payload[80:96]
|
||||
|
||||
# TODO check noise tag
|
||||
print("noise_tag: ", hexlify(noise_tag).decode())
|
||||
# # TODO check noise tag
|
||||
# print("noise_tag: ", hexlify(noise_tag).decode())
|
||||
|
||||
# Prepare and send handshake completion request
|
||||
PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00"
|
||||
IV_1 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
|
||||
IV_2 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
|
||||
h = _sha256_of_two(PROTOCOL_NAME, self.device_properties)
|
||||
h = _sha256_of_two(h, host_ephemeral_pubkey)
|
||||
h = _sha256_of_two(h, trezor_ephemeral_pubkey)
|
||||
ck, k = _hkdf(
|
||||
PROTOCOL_NAME,
|
||||
curve25519.multiply(host_ephemeral_privkey, trezor_ephemeral_pubkey),
|
||||
)
|
||||
# # Prepare and send handshake completion request
|
||||
# PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00"
|
||||
# IV_1 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
|
||||
# IV_2 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
|
||||
# h = _sha256_of_two(PROTOCOL_NAME, self.device_properties)
|
||||
# h = _sha256_of_two(h, host_ephemeral_pubkey)
|
||||
# h = _sha256_of_two(h, trezor_ephemeral_pubkey)
|
||||
# ck, k = _hkdf(
|
||||
# PROTOCOL_NAME,
|
||||
# curve25519.multiply(host_ephemeral_privkey, trezor_ephemeral_pubkey),
|
||||
# )
|
||||
|
||||
aes_ctx = AESGCM(k)
|
||||
try:
|
||||
trezor_masked_static_pubkey = aes_ctx.decrypt(
|
||||
IV_1, encrypted_trezor_static_pubkey, h
|
||||
)
|
||||
# print("masked_key", hexlify(trezor_masked_static_pubkey).decode())
|
||||
except Exception as e:
|
||||
print(type(e)) # TODO how to handle potential exceptions? Q for Matejcik
|
||||
h = _sha256_of_two(h, encrypted_trezor_static_pubkey)
|
||||
ck, k = _hkdf(
|
||||
ck, curve25519.multiply(host_ephemeral_privkey, trezor_masked_static_pubkey)
|
||||
)
|
||||
aes_ctx = AESGCM(k)
|
||||
# aes_ctx = AESGCM(k)
|
||||
# try:
|
||||
# trezor_masked_static_pubkey = aes_ctx.decrypt(
|
||||
# IV_1, encrypted_trezor_static_pubkey, h
|
||||
# )
|
||||
# # print("masked_key", hexlify(trezor_masked_static_pubkey).decode())
|
||||
# except Exception as e:
|
||||
# print(type(e)) # TODO how to handle potential exceptions? Q for Matejcik
|
||||
# h = _sha256_of_two(h, encrypted_trezor_static_pubkey)
|
||||
# ck, k = _hkdf(
|
||||
# ck, curve25519.multiply(host_ephemeral_privkey, trezor_masked_static_pubkey)
|
||||
# )
|
||||
# aes_ctx = AESGCM(k)
|
||||
|
||||
tag_of_empty_string = aes_ctx.encrypt(IV_1, b"", h)
|
||||
h = _sha256_of_two(h, tag_of_empty_string)
|
||||
# TODO: search for saved credentials (or possibly not, as we skip pairing phase)
|
||||
# tag_of_empty_string = aes_ctx.encrypt(IV_1, b"", h)
|
||||
# h = _sha256_of_two(h, tag_of_empty_string)
|
||||
# # TODO: search for saved credentials (or possibly not, as we skip pairing phase)
|
||||
|
||||
zeroes_32 = int.to_bytes(0, 32, "little")
|
||||
temp_host_static_privkey = curve25519.get_private_key(zeroes_32)
|
||||
temp_host_static_pubkey = curve25519.get_public_key(temp_host_static_privkey)
|
||||
aes_ctx = AESGCM(k)
|
||||
encrypted_host_static_pubkey = aes_ctx.encrypt(IV_2, temp_host_static_pubkey, h)
|
||||
h = _sha256_of_two(h, encrypted_host_static_pubkey)
|
||||
ck, k = _hkdf(
|
||||
ck, curve25519.multiply(temp_host_static_privkey, trezor_ephemeral_pubkey)
|
||||
)
|
||||
msg_data = mapping.encode_without_wire_type(
|
||||
messages.ThpHandshakeCompletionReqNoisePayload(
|
||||
pairing_methods=[
|
||||
messages.ThpPairingMethod.NoMethod,
|
||||
]
|
||||
)
|
||||
)
|
||||
# zeroes_32 = int.to_bytes(0, 32, "little")
|
||||
# temp_host_static_privkey = curve25519.get_private_key(zeroes_32)
|
||||
# temp_host_static_pubkey = curve25519.get_public_key(temp_host_static_privkey)
|
||||
# aes_ctx = AESGCM(k)
|
||||
# encrypted_host_static_pubkey = aes_ctx.encrypt(IV_2, temp_host_static_pubkey, h)
|
||||
# h = _sha256_of_two(h, encrypted_host_static_pubkey)
|
||||
# ck, k = _hkdf(
|
||||
# ck, curve25519.multiply(temp_host_static_privkey, trezor_ephemeral_pubkey)
|
||||
# )
|
||||
# msg_data = mapping.encode_without_wire_type(
|
||||
# messages.ThpHandshakeCompletionReqNoisePayload(
|
||||
# pairing_methods=[
|
||||
# messages.ThpPairingMethod.NoMethod,
|
||||
# ]
|
||||
# )
|
||||
# )
|
||||
|
||||
aes_ctx = AESGCM(k)
|
||||
# aes_ctx = AESGCM(k)
|
||||
|
||||
encrypted_payload = aes_ctx.encrypt(IV_1, msg_data, h)
|
||||
h = _sha256_of_two(h, encrypted_payload)
|
||||
ha_completion_req_header = PacketHeader(
|
||||
0x12,
|
||||
self.cid,
|
||||
len(encrypted_host_static_pubkey)
|
||||
+ len(encrypted_payload)
|
||||
+ CHECKSUM_LENGTH,
|
||||
)
|
||||
thp_io.write_payload_to_wire_and_add_checksum(
|
||||
self.handle,
|
||||
ha_completion_req_header,
|
||||
encrypted_host_static_pubkey + encrypted_payload,
|
||||
)
|
||||
# encrypted_payload = aes_ctx.encrypt(IV_1, msg_data, h)
|
||||
# h = _sha256_of_two(h, encrypted_payload)
|
||||
# ha_completion_req_header = PacketHeader(
|
||||
# 0x12,
|
||||
# self.cid,
|
||||
# len(encrypted_host_static_pubkey)
|
||||
# + len(encrypted_payload)
|
||||
# + CHECKSUM_LENGTH,
|
||||
# )
|
||||
# thp_io.write_payload_to_wire_and_add_checksum(
|
||||
# self.handle,
|
||||
# ha_completion_req_header,
|
||||
# encrypted_host_static_pubkey + encrypted_payload,
|
||||
# )
|
||||
|
||||
# Read ACK
|
||||
header, payload = self._read_until_valid_crc_check()
|
||||
if not header.is_ack() or len(payload) > 0:
|
||||
print("Received message is not a valid ACK ")
|
||||
# # Read ACK
|
||||
# header, payload = self._read_until_valid_crc_check()
|
||||
# if not header.is_ack() or len(payload) > 0:
|
||||
# print("Received message is not a valid ACK ")
|
||||
|
||||
# Read handshake completion response, ignore payload as we do not care about the state
|
||||
header, _ = self._read_until_valid_crc_check()
|
||||
if not header.is_handshake_comp_response():
|
||||
print("Received message is not a valid handshake completion response")
|
||||
self._send_ack_2()
|
||||
# # Read handshake completion response, ignore payload as we do not care about the state
|
||||
# header, _ = self._read_until_valid_crc_check()
|
||||
# if not header.is_handshake_comp_response():
|
||||
# print("Received message is not a valid handshake completion response")
|
||||
# self._send_ack_2()
|
||||
|
||||
self.key_request, self.key_response = _hkdf(ck, b"")
|
||||
self.nonce_request: int = 0
|
||||
self.nonce_response: int = 1
|
||||
# self.key_request, self.key_response = _hkdf(ck, b"")
|
||||
# self.nonce_request: int = 0
|
||||
# self.nonce_response: int = 1
|
||||
|
||||
# Send StartPairingReqest message
|
||||
message = messages.ThpStartPairingRequest()
|
||||
message_type, message_data = mapping.encode(message)
|
||||
# # Send StartPairingReqest message
|
||||
# message = messages.ThpStartPairingRequest()
|
||||
# message_type, message_data = mapping.encode(message)
|
||||
|
||||
self._encrypt_and_write(message_type.to_bytes(2, "big"), message_data)
|
||||
# self._encrypt_and_write(message_type.to_bytes(2, "big"), message_data)
|
||||
|
||||
# Read ACK
|
||||
header, payload = self._read_until_valid_crc_check()
|
||||
if not header.is_ack() or len(payload) > 0:
|
||||
print("Received message is not a valid ACK ")
|
||||
# # Read ACK
|
||||
# header, payload = self._read_until_valid_crc_check()
|
||||
# if not header.is_ack() or len(payload) > 0:
|
||||
# print("Received message is not a valid ACK ")
|
||||
|
||||
# Read
|
||||
_, msg_type, msg_data = self.read_and_decrypt()
|
||||
maaa = mapping.decode(msg_type, msg_data)
|
||||
self._send_ack_1()
|
||||
# # Read
|
||||
# _, msg_type, msg_data = self.read_and_decrypt()
|
||||
# maaa = mapping.decode(msg_type, msg_data)
|
||||
# self._send_ack_1()
|
||||
|
||||
assert isinstance(maaa, messages.ThpEndResponse)
|
||||
# assert isinstance(maaa, messages.ThpEndResponse)
|
||||
|
||||
# Send get features
|
||||
message = messages.GetFeatures()
|
||||
message_type, message_data = mapping.encode(message)
|
||||
# # Send get features
|
||||
# message = messages.GetFeatures()
|
||||
# message_type, message_data = mapping.encode(message)
|
||||
|
||||
self.session_id: int = 0
|
||||
self._encrypt_and_write(message_type.to_bytes(2, "big"), message_data, 0x14)
|
||||
_ = thp_io.read(self.handle)
|
||||
session_id, msg_type, msg_data = self.read_and_decrypt()
|
||||
features = mapping.decode(msg_type, msg_data)
|
||||
assert isinstance(features, messages.Features)
|
||||
features.session_id = int.to_bytes(self.cid, 2, "big") + session_id
|
||||
self._send_ack_2()
|
||||
return features
|
||||
# self.session_id: int = 0
|
||||
# self._encrypt_and_write(message_type.to_bytes(2, "big"), message_data, 0x14)
|
||||
# _ = thp_io.read(self.handle)
|
||||
# session_id, msg_type, msg_data = self.read_and_decrypt()
|
||||
# features = mapping.decode(msg_type, msg_data)
|
||||
# assert isinstance(features, messages.Features)
|
||||
# features.session_id = int.to_bytes(self.cid, 2, "big") + session_id
|
||||
# self._send_ack_2()
|
||||
# return features
|
||||
...
|
||||
|
||||
def _encrypt_and_write(
|
||||
self, message_type: bytes, message_data: bytes, ctrl_byte: int = 0x04
|
||||
) -> None:
|
||||
assert self.key_request is not None
|
||||
aes_ctx = AESGCM(self.key_request)
|
||||
data = self.session_id.to_bytes(1, "big") + message_type + message_data
|
||||
nonce = _get_iv_from_nonce(self.nonce_request)
|
||||
self.nonce_request += 1
|
||||
encrypted_message = aes_ctx.encrypt(nonce, data, b"")
|
||||
header = PacketHeader(
|
||||
ctrl_byte, self.cid, len(encrypted_message) + CHECKSUM_LENGTH
|
||||
)
|
||||
# assert self.key_request is not None
|
||||
# aes_ctx = AESGCM(self.key_request)
|
||||
# data = self.session_id.to_bytes(1, "big") + message_type + message_data
|
||||
# nonce = _get_iv_from_nonce(self.nonce_request)
|
||||
# self.nonce_request += 1
|
||||
# encrypted_message = aes_ctx.encrypt(nonce, data, b"")
|
||||
# header = PacketHeader(
|
||||
# ctrl_byte, self.cid, len(encrypted_message) + CHECKSUM_LENGTH
|
||||
# )
|
||||
|
||||
thp_io.write_payload_to_wire_and_add_checksum(
|
||||
self.handle, header, encrypted_message
|
||||
)
|
||||
|
||||
def _send_ack_1(self):
|
||||
header = PacketHeader(0x20, self.cid, 4)
|
||||
thp_io.write_payload_to_wire_and_add_checksum(self.handle, header, b"")
|
||||
|
||||
def _send_ack_2(self):
|
||||
header = PacketHeader(0x28, self.cid, 4)
|
||||
thp_io.write_payload_to_wire_and_add_checksum(self.handle, header, b"")
|
||||
# thp_io.write_payload_to_wire_and_add_checksum(
|
||||
# self.handle, header, encrypted_message
|
||||
# )
|
||||
...
|
||||
|
||||
def _write_message(self, message: MessageType, mapping: ProtobufMapping):
|
||||
try:
|
||||
@ -244,43 +231,46 @@ class ProtocolV2(Protocol):
|
||||
print(type(e))
|
||||
|
||||
def write(self, message_type: int, message_data: bytes) -> None:
|
||||
data = (
|
||||
self.session_id.to_bytes(1, "big")
|
||||
+ message_type.to_bytes(2, "big")
|
||||
+ message_data
|
||||
)
|
||||
ctrl_byte = 0x04
|
||||
self._write_and_encrypt(data, ctrl_byte)
|
||||
# data = (
|
||||
# self.session_id.to_bytes(1, "big")
|
||||
# + message_type.to_bytes(2, "big")
|
||||
# + message_data
|
||||
# )
|
||||
# ctrl_byte = 0x04
|
||||
# self._write_and_encrypt(data, ctrl_byte)
|
||||
...
|
||||
|
||||
def _write_and_encrypt(self, data: bytes, ctrl_byte: int) -> None:
|
||||
aes_ctx = AESGCM(self.key_request)
|
||||
nonce = _get_iv_from_nonce(self.nonce_request)
|
||||
self.nonce_request += 1
|
||||
encrypted_data = aes_ctx.encrypt(nonce, data, b"")
|
||||
header = PacketHeader(
|
||||
ctrl_byte, self.cid, len(encrypted_data) + CHECKSUM_LENGTH
|
||||
)
|
||||
thp_io.write_payload_to_wire_and_add_checksum(
|
||||
self.handle, header, encrypted_data
|
||||
)
|
||||
# aes_ctx = AESGCM(self.key_request)
|
||||
# nonce = _get_iv_from_nonce(self.nonce_request)
|
||||
# self.nonce_request += 1
|
||||
# encrypted_data = aes_ctx.encrypt(nonce, data, b"")
|
||||
# header = PacketHeader(
|
||||
# ctrl_byte, self.cid, len(encrypted_data) + CHECKSUM_LENGTH
|
||||
# )
|
||||
# thp_io.write_payload_to_wire_and_add_checksum(
|
||||
# self.handle, header, encrypted_data
|
||||
# )
|
||||
...
|
||||
|
||||
def read_and_decrypt(self) -> Tuple[bytes, int, bytes]:
|
||||
header, raw_payload = self._read_until_valid_crc_check()
|
||||
if not header.is_encrypted_transport():
|
||||
print("Trying to decrypt not encrypted message!")
|
||||
aes_ctx = AESGCM(self.key_response)
|
||||
nonce = _get_iv_from_nonce(self.nonce_response)
|
||||
self.nonce_response += 1
|
||||
# header, raw_payload = self._read_until_valid_crc_check()
|
||||
# if not header.is_encrypted_transport():
|
||||
# print("Trying to decrypt not encrypted message!")
|
||||
# aes_ctx = AESGCM(self.key_response)
|
||||
# nonce = _get_iv_from_nonce(self.nonce_response)
|
||||
# self.nonce_response += 1
|
||||
|
||||
message = aes_ctx.decrypt(nonce, raw_payload, b"")
|
||||
session_id = message[0]
|
||||
message_type = message[1:3]
|
||||
message_data = message[3:]
|
||||
return (
|
||||
int.to_bytes(session_id, 1, "big"),
|
||||
int.from_bytes(message_type, "big"),
|
||||
message_data,
|
||||
)
|
||||
# message = aes_ctx.decrypt(nonce, raw_payload, b"")
|
||||
# session_id = message[0]
|
||||
# message_type = message[1:3]
|
||||
# message_data = message[3:]
|
||||
# return (
|
||||
# int.to_bytes(session_id, 1, "big"),
|
||||
# int.from_bytes(message_type, "big"),
|
||||
# message_data,
|
||||
# )
|
||||
...
|
||||
|
||||
def end_session(self, session_id: bytes) -> None:
|
||||
pass
|
||||
@ -290,22 +280,24 @@ class ProtocolV2(Protocol):
|
||||
return self.start_session("")
|
||||
|
||||
def start_session(self, passphrase: str) -> bytes:
|
||||
try:
|
||||
msg = messages.ThpCreateNewSession(passphrase=passphrase)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print("s")
|
||||
# try:
|
||||
# msg = messages.ThpCreateNewSession(passphrase=passphrase)
|
||||
# except Exception as e:
|
||||
# print(e)
|
||||
# print("s")
|
||||
|
||||
self._write_message(msg, self.mapping)
|
||||
print("p")
|
||||
response_type, response_data = self._read_until_valid_crc_check()
|
||||
print(response_type, response_data)
|
||||
return b""
|
||||
# self._write_message(msg, self.mapping)
|
||||
# print("p")
|
||||
# response_type, response_data = self._read_until_valid_crc_check()
|
||||
# print(response_type, response_data)
|
||||
# return b""
|
||||
...
|
||||
|
||||
def read(self) -> Tuple[int, bytes]:
|
||||
header, raw_payload, chksum = thp_io.read(self.handle)
|
||||
print("Read message", hexlify(raw_payload))
|
||||
return (0x00, header.to_bytes_init() + raw_payload + chksum) # TODO change
|
||||
# header, raw_payload, chksum = thp_io.read(self.handle)
|
||||
# print("Read message", hexlify(raw_payload))
|
||||
# return (0x00, header.to_bytes_init() + raw_payload + chksum) # TODO change
|
||||
...
|
||||
|
||||
def _get_control_byte(self) -> bytes:
|
||||
return b"\x42"
|
||||
@ -313,16 +305,17 @@ class ProtocolV2(Protocol):
|
||||
def _read_until_valid_crc_check(
|
||||
self,
|
||||
) -> Tuple[PacketHeader, bytes]:
|
||||
is_valid = False
|
||||
header, payload, chksum = thp_io.read(self.handle)
|
||||
while not is_valid:
|
||||
is_valid = checksum.is_valid(chksum, header.to_bytes_init() + payload)
|
||||
if not is_valid:
|
||||
print(hexlify(header.to_bytes_init() + payload + chksum))
|
||||
LOG.debug("Received a message with invalid checksum")
|
||||
header, payload, chksum = thp_io.read(self.handle)
|
||||
# is_valid = False
|
||||
# header, payload, chksum = thp_io.read(self.handle)
|
||||
# while not is_valid:
|
||||
# is_valid = checksum.is_valid(chksum, header.to_bytes_init() + payload)
|
||||
# if not is_valid:
|
||||
# print(hexlify(header.to_bytes_init() + payload + chksum))
|
||||
# LOG.debug("Received a message with invalid checksum")
|
||||
# header, payload, chksum = thp_io.read(self.handle)
|
||||
|
||||
return header, payload
|
||||
# return header, payload
|
||||
...
|
||||
|
||||
def _is_valid_channel_allocation_response(
|
||||
self, header: PacketHeader, payload: bytes, original_nonce: bytes
|
||||
|
@ -1,15 +1,12 @@
|
||||
import struct
|
||||
from binascii import hexlify
|
||||
from typing import Tuple
|
||||
|
||||
from ..protocol import Handle
|
||||
from ..new.transport import NewTransport
|
||||
from ..thp import checksum
|
||||
from .packet_header import PacketHeader
|
||||
|
||||
INIT_HEADER_LENGTH = 5
|
||||
CONT_HEADER_LENGTH = 3
|
||||
PACKET_LENGTH = 64
|
||||
CHECKSUM_LENGTH = 4
|
||||
MAX_PAYLOAD_LEN = 60000
|
||||
MESSAGE_TYPE_LENGTH = 2
|
||||
|
||||
@ -17,48 +14,54 @@ CONTINUATION_PACKET = 0x80
|
||||
|
||||
|
||||
def write_payload_to_wire_and_add_checksum(
|
||||
handle: Handle, header: PacketHeader, transport_payload: bytes
|
||||
transport: NewTransport, header: PacketHeader, transport_payload: bytes
|
||||
):
|
||||
chksum: bytes = checksum.compute(header.to_bytes_init() + transport_payload)
|
||||
data = transport_payload + chksum
|
||||
write_payload_to_wire(handle, header, data)
|
||||
write_payload_to_wire(transport, header, data)
|
||||
|
||||
|
||||
def write_payload_to_wire(
|
||||
handle: Handle, header: PacketHeader, transport_payload: bytes
|
||||
transport: NewTransport, header: PacketHeader, transport_payload: bytes
|
||||
):
|
||||
handle.open()
|
||||
transport.open()
|
||||
buffer = bytearray(transport_payload)
|
||||
chunk = header.to_bytes_init() + buffer[: PACKET_LENGTH - INIT_HEADER_LENGTH]
|
||||
chunk = chunk.ljust(PACKET_LENGTH, b"\x00")
|
||||
handle.write_chunk(chunk)
|
||||
chunk = header.to_bytes_init() + buffer[: transport.CHUNK_SIZE - INIT_HEADER_LENGTH]
|
||||
chunk = chunk.ljust(transport.CHUNK_SIZE, b"\x00")
|
||||
transport.write_chunk(chunk)
|
||||
|
||||
buffer = buffer[PACKET_LENGTH - INIT_HEADER_LENGTH :]
|
||||
buffer = buffer[transport.CHUNK_SIZE - INIT_HEADER_LENGTH :]
|
||||
while buffer:
|
||||
chunk = header.to_bytes_cont() + buffer[: PACKET_LENGTH - CONT_HEADER_LENGTH]
|
||||
chunk = chunk.ljust(PACKET_LENGTH, b"\x00")
|
||||
handle.write_chunk(chunk)
|
||||
buffer = buffer[PACKET_LENGTH - CONT_HEADER_LENGTH :]
|
||||
chunk = (
|
||||
header.to_bytes_cont() + buffer[: transport.CHUNK_SIZE - CONT_HEADER_LENGTH]
|
||||
)
|
||||
chunk = chunk.ljust(transport.CHUNK_SIZE, b"\x00")
|
||||
transport.write_chunk(chunk)
|
||||
buffer = buffer[transport.CHUNK_SIZE - CONT_HEADER_LENGTH :]
|
||||
|
||||
|
||||
def read(handle: Handle) -> Tuple[PacketHeader, bytes, bytes]:
|
||||
def read(transport: NewTransport) -> Tuple[PacketHeader, bytes, bytes]:
|
||||
buffer = bytearray()
|
||||
# Read header with first part of message data
|
||||
header, first_chunk = read_first(handle)
|
||||
header, first_chunk = read_first(transport)
|
||||
buffer.extend(first_chunk)
|
||||
|
||||
# Read the rest of the message
|
||||
while len(buffer) < header.data_length:
|
||||
buffer.extend(read_next(handle, header.cid))
|
||||
buffer.extend(read_next(transport, header.cid))
|
||||
# print("buffer read (data):", hexlify(buffer).decode())
|
||||
# print("buffer len (data):", datalen)
|
||||
# TODO check checksum?? or do not strip ?
|
||||
data_len = header.data_length - CHECKSUM_LENGTH
|
||||
return header, buffer[:data_len], buffer[data_len : data_len + CHECKSUM_LENGTH]
|
||||
data_len = header.data_length - checksum.CHECKSUM_LENGTH
|
||||
return (
|
||||
header,
|
||||
buffer[:data_len],
|
||||
buffer[data_len : data_len + checksum.CHECKSUM_LENGTH],
|
||||
)
|
||||
|
||||
|
||||
def read_first(handle: Handle) -> Tuple[PacketHeader, bytes]:
|
||||
chunk = handle.read_chunk()
|
||||
def read_first(transport: NewTransport) -> Tuple[PacketHeader, bytes]:
|
||||
chunk = transport.read_chunk()
|
||||
try:
|
||||
ctrl_byte, cid, data_length = struct.unpack(
|
||||
PacketHeader.format_str_init, chunk[:INIT_HEADER_LENGTH]
|
||||
@ -70,8 +73,8 @@ def read_first(handle: Handle) -> Tuple[PacketHeader, bytes]:
|
||||
return PacketHeader(ctrl_byte, cid, data_length), data
|
||||
|
||||
|
||||
def read_next(handle: Handle, cid: int) -> bytes:
|
||||
chunk = handle.read_chunk()
|
||||
def read_next(transport: NewTransport, cid: int) -> bytes:
|
||||
chunk = transport.read_chunk()
|
||||
ctrl_byte, read_cid = struct.unpack(
|
||||
PacketHeader.format_str_cont, chunk[:CONT_HEADER_LENGTH]
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user