mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-02-22 12:32:02 +00:00
TEMPORARY WIP - trezorlib
This commit is contained in:
parent
06cc68cc46
commit
f6ff8529c6
@ -653,6 +653,7 @@ def update(
|
|||||||
against downloaded firmware fingerprint. Otherwise fingerprint is checked
|
against downloaded firmware fingerprint. Otherwise fingerprint is checked
|
||||||
against data.trezor.io information, if available.
|
against data.trezor.io information, if available.
|
||||||
"""
|
"""
|
||||||
|
print("client context")
|
||||||
with obj.client_context() as client:
|
with obj.client_context() as client:
|
||||||
if sum(bool(x) for x in (filename, url, version)) > 1:
|
if sum(bool(x) for x in (filename, url, version)) > 1:
|
||||||
click.echo("You can use only one of: filename, url, version.")
|
click.echo("You can use only one of: filename, url, version.")
|
||||||
|
@ -291,6 +291,7 @@ def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]:
|
|||||||
client = TrezorClient(transport, ui=ui.ClickUI())
|
client = TrezorClient(transport, ui=ui.ClickUI())
|
||||||
description = format_device_name(client.features)
|
description = format_device_name(client.features)
|
||||||
client.end_session()
|
client.end_session()
|
||||||
|
print("after end session")
|
||||||
except DeviceIsBusy:
|
except DeviceIsBusy:
|
||||||
description = "Device is in use by another process"
|
description = "Device is in use by another process"
|
||||||
except Exception:
|
except Exception:
|
||||||
|
@ -134,12 +134,24 @@ class TrezorClient(Generic[UI]):
|
|||||||
self.session_id = session_id
|
self.session_id = session_id
|
||||||
if _init_device:
|
if _init_device:
|
||||||
self.init_device(session_id=session_id, derive_cardano=derive_cardano)
|
self.init_device(session_id=session_id, derive_cardano=derive_cardano)
|
||||||
|
self.resume_session()
|
||||||
|
|
||||||
def open(self) -> None:
|
def open(self) -> None:
|
||||||
if self.session_counter == 0:
|
if self.session_counter == 0:
|
||||||
|
session_id = self.transport.resume_session(b"")
|
||||||
|
if self.session_id != session_id:
|
||||||
|
print("Failed to resume session, allocated a new session")
|
||||||
|
self.session_id = session_id
|
||||||
self.transport.deprecated_begin_session()
|
self.transport.deprecated_begin_session()
|
||||||
self.session_counter += 1
|
self.session_counter += 1
|
||||||
|
|
||||||
|
def resume_session(self) -> None:
|
||||||
|
print("resume session")
|
||||||
|
new_id = self.transport.resume_session(self.session_id or b"")
|
||||||
|
if self.session_id != new_id:
|
||||||
|
print("Failed to resume session, allocated a new session")
|
||||||
|
self.session_id = new_id
|
||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
self.session_counter = max(self.session_counter - 1, 0)
|
self.session_counter = max(self.session_counter - 1, 0)
|
||||||
if self.session_counter == 0:
|
if self.session_counter == 0:
|
||||||
@ -151,8 +163,13 @@ class TrezorClient(Generic[UI]):
|
|||||||
|
|
||||||
def call_raw(self, msg: "MessageType") -> "MessageType":
|
def call_raw(self, msg: "MessageType") -> "MessageType":
|
||||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||||
|
print("self.call_raw-start")
|
||||||
|
|
||||||
self._raw_write(msg)
|
self._raw_write(msg)
|
||||||
return self._raw_read()
|
print("self.call_raw-after write")
|
||||||
|
x = self._raw_read()
|
||||||
|
print("self.call_raw-end")
|
||||||
|
return x
|
||||||
|
|
||||||
def _raw_write(self, msg: "MessageType") -> None:
|
def _raw_write(self, msg: "MessageType") -> None:
|
||||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||||
@ -169,7 +186,9 @@ class TrezorClient(Generic[UI]):
|
|||||||
|
|
||||||
def _raw_read(self) -> "MessageType":
|
def _raw_read(self) -> "MessageType":
|
||||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||||
|
print("raw read - start")
|
||||||
msg_type, msg_bytes = self.transport.read()
|
msg_type, msg_bytes = self.transport.read()
|
||||||
|
print("type/data", msg_type, msg_bytes)
|
||||||
LOG.log(
|
LOG.log(
|
||||||
DUMP_BYTES,
|
DUMP_BYTES,
|
||||||
f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
|
f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
|
||||||
@ -253,6 +272,7 @@ class TrezorClient(Generic[UI]):
|
|||||||
|
|
||||||
@session
|
@session
|
||||||
def call(self, msg: "MessageType") -> "MessageType":
|
def call(self, msg: "MessageType") -> "MessageType":
|
||||||
|
print("self.call-start")
|
||||||
self.check_firmware_version()
|
self.check_firmware_version()
|
||||||
resp = self.call_raw(msg)
|
resp = self.call_raw(msg)
|
||||||
while True:
|
while True:
|
||||||
@ -263,10 +283,13 @@ class TrezorClient(Generic[UI]):
|
|||||||
elif isinstance(resp, messages.ButtonRequest):
|
elif isinstance(resp, messages.ButtonRequest):
|
||||||
resp = self._callback_button(resp)
|
resp = self._callback_button(resp)
|
||||||
elif isinstance(resp, messages.Failure):
|
elif isinstance(resp, messages.Failure):
|
||||||
|
print("self.call-failure")
|
||||||
|
|
||||||
if resp.code == messages.FailureType.ActionCancelled:
|
if resp.code == messages.FailureType.ActionCancelled:
|
||||||
raise exceptions.Cancelled
|
raise exceptions.Cancelled
|
||||||
raise exceptions.TrezorFailure(resp)
|
raise exceptions.TrezorFailure(resp)
|
||||||
else:
|
else:
|
||||||
|
print("self.call-end")
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
def _refresh_features(self, features: messages.Features) -> None:
|
def _refresh_features(self, features: messages.Features) -> None:
|
||||||
@ -311,7 +334,7 @@ class TrezorClient(Generic[UI]):
|
|||||||
self._refresh_features(resp)
|
self._refresh_features(resp)
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
@session
|
# @session
|
||||||
def init_device(
|
def init_device(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
@ -352,11 +375,14 @@ class TrezorClient(Generic[UI]):
|
|||||||
elif session_id is not None:
|
elif session_id is not None:
|
||||||
self.session_id = session_id
|
self.session_id = session_id
|
||||||
|
|
||||||
|
print("before init conn")
|
||||||
|
|
||||||
resp = self.transport.initialize_connection(
|
resp = self.transport.initialize_connection(
|
||||||
mapping=self.mapping,
|
mapping=self.mapping,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
derive_cardano=derive_cardano,
|
derive_cardano=derive_cardano,
|
||||||
)
|
)
|
||||||
|
print("here")
|
||||||
if isinstance(resp, messages.Failure):
|
if isinstance(resp, messages.Failure):
|
||||||
# can happen if `derive_cardano` does not match the current session
|
# can happen if `derive_cardano` does not match the current session
|
||||||
raise exceptions.TrezorFailure(resp)
|
raise exceptions.TrezorFailure(resp)
|
||||||
@ -377,6 +403,7 @@ class TrezorClient(Generic[UI]):
|
|||||||
# exchange happens.
|
# exchange happens.
|
||||||
reported_session_id = resp.session_id
|
reported_session_id = resp.session_id
|
||||||
self._refresh_features(resp)
|
self._refresh_features(resp)
|
||||||
|
print("there:", reported_session_id)
|
||||||
return reported_session_id
|
return reported_session_id
|
||||||
|
|
||||||
def is_outdated(self) -> bool:
|
def is_outdated(self) -> bool:
|
||||||
@ -467,14 +494,19 @@ class TrezorClient(Generic[UI]):
|
|||||||
This is a no-op in bootloader mode, as it does not support session management.
|
This is a no-op in bootloader mode, as it does not support session management.
|
||||||
"""
|
"""
|
||||||
# since: 2.3.4, 1.9.4
|
# since: 2.3.4, 1.9.4
|
||||||
|
print("end session")
|
||||||
try:
|
try:
|
||||||
if not self.features.bootloader_mode:
|
if not self.features.bootloader_mode:
|
||||||
self.call(messages.EndSession())
|
self.transport.end_session(self.session_id or b"")
|
||||||
|
# self.call(messages.EndSession())
|
||||||
except exceptions.TrezorFailure:
|
except exceptions.TrezorFailure:
|
||||||
# A failure most likely means that the FW version does not support
|
# A failure most likely means that the FW version does not support
|
||||||
# the EndSession call. We ignore the failure and clear the local session_id.
|
# the EndSession call. We ignore the failure and clear the local session_id.
|
||||||
# The client-side end result is identical.
|
# The client-side end result is identical.
|
||||||
pass
|
pass
|
||||||
|
except ValueError as e:
|
||||||
|
print(e)
|
||||||
|
print(e.args)
|
||||||
self.session_id = None
|
self.session_id = None
|
||||||
|
|
||||||
@session
|
@session
|
||||||
|
@ -48,7 +48,7 @@ from .client import TrezorClient
|
|||||||
from .exceptions import TrezorFailure
|
from .exceptions import TrezorFailure
|
||||||
from .log import DUMP_BYTES
|
from .log import DUMP_BYTES
|
||||||
from .messages import DebugWaitType
|
from .messages import DebugWaitType
|
||||||
from .tools import expect, session
|
from .tools import expect
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from typing_extensions import Protocol
|
from typing_extensions import Protocol
|
||||||
@ -1086,7 +1086,7 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
"""
|
"""
|
||||||
if not self.in_with_statement:
|
if not self.in_with_statement:
|
||||||
raise RuntimeError("Must be called inside 'with' statement")
|
raise RuntimeError("Must be called inside 'with' statement")
|
||||||
|
|
||||||
if input_flow is None:
|
if input_flow is None:
|
||||||
self.ui.input_flow = None
|
self.ui.input_flow = None
|
||||||
return
|
return
|
||||||
@ -1287,7 +1287,7 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
# Start by canceling whatever is on screen. This will work to cancel T1 PIN
|
# Start by canceling whatever is on screen. This will work to cancel T1 PIN
|
||||||
# prompt, which is in TINY mode and does not respond to `Ping`.
|
# prompt, which is in TINY mode and does not respond to `Ping`.
|
||||||
cancel_msg = mapping.DEFAULT_MAPPING.encode(messages.Cancel())
|
cancel_msg = mapping.DEFAULT_MAPPING.encode(messages.Cancel())
|
||||||
self.transport.begin_session()
|
self.transport.deprecated_begin_session()
|
||||||
try:
|
try:
|
||||||
self.transport.write(*cancel_msg)
|
self.transport.write(*cancel_msg)
|
||||||
|
|
||||||
@ -1302,7 +1302,7 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
finally:
|
finally:
|
||||||
self.transport.end_session()
|
self.transport.end_session(self.session_id or b"")
|
||||||
|
|
||||||
def mnemonic_callback(self, _) -> str:
|
def mnemonic_callback(self, _) -> str:
|
||||||
word, pos = self.debug.read_recovery_word()
|
word, pos = self.debug.read_recovery_word()
|
||||||
|
@ -63,9 +63,10 @@ class ProtobufMapping:
|
|||||||
wire_type = self.class_to_type_override.get(type(msg), msg.MESSAGE_WIRE_TYPE)
|
wire_type = self.class_to_type_override.get(type(msg), msg.MESSAGE_WIRE_TYPE)
|
||||||
if wire_type is None:
|
if wire_type is None:
|
||||||
raise ValueError("Cannot encode class without wire type")
|
raise ValueError("Cannot encode class without wire type")
|
||||||
|
print("wire type", wire_type)
|
||||||
buf = io.BytesIO()
|
buf = io.BytesIO()
|
||||||
protobuf.dump_message(buf, msg)
|
protobuf.dump_message(buf, msg)
|
||||||
|
print("test")
|
||||||
return wire_type, buf.getvalue()
|
return wire_type, buf.getvalue()
|
||||||
|
|
||||||
def encode_without_wire_type(self, msg: protobuf.MessageType) -> bytes:
|
def encode_without_wire_type(self, msg: protobuf.MessageType) -> bytes:
|
||||||
|
@ -297,6 +297,7 @@ def session(
|
|||||||
return f(client, *args, **kwargs)
|
return f(client, *args, **kwargs)
|
||||||
finally:
|
finally:
|
||||||
client.close()
|
client.close()
|
||||||
|
print("wrap end")
|
||||||
|
|
||||||
return wrapped_f
|
return wrapped_f
|
||||||
|
|
||||||
|
@ -95,7 +95,7 @@ class Protocol:
|
|||||||
def resume_session(self, session_id: bytes) -> bytes:
|
def resume_session(self, session_id: bytes) -> bytes:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def end_session(self, session_id: bytes) -> bytes:
|
def end_session(self, session_id: bytes) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
# XXX we might be able to remove this now that TrezorClient does session handling
|
# XXX we might be able to remove this now that TrezorClient does session handling
|
||||||
@ -147,7 +147,7 @@ class ProtocolBasedTransport(Transport):
|
|||||||
def resume_session(self, session_id: bytes) -> bytes:
|
def resume_session(self, session_id: bytes) -> bytes:
|
||||||
return self.protocol.resume_session(session_id)
|
return self.protocol.resume_session(session_id)
|
||||||
|
|
||||||
def end_session(self, session_id: bytes) -> bytes:
|
def end_session(self, session_id: bytes) -> None:
|
||||||
return self.protocol.end_session(session_id)
|
return self.protocol.end_session(session_id)
|
||||||
|
|
||||||
def deprecated_begin_session(self) -> None:
|
def deprecated_begin_session(self) -> None:
|
||||||
|
@ -70,3 +70,6 @@ class ProtocolV1(Protocol):
|
|||||||
if chunk[:1] != b"?":
|
if chunk[:1] != b"?":
|
||||||
raise RuntimeError("Unexpected magic characters")
|
raise RuntimeError("Unexpected magic characters")
|
||||||
return chunk[1:]
|
return chunk[1:]
|
||||||
|
|
||||||
|
def end_session(self, session_id: bytes) -> None:
|
||||||
|
return super().end_session(session_id)
|
||||||
|
@ -1,6 +1,346 @@
|
|||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from binascii import hexlify
|
||||||
|
from enum import IntEnum
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||||
|
|
||||||
|
from .. import messages
|
||||||
|
from ..mapping import ProtobufMapping
|
||||||
|
from ..protobuf import MessageType
|
||||||
from ..transport.protocol import Handle, Protocol
|
from ..transport.protocol import Handle, Protocol
|
||||||
|
from .thp import checksum, curve25519, thp_io
|
||||||
|
from .thp.checksum import CHECKSUM_LENGTH
|
||||||
|
from .thp.packet_header import PacketHeader
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _sha256_of_two(val_1: bytes, val_2: bytes) -> bytes:
|
||||||
|
hash = hashlib.sha256(val_1)
|
||||||
|
hash.update(val_2)
|
||||||
|
return hash.digest()
|
||||||
|
|
||||||
|
|
||||||
|
def _hkdf(chaining_key: bytes, input: bytes):
|
||||||
|
temp_key = hmac.new(chaining_key, input, hashlib.sha256).digest()
|
||||||
|
output_1 = hmac.new(temp_key, b"\x01", hashlib.sha256).digest()
|
||||||
|
ctx_output_2 = hmac.new(temp_key, output_1, hashlib.sha256)
|
||||||
|
ctx_output_2.update(b"\x02")
|
||||||
|
output_2 = ctx_output_2.digest()
|
||||||
|
return (output_1, output_2)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_iv_from_nonce(nonce: int) -> bytes:
|
||||||
|
if not nonce <= 0xFFFFFFFFFFFFFFFF:
|
||||||
|
raise ValueError("Nonce overflow, terminate the channel")
|
||||||
|
return bytes(4) + nonce.to_bytes(8, "big")
|
||||||
|
|
||||||
|
|
||||||
class ProtocolV2(Protocol):
|
class ProtocolV2(Protocol):
|
||||||
def __init__(self, handle: Handle) -> None:
|
def __init__(self, handle: Handle) -> None:
|
||||||
super().__init__(handle)
|
super().__init__(handle)
|
||||||
|
|
||||||
|
def initialize_connection(
|
||||||
|
self,
|
||||||
|
mapping: ProtobufMapping,
|
||||||
|
session_id: Optional[bytes] = None,
|
||||||
|
derive_caradano: Optional[bool] = None,
|
||||||
|
):
|
||||||
|
self.session_id: int = 0
|
||||||
|
self.sync_bit_send: int = 0
|
||||||
|
self.sync_bit_receive: int = 0
|
||||||
|
self.mapping = mapping
|
||||||
|
# Send channel allocation request
|
||||||
|
channel_id_request_nonce = os.urandom(8)
|
||||||
|
thp_io.write_payload_to_wire_and_add_checksum(
|
||||||
|
self.handle,
|
||||||
|
PacketHeader.get_channel_allocation_request_header(12),
|
||||||
|
channel_id_request_nonce,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Read channel allocation response
|
||||||
|
header, payload = self._read_until_valid_crc_check()
|
||||||
|
if not self._is_valid_channel_allocation_response(
|
||||||
|
header, payload, channel_id_request_nonce
|
||||||
|
):
|
||||||
|
print("TODO raise exception here, I guess")
|
||||||
|
|
||||||
|
self.cid = int.from_bytes(payload[8:10], "big")
|
||||||
|
self.device_properties = payload[10:]
|
||||||
|
|
||||||
|
# Send handshake init request
|
||||||
|
ha_init_req_header = PacketHeader(0, self.cid, 36)
|
||||||
|
host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32))
|
||||||
|
host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey)
|
||||||
|
|
||||||
|
thp_io.write_payload_to_wire_and_add_checksum(
|
||||||
|
self.handle, ha_init_req_header, host_ephemeral_pubkey
|
||||||
|
)
|
||||||
|
|
||||||
|
# Read ACK
|
||||||
|
header, payload = self._read_until_valid_crc_check()
|
||||||
|
if not header.is_ack() or len(payload) > 0:
|
||||||
|
print("Received message is not a valid ACK ")
|
||||||
|
|
||||||
|
# Read handshake init response
|
||||||
|
header, payload = self._read_until_valid_crc_check()
|
||||||
|
self._send_ack_1()
|
||||||
|
|
||||||
|
if not header.is_handshake_init_response():
|
||||||
|
print("Received message is not a valid handshake init response message")
|
||||||
|
|
||||||
|
trezor_ephemeral_pubkey = payload[:32]
|
||||||
|
encrypted_trezor_static_pubkey = payload[32:80]
|
||||||
|
noise_tag = payload[80:96]
|
||||||
|
|
||||||
|
# TODO check noise tag
|
||||||
|
print("noise_tag: ", hexlify(noise_tag).decode())
|
||||||
|
|
||||||
|
# Prepare and send handshake completion request
|
||||||
|
PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00"
|
||||||
|
IV_1 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
|
||||||
|
IV_2 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
|
||||||
|
h = _sha256_of_two(PROTOCOL_NAME, self.device_properties)
|
||||||
|
h = _sha256_of_two(h, host_ephemeral_pubkey)
|
||||||
|
h = _sha256_of_two(h, trezor_ephemeral_pubkey)
|
||||||
|
ck, k = _hkdf(
|
||||||
|
PROTOCOL_NAME,
|
||||||
|
curve25519.multiply(host_ephemeral_privkey, trezor_ephemeral_pubkey),
|
||||||
|
)
|
||||||
|
|
||||||
|
aes_ctx = AESGCM(k)
|
||||||
|
try:
|
||||||
|
trezor_masked_static_pubkey = aes_ctx.decrypt(
|
||||||
|
IV_1, encrypted_trezor_static_pubkey, h
|
||||||
|
)
|
||||||
|
# print("masked_key", hexlify(trezor_masked_static_pubkey).decode())
|
||||||
|
except Exception as e:
|
||||||
|
print(type(e)) # TODO how to handle potential exceptions? Q for Matejcik
|
||||||
|
h = _sha256_of_two(h, encrypted_trezor_static_pubkey)
|
||||||
|
ck, k = _hkdf(
|
||||||
|
ck, curve25519.multiply(host_ephemeral_privkey, trezor_masked_static_pubkey)
|
||||||
|
)
|
||||||
|
aes_ctx = AESGCM(k)
|
||||||
|
|
||||||
|
tag_of_empty_string = aes_ctx.encrypt(IV_1, b"", h)
|
||||||
|
h = _sha256_of_two(h, tag_of_empty_string)
|
||||||
|
# TODO: search for saved credentials (or possibly not, as we skip pairing phase)
|
||||||
|
|
||||||
|
zeroes_32 = int.to_bytes(0, 32, "little")
|
||||||
|
temp_host_static_privkey = curve25519.get_private_key(zeroes_32)
|
||||||
|
temp_host_static_pubkey = curve25519.get_public_key(temp_host_static_privkey)
|
||||||
|
aes_ctx = AESGCM(k)
|
||||||
|
encrypted_host_static_pubkey = aes_ctx.encrypt(IV_2, temp_host_static_pubkey, h)
|
||||||
|
h = _sha256_of_two(h, encrypted_host_static_pubkey)
|
||||||
|
ck, k = _hkdf(
|
||||||
|
ck, curve25519.multiply(temp_host_static_privkey, trezor_ephemeral_pubkey)
|
||||||
|
)
|
||||||
|
msg_data = mapping.encode_without_wire_type(
|
||||||
|
messages.ThpHandshakeCompletionReqNoisePayload(
|
||||||
|
pairing_methods=[
|
||||||
|
messages.ThpPairingMethod.NoMethod,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
aes_ctx = AESGCM(k)
|
||||||
|
|
||||||
|
encrypted_payload = aes_ctx.encrypt(IV_1, msg_data, h)
|
||||||
|
h = _sha256_of_two(h, encrypted_payload)
|
||||||
|
ha_completion_req_header = PacketHeader(
|
||||||
|
0x12,
|
||||||
|
self.cid,
|
||||||
|
len(encrypted_host_static_pubkey)
|
||||||
|
+ len(encrypted_payload)
|
||||||
|
+ CHECKSUM_LENGTH,
|
||||||
|
)
|
||||||
|
thp_io.write_payload_to_wire_and_add_checksum(
|
||||||
|
self.handle,
|
||||||
|
ha_completion_req_header,
|
||||||
|
encrypted_host_static_pubkey + encrypted_payload,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Read ACK
|
||||||
|
header, payload = self._read_until_valid_crc_check()
|
||||||
|
if not header.is_ack() or len(payload) > 0:
|
||||||
|
print("Received message is not a valid ACK ")
|
||||||
|
|
||||||
|
# Read handshake completion response, ignore payload as we do not care about the state
|
||||||
|
header, _ = self._read_until_valid_crc_check()
|
||||||
|
if not header.is_handshake_comp_response():
|
||||||
|
print("Received message is not a valid handshake completion response")
|
||||||
|
self._send_ack_2()
|
||||||
|
|
||||||
|
self.key_request, self.key_response = _hkdf(ck, b"")
|
||||||
|
self.nonce_request: int = 0
|
||||||
|
self.nonce_response: int = 1
|
||||||
|
|
||||||
|
# Send StartPairingReqest message
|
||||||
|
message = messages.ThpStartPairingRequest()
|
||||||
|
message_type, message_data = mapping.encode(message)
|
||||||
|
|
||||||
|
self._encrypt_and_write(message_type.to_bytes(2, "big"), message_data)
|
||||||
|
|
||||||
|
# Read ACK
|
||||||
|
header, payload = self._read_until_valid_crc_check()
|
||||||
|
if not header.is_ack() or len(payload) > 0:
|
||||||
|
print("Received message is not a valid ACK ")
|
||||||
|
|
||||||
|
# Read
|
||||||
|
_, msg_type, msg_data = self.read_and_decrypt()
|
||||||
|
maaa = mapping.decode(msg_type, msg_data)
|
||||||
|
self._send_ack_1()
|
||||||
|
|
||||||
|
assert isinstance(maaa, messages.ThpEndResponse)
|
||||||
|
|
||||||
|
# Send get features
|
||||||
|
message = messages.GetFeatures()
|
||||||
|
message_type, message_data = mapping.encode(message)
|
||||||
|
|
||||||
|
self.session_id: int = 0
|
||||||
|
self._encrypt_and_write(message_type.to_bytes(2, "big"), message_data, 0x14)
|
||||||
|
_ = thp_io.read(self.handle)
|
||||||
|
session_id, msg_type, msg_data = self.read_and_decrypt()
|
||||||
|
features = mapping.decode(msg_type, msg_data)
|
||||||
|
assert isinstance(features, messages.Features)
|
||||||
|
features.session_id = int.to_bytes(self.cid, 2, "big") + session_id
|
||||||
|
self._send_ack_2()
|
||||||
|
return features
|
||||||
|
|
||||||
|
def _encrypt_and_write(
|
||||||
|
self, message_type: bytes, message_data: bytes, ctrl_byte: int = 0x04
|
||||||
|
) -> None:
|
||||||
|
assert self.key_request is not None
|
||||||
|
aes_ctx = AESGCM(self.key_request)
|
||||||
|
data = self.session_id.to_bytes(1, "big") + message_type + message_data
|
||||||
|
nonce = _get_iv_from_nonce(self.nonce_request)
|
||||||
|
self.nonce_request += 1
|
||||||
|
encrypted_message = aes_ctx.encrypt(nonce, data, b"")
|
||||||
|
header = PacketHeader(
|
||||||
|
ctrl_byte, self.cid, len(encrypted_message) + CHECKSUM_LENGTH
|
||||||
|
)
|
||||||
|
|
||||||
|
thp_io.write_payload_to_wire_and_add_checksum(
|
||||||
|
self.handle, header, encrypted_message
|
||||||
|
)
|
||||||
|
|
||||||
|
def _send_ack_1(self):
|
||||||
|
header = PacketHeader(0x20, self.cid, 4)
|
||||||
|
thp_io.write_payload_to_wire_and_add_checksum(self.handle, header, b"")
|
||||||
|
|
||||||
|
def _send_ack_2(self):
|
||||||
|
header = PacketHeader(0x28, self.cid, 4)
|
||||||
|
thp_io.write_payload_to_wire_and_add_checksum(self.handle, header, b"")
|
||||||
|
|
||||||
|
def _write_message(self, message: MessageType, mapping: ProtobufMapping):
|
||||||
|
try:
|
||||||
|
message_type, message_data = mapping.encode(message)
|
||||||
|
self.write(message_type, message_data)
|
||||||
|
except Exception as e:
|
||||||
|
print(type(e))
|
||||||
|
|
||||||
|
def write(self, message_type: int, message_data: bytes) -> None:
|
||||||
|
data = (
|
||||||
|
self.session_id.to_bytes(1, "big")
|
||||||
|
+ message_type.to_bytes(2, "big")
|
||||||
|
+ message_data
|
||||||
|
)
|
||||||
|
ctrl_byte = 0x04
|
||||||
|
self._write_and_encrypt(data, ctrl_byte)
|
||||||
|
|
||||||
|
def _write_and_encrypt(self, data: bytes, ctrl_byte: int) -> None:
|
||||||
|
aes_ctx = AESGCM(self.key_request)
|
||||||
|
nonce = _get_iv_from_nonce(self.nonce_request)
|
||||||
|
self.nonce_request += 1
|
||||||
|
encrypted_data = aes_ctx.encrypt(nonce, data, b"")
|
||||||
|
header = PacketHeader(
|
||||||
|
ctrl_byte, self.cid, len(encrypted_data) + CHECKSUM_LENGTH
|
||||||
|
)
|
||||||
|
thp_io.write_payload_to_wire_and_add_checksum(
|
||||||
|
self.handle, header, encrypted_data
|
||||||
|
)
|
||||||
|
|
||||||
|
def read_and_decrypt(self) -> Tuple[bytes, int, bytes]:
|
||||||
|
header, raw_payload = self._read_until_valid_crc_check()
|
||||||
|
if not header.is_encrypted_transport():
|
||||||
|
print("Trying to decrypt not encrypted message!")
|
||||||
|
aes_ctx = AESGCM(self.key_response)
|
||||||
|
nonce = _get_iv_from_nonce(self.nonce_response)
|
||||||
|
self.nonce_response += 1
|
||||||
|
|
||||||
|
message = aes_ctx.decrypt(nonce, raw_payload, b"")
|
||||||
|
session_id = message[0]
|
||||||
|
message_type = message[1:3]
|
||||||
|
message_data = message[3:]
|
||||||
|
return (
|
||||||
|
int.to_bytes(session_id, 1, "big"),
|
||||||
|
int.from_bytes(message_type, "big"),
|
||||||
|
message_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
def end_session(self, session_id: bytes) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def resume_session(self, session_id: bytes) -> bytes:
|
||||||
|
print("protocol 2 resume session")
|
||||||
|
return self.start_session("")
|
||||||
|
|
||||||
|
def start_session(self, passphrase: str) -> bytes:
|
||||||
|
try:
|
||||||
|
msg = messages.ThpCreateNewSession(passphrase=passphrase)
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
print("s")
|
||||||
|
|
||||||
|
self._write_message(msg, self.mapping)
|
||||||
|
print("p")
|
||||||
|
response_type, response_data = self._read_until_valid_crc_check()
|
||||||
|
print(response_type, response_data)
|
||||||
|
return b""
|
||||||
|
|
||||||
|
def read(self) -> Tuple[int, bytes]:
|
||||||
|
header, raw_payload, chksum = thp_io.read(self.handle)
|
||||||
|
print("Read message", hexlify(raw_payload))
|
||||||
|
return (0x00, header.to_bytes_init() + raw_payload + chksum) # TODO change
|
||||||
|
|
||||||
|
def _get_control_byte(self) -> bytes:
|
||||||
|
return b"\x42"
|
||||||
|
|
||||||
|
def _read_until_valid_crc_check(
|
||||||
|
self,
|
||||||
|
) -> Tuple[PacketHeader, bytes]:
|
||||||
|
is_valid = False
|
||||||
|
header, payload, chksum = thp_io.read(self.handle)
|
||||||
|
while not is_valid:
|
||||||
|
is_valid = checksum.is_valid(chksum, header.to_bytes_init() + payload)
|
||||||
|
if not is_valid:
|
||||||
|
print(hexlify(header.to_bytes_init() + payload + chksum))
|
||||||
|
LOG.debug("Received a message with invalid checksum")
|
||||||
|
header, payload, chksum = thp_io.read(self.handle)
|
||||||
|
|
||||||
|
return header, payload
|
||||||
|
|
||||||
|
def _is_valid_channel_allocation_response(
|
||||||
|
self, header: PacketHeader, payload: bytes, original_nonce: bytes
|
||||||
|
) -> bool:
|
||||||
|
if not header.is_channel_allocation_response():
|
||||||
|
print("Received message is not a channel allocation response")
|
||||||
|
return False
|
||||||
|
if len(payload) < 10:
|
||||||
|
print("Invalid channel allocation response payload")
|
||||||
|
return False
|
||||||
|
if payload[:8] != original_nonce:
|
||||||
|
print("Invalid channel allocation response payload (nonce mismatch)")
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
class ControlByteType(IntEnum):
|
||||||
|
CHANNEL_ALLOCATION_RES = 1
|
||||||
|
HANDSHAKE_INIT_RES = 2
|
||||||
|
HANDSHAKE_COMP_RES = 3
|
||||||
|
ACK = 4
|
||||||
|
ENCRYPTED_TRANSPORT = 5
|
||||||
|
82
python/src/trezorlib/transport/thp/packet_header.py
Normal file
82
python/src/trezorlib/transport/thp/packet_header.py
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
import struct
|
||||||
|
|
||||||
|
CODEC_V1 = 0x3F
|
||||||
|
CONTINUATION_PACKET = 0x80
|
||||||
|
HANDSHAKE_INIT_REQ = 0x00
|
||||||
|
HANDSHAKE_INIT_RES = 0x01
|
||||||
|
HANDSHAKE_COMP_REQ = 0x02
|
||||||
|
HANDSHAKE_COMP_RES = 0x03
|
||||||
|
ENCRYPTED_TRANSPORT = 0x04
|
||||||
|
|
||||||
|
CONTINUATION_PACKET_MASK = 0x80
|
||||||
|
ACK_MASK = 0xF7
|
||||||
|
DATA_MASK = 0xE7
|
||||||
|
|
||||||
|
ACK_MESSAGE = 0x20
|
||||||
|
_ERROR = 0x42
|
||||||
|
CHANNEL_ALLOCATION_REQ = 0x40
|
||||||
|
_CHANNEL_ALLOCATION_RES = 0x41
|
||||||
|
|
||||||
|
TREZOR_STATE_UNPAIRED = b"\x00"
|
||||||
|
TREZOR_STATE_PAIRED = b"\x01"
|
||||||
|
|
||||||
|
BROADCAST_CHANNEL_ID = 0xFFFF
|
||||||
|
|
||||||
|
|
||||||
|
class PacketHeader:
|
||||||
|
format_str_init = ">BHH"
|
||||||
|
format_str_cont = ">BH"
|
||||||
|
|
||||||
|
def __init__(self, ctrl_byte: int, cid: int, length: int) -> None:
|
||||||
|
self.ctrl_byte = ctrl_byte
|
||||||
|
self.cid = cid
|
||||||
|
self.data_length = length
|
||||||
|
|
||||||
|
def to_bytes_init(self) -> bytes:
|
||||||
|
return struct.pack(
|
||||||
|
self.format_str_init, self.ctrl_byte, self.cid, self.data_length
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_bytes_cont(self) -> bytes:
|
||||||
|
return struct.pack(self.format_str_cont, CONTINUATION_PACKET, self.cid)
|
||||||
|
|
||||||
|
def pack_to_init_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None:
|
||||||
|
struct.pack_into(
|
||||||
|
self.format_str_init,
|
||||||
|
buffer,
|
||||||
|
buffer_offset,
|
||||||
|
self.ctrl_byte,
|
||||||
|
self.cid,
|
||||||
|
self.data_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
def pack_to_cont_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None:
|
||||||
|
struct.pack_into(
|
||||||
|
self.format_str_cont, buffer, buffer_offset, CONTINUATION_PACKET, self.cid
|
||||||
|
)
|
||||||
|
|
||||||
|
def is_ack(self) -> bool:
|
||||||
|
return self.ctrl_byte & ACK_MASK == ACK_MESSAGE
|
||||||
|
|
||||||
|
def is_channel_allocation_response(self):
|
||||||
|
return (
|
||||||
|
self.cid == BROADCAST_CHANNEL_ID
|
||||||
|
and self.ctrl_byte == _CHANNEL_ALLOCATION_RES
|
||||||
|
)
|
||||||
|
|
||||||
|
def is_handshake_init_response(self) -> bool:
|
||||||
|
return self.ctrl_byte & DATA_MASK == HANDSHAKE_INIT_RES
|
||||||
|
|
||||||
|
def is_handshake_comp_response(self) -> bool:
|
||||||
|
return self.ctrl_byte & DATA_MASK == HANDSHAKE_COMP_RES
|
||||||
|
|
||||||
|
def is_encrypted_transport(self) -> bool:
|
||||||
|
return self.ctrl_byte & DATA_MASK == ENCRYPTED_TRANSPORT
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_error_header(cls, cid: int, length: int):
|
||||||
|
return cls(_ERROR, cid, length)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_channel_allocation_request_header(cls, length: int):
|
||||||
|
return cls(CHANNEL_ALLOCATION_REQ, BROADCAST_CHANNEL_ID, length)
|
89
python/src/trezorlib/transport/thp/thp_io.py
Normal file
89
python/src/trezorlib/transport/thp/thp_io.py
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
import struct
|
||||||
|
from binascii import hexlify
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
from ..protocol import Handle
|
||||||
|
from ..thp import checksum
|
||||||
|
from .packet_header import PacketHeader
|
||||||
|
|
||||||
|
INIT_HEADER_LENGTH = 5
|
||||||
|
CONT_HEADER_LENGTH = 3
|
||||||
|
PACKET_LENGTH = 64
|
||||||
|
CHECKSUM_LENGTH = 4
|
||||||
|
MAX_PAYLOAD_LEN = 60000
|
||||||
|
MESSAGE_TYPE_LENGTH = 2
|
||||||
|
|
||||||
|
CONTINUATION_PACKET = 0x80
|
||||||
|
|
||||||
|
|
||||||
|
def write_payload_to_wire_and_add_checksum(
|
||||||
|
handle: Handle, header: PacketHeader, transport_payload: bytes
|
||||||
|
):
|
||||||
|
chksum: bytes = checksum.compute(header.to_bytes_init() + transport_payload)
|
||||||
|
data = transport_payload + chksum
|
||||||
|
write_payload_to_wire(handle, header, data)
|
||||||
|
print("WOO")
|
||||||
|
|
||||||
|
|
||||||
|
def write_payload_to_wire(
|
||||||
|
handle: Handle, header: PacketHeader, transport_payload: bytes
|
||||||
|
):
|
||||||
|
print("tttt")
|
||||||
|
handle.open()
|
||||||
|
buffer = bytearray(transport_payload)
|
||||||
|
chunk = header.to_bytes_init() + buffer[: PACKET_LENGTH - INIT_HEADER_LENGTH]
|
||||||
|
print("x")
|
||||||
|
chunk = chunk.ljust(PACKET_LENGTH, b"\x00")
|
||||||
|
print("y")
|
||||||
|
print(hexlify(chunk))
|
||||||
|
handle.write_chunk(chunk)
|
||||||
|
print("fgh")
|
||||||
|
|
||||||
|
buffer = buffer[PACKET_LENGTH - INIT_HEADER_LENGTH :]
|
||||||
|
while buffer:
|
||||||
|
chunk = header.to_bytes_cont() + buffer[: PACKET_LENGTH - CONT_HEADER_LENGTH]
|
||||||
|
chunk = chunk.ljust(PACKET_LENGTH, b"\x00")
|
||||||
|
handle.write_chunk(chunk)
|
||||||
|
buffer = buffer[PACKET_LENGTH - CONT_HEADER_LENGTH :]
|
||||||
|
|
||||||
|
|
||||||
|
def read(handle: Handle) -> Tuple[PacketHeader, bytes, bytes]:
|
||||||
|
buffer = bytearray()
|
||||||
|
# Read header with first part of message data
|
||||||
|
header, first_chunk = read_first(handle)
|
||||||
|
buffer.extend(first_chunk)
|
||||||
|
|
||||||
|
# Read the rest of the message
|
||||||
|
while len(buffer) < header.data_length:
|
||||||
|
buffer.extend(read_next(handle, header.cid))
|
||||||
|
# print("buffer read (data):", hexlify(buffer).decode())
|
||||||
|
# print("buffer len (data):", datalen)
|
||||||
|
# TODO check checksum?? or do not strip ?
|
||||||
|
data_len = header.data_length - CHECKSUM_LENGTH
|
||||||
|
return header, buffer[:data_len], buffer[data_len : data_len + CHECKSUM_LENGTH]
|
||||||
|
|
||||||
|
|
||||||
|
def read_first(handle: Handle) -> Tuple[PacketHeader, bytes]:
|
||||||
|
chunk = handle.read_chunk()
|
||||||
|
try:
|
||||||
|
ctrl_byte, cid, data_length = struct.unpack(
|
||||||
|
PacketHeader.format_str_init, chunk[:INIT_HEADER_LENGTH]
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
raise RuntimeError("Cannot parse header")
|
||||||
|
|
||||||
|
data = chunk[INIT_HEADER_LENGTH:]
|
||||||
|
return PacketHeader(ctrl_byte, cid, data_length), data
|
||||||
|
|
||||||
|
|
||||||
|
def read_next(handle: Handle, cid: int) -> bytes:
|
||||||
|
chunk = handle.read_chunk()
|
||||||
|
ctrl_byte, read_cid = struct.unpack(
|
||||||
|
PacketHeader.format_str_cont, chunk[:CONT_HEADER_LENGTH]
|
||||||
|
)
|
||||||
|
if ctrl_byte != CONTINUATION_PACKET:
|
||||||
|
raise RuntimeError("Continuation packet with incorrect control byte")
|
||||||
|
if read_cid != cid:
|
||||||
|
raise RuntimeError("Continuation packet for different channel")
|
||||||
|
|
||||||
|
return chunk[CONT_HEADER_LENGTH:]
|
@ -69,7 +69,9 @@ class WebUsbHandle:
|
|||||||
self.handle = None
|
self.handle = None
|
||||||
|
|
||||||
def write_chunk(self, chunk: bytes) -> None:
|
def write_chunk(self, chunk: bytes) -> None:
|
||||||
|
print("ti")
|
||||||
assert self.handle is not None
|
assert self.handle is not None
|
||||||
|
print("te")
|
||||||
if len(chunk) != 64:
|
if len(chunk) != 64:
|
||||||
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
|
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
|
||||||
LOG.log(DUMP_PACKETS, f"writing packet: {chunk.hex()}")
|
LOG.log(DUMP_PACKETS, f"writing packet: {chunk.hex()}")
|
||||||
|
Loading…
Reference in New Issue
Block a user