1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-03-22 02:55:43 +00:00

chore(core): adapt trezorlib transports to session based

[no changelog]
This commit is contained in:
M1nd3r 2025-02-04 15:21:19 +01:00
parent 07b857d1b3
commit aeca5223cc
11 changed files with 1122 additions and 8 deletions

View File

@ -34,6 +34,7 @@ TREZORD_HOST = "http://127.0.0.1:21325"
TREZORD_ORIGIN_HEADER = {"Origin": "https://python.trezor.io"}
TREZORD_VERSION_MODERN = (2, 0, 25)
TREZORD_VERSION_THP_SUPPORT = (2, 0, 31) # TODO add correct value
CONNECTION = requests.Session()
CONNECTION.headers.update(TREZORD_ORIGIN_HEADER)
@ -64,17 +65,26 @@ def is_legacy_bridge() -> bool:
return get_bridge_version() < TREZORD_VERSION_MODERN
def supports_protocolV2() -> bool:
return get_bridge_version() >= TREZORD_VERSION_THP_SUPPORT
def detect_protocol_version(transport: "BridgeTransport") -> int:
from .. import mapping, messages
from ..messages import FailureType
protocol_version = ProtocolVersion.PROTOCOL_V1
request_type, request_data = mapping.DEFAULT_MAPPING.encode(messages.Initialize())
transport.deprecated_begin_session()
transport.deprecated_write(request_type, request_data)
response_type, response_data = transport.deprecated_read()
_ = mapping.DEFAULT_MAPPING.decode(response_type, response_data)
transport.deprecated_begin_session()
transport.open()
transport.write_chunk(request_type.to_bytes(2, "big") + request_data)
response = transport.read_chunk()
response_type = int.from_bytes(response[:2], "big")
response_data = response[2:]
response = mapping.DEFAULT_MAPPING.decode(response_type, response_data)
if isinstance(response, messages.Failure):
if response.code == FailureType.InvalidProtocol:
LOG.debug("Protocol V2 detected")
protocol_version = ProtocolVersion.PROTOCOL_V2
return protocol_version

View File

@ -6,6 +6,7 @@ import typing as t
from .. import exceptions, messages, models
from ..protobuf import MessageType
from .thp.protocol_v1 import ProtocolV1Channel
from .thp.protocol_v2 import ProtocolV2Channel
if t.TYPE_CHECKING:
from ..client import TrezorClient
@ -172,3 +173,50 @@ def derive_seed(session: Session) -> None:
get_address(session, "Testnet", PASSPHRASE_TEST_PATH)
session.refresh_features()
class SessionV2(Session):
@classmethod
def new(
cls,
client: TrezorClient,
passphrase: str | None,
derive_cardano: bool,
session_id: int = 0,
) -> SessionV2:
assert isinstance(client.protocol, ProtocolV2Channel)
session = cls(client, session_id.to_bytes(1, "big"))
session.call(
messages.ThpCreateNewSession(
passphrase=passphrase, derive_cardano=derive_cardano
),
expect=messages.Success,
)
session.update_id_and_sid(session_id.to_bytes(1, "big"))
return session
def __init__(self, client: TrezorClient, id: bytes) -> None:
from ..debuglink import TrezorClientDebugLink
super().__init__(client, id)
assert isinstance(client.protocol, ProtocolV2Channel)
helper_debug = None
if isinstance(client, TrezorClientDebugLink):
helper_debug = client.debug
self.channel: ProtocolV2Channel = client.protocol.get_channel(helper_debug)
self.update_id_and_sid(id)
def _write(self, msg: t.Any) -> None:
LOG.debug("writing message %s", type(msg))
self.channel.write(self.sid, msg)
def _read(self) -> t.Any:
msg = self.channel.read(self.sid)
LOG.debug("reading message %s", type(msg))
return msg
def update_id_and_sid(self, id: bytes) -> None:
self._id = id
self.sid = int.from_bytes(id, "big") # TODO update to extract only sid

View File

