1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-02-22 12:32:02 +00:00

TEMPORARY WIP - trezorlib

This commit is contained in:
M1nd3r 2024-08-30 10:56:25 +02:00
parent 06cc68cc46
commit f6ff8529c6
12 changed files with 562 additions and 10 deletions

View File

@ -653,6 +653,7 @@ def update(
against downloaded firmware fingerprint. Otherwise fingerprint is checked against downloaded firmware fingerprint. Otherwise fingerprint is checked
against data.trezor.io information, if available. against data.trezor.io information, if available.
""" """
print("client context")
with obj.client_context() as client: with obj.client_context() as client:
if sum(bool(x) for x in (filename, url, version)) > 1: if sum(bool(x) for x in (filename, url, version)) > 1:
click.echo("You can use only one of: filename, url, version.") click.echo("You can use only one of: filename, url, version.")

View File

@ -291,6 +291,7 @@ def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]:
client = TrezorClient(transport, ui=ui.ClickUI()) client = TrezorClient(transport, ui=ui.ClickUI())
description = format_device_name(client.features) description = format_device_name(client.features)
client.end_session() client.end_session()
print("after end session")
except DeviceIsBusy: except DeviceIsBusy:
description = "Device is in use by another process" description = "Device is in use by another process"
except Exception: except Exception:

View File