@ -0,0 +1,102 @@
# from storage.cache_thp import ChannelCache
# from trezor import log
# from trezor.wire.thp import ThpError
# def is_ack_valid(cache: ChannelCache, ack_bit: int) -> bool:
# """
# Checks if:
# - an ACK message is expected
# - the received ACK message acknowledges correct sequence number (bit)
# """
# if not _is_ack_expected(cache):
# return False
# if not _has_ack_correct_sync_bit(cache, ack_bit):
# return False
# return True
# def _is_ack_expected(cache: ChannelCache) -> bool:
# is_expected: bool = not is_sending_allowed(cache)
# if __debug__ and not is_expected:
# log.debug(__name__, "Received unexpected ACK message")
# return is_expected
# def _has_ack_correct_sync_bit(cache: ChannelCache, sync_bit: int) -> bool:
# is_correct: bool = get_send_seq_bit(cache) == sync_bit
# if __debug__ and not is_correct:
# log.debug(__name__, "Received ACK message with wrong ack bit")
# return is_correct
# def is_sending_allowed(cache: ChannelCache) -> bool:
# """
# Checks whether sending a message in the provided channel is allowed.
# Note: Sending a message in a channel before receipt of ACK message for the previously
# sent message (in the channel) is prohibited, as it can lead to desynchronization.
# """
# return bool(cache.sync >> 7)
# def get_send_seq_bit(cache: ChannelCache) -> int:
# """
# Returns the sequential number (bit) of the next message to be sent
# in the provided channel.
# """
# return (cache.sync & 0x20) >> 5
# def get_expected_receive_seq_bit(cache: ChannelCache) -> int:
# """
# Returns the (expected) sequential number (bit) of the next message
# to be received in the provided channel.
# """
# return (cache.sync & 0x40) >> 6
# def set_sending_allowed(cache: ChannelCache, sending_allowed: bool) -> None:
# """
# Set the flag whether sending a message in this channel is allowed or not.
# """
# cache.sync &= 0x7F
# if sending_allowed:
# cache.sync |= 0x80
# def set_expected_receive_seq_bit(cache: ChannelCache, seq_bit: int) -> None:
# """
# Set the expected sequential number (bit) of the next message to be received
# in the provided channel
# """
# if __debug__:
# log.debug(__name__, "Set sync receive expected seq bit to %d", seq_bit)
# if seq_bit not in (0, 1):
# raise ThpError("Unexpected receive sync bit")
# # set second bit to "seq_bit" value
# cache.sync &= 0xBF
# if seq_bit:
# cache.sync |= 0x40
# def _set_send_seq_bit(cache: ChannelCache, seq_bit: int) -> None:
# if seq_bit not in (0, 1):
# raise ThpError("Unexpected send seq bit")
# if __debug__:
# log.debug(__name__, "setting sync send seq bit to %d", seq_bit)
# # set third bit to "seq_bit" value
# cache.sync &= 0xDF
# if seq_bit:
# cache.sync |= 0x20
# def set_send_seq_bit_to_opposite(cache: ChannelCache) -> None:
# """
# Set the sequential bit of the "next message to be send" to the opposite value,
# i.e. 1 -> 0 and 0 -> 1
# """
# _set_send_seq_bit(cache=cache, seq_bit=1 - get_send_seq_bit(cache))

View File

@ -6,7 +6,6 @@ import os
import typing as t
from .channel_data import ChannelData
from .protocol_and_channel import Channel
LOG = logging.getLogger(__name__)
@ -20,6 +19,11 @@ def get_channel_db() -> ChannelDatabase:
return db
if t.TYPE_CHECKING:
from .protocol_and_channel import Channel
from .protocol_v2 import ProtocolV2Channel
class ChannelDatabase:
def load_stored_channels(self) -> t.List[ChannelData]: ...
@ -73,7 +77,7 @@ class JsonChannelDatabase(ChannelDatabase):
with open(self.data_path, "w") as f:
json.dump(channels, f, indent=4)
def save_channel(self, new_channel: Channel):
def save_channel(self, new_channel: ProtocolV2Channel):
LOG.debug("save channel")
channels = self._read_all_channels()

View File

@ -0,0 +1,19 @@
import zlib
CHECKSUM_LENGTH = 4
def compute(data: bytes) -> bytes:
"""
Returns a CRC-32 checksum of the provided `data`.
"""
return zlib.crc32(data).to_bytes(CHECKSUM_LENGTH, "big")
def is_valid(checksum: bytes, data: bytes) -> bool:
"""
Checks whether the CRC-32 checksum of the `data` is the same
as the checksum provided in `checksum`.
"""
data_checksum = compute(data)
return checksum == data_checksum

View File

@ -0,0 +1,63 @@
CODEC_V1 = 0x3F
CONTINUATION_PACKET = 0x80
HANDSHAKE_INIT_REQ = 0x00
HANDSHAKE_INIT_RES = 0x01
HANDSHAKE_COMP_REQ = 0x02
HANDSHAKE_COMP_RES = 0x03
ENCRYPTED_TRANSPORT = 0x04
CONTINUATION_PACKET_MASK = 0x80
ACK_MASK = 0xF7
DATA_MASK = 0xE7
ACK_MESSAGE = 0x20
_ERROR = 0x42
CHANNEL_ALLOCATION_REQ = 0x40
_CHANNEL_ALLOCATION_RES = 0x41
TREZOR_STATE_UNPAIRED = b"\x00"
TREZOR_STATE_PAIRED = b"\x01"
def add_seq_bit_to_ctrl_byte(ctrl_byte: int, seq_bit: int) -> int:
if seq_bit == 0:
return ctrl_byte & 0xEF
if seq_bit == 1:
return ctrl_byte | 0x10
raise Exception("Unexpected sequence bit")
def add_ack_bit_to_ctrl_byte(ctrl_byte: int, ack_bit: int) -> int:
if ack_bit == 0:
return ctrl_byte & 0xF7
if ack_bit == 1:
return ctrl_byte | 0x08
raise Exception("Unexpected acknowledgement bit")
def get_seq_bit(ctrl_byte: int) -> int:
return (ctrl_byte & 0x10) >> 4
def is_ack(ctrl_byte: int) -> bool:
return ctrl_byte & ACK_MASK == ACK_MESSAGE
def is_error(ctrl_byte: int) -> bool:
return ctrl_byte == _ERROR
def is_continuation(ctrl_byte: int) -> bool:
return ctrl_byte & CONTINUATION_PACKET_MASK == CONTINUATION_PACKET
def is_encrypted_transport(ctrl_byte: int) -> bool:
return ctrl_byte & DATA_MASK == ENCRYPTED_TRANSPORT
def is_handshake_init_req(ctrl_byte: int) -> bool:
return ctrl_byte & DATA_MASK == HANDSHAKE_INIT_REQ
def is_handshake_comp_req(ctrl_byte: int) -> bool:
return ctrl_byte & DATA_MASK == HANDSHAKE_COMP_REQ

View File

@ -0,0 +1,40 @@
import typing as t
from hashlib import sha512
from . import curve25519
_PREFIX = b"\x08\x43\x50\x61\x63\x65\x32\x35\x35\x06"
_PADDING = b"\x6f\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x20"
class Cpace:
"""
CPace, a balanced composable PAKE: https://datatracker.ietf.org/doc/draft-irtf-cfrg-cpace/
"""
random_bytes: t.Callable[[int], bytes]
def __init__(self, handshake_hash: bytes) -> None:
self.handshake_hash: bytes = handshake_hash
self.shared_secret: bytes
self.host_private_key: bytes
self.host_public_key: bytes
def generate_keys_and_secret(
self, code_code_entry: bytes, trezor_public_key: bytes
) -> None:
"""
Generate ephemeral key pair and a shared secret using Elligator2 with X25519.
"""
sha_ctx = sha512(_PREFIX)
sha_ctx.update(code_code_entry)
sha_ctx.update(_PADDING)
sha_ctx.update(self.handshake_hash)
sha_ctx.update(b"\x00")
pregenerator = sha_ctx.digest()[:32]
generator = curve25519.elligator2(pregenerator)
self.host_private_key = self.random_bytes(32)
self.host_public_key = curve25519.multiply(self.host_private_key, generator)
self.shared_secret = curve25519.multiply(
self.host_private_key, trezor_public_key
)

View File

@ -0,0 +1,159 @@
from typing import Tuple
p = 2**255 - 19
J = 486662
c3 = 19681161376707505956807079304988542015446066515923890162744021073123829784752 # sqrt(-1)
c4 = 7237005577332262213973186563042994240829374041602535252466099000494570602493 # (p - 5) // 8
a24 = 121666 # (J + 2) // 4
def decode_scalar(scalar: bytes) -> int:
# decodeScalar25519 from
# https://datatracker.ietf.org/doc/html/rfc7748#section-5
if len(scalar) != 32:
raise ValueError("Invalid length of scalar")
array = bytearray(scalar)
array[0] &= 248
array[31] &= 127
array[31] |= 64
return int.from_bytes(array, "little")
def decode_coordinate(coordinate: bytes) -> int:
# decodeUCoordinate from
# https://datatracker.ietf.org/doc/html/rfc7748#section-5
if len(coordinate) != 32:
raise ValueError("Invalid length of coordinate")
array = bytearray(coordinate)
array[-1] &= 0x7F
return int.from_bytes(array, "little") % p
def encode_coordinate(coordinate: int) -> bytes:
# encodeUCoordinate from
# https://datatracker.ietf.org/doc/html/rfc7748#section-5
return coordinate.to_bytes(32, "little")
def get_private_key(secret: bytes) -> bytes:
return decode_scalar(secret).to_bytes(32, "little")
def get_public_key(private_key: bytes) -> bytes:
base_point = int.to_bytes(9, 32, "little")
return multiply(private_key, base_point)
def multiply(private_scalar: bytes, public_point: bytes):
# X25519 from
# https://datatracker.ietf.org/doc/html/rfc7748#section-5
def ladder_operation(
x1: int, x2: int, z2: int, x3: int, z3: int
) -> Tuple[int, int, int, int]:
# https://hyperelliptic.org/EFD/g1p/auto-montgom-xz.html#ladder-ladd-1987-m-3
# (x4, z4) = 2 * (x2, z2)
# (x5, z5) = (x2, z2) + (x3, z3)
# where (x1, 1) = (x3, z3) - (x2, z2)
a = (x2 + z2) % p
aa = (a * a) % p
b = (x2 - z2) % p
bb = (b * b) % p
e = (aa - bb) % p
c = (x3 + z3) % p
d = (x3 - z3) % p
da = (d * a) % p
cb = (c * b) % p
t0 = (da + cb) % p
x5 = (t0 * t0) % p
t1 = (da - cb) % p
t2 = (t1 * t1) % p
z5 = (x1 * t2) % p
x4 = (aa * bb) % p
t3 = (a24 * e) % p
t4 = (bb + t3) % p
z4 = (e * t4) % p
return x4, z4, x5, z5
def conditional_swap(first: int, second: int, condition: int):
# Returns (second, first) if condition is true and (first, second) otherwise
# Must be implemented in a way that it is constant time
true_mask = -condition
false_mask = ~true_mask
return (first & false_mask) | (second & true_mask), (second & false_mask) | (
first & true_mask
)
k = decode_scalar(private_scalar)
u = decode_coordinate(public_point)
x_1 = u
x_2 = 1
z_2 = 0
x_3 = u
z_3 = 1
swap = 0
for i in reversed(range(256)):
bit = (k >> i) & 1
swap = bit ^ swap
(x_2, x_3) = conditional_swap(x_2, x_3, swap)
(z_2, z_3) = conditional_swap(z_2, z_3, swap)
swap = bit
x_2, z_2, x_3, z_3 = ladder_operation(x_1, x_2, z_2, x_3, z_3)
(x_2, x_3) = conditional_swap(x_2, x_3, swap)
(z_2, z_3) = conditional_swap(z_2, z_3, swap)
x = pow(z_2, p - 2, p) * x_2 % p
return encode_coordinate(x)
def elligator2(point: bytes) -> bytes:
# map_to_curve_elligator2_curve25519 from
# https://www.rfc-editor.org/rfc/rfc9380.html#ell2-opt
def conditional_move(first: int, second: int, condition: bool):
# Returns second if condition is true and first otherwise
# Must be implemented in a way that it is constant time
true_mask = -condition
false_mask = ~true_mask
return (first & false_mask) | (second & true_mask)
u = decode_coordinate(point)
tv1 = (u * u) % p
tv1 = (2 * tv1) % p
xd = (tv1 + 1) % p
x1n = (-J) % p
tv2 = (xd * xd) % p
gxd = (tv2 * xd) % p
gx1 = (J * tv1) % p
gx1 = (gx1 * x1n) % p
gx1 = (gx1 + tv2) % p
gx1 = (gx1 * x1n) % p
tv3 = (gxd * gxd) % p
tv2 = (tv3 * tv3) % p
tv3 = (tv3 * gxd) % p
tv3 = (tv3 * gx1) % p
tv2 = (tv2 * tv3) % p
y11 = pow(tv2, c4, p)
y11 = (y11 * tv3) % p
y12 = (y11 * c3) % p
tv2 = (y11 * y11) % p
tv2 = (tv2 * gxd) % p
e1 = tv2 == gx1
y1 = conditional_move(y12, y11, e1)
x2n = (x1n * tv1) % p
tv2 = (y1 * y1) % p
tv2 = (tv2 * gxd) % p
e3 = tv2 == gx1
xn = conditional_move(x2n, x1n, e3)
x = xn * pow(xd, p - 2, p) % p
return encode_coordinate(x)