@ -134,12 +134,24 @@ class TrezorClient(Generic[UI]):
self.session_id = session_id self.session_id = session_id
if _init_device: if _init_device:
self.init_device(session_id=session_id, derive_cardano=derive_cardano) self.init_device(session_id=session_id, derive_cardano=derive_cardano)
self.resume_session()
def open(self) -> None: def open(self) -> None:
if self.session_counter == 0: if self.session_counter == 0:
session_id = self.transport.resume_session(b"")
if self.session_id != session_id:
print("Failed to resume session, allocated a new session")
self.session_id = session_id
self.transport.deprecated_begin_session() self.transport.deprecated_begin_session()
self.session_counter += 1 self.session_counter += 1
def resume_session(self) -> None:
print("resume session")
new_id = self.transport.resume_session(self.session_id or b"")
if self.session_id != new_id:
print("Failed to resume session, allocated a new session")
self.session_id = new_id
def close(self) -> None: def close(self) -> None:
self.session_counter = max(self.session_counter - 1, 0) self.session_counter = max(self.session_counter - 1, 0)
if self.session_counter == 0: if self.session_counter == 0:
@ -151,8 +163,13 @@ class TrezorClient(Generic[UI]):
def call_raw(self, msg: "MessageType") -> "MessageType": def call_raw(self, msg: "MessageType") -> "MessageType":
__tracebackhide__ = True # for pytest # pylint: disable=W0612 __tracebackhide__ = True # for pytest # pylint: disable=W0612
print("self.call_raw-start")
self._raw_write(msg) self._raw_write(msg)
return self._raw_read() print("self.call_raw-after write")
x = self._raw_read()
print("self.call_raw-end")
return x
def _raw_write(self, msg: "MessageType") -> None: def _raw_write(self, msg: "MessageType") -> None:
__tracebackhide__ = True # for pytest # pylint: disable=W0612 __tracebackhide__ = True # for pytest # pylint: disable=W0612
@ -169,7 +186,9 @@ class TrezorClient(Generic[UI]):
def _raw_read(self) -> "MessageType": def _raw_read(self) -> "MessageType":
__tracebackhide__ = True # for pytest # pylint: disable=W0612 __tracebackhide__ = True # for pytest # pylint: disable=W0612
print("raw read - start")
msg_type, msg_bytes = self.transport.read() msg_type, msg_bytes = self.transport.read()
print("type/data", msg_type, msg_bytes)
LOG.log( LOG.log(
DUMP_BYTES, DUMP_BYTES,
f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
@ -253,6 +272,7 @@ class TrezorClient(Generic[UI]):
@session @session
def call(self, msg: "MessageType") -> "MessageType": def call(self, msg: "MessageType") -> "MessageType":
print("self.call-start")
self.check_firmware_version() self.check_firmware_version()
resp = self.call_raw(msg) resp = self.call_raw(msg)
while True: while True:
@ -263,10 +283,13 @@ class TrezorClient(Generic[UI]):
elif isinstance(resp, messages.ButtonRequest): elif isinstance(resp, messages.ButtonRequest):
resp = self._callback_button(resp) resp = self._callback_button(resp)
elif isinstance(resp, messages.Failure): elif isinstance(resp, messages.Failure):
print("self.call-failure")
if resp.code == messages.FailureType.ActionCancelled: if resp.code == messages.FailureType.ActionCancelled:
raise exceptions.Cancelled raise exceptions.Cancelled
raise exceptions.TrezorFailure(resp) raise exceptions.TrezorFailure(resp)
else: else:
print("self.call-end")
return resp return resp
def _refresh_features(self, features: messages.Features) -> None: def _refresh_features(self, features: messages.Features) -> None:
@ -311,7 +334,7 @@ class TrezorClient(Generic[UI]):
self._refresh_features(resp) self._refresh_features(resp)
return resp return resp
@session # @session
def init_device( def init_device(
self, self,
*, *,
@ -352,11 +375,14 @@ class TrezorClient(Generic[UI]):
elif session_id is not None: elif session_id is not None:
self.session_id = session_id self.session_id = session_id
print("before init conn")
resp = self.transport.initialize_connection( resp = self.transport.initialize_connection(
mapping=self.mapping, mapping=self.mapping,
session_id=session_id, session_id=session_id,
derive_cardano=derive_cardano, derive_cardano=derive_cardano,
) )
print("here")
if isinstance(resp, messages.Failure): if isinstance(resp, messages.Failure):
# can happen if `derive_cardano` does not match the current session # can happen if `derive_cardano` does not match the current session
raise exceptions.TrezorFailure(resp) raise exceptions.TrezorFailure(resp)
@ -377,6 +403,7 @@ class TrezorClient(Generic[UI]):
# exchange happens. # exchange happens.
reported_session_id = resp.session_id reported_session_id = resp.session_id
self._refresh_features(resp) self._refresh_features(resp)
print("there:", reported_session_id)
return reported_session_id return reported_session_id
def is_outdated(self) -> bool: def is_outdated(self) -> bool:
@ -467,14 +494,19 @@ class TrezorClient(Generic[UI]):
This is a no-op in bootloader mode, as it does not support session management. This is a no-op in bootloader mode, as it does not support session management.
""" """
# since: 2.3.4, 1.9.4 # since: 2.3.4, 1.9.4
print("end session")
try: try:
if not self.features.bootloader_mode: if not self.features.bootloader_mode:
self.call(messages.EndSession()) self.transport.end_session(self.session_id or b"")
# self.call(messages.EndSession())
except exceptions.TrezorFailure: except exceptions.TrezorFailure:
# A failure most likely means that the FW version does not support # A failure most likely means that the FW version does not support
# the EndSession call. We ignore the failure and clear the local session_id. # the EndSession call. We ignore the failure and clear the local session_id.
# The client-side end result is identical. # The client-side end result is identical.
pass pass
except ValueError as e:
print(e)
print(e.args)
self.session_id = None self.session_id = None
@session @session

View File

@ -48,7 +48,7 @@ from .client import TrezorClient
from .exceptions import TrezorFailure from .exceptions import TrezorFailure
from .log import DUMP_BYTES from .log import DUMP_BYTES
from .messages import DebugWaitType from .messages import DebugWaitType
from .tools import expect, session from .tools import expect
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import Protocol from typing_extensions import Protocol
@ -1086,7 +1086,7 @@ class TrezorClientDebugLink(TrezorClient):
""" """
if not self.in_with_statement: if not self.in_with_statement:
raise RuntimeError("Must be called inside 'with' statement") raise RuntimeError("Must be called inside 'with' statement")
if input_flow is None: if input_flow is None:
self.ui.input_flow = None self.ui.input_flow = None
return return
@ -1287,7 +1287,7 @@ class TrezorClientDebugLink(TrezorClient):
# Start by canceling whatever is on screen. This will work to cancel T1 PIN # Start by canceling whatever is on screen. This will work to cancel T1 PIN
# prompt, which is in TINY mode and does not respond to `Ping`. # prompt, which is in TINY mode and does not respond to `Ping`.
cancel_msg = mapping.DEFAULT_MAPPING.encode(messages.Cancel()) cancel_msg = mapping.DEFAULT_MAPPING.encode(messages.Cancel())
self.transport.begin_session() self.transport.deprecated_begin_session()
try: try:
self.transport.write(*cancel_msg) self.transport.write(*cancel_msg)
@ -1302,7 +1302,7 @@ class TrezorClientDebugLink(TrezorClient):
except Exception: except Exception:
pass pass
finally: finally:
self.transport.end_session() self.transport.end_session(self.session_id or b"")
def mnemonic_callback(self, _) -> str: def mnemonic_callback(self, _) -> str:
word, pos = self.debug.read_recovery_word() word, pos = self.debug.read_recovery_word()

View File

@ -63,9 +63,10 @@ class ProtobufMapping:
wire_type = self.class_to_type_override.get(type(msg), msg.MESSAGE_WIRE_TYPE) wire_type = self.class_to_type_override.get(type(msg), msg.MESSAGE_WIRE_TYPE)
if wire_type is None: if wire_type is None:
raise ValueError("Cannot encode class without wire type") raise ValueError("Cannot encode class without wire type")
print("wire type", wire_type)
buf = io.BytesIO() buf = io.BytesIO()
protobuf.dump_message(buf, msg) protobuf.dump_message(buf, msg)
print("test")
return wire_type, buf.getvalue() return wire_type, buf.getvalue()
def encode_without_wire_type(self, msg: protobuf.MessageType) -> bytes: def encode_without_wire_type(self, msg: protobuf.MessageType) -> bytes:

View File

@ -297,6 +297,7 @@ def session(
return f(client, *args, **kwargs) return f(client, *args, **kwargs)
finally: finally:
client.close() client.close()
print("wrap end")
return wrapped_f return wrapped_f

View File

@ -95,7 +95,7 @@ class Protocol:
def resume_session(self, session_id: bytes) -> bytes: def resume_session(self, session_id: bytes) -> bytes:
raise NotImplementedError raise NotImplementedError
def end_session(self, session_id: bytes) -> bytes: def end_session(self, session_id: bytes) -> None:
raise NotImplementedError raise NotImplementedError
# XXX we might be able to remove this now that TrezorClient does session handling # XXX we might be able to remove this now that TrezorClient does session handling
@ -147,7 +147,7 @@ class ProtocolBasedTransport(Transport):
def resume_session(self, session_id: bytes) -> bytes: def resume_session(self, session_id: bytes) -> bytes:
return self.protocol.resume_session(session_id) return self.protocol.resume_session(session_id)
def end_session(self, session_id: bytes) -> bytes: def end_session(self, session_id: bytes) -> None:
return self.protocol.end_session(session_id) return self.protocol.end_session(session_id)
def deprecated_begin_session(self) -> None: def deprecated_begin_session(self) -> None:

View File

@ -70,3 +70,6 @@ class ProtocolV1(Protocol):
if chunk[:1] != b"?": if chunk[:1] != b"?":
raise RuntimeError("Unexpected magic characters") raise RuntimeError("Unexpected magic characters")
return chunk[1:] return chunk[1:]
def end_session(self, session_id: bytes) -> None:
return super().end_session(session_id)

View File

@ -1,6 +1,346 @@
import hashlib
import hmac
import logging
import os
from binascii import hexlify
from enum import IntEnum
from typing import Optional, Tuple
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from .. import messages
from ..mapping import ProtobufMapping
from ..protobuf import MessageType
from ..transport.protocol import Handle, Protocol from ..transport.protocol import Handle, Protocol
from .thp import checksum, curve25519, thp_io
from .thp.checksum import CHECKSUM_LENGTH
from .thp.packet_header import PacketHeader
LOG = logging.getLogger(__name__)
def _sha256_of_two(val_1: bytes, val_2: bytes) -> bytes:
hash = hashlib.sha256(val_1)
hash.update(val_2)
return hash.digest()
def _hkdf(chaining_key: bytes, input: bytes):
temp_key = hmac.new(chaining_key, input, hashlib.sha256).digest()
output_1 = hmac.new(temp_key, b"\x01", hashlib.sha256).digest()
ctx_output_2 = hmac.new(temp_key, output_1, hashlib.sha256)
ctx_output_2.update(b"\x02")
output_2 = ctx_output_2.digest()
return (output_1, output_2)
def _get_iv_from_nonce(nonce: int) -> bytes:
if not nonce <= 0xFFFFFFFFFFFFFFFF:
raise ValueError("Nonce overflow, terminate the channel")
return bytes(4) + nonce.to_bytes(8, "big")
class ProtocolV2(Protocol): class ProtocolV2(Protocol):
def __init__(self, handle: Handle) -> None: def __init__(self, handle: Handle) -> None:
super().__init__(handle) super().__init__(handle)
def initialize_connection(
self,
mapping: ProtobufMapping,
session_id: Optional[bytes] = None,
derive_caradano: Optional[bool] = None,
):
self.session_id: int = 0
self.sync_bit_send: int = 0
self.sync_bit_receive: int = 0
self.mapping = mapping
# Send channel allocation request
channel_id_request_nonce = os.urandom(8)
thp_io.write_payload_to_wire_and_add_checksum(
self.handle,
PacketHeader.get_channel_allocation_request_header(12),
channel_id_request_nonce,
)
# Read channel allocation response
header, payload = self._read_until_valid_crc_check()
if not self._is_valid_channel_allocation_response(
header, payload, channel_id_request_nonce
):
print("TODO raise exception here, I guess")
self.cid = int.from_bytes(payload[8:10], "big")
self.device_properties = payload[10:]
# Send handshake init request
ha_init_req_header = PacketHeader(0, self.cid, 36)
host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32))
host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey)
thp_io.write_payload_to_wire_and_add_checksum(
self.handle, ha_init_req_header, host_ephemeral_pubkey
)
# Read ACK
header, payload = self._read_until_valid_crc_check()
if not header.is_ack() or len(payload) > 0:
print("Received message is not a valid ACK ")
# Read handshake init response
header, payload = self._read_until_valid_crc_check()
self._send_ack_1()
if not header.is_handshake_init_response():
print("Received message is not a valid handshake init response message")
trezor_ephemeral_pubkey = payload[:32]
encrypted_trezor_static_pubkey = payload[32:80]
noise_tag = payload[80:96]
# TODO check noise tag
print("noise_tag: ", hexlify(noise_tag).decode())
# Prepare and send handshake completion request
PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00"
IV_1 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
IV_2 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
h = _sha256_of_two(PROTOCOL_NAME, self.device_properties)
h = _sha256_of_two(h, host_ephemeral_pubkey)
h = _sha256_of_two(h, trezor_ephemeral_pubkey)
ck, k = _hkdf(
PROTOCOL_NAME,
curve25519.multiply(host_ephemeral_privkey, trezor_ephemeral_pubkey),
)
aes_ctx = AESGCM(k)
try:
trezor_masked_static_pubkey = aes_ctx.decrypt(
IV_1, encrypted_trezor_static_pubkey, h
)
# print("masked_key", hexlify(trezor_masked_static_pubkey).decode())
except Exception as e:
print(type(e)) # TODO how to handle potential exceptions? Q for Matejcik
h = _sha256_of_two(h, encrypted_trezor_static_pubkey)
ck, k = _hkdf(
ck, curve25519.multiply(host_ephemeral_privkey, trezor_masked_static_pubkey)
)
aes_ctx = AESGCM(k)
tag_of_empty_string = aes_ctx.encrypt(IV_1, b"", h)
h = _sha256_of_two(h, tag_of_empty_string)
# TODO: search for saved credentials (or possibly not, as we skip pairing phase)
zeroes_32 = int.to_bytes(0, 32, "little")
temp_host_static_privkey = curve25519.get_private_key(zeroes_32)
temp_host_static_pubkey = curve25519.get_public_key(temp_host_static_privkey)
aes_ctx = AESGCM(k)
encrypted_host_static_pubkey = aes_ctx.encrypt(IV_2, temp_host_static_pubkey, h)
h = _sha256_of_two(h, encrypted_host_static_pubkey)
ck, k = _hkdf(
ck, curve25519.multiply(temp_host_static_privkey, trezor_ephemeral_pubkey)
)
msg_data = mapping.encode_without_wire_type(
messages.ThpHandshakeCompletionReqNoisePayload(
pairing_methods=[
messages.ThpPairingMethod.NoMethod,
]
)
)
aes_ctx = AESGCM(k)
encrypted_payload = aes_ctx.encrypt(IV_1, msg_data, h)
h = _sha256_of_two(h, encrypted_payload)
ha_completion_req_header = PacketHeader(
0x12,
self.cid,
len(encrypted_host_static_pubkey)
+ len(encrypted_payload)
+ CHECKSUM_LENGTH,
)
thp_io.write_payload_to_wire_and_add_checksum(
self.handle,
ha_completion_req_header,
encrypted_host_static_pubkey + encrypted_payload,
)
# Read ACK
header, payload = self._read_until_valid_crc_check()
if not header.is_ack() or len(payload) > 0:
print("Received message is not a valid ACK ")
# Read handshake completion response, ignore payload as we do not care about the state
header, _ = self._read_until_valid_crc_check()
if not header.is_handshake_comp_response():
print("Received message is not a valid handshake completion response")
self._send_ack_2()
self.key_request, self.key_response = _hkdf(ck, b"")
self.nonce_request: int = 0
self.nonce_response: int = 1
# Send StartPairingReqest message
message = messages.ThpStartPairingRequest()
message_type, message_data = mapping.encode(message)
self._encrypt_and_write(message_type.to_bytes(2, "big"), message_data)
# Read ACK
header, payload = self._read_until_valid_crc_check()
if not header.is_ack() or len(payload) > 0:
print("Received message is not a valid ACK ")
# Read
_, msg_type, msg_data = self.read_and_decrypt()
maaa = mapping.decode(msg_type, msg_data)
self._send_ack_1()
assert isinstance(maaa, messages.ThpEndResponse)
# Send get features
message = messages.GetFeatures()
message_type, message_data = mapping.encode(message)
self.session_id: int = 0
self._encrypt_and_write(message_type.to_bytes(2, "big"), message_data, 0x14)
_ = thp_io.read(self.handle)
session_id, msg_type, msg_data = self.read_and_decrypt()
features = mapping.decode(msg_type, msg_data)
assert isinstance(features, messages.Features)
features.session_id = int.to_bytes(self.cid, 2, "big") + session_id
self._send_ack_2()
return features
def _encrypt_and_write(
self, message_type: bytes, message_data: bytes, ctrl_byte: int = 0x04
) -> None:
assert self.key_request is not None
aes_ctx = AESGCM(self.key_request)
data = self.session_id.to_bytes(1, "big") + message_type + message_data
nonce = _get_iv_from_nonce(self.nonce_request)
self.nonce_request += 1
encrypted_message = aes_ctx.encrypt(nonce, data, b"")
header = PacketHeader(
ctrl_byte, self.cid, len(encrypted_message) + CHECKSUM_LENGTH
)
thp_io.write_payload_to_wire_and_add_checksum(
self.handle, header, encrypted_message
)
def _send_ack_1(self):
header = PacketHeader(0x20, self.cid, 4)
thp_io.write_payload_to_wire_and_add_checksum(self.handle, header, b"")
def _send_ack_2(self):
header = PacketHeader(0x28, self.cid, 4)
thp_io.write_payload_to_wire_and_add_checksum(self.handle, header, b"")
def _write_message(self, message: MessageType, mapping: ProtobufMapping):
try:
message_type, message_data = mapping.encode(message)
self.write(message_type, message_data)
except Exception as e:
print(type(e))
def write(self, message_type: int, message_data: bytes) -> None:
data = (
self.session_id.to_bytes(1, "big")
+ message_type.to_bytes(2, "big")
+ message_data
)
ctrl_byte = 0x04
self._write_and_encrypt(data, ctrl_byte)
def _write_and_encrypt(self, data: bytes, ctrl_byte: int) -> None:
aes_ctx = AESGCM(self.key_request)
nonce = _get_iv_from_nonce(self.nonce_request)
self.nonce_request += 1
encrypted_data = aes_ctx.encrypt(nonce, data, b"")
header = PacketHeader(
ctrl_byte, self.cid, len(encrypted_data) + CHECKSUM_LENGTH
)
thp_io.write_payload_to_wire_and_add_checksum(
self.handle, header, encrypted_data
)
def read_and_decrypt(self) -> Tuple[bytes, int, bytes]:
header, raw_payload = self._read_until_valid_crc_check()
if not header.is_encrypted_transport():
print("Trying to decrypt not encrypted message!")
aes_ctx = AESGCM(self.key_response)
nonce = _get_iv_from_nonce(self.nonce_response)
self.nonce_response += 1
message = aes_ctx.decrypt(nonce, raw_payload, b"")
session_id = message[0]
message_type = message[1:3]
message_data = message[3:]
return (
int.to_bytes(session_id, 1, "big"),
int.from_bytes(message_type, "big"),
message_data,
)
def end_session(self, session_id: bytes) -> None:
pass
def resume_session(self, session_id: bytes) -> bytes:
print("protocol 2 resume session")
return self.start_session("")
def start_session(self, passphrase: str) -> bytes:
try:
msg = messages.ThpCreateNewSession(passphrase=passphrase)
except Exception as e:
print(e)
print("s")
self._write_message(msg, self.mapping)
print("p")
response_type, response_data = self._read_until_valid_crc_check()
print(response_type, response_data)
return b""
def read(self) -> Tuple[int, bytes]:
header, raw_payload, chksum = thp_io.read(self.handle)
print("Read message", hexlify(raw_payload))
return (0x00, header.to_bytes_init() + raw_payload + chksum) # TODO change
def _get_control_byte(self) -> bytes:
return b"\x42"
def _read_until_valid_crc_check(
self,
) -> Tuple[PacketHeader, bytes]:
is_valid = False
header, payload, chksum = thp_io.read(self.handle)
while not is_valid:
is_valid = checksum.is_valid(chksum, header.to_bytes_init() + payload)
if not is_valid:
print(hexlify(header.to_bytes_init() + payload + chksum))
LOG.debug("Received a message with invalid checksum")
header, payload, chksum = thp_io.read(self.handle)
return header, payload
def _is_valid_channel_allocation_response(
self, header: PacketHeader, payload: bytes, original_nonce: bytes
) -> bool:
if not header.is_channel_allocation_response():
print("Received message is not a channel allocation response")
return False
if len(payload) < 10:
print("Invalid channel allocation response payload")
return False
if payload[:8] != original_nonce:
print("Invalid channel allocation response payload (nonce mismatch)")
return False
return True
class ControlByteType(IntEnum):
CHANNEL_ALLOCATION_RES = 1
HANDSHAKE_INIT_RES = 2
HANDSHAKE_COMP_RES = 3
ACK = 4
ENCRYPTED_TRANSPORT = 5

View File

@ -0,0 +1,82 @@
import struct
CODEC_V1 = 0x3F
CONTINUATION_PACKET = 0x80
HANDSHAKE_INIT_REQ = 0x00
HANDSHAKE_INIT_RES = 0x01
HANDSHAKE_COMP_REQ = 0x02
HANDSHAKE_COMP_RES = 0x03
ENCRYPTED_TRANSPORT = 0x04
CONTINUATION_PACKET_MASK = 0x80
ACK_MASK = 0xF7
DATA_MASK = 0xE7
ACK_MESSAGE = 0x20
_ERROR = 0x42
CHANNEL_ALLOCATION_REQ = 0x40
_CHANNEL_ALLOCATION_RES = 0x41
TREZOR_STATE_UNPAIRED = b"\x00"
TREZOR_STATE_PAIRED = b"\x01"
BROADCAST_CHANNEL_ID = 0xFFFF
class PacketHeader:
format_str_init = ">BHH"
format_str_cont = ">BH"
def __init__(self, ctrl_byte: int, cid: int, length: int) -> None:
self.ctrl_byte = ctrl_byte
self.cid = cid
self.data_length = length
def to_bytes_init(self) -> bytes:
return struct.pack(
self.format_str_init, self.ctrl_byte, self.cid, self.data_length
)
def to_bytes_cont(self) -> bytes:
return struct.pack(self.format_str_cont, CONTINUATION_PACKET, self.cid)
def pack_to_init_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None:
struct.pack_into(
self.format_str_init,
buffer,
buffer_offset,
self.ctrl_byte,
self.cid,
self.data_length,
)
def pack_to_cont_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None:
struct.pack_into(
self.format_str_cont, buffer, buffer_offset, CONTINUATION_PACKET, self.cid
)
def is_ack(self) -> bool:
return self.ctrl_byte & ACK_MASK == ACK_MESSAGE
def is_channel_allocation_response(self):
return (
self.cid == BROADCAST_CHANNEL_ID
and self.ctrl_byte == _CHANNEL_ALLOCATION_RES
)
def is_handshake_init_response(self) -> bool:
return self.ctrl_byte & DATA_MASK == HANDSHAKE_INIT_RES
def is_handshake_comp_response(self) -> bool:
return self.ctrl_byte & DATA_MASK == HANDSHAKE_COMP_RES
def is_encrypted_transport(self) -> bool:
return self.ctrl_byte & DATA_MASK == ENCRYPTED_TRANSPORT
@classmethod
def get_error_header(cls, cid: int, length: int):
return cls(_ERROR, cid, length)
@classmethod
def get_channel_allocation_request_header(cls, length: int):
return cls(CHANNEL_ALLOCATION_REQ, BROADCAST_CHANNEL_ID, length)

View File

@ -0,0 +1,89 @@
import struct
from binascii import hexlify
from typing import Tuple
from ..protocol import Handle
from ..thp import checksum
from .packet_header import PacketHeader
INIT_HEADER_LENGTH = 5
CONT_HEADER_LENGTH = 3
PACKET_LENGTH = 64
CHECKSUM_LENGTH = 4
MAX_PAYLOAD_LEN = 60000
MESSAGE_TYPE_LENGTH = 2
CONTINUATION_PACKET = 0x80
def write_payload_to_wire_and_add_checksum(
handle: Handle, header: PacketHeader, transport_payload: bytes
):
chksum: bytes = checksum.compute(header.to_bytes_init() + transport_payload)
data = transport_payload + chksum
write_payload_to_wire(handle, header, data)
print("WOO")
def write_payload_to_wire(
handle: Handle, header: PacketHeader, transport_payload: bytes
):
print("tttt")
handle.open()
buffer = bytearray(transport_payload)
chunk = header.to_bytes_init() + buffer[: PACKET_LENGTH - INIT_HEADER_LENGTH]
print("x")
chunk = chunk.ljust(PACKET_LENGTH, b"\x00")
print("y")
print(hexlify(chunk))
handle.write_chunk(chunk)
print("fgh")
buffer = buffer[PACKET_LENGTH - INIT_HEADER_LENGTH :]
while buffer:
chunk = header.to_bytes_cont() + buffer[: PACKET_LENGTH - CONT_HEADER_LENGTH]
chunk = chunk.ljust(PACKET_LENGTH, b"\x00")
handle.write_chunk(chunk)
buffer = buffer[PACKET_LENGTH - CONT_HEADER_LENGTH :]
def read(handle: Handle) -> Tuple[PacketHeader, bytes, bytes]:
buffer = bytearray()
# Read header with first part of message data
header, first_chunk = read_first(handle)
buffer.extend(first_chunk)
# Read the rest of the message
while len(buffer) < header.data_length:
buffer.extend(read_next(handle, header.cid))
# print("buffer read (data):", hexlify(buffer).decode())
# print("buffer len (data):", datalen)
# TODO check checksum?? or do not strip ?
data_len = header.data_length - CHECKSUM_LENGTH
return header, buffer[:data_len], buffer[data_len : data_len + CHECKSUM_LENGTH]
def read_first(handle: Handle) -> Tuple[PacketHeader, bytes]:
chunk = handle.read_chunk()
try:
ctrl_byte, cid, data_length = struct.unpack(
PacketHeader.format_str_init, chunk[:INIT_HEADER_LENGTH]
)
except Exception:
raise RuntimeError("Cannot parse header")
data = chunk[INIT_HEADER_LENGTH:]
return PacketHeader(ctrl_byte, cid, data_length), data
def read_next(handle: Handle, cid: int) -> bytes:
chunk = handle.read_chunk()
ctrl_byte, read_cid = struct.unpack(
PacketHeader.format_str_cont, chunk[:CONT_HEADER_LENGTH]
)
if ctrl_byte != CONTINUATION_PACKET:
raise RuntimeError("Continuation packet with incorrect control byte")
if read_cid != cid:
raise RuntimeError("Continuation packet for different channel")
return chunk[CONT_HEADER_LENGTH:]

View File

@ -69,7 +69,9 @@ class WebUsbHandle:
self.handle = None self.handle = None
def write_chunk(self, chunk: bytes) -> None: def write_chunk(self, chunk: bytes) -> None:
print("ti")
assert self.handle is not None assert self.handle is not None
print("te")
if len(chunk) != 64: if len(chunk) != 64:
raise TransportException(f"Unexpected chunk size: {len(chunk)}") raise TransportException(f"Unexpected chunk size: {len(chunk)}")
LOG.log(DUMP_PACKETS, f"writing packet: {chunk.hex()}") LOG.log(DUMP_PACKETS, f"writing packet: {chunk.hex()}")