View File

@ -0,0 +1,82 @@
import struct
CODEC_V1 = 0x3F
CONTINUATION_PACKET = 0x80
HANDSHAKE_INIT_REQ = 0x00
HANDSHAKE_INIT_RES = 0x01
HANDSHAKE_COMP_REQ = 0x02
HANDSHAKE_COMP_RES = 0x03
ENCRYPTED_TRANSPORT = 0x04
CONTINUATION_PACKET_MASK = 0x80
ACK_MASK = 0xF7
DATA_MASK = 0xE7
ACK_MESSAGE = 0x20
_ERROR = 0x42
CHANNEL_ALLOCATION_REQ = 0x40
_CHANNEL_ALLOCATION_RES = 0x41
TREZOR_STATE_UNPAIRED = b"\x00"
TREZOR_STATE_PAIRED = b"\x01"
BROADCAST_CHANNEL_ID = 0xFFFF
class MessageHeader:
format_str_init = ">BHH"
format_str_cont = ">BH"
def __init__(self, ctrl_byte: int, cid: int, length: int) -> None:
self.ctrl_byte = ctrl_byte
self.cid = cid
self.data_length = length
def to_bytes_init(self) -> bytes:
return struct.pack(
self.format_str_init, self.ctrl_byte, self.cid, self.data_length
)
def to_bytes_cont(self) -> bytes:
return struct.pack(self.format_str_cont, CONTINUATION_PACKET, self.cid)
def pack_to_init_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None:
struct.pack_into(
self.format_str_init,
buffer,
buffer_offset,
self.ctrl_byte,
self.cid,
self.data_length,
)
def pack_to_cont_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None:
struct.pack_into(
self.format_str_cont, buffer, buffer_offset, CONTINUATION_PACKET, self.cid
)
def is_ack(self) -> bool:
return self.ctrl_byte & ACK_MASK == ACK_MESSAGE
def is_channel_allocation_response(self):
return (
self.cid == BROADCAST_CHANNEL_ID
and self.ctrl_byte == _CHANNEL_ALLOCATION_RES
)
def is_handshake_init_response(self) -> bool:
return self.ctrl_byte & DATA_MASK == HANDSHAKE_INIT_RES
def is_handshake_comp_response(self) -> bool:
return self.ctrl_byte & DATA_MASK == HANDSHAKE_COMP_RES
def is_encrypted_transport(self) -> bool:
return self.ctrl_byte & DATA_MASK == ENCRYPTED_TRANSPORT
@classmethod
def get_error_header(cls, cid: int, length: int):
return cls(_ERROR, cid, length)
@classmethod
def get_channel_allocation_request_header(cls, length: int):
return cls(CHANNEL_ALLOCATION_REQ, BROADCAST_CHANNEL_ID, length)

View File

@ -0,0 +1,490 @@
from __future__ import annotations
import hashlib
import hmac
import logging
import os
import typing as t
from binascii import hexlify
import click
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from ... import exceptions, messages, protobuf
from ...mapping import ProtobufMapping
from .. import Transport
from ..thp import checksum, curve25519, thp_io
from ..thp.channel_data import ChannelData
from ..thp.checksum import CHECKSUM_LENGTH
from ..thp.message_header import MessageHeader
from . import control_byte
from .channel_database import ChannelDatabase, get_channel_db
from .protocol_and_channel import Channel
LOG = logging.getLogger(__name__)
DEFAULT_SESSION_ID: int = 0
if t.TYPE_CHECKING:
from ...debuglink import DebugLink
MT = t.TypeVar("MT", bound=protobuf.MessageType)
def _sha256_of_two(val_1: bytes, val_2: bytes) -> bytes:
hash = hashlib.sha256(val_1)
hash.update(val_2)
return hash.digest()
def _hkdf(chaining_key: bytes, input: bytes):
temp_key = hmac.new(chaining_key, input, hashlib.sha256).digest()
output_1 = hmac.new(temp_key, b"\x01", hashlib.sha256).digest()
ctx_output_2 = hmac.new(temp_key, output_1, hashlib.sha256)
ctx_output_2.update(b"\x02")
output_2 = ctx_output_2.digest()
return (output_1, output_2)
def _get_iv_from_nonce(nonce: int) -> bytes:
if not nonce <= 0xFFFFFFFFFFFFFFFF:
raise ValueError("Nonce overflow, terminate the channel")
return bytes(4) + nonce.to_bytes(8, "big")
class ProtocolV2Channel(Channel):
channel_id: int
channel_database: ChannelDatabase
key_request: bytes
key_response: bytes
nonce_request: int
nonce_response: int
sync_bit_send: int
sync_bit_receive: int
handshake_hash: bytes
_has_valid_channel: bool = False
_features: messages.Features | None = None
def __init__(
self,
transport: Transport,
mapping: ProtobufMapping,
channel_data: ChannelData | None = None,
) -> None:
self.channel_database: ChannelDatabase = get_channel_db()
super().__init__(transport, mapping)
if channel_data is not None:
self.channel_id = channel_data.channel_id
self.key_request = bytes.fromhex(channel_data.key_request)
self.key_response = bytes.fromhex(channel_data.key_response)
self.nonce_request = channel_data.nonce_request
self.nonce_response = channel_data.nonce_response
self.sync_bit_receive = channel_data.sync_bit_receive
self.sync_bit_send = channel_data.sync_bit_send
self.handshake_hash = bytes.fromhex(channel_data.handshake_hash)
self._has_valid_channel = True
def get_channel(self, helper_debug: DebugLink | None = None) -> ProtocolV2Channel:
if not self._has_valid_channel:
self._establish_new_channel(helper_debug)
return self
def get_channel_data(self) -> ChannelData:
return ChannelData(
protocol_version_major=2,
protocol_version_minor=2,
transport_path=self.transport.get_path(),
channel_id=self.channel_id,
key_request=self.key_request,
key_response=self.key_response,
nonce_request=self.nonce_request,
nonce_response=self.nonce_response,
sync_bit_receive=self.sync_bit_receive,
sync_bit_send=self.sync_bit_send,
handshake_hash=self.handshake_hash,
)
def read(self, session_id: int) -> t.Any:
sid, msg_type, msg_data = self.read_and_decrypt()
if sid != session_id:
raise Exception("Received messsage on a different session.")
self.channel_database.save_channel(self)
return self.mapping.decode(msg_type, msg_data)
def write(self, session_id: int, msg: t.Any) -> None:
msg_type, msg_data = self.mapping.encode(msg)
self._encrypt_and_write(session_id, msg_type, msg_data)
self.channel_database.save_channel(self)
def get_features(self) -> messages.Features:
if not self._has_valid_channel:
self._establish_new_channel()
if self._features is None:
self.update_features()
assert self._features is not None
return self._features
def update_features(self) -> None:
message = messages.GetFeatures()
message_type, message_data = self.mapping.encode(message)
self.session_id: int = DEFAULT_SESSION_ID
self._encrypt_and_write(DEFAULT_SESSION_ID, message_type, message_data)
_ = self._read_until_valid_crc_check() # TODO check ACK
_, msg_type, msg_data = self.read_and_decrypt()
features = self.mapping.decode(msg_type, msg_data)
if not isinstance(features, messages.Features):
raise exceptions.TrezorException("Unexpected response to GetFeatures")
self._features = features
def _send_message(
self,
message: protobuf.MessageType,
session_id: int = DEFAULT_SESSION_ID,
):
message_type, message_data = self.mapping.encode(message)
self._encrypt_and_write(session_id, message_type, message_data)
self._read_ack()
def _read_message(self, message_type: type[MT]) -> MT:
_, msg_type, msg_data = self.read_and_decrypt()
msg = self.mapping.decode(msg_type, msg_data)
assert isinstance(msg, message_type)
return msg
def _establish_new_channel(self, helper_debug: DebugLink | None = None) -> None:
self._reset_sync_bits()
self._do_channel_allocation()
self._do_handshake()
self._do_pairing(helper_debug)
def _reset_sync_bits(self) -> None:
self.sync_bit_send = 0
self.sync_bit_receive = 0
def _do_channel_allocation(self) -> None:
channel_allocation_nonce = os.urandom(8)
self._send_channel_allocation_request(channel_allocation_nonce)
cid, dp = self._read_channel_allocation_response(channel_allocation_nonce)
self.channel_id = cid
self.device_properties = dp
def _send_channel_allocation_request(self, nonce: bytes):
thp_io.write_payload_to_wire_and_add_checksum(
self.transport,
MessageHeader.get_channel_allocation_request_header(12),
nonce,
)
def _read_channel_allocation_response(
self, expected_nonce: bytes
) -> tuple[int, bytes]:
header, payload = self._read_until_valid_crc_check()
if not self._is_valid_channel_allocation_response(
header, payload, expected_nonce
):
raise Exception("Invalid channel allocation response.")
channel_id = int.from_bytes(payload[8:10], "big")
device_properties = payload[10:]
return (channel_id, device_properties)
def _do_handshake(
self, credential: bytes | None = None, host_static_privkey: bytes | None = None
):
host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32))
host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey)
self._send_handshake_init_request(host_ephemeral_pubkey)
self._read_ack()
init_response = self._read_handshake_init_response()
trezor_ephemeral_pubkey = init_response[:32]
encrypted_trezor_static_pubkey = init_response[32:80]
noise_tag = init_response[80:96]
LOG.debug("noise_tag: %s", hexlify(noise_tag).decode())
# TODO check noise_tag is valid
ck = self._send_handshake_completion_request(
host_ephemeral_pubkey,
host_ephemeral_privkey,
trezor_ephemeral_pubkey,
encrypted_trezor_static_pubkey,
credential,
host_static_privkey,
)
self._read_ack()
self._read_handshake_completion_response()
self.key_request, self.key_response = _hkdf(ck, b"")
self.nonce_request = 0
self.nonce_response = 1
def _send_handshake_init_request(self, host_ephemeral_pubkey: bytes) -> None:
ha_init_req_header = MessageHeader(0, self.channel_id, 36)
thp_io.write_payload_to_wire_and_add_checksum(
self.transport, ha_init_req_header, host_ephemeral_pubkey
)
def _read_handshake_init_response(self) -> bytes:
header, payload = self._read_until_valid_crc_check()
self._send_ack_0()
if header.ctrl_byte == 0x42:
if payload == b"\x05":
raise exceptions.DeviceLockedException()
if not header.is_handshake_init_response():
LOG.debug("Received message is not a valid handshake init response message")
click.echo(
"Received message is not a valid handshake init response message",
err=True,
)
return payload
def _send_handshake_completion_request(
self,
host_ephemeral_pubkey: bytes,
host_ephemeral_privkey: bytes,
trezor_ephemeral_pubkey: bytes,
encrypted_trezor_static_pubkey: bytes,
credential: bytes | None = None,
host_static_privkey: bytes | None = None,
) -> bytes:
PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00"
IV_1 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
IV_2 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
h = _sha256_of_two(PROTOCOL_NAME, self.device_properties)
h = _sha256_of_two(h, host_ephemeral_pubkey)
h = _sha256_of_two(h, trezor_ephemeral_pubkey)
ck, k = _hkdf(
PROTOCOL_NAME,
curve25519.multiply(host_ephemeral_privkey, trezor_ephemeral_pubkey),
)
aes_ctx = AESGCM(k)
try:
trezor_masked_static_pubkey = aes_ctx.decrypt(
IV_1, encrypted_trezor_static_pubkey, h
)
except Exception as e:
click.echo(
f"Exception of type{type(e)}", err=True
) # TODO how to handle potential exceptions? Q for Matejcik
h = _sha256_of_two(h, encrypted_trezor_static_pubkey)
ck, k = _hkdf(
ck, curve25519.multiply(host_ephemeral_privkey, trezor_masked_static_pubkey)
)
aes_ctx = AESGCM(k)
tag_of_empty_string = aes_ctx.encrypt(IV_1, b"", h)
h = _sha256_of_two(h, tag_of_empty_string)
# TODO: search for saved credentials
if host_static_privkey is not None and credential is not None:
host_static_pubkey = curve25519.get_public_key(host_static_privkey)
else:
credential = None
zeroes_32 = int.to_bytes(0, 32, "little")
temp_host_static_privkey = curve25519.get_private_key(zeroes_32)
temp_host_static_pubkey = curve25519.get_public_key(
temp_host_static_privkey
)
host_static_privkey = temp_host_static_privkey
host_static_pubkey = temp_host_static_pubkey
aes_ctx = AESGCM(k)
encrypted_host_static_pubkey = aes_ctx.encrypt(IV_2, host_static_pubkey, h)
h = _sha256_of_two(h, encrypted_host_static_pubkey)
ck, k = _hkdf(
ck, curve25519.multiply(host_static_privkey, trezor_ephemeral_pubkey)
)
msg_data = self.mapping.encode_without_wire_type(
messages.ThpHandshakeCompletionReqNoisePayload(
host_pairing_credential=credential,
)
)
aes_ctx = AESGCM(k)
encrypted_payload = aes_ctx.encrypt(IV_1, msg_data, h)
h = _sha256_of_two(h, encrypted_payload[:-16])
ha_completion_req_header = MessageHeader(
0x12,
self.channel_id,
len(encrypted_host_static_pubkey)
+ len(encrypted_payload)
+ CHECKSUM_LENGTH,
)
thp_io.write_payload_to_wire_and_add_checksum(
self.transport,
ha_completion_req_header,
encrypted_host_static_pubkey + encrypted_payload,
)
self.handshake_hash = h
return ck
def _read_handshake_completion_response(self) -> None:
# Read handshake completion response, ignore payload as we do not care about the state
header, _ = self._read_until_valid_crc_check()
if not header.is_handshake_comp_response():
click.echo(
"Received message is not a valid handshake completion response",
err=True,
)
self._send_ack_1()
def _do_pairing(self, helper_debug: DebugLink | None):
self._send_message(messages.ThpPairingRequest())
self._read_message(messages.ButtonRequest)
self._send_message(messages.ButtonAck())
if helper_debug is not None:
helper_debug.press_yes()
self._read_message(messages.ThpPairingRequestApproved)
self._send_message(
messages.ThpSelectMethod(
selected_pairing_method=messages.ThpPairingMethod.SkipPairing
)
)
self._read_message(messages.ThpEndResponse)
self._has_valid_channel = True
def _read_ack(self):
header, payload = self._read_until_valid_crc_check()
if not header.is_ack() or len(payload) > 0:
click.echo("Received message is not a valid ACK", err=True)
def _send_ack_0(self):
LOG.debug("sending ack 0")
header = MessageHeader(0x20, self.channel_id, 4)
thp_io.write_payload_to_wire_and_add_checksum(self.transport, header, b"")
def _send_ack_1(self):
LOG.debug("sending ack 1")
header = MessageHeader(0x28, self.channel_id, 4)
thp_io.write_payload_to_wire_and_add_checksum(self.transport, header, b"")
def _encrypt_and_write(
self,
session_id: int,
message_type: int,
message_data: bytes,
ctrl_byte: int | None = None,
) -> None:
assert self.key_request is not None
aes_ctx = AESGCM(self.key_request)
if ctrl_byte is None:
ctrl_byte = control_byte.add_seq_bit_to_ctrl_byte(0x04, self.sync_bit_send)
self.sync_bit_send = 1 - self.sync_bit_send
sid = session_id.to_bytes(1, "big")
msg_type = message_type.to_bytes(2, "big")
data = sid + msg_type + message_data
nonce = _get_iv_from_nonce(self.nonce_request)
self.nonce_request += 1
encrypted_message = aes_ctx.encrypt(nonce, data, b"")
header = MessageHeader(
ctrl_byte, self.channel_id, len(encrypted_message) + CHECKSUM_LENGTH
)
thp_io.write_payload_to_wire_and_add_checksum(
self.transport, header, encrypted_message
)
def read_and_decrypt(self) -> t.Tuple[int, int, bytes]:
header, raw_payload = self._read_until_valid_crc_check()
if control_byte.is_ack(header.ctrl_byte):
# TODO fix this recursion
return self.read_and_decrypt()
if control_byte.is_error(header.ctrl_byte):
# TODO check for different channel
err = _get_error_from_int(raw_payload[0])
raise Exception("Received ThpError: " + err)
if not header.is_encrypted_transport():
click.echo(
"Trying to decrypt not encrypted message! ("
+ hexlify(header.to_bytes_init() + raw_payload).decode()
+ ")",
err=True,
)
if not control_byte.is_ack(header.ctrl_byte):
LOG.debug(
"--> Get sequence bit %d %s %s",
control_byte.get_seq_bit(header.ctrl_byte),
"from control byte",
hexlify(header.ctrl_byte.to_bytes(1, "big")).decode(),
)
if control_byte.get_seq_bit(header.ctrl_byte):
self._send_ack_1()
else:
self._send_ack_0()
aes_ctx = AESGCM(self.key_response)
nonce = _get_iv_from_nonce(self.nonce_response)
self.nonce_response += 1
message = aes_ctx.decrypt(nonce, raw_payload, b"")
session_id = message[0]
message_type = message[1:3]
message_data = message[3:]
return (
session_id,
int.from_bytes(message_type, "big"),
message_data,
)
def _read_until_valid_crc_check(
self,
) -> t.Tuple[MessageHeader, bytes]:
is_valid = False
header, payload, chksum = thp_io.read(self.transport)
while not is_valid:
is_valid = checksum.is_valid(chksum, header.to_bytes_init() + payload)
if not is_valid:
click.echo(
"Received a message with an invalid checksum:"
+ hexlify(header.to_bytes_init() + payload + chksum).decode(),
err=True,
)
header, payload, chksum = thp_io.read(self.transport)
return header, payload
def _is_valid_channel_allocation_response(
self, header: MessageHeader, payload: bytes, original_nonce: bytes
) -> bool:
if not header.is_channel_allocation_response():
click.echo(
"Received message is not a channel allocation response", err=True
)
return False
if len(payload) < 10:
click.echo("Invalid channel allocation response payload", err=True)
return False
if payload[:8] != original_nonce:
click.echo(
"Invalid channel allocation response payload (nonce mismatch)", err=True
)
return False
return True
def _get_error_from_int(error_code: int) -> str:
# TODO FIXME improve this (ThpErrorType)
if error_code == 1:
return "TRANSPORT BUSY"
if error_code == 2:
return "UNALLOCATED CHANNEL"
if error_code == 3:
return "DECRYPTION FAILED"
if error_code == 4:
return "INVALID DATA"
if error_code == 5:
return "DEVICE LOCKED"
raise Exception("Not Implemented error case")

View File

@ -0,0 +1,97 @@
import struct
from typing import Tuple
from .. import Transport
from ..thp import checksum
from .message_header import MessageHeader
INIT_HEADER_LENGTH = 5
CONT_HEADER_LENGTH = 3
MAX_PAYLOAD_LEN = 60000
MESSAGE_TYPE_LENGTH = 2
CONTINUATION_PACKET = 0x80
def write_payload_to_wire_and_add_checksum(
transport: Transport, header: MessageHeader, transport_payload: bytes
):
chksum: bytes = checksum.compute(header.to_bytes_init() + transport_payload)
data = transport_payload + chksum
write_payload_to_wire(transport, header, data)
def write_payload_to_wire(
transport: Transport, header: MessageHeader, transport_payload: bytes
):
transport.open()
buffer = bytearray(transport_payload)
if transport.CHUNK_SIZE is None:
transport.write_chunk(buffer)
return
chunk = header.to_bytes_init() + buffer[: transport.CHUNK_SIZE - INIT_HEADER_LENGTH]
chunk = chunk.ljust(transport.CHUNK_SIZE, b"\x00")
transport.write_chunk(chunk)
buffer = buffer[transport.CHUNK_SIZE - INIT_HEADER_LENGTH :]
while buffer:
chunk = (
header.to_bytes_cont() + buffer[: transport.CHUNK_SIZE - CONT_HEADER_LENGTH]
)
chunk = chunk.ljust(transport.CHUNK_SIZE, b"\x00")
transport.write_chunk(chunk)
buffer = buffer[transport.CHUNK_SIZE - CONT_HEADER_LENGTH :]
def read(transport: Transport) -> Tuple[MessageHeader, bytes, bytes]:
"""
Reads from the given wire transport.
Returns `Tuple[MessageHeader, bytes, bytes]`:
1. `header` (`MessageHeader`): Header of the message.
2. `data` (`bytes`): Contents of the message (if any).
3. `checksum` (`bytes`): crc32 checksum of the header + data.
"""
buffer = bytearray()
# Read header with first part of message data
header, first_chunk = read_first(transport)
buffer.extend(first_chunk)
# Read the rest of the message
while len(buffer) < header.data_length:
buffer.extend(read_next(transport, header.cid))
data_len = header.data_length - checksum.CHECKSUM_LENGTH
msg_data = buffer[:data_len]
chksum = buffer[data_len : data_len + checksum.CHECKSUM_LENGTH]
return (header, msg_data, chksum)
def read_first(transport: Transport) -> Tuple[MessageHeader, bytes]:
chunk = transport.read_chunk()
try:
ctrl_byte, cid, data_length = struct.unpack(
MessageHeader.format_str_init, chunk[:INIT_HEADER_LENGTH]
)
except Exception:
raise RuntimeError("Cannot parse header")
data = chunk[INIT_HEADER_LENGTH:]
return MessageHeader(ctrl_byte, cid, data_length), data
def read_next(transport: Transport, cid: int) -> bytes:
chunk = transport.read_chunk()
ctrl_byte, read_cid = struct.unpack(
MessageHeader.format_str_cont, chunk[:CONT_HEADER_LENGTH]
)
if ctrl_byte != CONTINUATION_PACKET:
raise RuntimeError("Continuation packet with incorrect control byte")
if read_cid != cid:
raise RuntimeError("Continuation packet for different channel")
return chunk[CONT_HEADER_LENGTH:]