feat(core): implement a new Trezor-Host protocol

M1nd3r/thp1
M1nd3r 2 months ago
parent 4b68431f7c
commit 5394ce8df7

@ -307,6 +307,7 @@ core unix frozen debug build:
needs: []
variables:
PYOPT: "0"
THP: "1"
script:
- $NIX_SHELL --run "poetry run make -C core build_unix_frozen"
artifacts:

@ -39,6 +39,7 @@ message Failure {
Failure_PinMismatch = 12;
Failure_WipeCodeMismatch = 13;
Failure_InvalidSession = 14;
Failure_ThpUnallocatedSession=15;
Failure_FirmwareError = 99;
}
}

@ -0,0 +1,216 @@
syntax = "proto2";
package hw.trezor.messages.thp;
// Sugar for easier handling in Java
option java_package = "com.satoshilabs.trezor.lib.protobuf";
option java_outer_classname = "TrezorMessageThp";
/**
* Numeric identifiers of pairing methods.
* @embed
*/
enum ThpPairingMethod {
NoMethod = 1; // Trust without MITM protection.
CodeEntry = 2; // User types code diplayed on Trezor into the host application.
QrCode = 3; // User scans code displayed on Trezor into host application.
NFC_Unidirectional = 4; // Trezor transmits an authentication key to the host device via NFC.
}
/**
* @embed
*/
message ThpDeviceProperties {
optional string internal_model = 1; // Internal model name e.g. "T2B1".
optional uint32 model_variant = 2; // Encodes the device properties such as color.
optional bool bootloader_mode = 3; // Indicates whether the device is in bootloader or firmware mode.
optional uint32 protocol_version = 4; // The communication protocol version supported by the firmware.
repeated ThpPairingMethod pairing_methods = 5; // The pairing methods supported by the Trezor.
}
/**
* @embed
*/
message ThpHandshakeCompletionReqNoisePayload {
optional bytes host_pairing_credential = 1; // Host's pairing credential
repeated ThpPairingMethod pairing_methods = 2; // The pairing methods chosen by the host
}
/**
* Request: Ask device for a new session with given passphrase.
* @start
* @next ThpNewSession
*/
message ThpCreateNewSession{
optional string passphrase = 1;
optional bool on_device = 2; // User wants to enter passphrase on the device
optional bool derive_cardano = 3; // If True, Cardano keys will be derived. Ignored with BTC-only
}
/**
* Response: Contains session_id of the newly created session.
* @end
*/
message ThpNewSession{
optional uint32 new_session_id = 1;
}
/**
* Request: Start pairing process.
* @start
* @next ThpCodeEntryCommitment
* @next ThpPairingPreparationsFinished
*/
message ThpStartPairingRequest{
optional string host_name = 1; // Human-readable host name
}
/**
* Response: Pairing is ready for user input / OOB communication.
* @next ThpCodeEntryCpace
* @next ThpQrCodeTag
* @next ThpNfcUnidirectionalTag
*/
message ThpPairingPreparationsFinished{
}
/**
* Response: If Code Entry is an allowed pairing option, Trezor responds with a commitment.
* @next ThpCodeEntryChallenge
*/
message ThpCodeEntryCommitment {
optional bytes commitment = 1; // SHA-256 of Trezor's random 32-byte secret
}
/**
* Response: Host responds to Trezor's Code Entry commitment with a challenge.
* @next ThpPairingPreparationsFinished
*/
message ThpCodeEntryChallenge {
optional bytes challenge = 1; // host's random 32-byte challenge
}
/**
* Request: User selected Code Entry option in Host. Host starts CPACE protocol with Trezor.
* @next ThpCodeEntryCpaceTrezor
*/
message ThpCodeEntryCpaceHost {
optional bytes cpace_host_public_key = 1; // Host's ephemeral CPace public key
}
/**
* Response: Trezor continues with the CPACE protocol.
* @next ThpCodeEntryTag
*/
message ThpCodeEntryCpaceTrezor {
optional bytes cpace_trezor_public_key = 1; // Trezor's ephemeral CPace public key
}
/**
* Response: Host continues with the CPACE protocol.
* @next ThpCodeEntrySecret
*/
message ThpCodeEntryTag {
optional bytes tag = 2; // SHA-256 of shared secret
}
/**
* Response: Trezor finishes the CPACE protocol.
* @next ThpCredentialRequest
* @next ThpEndRequest
*/
message ThpCodeEntrySecret {
optional bytes secret = 1; // Trezor's secret
}
/**
* Request: User selected QR Code pairing option. Host sends a QR Tag.
* @next ThpQrCodeSecret
*/
message ThpQrCodeTag {
optional bytes tag = 1; // SHA-256 of shared secret
}
/**
* Response: Trezor sends the QR secret.
* @next ThpCredentialRequest
* @next ThpEndRequest
*/
message ThpQrCodeSecret {
optional bytes secret = 1; // Trezor's secret
}
/**
* Request: User selected Unidirectional NFC pairing option. Host sends an Unidirectional NFC Tag.
* @next ThpNfcUnidirectionalSecret
*/
message ThpNfcUnidirectionalTag {
optional bytes tag = 1; // SHA-256 of shared secret
}
/**
* Response: Trezor sends the Unidirectioal NFC secret.
* @next ThpCredentialRequest
* @next ThpEndRequest
*/
message ThpNfcUnidirectionalSecret {
optional bytes secret = 1; // Trezor's secret
}
/**
* Request: Host requests issuance of a new pairing credential.
* @start
* @next ThpCredentialResponse
*/
message ThpCredentialRequest {
optional bytes host_static_pubkey = 1; // Host's static public key used in the handshake.
}
/**
* Response: Trezor issues a new pairing credential.
* @next ThpCredentialRequest
* @next ThpEndRequest
*/
message ThpCredentialResponse {
optional bytes trezor_static_pubkey = 1; // Trezor's static public key used in the handshake.
optional bytes credential = 2; // The pairing credential issued by the Trezor to the host.
}
/**
* Request: Host requests transition to the encrypted traffic phase.
* @start
* @next ThpEndResponse
*/
message ThpEndRequest {}
/**
* Response: Trezor approves transition to the encrypted traffic phase
* @end
*/
message ThpEndResponse {}
/**
* Only for internal use.
* @embed
*/
message ThpCredentialMetadata {
optional string host_name = 1; // Human-readable host name
}
/**
* Only for internal use.
* @embed
*/
message ThpPairingCredential {
optional ThpCredentialMetadata cred_metadata = 1; // Credential metadata
optional bytes mac = 2; // Message authentication code generated by the Trezor
}
/**
* Only for internal use.
* @embed
*/
message ThpAuthenticatedCredentialData {
optional bytes host_static_pubkey = 1; // Host's static public key used in the handshake
optional ThpCredentialMetadata cred_metadata = 2; // Credential metadata
}

@ -42,6 +42,10 @@ extend google.protobuf.EnumValueOptions {
optional bool wire_tiny = 50006; // message is handled by Trezor when the USB stack is in tiny mode
optional bool wire_bootloader = 50007; // message is only handled by Trezor Bootloader
optional bool wire_no_fsm = 50008; // message is not handled by Trezor unless the USB stack is in tiny mode
optional bool channel_in = 50009;
optional bool channel_out = 50010;
optional bool pairing_in = 50011;
optional bool pairing_out = 50012;
optional bool bitcoin_only = 60000; // enum value is available on BITCOIN_ONLY build
// (messages not marked bitcoin_only will be EXCLUDED)
@ -375,4 +379,26 @@ enum MessageType {
MessageType_SolanaAddress = 903 [(wire_out) = true];
MessageType_SolanaSignTx = 904 [(wire_in) = true];
MessageType_SolanaTxSignature = 905 [(wire_out) = true];
// THP
MessageType_ThpCreateNewSession = 1000[(bitcoin_only)=true, (channel_in) = true];
MessageType_ThpNewSession = 1001[(bitcoin_only)=true, (channel_out) = true];
MessageType_ThpStartPairingRequest = 1008 [(bitcoin_only) = true, (pairing_in) = true];
MessageType_ThpPairingPreparationsFinished = 1009 [(bitcoin_only) = true, (pairing_out) = true];
MessageType_ThpCredentialRequest = 1010 [(bitcoin_only) = true, (pairing_in) = true];
MessageType_ThpCredentialResponse = 1011 [(bitcoin_only) = true, (pairing_out) = true];
MessageType_ThpEndRequest = 1012 [(bitcoin_only) = true, (pairing_in) = true];
MessageType_ThpEndResponse = 1013[(bitcoin_only) = true, (pairing_out) = true];
MessageType_ThpCodeEntryCommitment = 1016[(bitcoin_only)=true, (pairing_out) = true];
MessageType_ThpCodeEntryChallenge = 1017[(bitcoin_only)=true, (pairing_in) = true];
MessageType_ThpCodeEntryCpaceHost = 1018[(bitcoin_only)=true, (pairing_in) = true];
MessageType_ThpCodeEntryCpaceTrezor = 1019[(bitcoin_only)=true, (pairing_out) = true];
MessageType_ThpCodeEntryTag = 1020[(bitcoin_only)=true, (pairing_in) = true];
MessageType_ThpCodeEntrySecret = 1021[(bitcoin_only)=true, (pairing_out) = true];
MessageType_ThpQrCodeTag = 1024[(bitcoin_only)=true, (pairing_in) = true];
MessageType_ThpQrCodeSecret = 1025[(bitcoin_only)=true, (pairing_out) = true];
MessageType_ThpNfcUnidirectionalTag = 1032[(bitcoin_only)=true, (pairing_in) = true];
MessageType_ThpNfcUnidirectionalSecret = 1033[(bitcoin_only)=true, (pairing_in) = true];
}

@ -133,9 +133,7 @@ WIRETYPE_ENTRY = c.Sequence(
)
# QDEF(MP_QSTR_copysign, 5171, 8, "copysign")
QDEF_RE = re.compile(
r'^QDEF\(MP_QSTR(\S+), ([0-9]+), ([0-9])+, "(.*)"\)$'
)
QDEF_RE = re.compile(r'^QDEF\(MP_QSTR(\S+), ([0-9]+), ([0-9])+, "(.*)"\)$')
@dataclass
@ -558,6 +556,8 @@ class RustBlobRenderer:
enums = []
cursor = 0
for enum in sorted(self.descriptor.enums, key=lambda e: e.name):
if enum.name == "MessageType":
continue
self.enum_map[enum.name] = cursor
enum_blob = ENUM_ENTRY.build(sorted(v.number for v in enum.value))
enums.append(enum_blob)

@ -290,10 +290,16 @@ build_unix: templates ## build unix port
build_unix_frozen: templates build_cross ## build unix port with frozen modules
$(SCONS) CFLAGS="$(CFLAGS)" $(UNIX_BUILD_DIR)/trezor-emu-core $(UNIX_PORT_OPTS) \
TREZOR_MODEL="$(TREZOR_MODEL)" CMAKELISTS="$(CMAKELISTS)" \
TREZOR_MODEL="$(TREZOR_MODEL)" CMAKELISTS="$(CMAKELISTS)" THP="$(THP)"\
PYOPT="$(PYOPT)" BITCOIN_ONLY="$(BITCOIN_ONLY)" TREZOR_EMULATOR_ASAN="$(ADDRESS_SANITIZER)" \
TREZOR_MEMPERF="$(TREZOR_MEMPERF)" TREZOR_EMULATOR_FROZEN=1 NEW_RENDERING="$(NEW_RENDERING)"
build_unix_frozen_debug: templates build_cross ## build unix port with frozen modules and DEBUG (PYOPT="0")
$(SCONS) CFLAGS="$(CFLAGS)" $(UNIX_BUILD_DIR)/trezor-emu-core $(UNIX_PORT_OPTS) \
TREZOR_MODEL="$(TREZOR_MODEL)" CMAKELISTS="$(CMAKELISTS)" THP="$(THP)"\
PYOPT="0" BITCOIN_ONLY="$(BITCOIN_ONLY)" TREZOR_EMULATOR_ASAN="$(ADDRESS_SANITIZER)" \
TREZOR_MEMPERF="$(TREZOR_MEMPERF)" TREZOR_EMULATOR_FROZEN=1
build_unix_debug: templates ## build unix port
$(SCONS) --max-drift=1 CFLAGS="$(CFLAGS)" $(UNIX_BUILD_DIR)/trezor-emu-core $(UNIX_PORT_OPTS) \
TREZOR_MODEL="$(TREZOR_MODEL)" CMAKELISTS="$(CMAKELISTS)" \

@ -653,6 +653,8 @@ if FROZEN:
else:
raise ValueError('Unknown layout')
if THP:
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/thp/*.py'))
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/*.py'))
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'storage/*.py',
@ -744,6 +746,9 @@ if FROZEN:
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'apps/tezos/*.py'))
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/enums/Tezos*.py'))
if THP:
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'apps/thp/*.py'))
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'apps/zcash/*.py'))
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'apps/webauthn/*.py'))
@ -777,7 +782,7 @@ if FROZEN:
source_files = SOURCE_MOD + SOURCE_MOD_CRYPTO + SOURCE_FIRMWARE + SOURCE_MICROPYTHON + SOURCE_MICROPYTHON_SPEED + SOURCE_HAL
obj_program = []
obj_program.extend(env.Object(source=SOURCE_MOD))
obj_program.extend(env.Object(source=SOURCE_MOD_CRYPTO, CCFLAGS='$CCFLAGS -ftrivial-auto-var-init=zero'))
obj_program.extend(env.Object(source=SOURCE_MOD_CRYPTO))
if FEATURE_FLAGS["SECP256K1_ZKP"]:
obj_program.extend(env.Object(source=SOURCE_MOD_SECP256K1_ZKP, CCFLAGS='$CCFLAGS -Wno-unused-function'))
source_files.extend(SOURCE_MOD_SECP256K1_ZKP)

@ -696,6 +696,8 @@ if FROZEN:
else:
raise ValueError('Unknown layout')
if THP:
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/thp/*.py'))
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/*.py'))
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'storage/*.py',
@ -789,6 +791,9 @@ if FROZEN:
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'apps/tezos/*.py'))
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/enums/Tezos*.py'))
if THP:
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'apps/thp/*.py'))
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'apps/zcash/*.py'))
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'apps/webauthn/*.py'))

@ -356,7 +356,7 @@ pub static mp_module_trezorproto: Module = obj_module! {
/// """Calculate length of encoding of the specified message."""
Qstr::MP_QSTR_encoded_length => obj_fn_1!(protobuf_len).as_obj(),
/// def encode(buffer: bytearray, msg: MessageType) -> int:
/// def encode(buffer: bytearray | memoryview, msg: MessageType) -> int:
/// """Encode the message into the specified buffer. Return length of
/// encoding."""
Qstr::MP_QSTR_encode => obj_fn_2!(protobuf_encode).as_obj()

@ -42,6 +42,6 @@ def encoded_length(msg: MessageType) -> int:
# rust/src/protobuf/obj.rs
def encode(buffer: bytearray, msg: MessageType) -> int:
def encode(buffer: bytearray | memoryview, msg: MessageType) -> int:
"""Encode the message into the specified buffer. Return length of
encoding."""

@ -47,6 +47,12 @@ storage
import storage
storage.cache
import storage.cache
storage.cache_codec
import storage.cache_codec
storage.cache_common
import storage.cache_common
storage.cache_thp
import storage.cache_thp
storage.common
import storage.common
storage.debug
@ -133,6 +139,8 @@ trezor.enums.SafetyCheckLevel
import trezor.enums.SafetyCheckLevel
trezor.enums.SdProtectOperationType
import trezor.enums.SdProtectOperationType
trezor.enums.ThpPairingMethod
import trezor.enums.ThpPairingMethod
trezor.enums.WordRequestType
import trezor.enums.WordRequestType
trezor.enums
@ -209,6 +217,48 @@ trezor.wire.context
import trezor.wire.context
trezor.wire.errors
import trezor.wire.errors
trezor.wire.message_handler
import trezor.wire.message_handler
trezor.wire.protocol_common
import trezor.wire.protocol_common
trezor.wire.thp
import trezor.wire.thp
trezor.wire.thp.alternating_bit_protocol
import trezor.wire.thp.alternating_bit_protocol
trezor.wire.thp.channel
import trezor.wire.thp.channel
trezor.wire.thp.channel_manager
import trezor.wire.thp.channel_manager
trezor.wire.thp.checksum
import trezor.wire.thp.checksum
trezor.wire.thp.control_byte
import trezor.wire.thp.control_byte
trezor.wire.thp.cpace
import trezor.wire.thp.cpace
trezor.wire.thp.crypto
import trezor.wire.thp.crypto
trezor.wire.thp.handler_provider
import trezor.wire.thp.handler_provider
trezor.wire.thp.interface_manager
import trezor.wire.thp.interface_manager
trezor.wire.thp.memory_manager
import trezor.wire.thp.memory_manager
trezor.wire.thp.pairing_context
import trezor.wire.thp.pairing_context
trezor.wire.thp.received_message_handler
import trezor.wire.thp.received_message_handler
trezor.wire.thp.session_context
import trezor.wire.thp.session_context
trezor.wire.thp.session_manager
import trezor.wire.thp.session_manager
trezor.wire.thp.thp_messages
import trezor.wire.thp.thp_messages
trezor.wire.thp.transmission_loop
import trezor.wire.thp.transmission_loop
trezor.wire.thp.writer
import trezor.wire.thp.writer
trezor.wire.thp_v3
import trezor.wire.thp_v3
trezor.workflow
import trezor.workflow
apps
@ -293,6 +343,8 @@ apps.common.backup
import apps.common.backup
apps.common.backup_types
import apps.common.backup_types
apps.common.cache
import apps.common.cache
apps.common.cbor
import apps.common.cbor
apps.common.coininfo
@ -381,6 +433,14 @@ apps.misc.get_firmware_hash
import apps.misc.get_firmware_hash
apps.misc.sign_identity
import apps.misc.sign_identity
apps.thp
import apps.thp
apps.thp.create_session
import apps.thp.create_session
apps.thp.credential_manager
import apps.thp.credential_manager
apps.thp.pairing
import apps.thp.pairing
apps.workflow_handlers
import apps.workflow_handlers

@ -1,11 +1,19 @@
from typing import TYPE_CHECKING
import storage.cache as storage_cache
import storage.cache_codec as cache_codec
import storage.device as storage_device
from storage.cache import check_thp_is_not_used
from storage.cache_common import (
APP_COMMON_BUSY_DEADLINE_MS,
APP_COMMON_DERIVE_CARDANO,
APP_COMMON_SEED,
)
from trezor import TR, config, utils, wire, workflow
from trezor.enums import HomescreenFormat, MessageType
from trezor.messages import Success, UnlockPath
from trezor.ui.layouts import confirm_action
from trezor.wire import context
from trezor.wire.message_handler import filters, remove_filter
from . import workflow_handlers
@ -34,7 +42,7 @@ def busy_expiry_ms() -> int:
Returns the time left until the busy state expires or 0 if the device is not in the busy state.
"""
busy_deadline_ms = storage_cache.get_int(storage_cache.APP_COMMON_BUSY_DEADLINE_MS)
busy_deadline_ms = context.cache_get_int(APP_COMMON_BUSY_DEADLINE_MS)
if busy_deadline_ms is None:
return 0
@ -199,13 +207,18 @@ def get_features() -> Features:
return f
async def handle_Initialize(msg: Initialize) -> Features:
session_id = storage_cache.start_session(msg.session_id)
@check_thp_is_not_used
async def handle_Initialize(
msg: Initialize,
) -> Features:
session_id = cache_codec.start_session(msg.session_id)
if not utils.BITCOIN_ONLY:
derive_cardano = storage_cache.get_bool(storage_cache.APP_COMMON_DERIVE_CARDANO)
have_seed = storage_cache.is_set(storage_cache.APP_COMMON_SEED)
# TODO change cardano derivation
# ctx = context.get_context()
if not utils.BITCOIN_ONLY:
derive_cardano = context.cache_get_bool(APP_COMMON_DERIVE_CARDANO)
have_seed = context.cache_is_set(APP_COMMON_SEED)
if (
have_seed
and msg.derive_cardano is not None
@ -213,14 +226,12 @@ async def handle_Initialize(msg: Initialize) -> Features:
):
# seed is already derived, and host wants to change derive_cardano setting
# => create a new session
storage_cache.end_current_session()
session_id = storage_cache.start_session()
cache_codec.end_current_session()
session_id = cache_codec.start_session()
have_seed = False
if not have_seed:
storage_cache.set_bool(
storage_cache.APP_COMMON_DERIVE_CARDANO, bool(msg.derive_cardano)
)
context.cache_set_bool(APP_COMMON_DERIVE_CARDANO, bool(msg.derive_cardano))
features = get_features()
features.session_id = session_id
@ -249,16 +260,16 @@ async def handle_SetBusy(msg: SetBusy) -> Success:
import utime
deadline = utime.ticks_add(utime.ticks_ms(), msg.expiry_ms)
storage_cache.set_int(storage_cache.APP_COMMON_BUSY_DEADLINE_MS, deadline)
context.cache_set_int(APP_COMMON_BUSY_DEADLINE_MS, deadline)
else:
storage_cache.delete(storage_cache.APP_COMMON_BUSY_DEADLINE_MS)
context.cache_delete(APP_COMMON_BUSY_DEADLINE_MS)
set_homescreen()
workflow.close_others()
return Success()
async def handle_EndSession(msg: EndSession) -> Success:
storage_cache.end_current_session()
cache_codec.end_current_session()
return Success()
@ -273,7 +284,7 @@ async def handle_Ping(msg: Ping) -> Success:
async def handle_DoPreauthorized(msg: DoPreauthorized) -> protobuf.MessageType:
from trezor.messages import PreauthorizedRequest
from trezor.wire.context import call_any, get_context
from trezor.wire.context import call_any
from apps.common import authorization
@ -286,11 +297,9 @@ async def handle_DoPreauthorized(msg: DoPreauthorized) -> protobuf.MessageType:
req = await call_any(PreauthorizedRequest(), *wire_types)
assert req.MESSAGE_WIRE_TYPE is not None
handler = workflow_handlers.find_registered_handler(
get_context().iface, req.MESSAGE_WIRE_TYPE
)
handler = workflow_handlers.find_registered_handler(req.MESSAGE_WIRE_TYPE)
if handler is None:
return wire.unexpected_message()
return wire.message_handler.unexpected_message()
return await handler(req, authorization.get()) # type: ignore [Expected 1 positional argument]
@ -298,7 +307,7 @@ async def handle_DoPreauthorized(msg: DoPreauthorized) -> protobuf.MessageType:
async def handle_UnlockPath(msg: UnlockPath) -> protobuf.MessageType:
from trezor.crypto import hmac
from trezor.messages import UnlockedPathRequest
from trezor.wire.context import call_any, get_context
from trezor.wire.context import call_any
from apps.common.paths import SLIP25_PURPOSE
from apps.common.seed import Slip21Node, get_seed
@ -339,9 +348,7 @@ async def handle_UnlockPath(msg: UnlockPath) -> protobuf.MessageType:
req = await call_any(UnlockedPathRequest(mac=expected_mac), *wire_types)
assert req.MESSAGE_WIRE_TYPE in wire_types
handler = workflow_handlers.find_registered_handler(
get_context().iface, req.MESSAGE_WIRE_TYPE
)
handler = workflow_handlers.find_registered_handler(req.MESSAGE_WIRE_TYPE)
assert handler is not None
return await handler(req, msg) # type: ignore [Expected 1 positional argument]
@ -361,7 +368,7 @@ def set_homescreen() -> None:
set_default = workflow.set_default # local_cache_attribute
if storage_cache.is_set(storage_cache.APP_COMMON_BUSY_DEADLINE_MS):
if context.cache_is_set(APP_COMMON_BUSY_DEADLINE_MS):
from apps.homescreen import busyscreen
set_default(busyscreen)
@ -390,7 +397,7 @@ def set_homescreen() -> None:
def lock_device(interrupt_workflow: bool = True) -> None:
if config.has_pin():
config.lock()
wire.filters.append(_pinlock_filter)
filters.append(_pinlock_filter)
set_homescreen()
if interrupt_workflow:
workflow.close_others()
@ -426,7 +433,7 @@ async def unlock_device() -> None:
_SCREENSAVER_IS_ON = False
set_homescreen()
wire.remove_filter(_pinlock_filter)
remove_filter(_pinlock_filter)
def _pinlock_filter(msg_type: int, prev_handler: Handler[Msg]) -> Handler[Msg]:
@ -479,4 +486,4 @@ def boot() -> None:
backup.activate_repeated_backup()
if not config.is_unlocked():
# pinlocked handler should always be the last one
wire.filters.append(_pinlock_filter)
filters.append(_pinlock_filter)

@ -1,7 +1,7 @@
from micropython import const
from typing import TYPE_CHECKING
from trezor.wire import DataError
from trezor.wire import DataError, context
from .. import writers
@ -26,7 +26,7 @@ class PaymentRequestVerifier:
def __init__(
self, msg: TxAckPaymentRequest, coin: coininfo.CoinInfo, keychain: Keychain
) -> None:
from storage import cache
from storage.cache_common import APP_COMMON_NONCE
from trezor.crypto.hashlib import sha256
from trezor.utils import HashWriter
@ -42,9 +42,9 @@ class PaymentRequestVerifier:
if msg.nonce:
nonce = bytes(msg.nonce)
if cache.get(cache.APP_COMMON_NONCE) != nonce:
if context.cache_get(APP_COMMON_NONCE) != nonce:
raise DataError("Invalid nonce in payment request.")
cache.delete(cache.APP_COMMON_NONCE)
context.cache_delete(APP_COMMON_NONCE)
else:
nonce = b""
if msg.memos:

@ -1,6 +1,11 @@
from typing import TYPE_CHECKING
from storage import cache, device
import storage.device as device
from storage.cache_common import (
APP_CARDANO_ICARUS_SECRET,
APP_CARDANO_ICARUS_TREZOR_SECRET,
APP_COMMON_DERIVE_CARDANO,
)
from trezor import wire
from trezor.crypto import cardano
@ -15,6 +20,7 @@ if TYPE_CHECKING:
from trezor import messages
from trezor.crypto import bip32
from trezor.enums import CardanoDerivationType
from trezor.wire.protocol_common import Context
from apps.common.keychain import Handler, MsgOut
from apps.common.paths import Bip32Path
@ -110,9 +116,9 @@ def is_minting_path(path: Bip32Path) -> bool:
return path[: len(MINTING_ROOT)] == MINTING_ROOT
def derive_and_store_secrets(passphrase: str) -> None:
def derive_and_store_secrets(ctx: Context, passphrase: str) -> None:
assert device.is_initialized()
assert cache.get_bool(cache.APP_COMMON_DERIVE_CARDANO)
assert ctx.cache.get_bool(APP_COMMON_DERIVE_CARDANO)
if not mnemonic.is_bip39():
# nothing to do for SLIP-39, where we can derive the root from the main seed
@ -132,14 +138,15 @@ def derive_and_store_secrets(passphrase: str) -> None:
else:
icarus_trezor_secret = icarus_secret
cache.set(cache.APP_CARDANO_ICARUS_SECRET, icarus_secret)
cache.set(cache.APP_CARDANO_ICARUS_TREZOR_SECRET, icarus_trezor_secret)
ctx.cache.set(APP_CARDANO_ICARUS_SECRET, icarus_secret)
ctx.cache.set(APP_CARDANO_ICARUS_TREZOR_SECRET, icarus_trezor_secret)
async def _get_keychain_bip39(derivation_type: CardanoDerivationType) -> Keychain:
from trezor.enums import CardanoDerivationType
from trezor.wire import context
from apps.common.seed import derive_and_store_roots
from apps.common.seed import derive_and_store_roots_legacy
if not device.is_initialized():
raise wire.NotInitialized("Device is not initialized")
@ -148,19 +155,19 @@ async def _get_keychain_bip39(derivation_type: CardanoDerivationType) -> Keychai
seed = await get_seed()
return Keychain(cardano.from_seed_ledger(seed))
if not cache.get_bool(cache.APP_COMMON_DERIVE_CARDANO):
if not context.cache_get_bool(APP_COMMON_DERIVE_CARDANO):
raise wire.ProcessError("Cardano derivation is not enabled for this session")
if derivation_type == CardanoDerivationType.ICARUS:
cache_entry = cache.APP_CARDANO_ICARUS_SECRET
cache_entry = APP_CARDANO_ICARUS_SECRET
else:
cache_entry = cache.APP_CARDANO_ICARUS_TREZOR_SECRET
cache_entry = APP_CARDANO_ICARUS_TREZOR_SECRET
# _get_secret
secret = cache.get(cache_entry)
secret = context.cache_get(cache_entry)
if secret is None:
await derive_and_store_roots()
secret = cache.get(cache_entry)
await derive_and_store_roots_legacy()
secret = context.cache_get(cache_entry)
assert secret is not None
root = cardano.from_secret(secret)

@ -1,23 +1,21 @@
from typing import Iterable
import storage.cache as storage_cache
from storage.cache_common import (
APP_COMMON_AUTHORIZATION_DATA,
APP_COMMON_AUTHORIZATION_TYPE,
)
from trezor import protobuf
from trezor.enums import MessageType
from trezor.wire import context
WIRE_TYPES: dict[int, tuple[int, ...]] = {
MessageType.AuthorizeCoinJoin: (MessageType.SignTx, MessageType.GetOwnershipProof),
}
APP_COMMON_AUTHORIZATION_DATA = (
storage_cache.APP_COMMON_AUTHORIZATION_DATA
) # global_import_cache
APP_COMMON_AUTHORIZATION_TYPE = (
storage_cache.APP_COMMON_AUTHORIZATION_TYPE
) # global_import_cache
def is_set() -> bool:
return storage_cache.get(APP_COMMON_AUTHORIZATION_TYPE) is not None
return context.cache_get(APP_COMMON_AUTHORIZATION_TYPE) is not None
def set(auth_message: protobuf.MessageType) -> None:
@ -29,16 +27,16 @@ def set(auth_message: protobuf.MessageType) -> None:
# (because only wire-level messages have wire_type, which we use as identifier)
ensure(auth_message.MESSAGE_WIRE_TYPE is not None)
assert auth_message.MESSAGE_WIRE_TYPE is not None # so that typechecker knows too
storage_cache.set_int(APP_COMMON_AUTHORIZATION_TYPE, auth_message.MESSAGE_WIRE_TYPE)
storage_cache.set(APP_COMMON_AUTHORIZATION_DATA, buffer)
context.cache_set_int(APP_COMMON_AUTHORIZATION_TYPE, auth_message.MESSAGE_WIRE_TYPE)
context.cache_set(APP_COMMON_AUTHORIZATION_DATA, buffer)
def get() -> protobuf.MessageType | None:
stored_auth_type = storage_cache.get_int(APP_COMMON_AUTHORIZATION_TYPE)
stored_auth_type = context.cache_get_int(APP_COMMON_AUTHORIZATION_TYPE)
if not stored_auth_type:
return None
buffer = storage_cache.get(APP_COMMON_AUTHORIZATION_DATA, b"")
buffer = context.cache_get(APP_COMMON_AUTHORIZATION_DATA, b"")
return protobuf.load_message_buffer(buffer, stored_auth_type)
@ -49,7 +47,7 @@ def is_set_any_session(auth_type: MessageType) -> bool:
def get_wire_types() -> Iterable[int]:
stored_auth_type = storage_cache.get_int(APP_COMMON_AUTHORIZATION_TYPE)
stored_auth_type = context.cache_get_int(APP_COMMON_AUTHORIZATION_TYPE)
if stored_auth_type is None:
return ()
@ -57,5 +55,5 @@ def get_wire_types() -> Iterable[int]:
def clear() -> None:
storage_cache.delete(APP_COMMON_AUTHORIZATION_TYPE)
storage_cache.delete(APP_COMMON_AUTHORIZATION_DATA)
context.cache_delete(APP_COMMON_AUTHORIZATION_TYPE)
context.cache_delete(APP_COMMON_AUTHORIZATION_DATA)

@ -1,25 +1,27 @@
from typing import TYPE_CHECKING
import storage.cache as storage_cache
import storage.cache_common as storage_cache
from trezor import wire
from trezor.enums import MessageType
from trezor.wire import context
from trezor.wire.message_handler import filters, remove_filter
if TYPE_CHECKING:
from trezor.wire import Handler, Msg
def repeated_backup_enabled() -> bool:
return storage_cache.get_bool(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED)
return context.cache_get_bool(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED)
def activate_repeated_backup():
storage_cache.set_bool(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED, True)
wire.filters.append(_repeated_backup_filter)
context.cache_set_bool(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED, True)
filters.append(_repeated_backup_filter)
def deactivate_repeated_backup():
storage_cache.delete(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED)
wire.remove_filter(_repeated_backup_filter)
context.cache_delete(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED)
remove_filter(_repeated_backup_filter)
_ALLOW_WHILE_REPEATED_BACKUP_UNLOCKED = (

@ -0,0 +1,23 @@
from typing import TYPE_CHECKING
from trezor.wire import context
if TYPE_CHECKING:
from typing import Callable, ParamSpec
P = ParamSpec("P")
ByteFunc = Callable[P, bytes]
def stored(key: int) -> Callable[[ByteFunc[P]], ByteFunc[P]]:
def decorator(func: ByteFunc[P]) -> ByteFunc[P]:
def wrapper(*args: P.args, **kwargs: P.kwargs):
value = context.cache_get(key)
if value is None:
value = func(*args, **kwargs)
context.cache_set(key, value)
return value
return wrapper
return decorator

@ -1,15 +1,48 @@
from micropython import const
from typing import TYPE_CHECKING
import storage.device as storage_device
from storage.cache import check_thp_is_not_used
from trezor.wire import DataError
_MAX_PASSPHRASE_LEN = const(50)
if TYPE_CHECKING:
from trezor.messages import ThpCreateNewSession
def is_enabled() -> bool:
return storage_device.is_passphrase_enabled()
async def get_passphrase(msg: ThpCreateNewSession) -> str:
if not is_enabled():
return ""
if msg.on_device or storage_device.get_passphrase_always_on_device():
passphrase = await _get_on_device()
else:
passphrase = msg.passphrase or ""
if passphrase:
await _handle_displaying_passphrase_from_host(passphrase)
if len(passphrase.encode()) > _MAX_PASSPHRASE_LEN:
raise DataError(f"Maximum passphrase length is {_MAX_PASSPHRASE_LEN} bytes")
return passphrase
async def _get_on_device() -> str:
from trezor import workflow
from trezor.ui.layouts import request_passphrase_on_device
workflow.close_others() # request exclusive UI access
passphrase = await request_passphrase_on_device(_MAX_PASSPHRASE_LEN)
return passphrase
@check_thp_is_not_used
async def get() -> str:
from trezor import workflow
@ -29,8 +62,8 @@ async def get() -> str:
return passphrase
@check_thp_is_not_used
async def _request_on_host() -> str:
from trezor import TR
from trezor.messages import PassphraseAck, PassphraseRequest
from trezor.ui.layouts import request_passphrase_on_host
from trezor.wire.context import call
@ -55,29 +88,34 @@ async def _request_on_host() -> str:
# non-empty passphrase
if passphrase:
from trezor.ui.layouts import confirm_action, confirm_blob
# We want to hide the passphrase, or show it, according to settings.
if storage_device.get_hide_passphrase_from_host():
await confirm_action(
"passphrase_host1_hidden",
TR.passphrase__wallet,
description=TR.passphrase__from_host_not_shown,
prompt_screen=True,
prompt_title=TR.passphrase__access_wallet,
)
else:
await confirm_action(
"passphrase_host1",
TR.passphrase__wallet,
description=TR.passphrase__next_screen_will_show_passphrase,
verb=TR.buttons__continue,
)
await confirm_blob(
"passphrase_host2",
TR.passphrase__title_confirm,
passphrase,
)
await _handle_displaying_passphrase_from_host(passphrase)
return passphrase
async def _handle_displaying_passphrase_from_host(passphrase: str) -> None:
from trezor import TR
from trezor.ui.layouts import confirm_action, confirm_blob
# We want to hide the passphrase, or show it, according to settings.
if storage_device.get_hide_passphrase_from_host():
await confirm_action(
"passphrase_host1_hidden",
TR.passphrase__wallet,
description=TR.passphrase__from_host_not_shown,
prompt_screen=True,
prompt_title=TR.passphrase__access_wallet,
)
else:
await confirm_action(
"passphrase_host1",
TR.passphrase__wallet,
description=TR.passphrase__next_screen_will_show_passphrase,
verb=TR.buttons__continue,
)
await confirm_blob(
"passphrase_host2",
TR.passphrase__title_confirm,
passphrase,
)

@ -1,9 +1,10 @@
import utime
from typing import Any, NoReturn
import storage.cache as storage_cache
from storage.cache_common import APP_COMMON_REQUEST_PIN_LAST_UNLOCK
from trezor import TR, config, utils, wire
from trezor.ui.layouts import show_error_and_raise
from trezor.wire import context
async def _request_sd_salt(
@ -77,7 +78,7 @@ async def request_pin_and_sd_salt(
def _set_last_unlock_time() -> None:
now = utime.ticks_ms()
storage_cache.set_int(storage_cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK, now)
context.cache_set_int(APP_COMMON_REQUEST_PIN_LAST_UNLOCK, now)
_DEF_ARG_PIN_ENTER: str = TR.pin__enter
@ -91,7 +92,7 @@ async def verify_user_pin(
) -> None:
# _get_last_unlock_time
last_unlock = int.from_bytes(
storage_cache.get(storage_cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK, b""), "big"
context.cache_get(APP_COMMON_REQUEST_PIN_LAST_UNLOCK, b""), "big"
)
if (

@ -1,15 +1,15 @@
import storage.cache as storage_cache
import storage.device as storage_device
from storage.cache import APP_COMMON_SAFETY_CHECKS_TEMPORARY
from storage.cache_common import APP_COMMON_SAFETY_CHECKS_TEMPORARY
from storage.device import SAFETY_CHECK_LEVEL_PROMPT, SAFETY_CHECK_LEVEL_STRICT
from trezor.enums import SafetyCheckLevel
from trezor.wire import context
def read_setting() -> SafetyCheckLevel:
"""
Returns the effective safety check level.
"""
temporary_safety_check_level = storage_cache.get(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
temporary_safety_check_level = context.cache_get(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
if temporary_safety_check_level:
return int.from_bytes(temporary_safety_check_level, "big") # type: ignore [int-into-enum]
else:
@ -27,14 +27,14 @@ def apply_setting(level: SafetyCheckLevel) -> None:
Changes the safety level settings.
"""
if level == SafetyCheckLevel.Strict:
storage_cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
context.cache_delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT)
elif level == SafetyCheckLevel.PromptAlways:
storage_cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
context.cache_delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_PROMPT)
elif level == SafetyCheckLevel.PromptTemporarily:
storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT)
storage_cache.set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, level.to_bytes(1, "big"))
context.cache_set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, level.to_bytes(1, "big"))
else:
raise ValueError("Unknown SafetyCheckLevel")

@ -1,18 +1,33 @@
from typing import TYPE_CHECKING
import storage.cache as storage_cache
import storage.device as storage_device
from storage.cache import check_thp_is_not_used
from storage.cache_common import APP_COMMON_SEED, APP_COMMON_SEED_WITHOUT_PASSPHRASE
from trezor import utils
from trezor.crypto import hmac
from trezor.wire import context
from trezor.wire.context import get_context
from trezor.wire.errors import DataError
from apps.common import cache
from . import mnemonic
from .passphrase import get as get_passphrase
from .passphrase import get as get_passphrase_legacy
from .passphrase import get_passphrase as get_passphrase
if TYPE_CHECKING:
from trezor.crypto import bip32
from trezor.messages import ThpCreateNewSession
from trezor.wire.protocol_common import Context
from .paths import Bip32Path, Slip21Path
if not utils.BITCOIN_ONLY:
from storage.cache_common import (
APP_CARDANO_ICARUS_SECRET,
APP_COMMON_DERIVE_CARDANO,
)
class Slip21Node:
"""
@ -45,54 +60,71 @@ class Slip21Node:
return Slip21Node(data=self.data)
async def get_seed() -> bytes:
common_seed = context.cache_get(APP_COMMON_SEED)
assert common_seed is not None
return common_seed
if not utils.BITCOIN_ONLY:
# === Cardano variant ===
# We want to derive both the normal seed and the Cardano seed together, AND
# expose a method for Cardano to do the same
async def derive_and_store_roots() -> None:
async def derive_and_store_roots(ctx: Context, msg: ThpCreateNewSession) -> None:
if msg.passphrase is not None and msg.on_device:
raise DataError("Passphrase provided when it shouldn't be!")
from trezor import wire
if not storage_device.is_initialized():
raise wire.NotInitialized("Device is not initialized")
need_seed = not storage_cache.is_set(storage_cache.APP_COMMON_SEED)
need_cardano_secret = storage_cache.get_bool(
storage_cache.APP_COMMON_DERIVE_CARDANO
) and not storage_cache.is_set(storage_cache.APP_CARDANO_ICARUS_SECRET)
if ctx.cache.is_set(APP_COMMON_SEED):
raise Exception("Seed is already set!")
if ctx.cache.is_set(APP_CARDANO_ICARUS_SECRET):
raise Exception("Cardano icarus secret is already set!")
passphrase = await get_passphrase(msg)
common_seed = mnemonic.get_seed(passphrase)
ctx.cache.set(APP_COMMON_SEED, common_seed)
if msg.derive_cardano:
from apps.cardano.seed import derive_and_store_secrets
derive_and_store_secrets(ctx, passphrase)
@check_thp_is_not_used
async def derive_and_store_roots_legacy() -> None:
from trezor import wire
if not storage_device.is_initialized():
raise wire.NotInitialized("Device is not initialized")
ctx = get_context()
need_seed = not ctx.cache.is_set(APP_COMMON_SEED)
need_cardano_secret = ctx.cache.get_bool(
APP_COMMON_DERIVE_CARDANO
) and not ctx.cache.is_set(APP_CARDANO_ICARUS_SECRET)
if not need_seed and not need_cardano_secret:
return
passphrase = await get_passphrase()
passphrase = await get_passphrase_legacy()
if need_seed:
common_seed = mnemonic.get_seed(passphrase)
storage_cache.set(storage_cache.APP_COMMON_SEED, common_seed)
ctx.cache.set(APP_COMMON_SEED, common_seed)
if need_cardano_secret:
from apps.cardano.seed import derive_and_store_secrets
derive_and_store_secrets(passphrase)
@storage_cache.stored_async(storage_cache.APP_COMMON_SEED)
async def get_seed() -> bytes:
await derive_and_store_roots()
common_seed = storage_cache.get(storage_cache.APP_COMMON_SEED)
assert common_seed is not None
return common_seed
else:
# === Bitcoin-only variant ===
# We use the simple version of `get_seed` that never needs to derive anything else.
@storage_cache.stored_async(storage_cache.APP_COMMON_SEED)
async def get_seed() -> bytes:
passphrase = await get_passphrase()
return mnemonic.get_seed(passphrase)
derive_and_store_secrets(ctx, passphrase)
@storage_cache.stored(storage_cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE)
@cache.stored(APP_COMMON_SEED_WITHOUT_PASSPHRASE)
def _get_seed_without_passphrase() -> bytes:
if not storage_device.is_initialized():
raise Exception("Device is not initialized")

@ -44,7 +44,7 @@ if __debug__:
layout_change_chan = loop.chan()
DEBUG_CONTEXT: context.Context | None = None
DEBUG_CONTEXT: context.CodecContext | None = None
LAYOUT_WATCHER_NONE = 0
LAYOUT_WATCHER_STATE = 1

@ -5,10 +5,11 @@ if TYPE_CHECKING:
async def get_nonce(msg: GetNonce) -> Nonce:
from storage import cache
from storage.cache_common import APP_COMMON_NONCE
from trezor.crypto import random
from trezor.messages import Nonce
from trezor.wire.context import cache_set
nonce = random.bytes(32)
cache.set(cache.APP_COMMON_NONCE, nonce)
cache_set(APP_COMMON_NONCE, nonce)
return Nonce(nonce=nonce)

@ -36,7 +36,7 @@ async def recovery_process() -> Success:
recovery_type = storage_recovery.get_type()
wire.AVOID_RESTARTING_FOR = (
wire.message_handler.AVOID_RESTARTING_FOR = (
MessageType.Initialize,
MessageType.GetFeatures,
MessageType.EndSession,
@ -57,7 +57,7 @@ async def _continue_repeated_backup() -> None:
from apps.common import backup
from apps.management.backup_device import perform_backup
wire.AVOID_RESTARTING_FOR = (
wire.message_handler.AVOID_RESTARTING_FOR = (
MessageType.Initialize,
MessageType.GetFeatures,
MessageType.EndSession,

@ -59,14 +59,15 @@ async def _init_step(
) -> MoneroLiveRefreshStartAck:
import storage.cache as storage_cache
from trezor.messages import MoneroLiveRefreshStartAck
from trezor.wire import context
from apps.common import paths
await paths.validate_path(keychain, msg.address_n)
if not storage_cache.get_bool(storage_cache.APP_MONERO_LIVE_REFRESH):
if not context.cache_get_bool(storage_cache.APP_MONERO_LIVE_REFRESH):
await layout.require_confirm_live_refresh()
storage_cache.set_bool(storage_cache.APP_MONERO_LIVE_REFRESH, True)
context.cache_set_bool(storage_cache.APP_MONERO_LIVE_REFRESH, b"\x01")
s.creds = misc.get_creds(keychain, msg.address_n, msg.network_type)

@ -0,0 +1,51 @@
from trezor import log, loop
from trezor.enums import FailureType
from trezor.messages import Failure, ThpCreateNewSession, ThpNewSession
from trezor.wire.context import get_context
from trezor.wire.errors import ActionCancelled, DataError
from trezor.wire.thp import SessionState
async def create_new_session(message: ThpCreateNewSession) -> ThpNewSession | Failure:
from trezor.wire.thp.session_manager import create_new_session
from apps.common.seed import derive_and_store_roots
ctx = get_context()
# Assert that context `ctx` is ManagementSessionContext
from trezor.wire.thp.session_context import ManagementSessionContext
assert isinstance(ctx, ManagementSessionContext)
channel = ctx.channel
# Do not use `ctx` beyond this point, as it is techically
# allowed to change inbetween await statements
new_session = create_new_session(channel)
try:
await derive_and_store_roots(new_session, message)
except DataError as e:
return Failure(code=FailureType.DataError, message=e.message)
except ActionCancelled as e:
return Failure(code=FailureType.ActionCancelled, message=e.message)
# TODO handle other errors
# TODO handle BITCOIN_ONLY
new_session.set_session_state(SessionState.ALLOCATED)
channel.sessions[new_session.session_id] = new_session
loop.schedule(new_session.handle())
new_session_id: int = new_session.session_id
# await get_seed() TODO
if __debug__:
log.debug(
__name__,
"create_new_session - new session created. Passphrase: %s, Session id: %d\n%s",
message.passphrase if message.passphrase is not None else "",
new_session.session_id,
str(channel.sessions),
)
return ThpNewSession(new_session_id=new_session_id)

@ -0,0 +1,110 @@
from typing import TYPE_CHECKING
from trezor import protobuf
from trezor.crypto import hmac
from trezor.messages import (
ThpAuthenticatedCredentialData,
ThpCredentialMetadata,
ThpPairingCredential,
)
from trezor.wire import message_handler
if TYPE_CHECKING:
from apps.common.paths import Slip21Path
def derive_cred_auth_key() -> bytes:
"""
Derive current credential authentication mac-ing key from device secret.
"""
from storage.device import get_cred_auth_key_counter, get_device_secret
from apps.common.seed import Slip21Node
# Derive the key using SLIP-21 https://github.com/satoshilabs/slips/blob/master/slip-0021.md,
# the derivation path is m/"Credential authentication key"/(counter 4-byte BE)
thp_secret = get_device_secret()
label = b"Credential authentication key"
counter = get_cred_auth_key_counter()
path: Slip21Path = [label, counter]
symmetric_key_node: Slip21Node = Slip21Node(thp_secret)
symmetric_key_node.derive_path(path)
cred_auth_key = symmetric_key_node.key()
return cred_auth_key
def invalidate_cred_auth_key() -> None:
from storage.device import increment_cred_auth_key_counter
increment_cred_auth_key_counter()
def issue_credential(
host_static_pubkey: bytes,
credential_metadata: ThpCredentialMetadata,
) -> bytes:
"""
Issue a pairing credential binded to the provided host static public key
and credential metadata.
"""
return _issue_credential(
derive_cred_auth_key(), host_static_pubkey, credential_metadata
)
def validate_credential(
encoded_pairing_credential_message: bytes,
host_static_pubkey: bytes,
) -> bool:
"""
Validate a pairing credential binded to the provided host static public key.
"""
return _validate_credential(
derive_cred_auth_key(), encoded_pairing_credential_message, host_static_pubkey
)
def _issue_credential(
cred_auth_key: bytes,
host_static_pubkey: bytes,
credential_metadata: ThpCredentialMetadata,
) -> bytes:
proto_msg = ThpAuthenticatedCredentialData(
host_static_pubkey=host_static_pubkey,
cred_metadata=credential_metadata,
)
authenticated_credential_data = _encode_message_into_new_buffer(proto_msg)
mac = hmac(hmac.SHA256, cred_auth_key, authenticated_credential_data).digest()
proto_msg = ThpPairingCredential(cred_metadata=credential_metadata, mac=mac)
credential_raw = _encode_message_into_new_buffer(proto_msg)
return credential_raw
def _validate_credential(
cred_auth_key: bytes,
encoded_pairing_credential_message: bytes,
host_static_pubkey: bytes,
) -> bool:
expected_type = protobuf.type_for_name("ThpPairingCredential")
credential = message_handler.wrap_protobuf_load(
encoded_pairing_credential_message, expected_type
)
assert ThpPairingCredential.is_type_of(credential)
proto_msg = ThpAuthenticatedCredentialData(
host_static_pubkey=host_static_pubkey,
cred_metadata=credential.cred_metadata,
)
authenticated_credential_data = _encode_message_into_new_buffer(proto_msg)
mac = hmac(hmac.SHA256, cred_auth_key, authenticated_credential_data).digest()
return mac == credential.mac
def _encode_message_into_new_buffer(msg: protobuf.MessageType) -> bytes:
msg_len = protobuf.encoded_length(msg)
new_buffer = bytearray(msg_len)
protobuf.encode(new_buffer, msg)
return new_buffer

@ -0,0 +1,379 @@
from typing import TYPE_CHECKING
from ubinascii import hexlify
from trezor import loop, protobuf
from trezor.crypto.hashlib import sha256
from trezor.enums import MessageType, ThpPairingMethod
from trezor.messages import (
ThpCodeEntryChallenge,
ThpCodeEntryCommitment,
ThpCodeEntryCpaceHost,
ThpCodeEntryCpaceTrezor,
ThpCodeEntrySecret,
ThpCodeEntryTag,
ThpCredentialMetadata,
ThpCredentialRequest,
ThpCredentialResponse,
ThpEndRequest,
ThpEndResponse,
ThpNfcUnidirectionalSecret,
ThpNfcUnidirectionalTag,
ThpPairingPreparationsFinished,
ThpQrCodeSecret,
ThpQrCodeTag,
ThpStartPairingRequest,
)
from trezor.wire.errors import ActionCancelled, UnexpectedMessage
from trezor.wire.thp import ChannelState, ThpError, crypto
from trezor.wire.thp.pairing_context import PairingContext
from .credential_manager import issue_credential
if __debug__:
from trezor import log
if TYPE_CHECKING:
from typing import Any, Callable, Concatenate, Container, ParamSpec, Tuple
P = ParamSpec("P")
FuncWithContext = Callable[Concatenate[PairingContext, P], Any]
#
# Helpers - decorators
def check_state_and_log(
*allowed_states: ChannelState,
) -> Callable[[FuncWithContext], FuncWithContext]:
def decorator(f: FuncWithContext) -> FuncWithContext:
def inner(context: PairingContext, *args: P.args, **kwargs: P.kwargs) -> object:
_check_state(context, *allowed_states)
if __debug__:
try:
log.debug(__name__, "started %s", f.__name__)
except AttributeError:
log.debug(
__name__,
"started a function that cannot be named, because it raises AttributeError, eg. closure",
)
return f(context, *args, **kwargs)
return inner
return decorator
def check_method_is_allowed(
pairing_method: ThpPairingMethod,
) -> Callable[[FuncWithContext], FuncWithContext]:
def decorator(f: FuncWithContext) -> FuncWithContext:
def inner(context: PairingContext, *args: P.args, **kwargs: P.kwargs) -> object:
_check_method_is_allowed(context, pairing_method)
return f(context, *args, **kwargs)
return inner
return decorator
#
# Pairing handlers
@check_state_and_log(ChannelState.TP1)
async def handle_pairing_request(
ctx: PairingContext, message: protobuf.MessageType
) -> ThpEndResponse:
if not ThpStartPairingRequest.is_type_of(message):
raise UnexpectedMessage("Unexpected message")
ctx.host_name = message.host_name or ""
skip_pairing = _is_method_included(ctx, ThpPairingMethod.NoMethod)
if skip_pairing:
return await _end_pairing(ctx)
await _prepare_pairing(ctx)
await ctx.write(ThpPairingPreparationsFinished())
ctx.channel_ctx.set_channel_state(ChannelState.TP3)
response = await show_display_data(ctx, _get_possible_pairing_methods(ctx))
# TODO disable NFC (if enabled)
response = await _handle_different_pairing_methods(ctx, response)
while ThpCredentialRequest.is_type_of(response):
response = await _handle_credential_request(ctx, response)
return await _handle_end_request(ctx, response)
async def _prepare_pairing(ctx: PairingContext) -> None:
if _is_method_included(ctx, ThpPairingMethod.CodeEntry):
await _handle_code_entry_is_included(ctx)
if _is_method_included(ctx, ThpPairingMethod.QrCode):
_handle_qr_code_is_included(ctx)
if _is_method_included(ctx, ThpPairingMethod.NFC_Unidirectional):
_handle_nfc_unidirectional_is_included(ctx)
async def show_display_data(ctx: PairingContext, expected_types: Container[int] = ()):
read_task = ctx.read(expected_types)
cancel_task = ctx.display_data.get_display_layout()
race = loop.race(read_task, cancel_task)
result = await race
if read_task in race.finished:
return result
if cancel_task in race.finished:
raise ActionCancelled
else:
return Exception("Should not happen") # TODO
@check_state_and_log(ChannelState.TP1)
async def _handle_code_entry_is_included(ctx: PairingContext) -> None:
commitment = sha256(ctx.secret).digest()
challenge_message = await ctx.call( # noqa: F841
ThpCodeEntryCommitment(commitment=commitment), ThpCodeEntryChallenge
)
ctx.channel_ctx.set_channel_state(ChannelState.TP2)
if not ThpCodeEntryChallenge.is_type_of(challenge_message):
raise UnexpectedMessage("Unexpected message")
if challenge_message.challenge is None:
raise Exception("Invalid message")
code_code_entry_hash = sha256(
challenge_message.challenge
+ ctx.secret
+ bytes("PairingMethod_CodeEntry", "utf-8")
).digest() # TODO add handshake hash
ctx.display_data.code_code_entry = (
int.from_bytes(code_code_entry_hash, "big") % 1000000
)
ctx.display_data.display_code_entry = True
@check_state_and_log(ChannelState.TP1, ChannelState.TP2)
def _handle_qr_code_is_included(ctx: PairingContext) -> None:
ctx.display_data.code_qr_code = sha256(
ctx.secret + bytes("PairingMethod_QrCode", "utf-8")
).digest()[
:16
] # TODO add handshake hash
ctx.display_data.display_qr_code = True
@check_state_and_log(ChannelState.TP1, ChannelState.TP2)
def _handle_nfc_unidirectional_is_included(ctx: PairingContext) -> None:
ctx.display_data.code_nfc_unidirectional = sha256(
ctx.secret + bytes("PairingMethod_NfcUnidirectional", "utf-8")
).digest()[
:16
] # TODO add handshake hash
ctx.display_data.display_nfc_unidirectional = True
@check_state_and_log(ChannelState.TP3)
async def _handle_different_pairing_methods(
ctx: PairingContext, response: protobuf.MessageType
) -> protobuf.MessageType:
if ThpCodeEntryCpaceHost.is_type_of(response):
return await _handle_code_entry_cpace(ctx, response)
if ThpQrCodeTag.is_type_of(response):
return await _handle_qr_code_tag(ctx, response)
if ThpNfcUnidirectionalTag.is_type_of(response):
return await _handle_nfc_unidirectional_tag(ctx, response)
raise UnexpectedMessage("Unexpected message")
@check_state_and_log(ChannelState.TP3)
@check_method_is_allowed(ThpPairingMethod.CodeEntry)
async def _handle_code_entry_cpace(
ctx: PairingContext, message: protobuf.MessageType
) -> protobuf.MessageType:
from trezor.wire.thp.cpace import Cpace
# TODO check that ThpCodeEntryCpaceHost message is valid
if TYPE_CHECKING:
assert isinstance(message, ThpCodeEntryCpaceHost)
if message.cpace_host_public_key is None:
raise ThpError("Message ThpCodeEntryCpaceHost has no public key")
ctx.cpace = Cpace(message.cpace_host_public_key)
assert ctx.display_data.code_code_entry is not None
ctx.cpace.generate_keys_and_secret(
ctx.display_data.code_code_entry.to_bytes(6, "big")
)
ctx.channel_ctx.set_channel_state(ChannelState.TP4)
response = await ctx.call(
ThpCodeEntryCpaceTrezor(cpace_trezor_public_key=ctx.cpace.trezor_public_key),
ThpCodeEntryTag,
)
return await _handle_code_entry_tag(ctx, response)
@check_state_and_log(ChannelState.TP4)
@check_method_is_allowed(ThpPairingMethod.CodeEntry)
async def _handle_code_entry_tag(
ctx: PairingContext, message: protobuf.MessageType
) -> protobuf.MessageType:
if TYPE_CHECKING:
assert isinstance(message, ThpCodeEntryTag)
expected_tag = sha256(ctx.cpace.shared_secret).digest()
if expected_tag != message.tag:
print(
"expected code entry tag:", hexlify(expected_tag).decode()
) # TODO remove after testing
print(
"expected code entry shared secret:",
hexlify(ctx.cpace.shared_secret).decode(),
) # TODO remove after testing
raise ThpError("Unexpected Entry Code Tag")
return await _handle_secret_reveal(
ctx,
msg=ThpCodeEntrySecret(secret=ctx.secret),
)
@check_state_and_log(ChannelState.TP3)
@check_method_is_allowed(ThpPairingMethod.QrCode)
async def _handle_qr_code_tag(
ctx: PairingContext, message: protobuf.MessageType
) -> protobuf.MessageType:
if TYPE_CHECKING:
assert isinstance(message, ThpQrCodeTag)
expected_tag = sha256(ctx.display_data.code_qr_code).digest()
if expected_tag != message.tag:
print(
"expected qr code tag:", hexlify(expected_tag).decode()
) # TODO remove after testing
raise ThpError("Unexpected QR Code Tag")
return await _handle_secret_reveal(
ctx,
msg=ThpQrCodeSecret(secret=ctx.secret),
)
@check_state_and_log(ChannelState.TP3)
@check_method_is_allowed(ThpPairingMethod.NFC_Unidirectional)
async def _handle_nfc_unidirectional_tag(
ctx: PairingContext, message: protobuf.MessageType
) -> protobuf.MessageType:
if TYPE_CHECKING:
assert isinstance(message, ThpNfcUnidirectionalTag)
expected_tag = sha256(ctx.display_data.code_nfc_unidirectional).digest()
if expected_tag != message.tag:
print(
"expected nfc tag:", hexlify(expected_tag).decode()
) # TODO remove after testing
raise ThpError("Unexpected NFC Unidirectional Tag")
return await _handle_secret_reveal(
ctx,
msg=ThpNfcUnidirectionalSecret(secret=ctx.secret),
)
@check_state_and_log(ChannelState.TP3, ChannelState.TP4)
async def _handle_secret_reveal(
ctx: PairingContext,
msg: protobuf.MessageType,
) -> protobuf.MessageType:
ctx.channel_ctx.set_channel_state(ChannelState.TC1)
return await ctx.call_any(
msg,
MessageType.ThpCredentialRequest,
MessageType.ThpEndRequest,
)
@check_state_and_log(ChannelState.TC1)
async def _handle_credential_request(
ctx: PairingContext, message: protobuf.MessageType
) -> protobuf.MessageType:
ctx.secret
if not ThpCredentialRequest.is_type_of(message):
raise UnexpectedMessage("Unexpected message")
if message.host_static_pubkey is None:
raise Exception("Invalid message") # TODO change failure type
trezor_static_pubkey = crypto.get_trezor_static_pubkey()
credential_metadata = ThpCredentialMetadata(host_name=ctx.host_name)
credential = issue_credential(message.host_static_pubkey, credential_metadata)
return await ctx.call_any(
ThpCredentialResponse(
trezor_static_pubkey=trezor_static_pubkey, credential=credential
),
MessageType.ThpCredentialRequest,
MessageType.ThpEndRequest,
)
@check_state_and_log(ChannelState.TC1)
async def _handle_end_request(
ctx: PairingContext, message: protobuf.MessageType
) -> ThpEndResponse:
if not ThpEndRequest.is_type_of(message):
raise UnexpectedMessage("Unexpected message")
return await _end_pairing(ctx)
async def _end_pairing(ctx: PairingContext) -> ThpEndResponse:
ctx.channel_ctx.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
return ThpEndResponse()
#
# Helpers - checkers
def _check_state(ctx: PairingContext, *allowed_states: ChannelState) -> None:
if ctx.channel_ctx.get_channel_state() not in allowed_states:
raise UnexpectedMessage("Unexpected message")
def _check_method_is_allowed(ctx: PairingContext, method: ThpPairingMethod) -> None:
if not _is_method_included(ctx, method):
raise ThpError("Unexpected pairing method")
def _is_method_included(ctx: PairingContext, method: ThpPairingMethod) -> bool:
return method in ctx.channel_ctx.selected_pairing_methods
#
# Helpers - getters
def _get_possible_pairing_methods(ctx: PairingContext) -> Tuple[int, ...]:
return tuple(
_get_message_type_for_method(method)
for method in ctx.channel_ctx.selected_pairing_methods
)
def _get_message_type_for_method(method: int) -> int:
if method is ThpPairingMethod.CodeEntry:
return MessageType.ThpCodeEntryCpaceHost
if method is ThpPairingMethod.NFC_Unidirectional:
return MessageType.ThpNfcUnidirectionalTag
if method is ThpPairingMethod.QrCode:
return MessageType.ThpQrCodeTag
raise ValueError("Unexpected pairing method - no message type available")

@ -374,6 +374,7 @@ async def _read_cmd(iface: HID) -> Cmd | None:
desc_cont = frame_cont()
read = loop.wait(iface.iface_num() | io.POLL_READ)
# wait for incoming command indefinitely
buf = await read
while True:
ifrm = overlay_struct(bytearray(buf), desc_init)
@ -409,9 +410,12 @@ async def _read_cmd(iface: HID) -> Cmd | None:
else:
data = data[:bcnt]
# set a timeout for subsequent reads
read.timeout_ms = _CTAP_HID_TIMEOUT_MS
while datalen < bcnt:
buf = await loop.race(read, loop.sleep(_CTAP_HID_TIMEOUT_MS))
if not isinstance(buf, bytes):
try:
buf = await read
except loop.Timeout:
if __debug__:
warning(__name__, "_ERR_MSG_TIMEOUT")
await send_cmd(cmd_error(ifrm_cid, _ERR_MSG_TIMEOUT), iface)
@ -494,7 +498,9 @@ async def send_cmd(cmd: Cmd, iface: HID) -> None:
if offset < datalen:
frm = overlay_struct(buf, cont_desc)
write = loop.wait(iface.iface_num() | io.POLL_WRITE)
write = loop.wait(
iface.iface_num() | io.POLL_WRITE, timeout_ms=_CTAP_HID_TIMEOUT_MS
)
while offset < datalen:
frm.seq = seq
copied = utils.memcpy(frm.data, 0, cmd.data, offset, datalen)
@ -502,10 +508,7 @@ async def send_cmd(cmd: Cmd, iface: HID) -> None:
if copied < _FRAME_CONT_SIZE:
frm.data[copied:] = bytearray(_FRAME_CONT_SIZE - copied)
while True:
ret = await loop.race(write, loop.sleep(_CTAP_HID_TIMEOUT_MS))
if ret is not None:
raise TimeoutError
await write
if iface.write(buf) > 0:
break
seq += 1

@ -1,8 +1,6 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from trezorio import WireInterface
from trezor.wire import Handler, Msg
@ -207,7 +205,7 @@ def _find_message_handler_module(msg_type: int) -> str:
raise ValueError
def find_registered_handler(iface: WireInterface, msg_type: int) -> Handler | None:
def find_registered_handler(msg_type: int) -> Handler | None:
if msg_type in workflow_handlers:
# Message has a handler available, return it directly.
return workflow_handlers[msg_type]

@ -54,7 +54,6 @@ async def bootscreen() -> None:
"""
while True:
try:
if can_lock_device():
enforce_welcome_screen_duration()
if utils.INTERNAL_MODEL == "T2T1":

@ -1,153 +1,10 @@
import builtins
import gc
from micropython import const
from typing import TYPE_CHECKING
from storage.cache_common import SESSIONLESS_FLAG, SessionlessCache
from trezor import utils
if TYPE_CHECKING:
from typing import Sequence, TypeVar, overload
T = TypeVar("T")
_MAX_SESSIONS_COUNT = const(10)
_SESSIONLESS_FLAG = const(128)
_SESSION_ID_LENGTH = const(32)
# Traditional cache keys
APP_COMMON_SEED = const(0)
APP_COMMON_AUTHORIZATION_TYPE = const(1)
APP_COMMON_AUTHORIZATION_DATA = const(2)
APP_COMMON_NONCE = const(3)
if not utils.BITCOIN_ONLY:
APP_COMMON_DERIVE_CARDANO = const(4)
APP_CARDANO_ICARUS_SECRET = const(5)
APP_CARDANO_ICARUS_TREZOR_SECRET = const(6)
APP_MONERO_LIVE_REFRESH = const(7)
# Keys that are valid across sessions
APP_COMMON_SEED_WITHOUT_PASSPHRASE = const(0 | _SESSIONLESS_FLAG)
APP_COMMON_SAFETY_CHECKS_TEMPORARY = const(1 | _SESSIONLESS_FLAG)
APP_COMMON_REQUEST_PIN_LAST_UNLOCK = const(2 | _SESSIONLESS_FLAG)
APP_COMMON_BUSY_DEADLINE_MS = const(3 | _SESSIONLESS_FLAG)
APP_MISC_COSI_NONCE = const(4 | _SESSIONLESS_FLAG)
APP_MISC_COSI_COMMITMENT = const(5 | _SESSIONLESS_FLAG)
APP_RECOVERY_REPEATED_BACKUP_UNLOCKED = const(6 | _SESSIONLESS_FLAG)
# === Homescreen storage ===
# This does not logically belong to the "cache" functionality, but the cache module is
# a convenient place to put this.
# When a Homescreen layout is instantiated, it checks the value of `homescreen_shown`
# to know whether it should render itself or whether the result of a previous instance
# is still on. This way we can avoid unnecessary fadeins/fadeouts when a workflow ends.
HOMESCREEN_ON = object()
LOCKSCREEN_ON = object()
BUSYSCREEN_ON = object()
homescreen_shown: object | None = None
# Timestamp of last autolock activity.
# Here to persist across main loop restart between workflows.
autolock_last_touch: int | None = None
class InvalidSessionError(Exception):
pass
class DataCache:
fields: Sequence[int] # field sizes
def __init__(self) -> None:
self.data = [bytearray(f + 1) for f in self.fields]
def set(self, key: int, value: bytes) -> None:
utils.ensure(key < len(self.fields))
utils.ensure(len(value) <= self.fields[key])
self.data[key][0] = 1
self.data[key][1:] = value
if TYPE_CHECKING:
@overload
def get(self, key: int) -> bytes | None: ...
@overload
def get(self, key: int, default: T) -> bytes | T: # noqa: F811
...
def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
utils.ensure(key < len(self.fields))
if self.data[key][0] != 1:
return default
return bytes(self.data[key][1:])
def is_set(self, key: int) -> bool:
utils.ensure(key < len(self.fields))
return self.data[key][0] == 1
def delete(self, key: int) -> None:
utils.ensure(key < len(self.fields))
self.data[key][:] = b"\x00"
def clear(self) -> None:
for i in range(len(self.fields)):
self.delete(i)
class SessionCache(DataCache):
def __init__(self) -> None:
self.session_id = bytearray(_SESSION_ID_LENGTH)
if utils.BITCOIN_ONLY:
self.fields = (
64, # APP_COMMON_SEED
2, # APP_COMMON_AUTHORIZATION_TYPE
128, # APP_COMMON_AUTHORIZATION_DATA
32, # APP_COMMON_NONCE
)
else:
self.fields = (
64, # APP_COMMON_SEED
2, # APP_COMMON_AUTHORIZATION_TYPE
128, # APP_COMMON_AUTHORIZATION_DATA
32, # APP_COMMON_NONCE
0, # APP_COMMON_DERIVE_CARDANO
96, # APP_CARDANO_ICARUS_SECRET
96, # APP_CARDANO_ICARUS_TREZOR_SECRET
0, # APP_MONERO_LIVE_REFRESH
)
self.last_usage = 0
super().__init__()
def export_session_id(self) -> bytes:
from trezorcrypto import random # avoid pulling in trezor.crypto
# generate a new session id if we don't have it yet
if not self.session_id:
self.session_id[:] = random.bytes(_SESSION_ID_LENGTH)
# export it as immutable bytes
return bytes(self.session_id)
def clear(self) -> None:
super().clear()
self.last_usage = 0
self.session_id[:] = b""
class SessionlessCache(DataCache):
def __init__(self) -> None:
self.fields = (
64, # APP_COMMON_SEED_WITHOUT_PASSPHRASE
1, # APP_COMMON_SAFETY_CHECKS_TEMPORARY
8, # APP_COMMON_REQUEST_PIN_LAST_UNLOCK
8, # APP_COMMON_BUSY_DEADLINE_MS
32, # APP_MISC_COSI_NONCE
32, # APP_MISC_COSI_COMMITMENT
0, # APP_RECOVERY_REPEATED_BACKUP_UNLOCKED
)
super().__init__()
# XXX
# Allocation notes:
# Instantiation of a DataCache subclass should make as little garbage as possible, so
@ -156,210 +13,61 @@ class SessionlessCache(DataCache):
# bytearrays, then later call `clear()` on all the existing objects, which resets them
# to zero length. This is producing some trash - `b[:]` allocates a slice.
_SESSIONS: list[SessionCache] = []
for _ in range(_MAX_SESSIONS_COUNT):
_SESSIONS.append(SessionCache())
_SESSIONLESS_CACHE = SessionlessCache()
for session in _SESSIONS:
session.clear()
_SESSIONLESS_CACHE.clear()
gc.collect()
_active_session_idx: int | None = None
_session_usage_counter = 0
def start_session(received_session_id: bytes | None = None) -> bytes:
global _active_session_idx
global _session_usage_counter
if (
received_session_id is not None
and len(received_session_id) != _SESSION_ID_LENGTH
):
# Prevent the caller from setting received_session_id=b"" and finding a cleared
# session. More generally, short-circuit the session id search, because we know
# that wrong-length session ids should not be in cache.
# Reduce to "session id not provided" case because that's what we do when
# caller supplies an id that is not found.
received_session_id = None
_session_usage_counter += 1
# attempt to find specified session id
if received_session_id:
for i in range(_MAX_SESSIONS_COUNT):
if _SESSIONS[i].session_id == received_session_id:
_active_session_idx = i
_SESSIONS[i].last_usage = _session_usage_counter
return received_session_id
# allocate least recently used session
lru_counter = _session_usage_counter
lru_session_idx = 0
for i in range(_MAX_SESSIONS_COUNT):
if _SESSIONS[i].last_usage < lru_counter:
lru_counter = _SESSIONS[i].last_usage
lru_session_idx = i
if utils.USE_THP:
from storage import cache_thp
_active_session_idx = lru_session_idx
selected_session = _SESSIONS[lru_session_idx]
selected_session.clear()
selected_session.last_usage = _session_usage_counter
return selected_session.export_session_id()
_PROTOCOL_CACHE = cache_thp
else:
from storage import cache_codec
_PROTOCOL_CACHE = cache_codec
def end_current_session() -> None:
global _active_session_idx
if _active_session_idx is None:
return
_SESSIONS[_active_session_idx].clear()
_active_session_idx = None
def set(key: int, value: bytes) -> None:
if key & _SESSIONLESS_FLAG:
_SESSIONLESS_CACHE.set(key ^ _SESSIONLESS_FLAG, value)
return
if _active_session_idx is None:
raise InvalidSessionError
_SESSIONS[_active_session_idx].set(key, value)
def _get_length(key: int) -> int:
if key & _SESSIONLESS_FLAG:
return _SESSIONLESS_CACHE.fields[key ^ _SESSIONLESS_FLAG]
elif _active_session_idx is None:
raise InvalidSessionError
else:
return _SESSIONS[_active_session_idx].fields[key]
def set_int(key: int, value: int) -> None:
length = _get_length(key)
encoded = value.to_bytes(length, "big")
# Ensure that the value fits within the length. Micropython's int.to_bytes()
# doesn't raise OverflowError.
assert int.from_bytes(encoded, "big") == value
set(key, encoded)
def set_bool(key: int, value: bool) -> None:
assert _get_length(key) == 0 # skipping get_length in production build
if value:
set(key, b"")
else:
delete(key)
if TYPE_CHECKING:
@overload
def get(key: int) -> bytes | None: ...
@overload
def get(key: int, default: T) -> bytes | T: # noqa: F811
...
def get(key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
if key & _SESSIONLESS_FLAG:
return _SESSIONLESS_CACHE.get(key ^ _SESSIONLESS_FLAG, default)
if _active_session_idx is None:
raise InvalidSessionError
return _SESSIONS[_active_session_idx].get(key, default)
_PROTOCOL_CACHE.initialize()
_SESSIONLESS_CACHE.clear()
def get_int(key: int, default: T | None = None) -> int | T | None: # noqa: F811
encoded = get(key)
if encoded is None:
return default
else:
return int.from_bytes(encoded, "big")
gc.collect()
def get_bool(key: int) -> bool: # noqa: F811
return get(key) is not None
def clear_all() -> None:
global autolock_last_touch
autolock_last_touch = None
_SESSIONLESS_CACHE.clear()
_PROTOCOL_CACHE.clear_all()
def get_int_all_sessions(key: int) -> builtins.set[int]:
sessions = [_SESSIONLESS_CACHE] if key & _SESSIONLESS_FLAG else _SESSIONS
values = builtins.set()
for session in sessions:
encoded = session.get(key)
if key & SESSIONLESS_FLAG:
values = builtins.set()
encoded = _SESSIONLESS_CACHE.get(key)
if encoded is not None:
values.add(int.from_bytes(encoded, "big"))
return values
def is_set(key: int) -> bool:
if key & _SESSIONLESS_FLAG:
return _SESSIONLESS_CACHE.is_set(key ^ _SESSIONLESS_FLAG)
if _active_session_idx is None:
raise InvalidSessionError
return _SESSIONS[_active_session_idx].is_set(key)
return values
return _PROTOCOL_CACHE.get_int_all_sessions(key)
def delete(key: int) -> None:
if key & _SESSIONLESS_FLAG:
return _SESSIONLESS_CACHE.delete(key ^ _SESSIONLESS_FLAG)
if _active_session_idx is None:
raise InvalidSessionError
return _SESSIONS[_active_session_idx].delete(key)
def get_sessionless_cache() -> SessionlessCache:
return _SESSIONLESS_CACHE
if TYPE_CHECKING:
from typing import Awaitable, Callable, ParamSpec, TypeVar
from typing import Callable, ParamSpec, TypeVar
T = TypeVar("T")
P = ParamSpec("P")
ByteFunc = Callable[P, bytes]
AsyncByteFunc = Callable[P, Awaitable[bytes]]
def stored(key: int) -> Callable[[ByteFunc[P]], ByteFunc[P]]:
def decorator(func: ByteFunc[P]) -> ByteFunc[P]:
def wrapper(*args: P.args, **kwargs: P.kwargs):
value = get(key)
if value is None:
value = func(*args, **kwargs)
set(key, value)
return value
return wrapper
def check_thp_is_not_used(f: Callable[P, T]) -> Callable[P, T]:
"""A type-safe decorator to raise an exception when the function is called with THP enabled.
return decorator
This decorator should be removed after the caches for Codec_v1 and THP are properly refactored and separated.
"""
def inner(*args: P.args, **kwargs: P.kwargs) -> T:
if utils.USE_THP:
raise Exception("Cannot call this function with the new THP enabled")
return f(*args, **kwargs)
def stored_async(key: int) -> Callable[[AsyncByteFunc[P]], AsyncByteFunc[P]]:
def decorator(func: AsyncByteFunc[P]) -> AsyncByteFunc[P]:
async def wrapper(*args: P.args, **kwargs: P.kwargs):
value = get(key)
if value is None:
value = await func(*args, **kwargs)
set(key, value)
return value
return wrapper
return decorator
def clear_all() -> None:
global _active_session_idx
global autolock_last_touch
_active_session_idx = None
_SESSIONLESS_CACHE.clear()
for session in _SESSIONS:
session.clear()
autolock_last_touch = None
return inner

@ -0,0 +1,145 @@
import builtins
from micropython import const
from typing import TYPE_CHECKING
from storage.cache_common import DataCache
from trezor import utils
if TYPE_CHECKING:
from typing import TypeVar
T = TypeVar("T")
_MAX_SESSIONS_COUNT = const(10)
SESSION_ID_LENGTH = const(32)
class SessionCache(DataCache):
def __init__(self) -> None:
self.session_id = bytearray(SESSION_ID_LENGTH)
if utils.BITCOIN_ONLY:
self.fields = (
64, # APP_COMMON_SEED
2, # APP_COMMON_AUTHORIZATION_TYPE
128, # APP_COMMON_AUTHORIZATION_DATA
32, # APP_COMMON_NONCE
)
else:
self.fields = (
64, # APP_COMMON_SEED
2, # APP_COMMON_AUTHORIZATION_TYPE
128, # APP_COMMON_AUTHORIZATION_DATA
32, # APP_COMMON_NONCE
0, # APP_COMMON_DERIVE_CARDANO
96, # APP_CARDANO_ICARUS_SECRET
96, # APP_CARDANO_ICARUS_TREZOR_SECRET
0, # APP_MONERO_LIVE_REFRESH
)
self.last_usage = 0
super().__init__()
def export_session_id(self) -> bytes:
from trezorcrypto import random # avoid pulling in trezor.crypto
# generate a new session id if we don't have it yet
if not self.session_id:
self.session_id[:] = random.bytes(SESSION_ID_LENGTH)
# export it as immutable bytes
return bytes(self.session_id)
def clear(self) -> None:
super().clear()
self.last_usage = 0
self.session_id[:] = b""
_SESSIONS: list[SessionCache] = []
def initialize() -> None:
global _SESSIONS
for _ in range(_MAX_SESSIONS_COUNT):
_SESSIONS.append(SessionCache())
for session in _SESSIONS:
session.clear()
initialize()
_active_session_idx: int | None = None
_session_usage_counter = 0
def get_active_session() -> SessionCache | None:
if _active_session_idx is None:
return None
return _SESSIONS[_active_session_idx]
def start_session(received_session_id: bytes | None = None) -> bytes:
global _active_session_idx
global _session_usage_counter
if (
received_session_id is not None
and len(received_session_id) != SESSION_ID_LENGTH
):
# Prevent the caller from setting received_session_id=b"" and finding a cleared
# session. More generally, short-circuit the session id search, because we know
# that wrong-length session ids should not be in cache.
# Reduce to "session id not provided" case because that's what we do when
# caller supplies an id that is not found.
received_session_id = None
_session_usage_counter += 1
# attempt to find specified session id
if received_session_id:
for i in range(_MAX_SESSIONS_COUNT):
if _SESSIONS[i].session_id == received_session_id:
_active_session_idx = i
_SESSIONS[i].last_usage = _session_usage_counter
return received_session_id
# allocate least recently used session
lru_counter = _session_usage_counter
lru_session_idx = 0
for i in range(_MAX_SESSIONS_COUNT):
if _SESSIONS[i].last_usage < lru_counter:
lru_counter = _SESSIONS[i].last_usage
lru_session_idx = i
_active_session_idx = lru_session_idx
selected_session = _SESSIONS[lru_session_idx]
selected_session.clear()
selected_session.last_usage = _session_usage_counter
return selected_session.export_session_id()
def end_current_session() -> None:
global _active_session_idx
if _active_session_idx is None:
return
_SESSIONS[_active_session_idx].clear()
_active_session_idx = None
def get_int_all_sessions(key: int) -> builtins.set[int]:
values = builtins.set()
for session in _SESSIONS:
encoded = session.get(key)
if encoded is not None:
values.add(int.from_bytes(encoded, "big"))
return values
def clear_all() -> None:
global _active_session_idx
_active_session_idx = None
for session in _SESSIONS:
session.clear()

@ -0,0 +1,179 @@
from micropython import const
from typing import TYPE_CHECKING
from trezor import utils
# Traditional cache keys
APP_COMMON_SEED = const(0)
APP_COMMON_AUTHORIZATION_TYPE = const(1)
APP_COMMON_AUTHORIZATION_DATA = const(2)
APP_COMMON_NONCE = const(3)
if not utils.BITCOIN_ONLY:
APP_COMMON_DERIVE_CARDANO = const(4)
APP_CARDANO_ICARUS_SECRET = const(5)
APP_CARDANO_ICARUS_TREZOR_SECRET = const(6)
APP_MONERO_LIVE_REFRESH = const(7)
# Cache keys for THP channel
if utils.USE_THP:
CHANNEL_HANDSHAKE_HASH = const(0)
CHANNEL_KEY_RECEIVE = const(1)
CHANNEL_KEY_SEND = const(2)
CHANNEL_NONCE_RECEIVE = const(3)
CHANNEL_NONCE_SEND = const(4)
# Keys that are valid across sessions
SESSIONLESS_FLAG = const(128)
APP_COMMON_SEED_WITHOUT_PASSPHRASE = const(0 | SESSIONLESS_FLAG)
APP_COMMON_SAFETY_CHECKS_TEMPORARY = const(1 | SESSIONLESS_FLAG)
APP_COMMON_REQUEST_PIN_LAST_UNLOCK = const(2 | SESSIONLESS_FLAG)
APP_COMMON_BUSY_DEADLINE_MS = const(3 | SESSIONLESS_FLAG)
APP_MISC_COSI_NONCE = const(4 | SESSIONLESS_FLAG)
APP_MISC_COSI_COMMITMENT = const(5 | SESSIONLESS_FLAG)
APP_RECOVERY_REPEATED_BACKUP_UNLOCKED = const(6 | SESSIONLESS_FLAG)
# === Homescreen storage ===
# This does not logically belong to the "cache" functionality, but the cache module is
# a convenient place to put this.
# When a Homescreen layout is instantiated, it checks the value of `homescreen_shown`
# to know whether it should render itself or whether the result of a previous instance
# is still on. This way we can avoid unnecessary fadeins/fadeouts when a workflow ends.
HOMESCREEN_ON = object()
LOCKSCREEN_ON = object()
BUSYSCREEN_ON = object()
homescreen_shown: object | None = None
# Timestamp of last autolock activity.
# Here to persist across main loop restart between workflows.
autolock_last_touch: int | None = None
if TYPE_CHECKING:
from typing import Sequence, TypeVar, overload
T = TypeVar("T")
class InvalidSessionError(Exception):
pass
class DataCache:
fields: Sequence[int]
def __init__(self) -> None:
self.data = [bytearray(f + 1) for f in self.fields]
if TYPE_CHECKING:
@overload
def get(self, key: int) -> bytes | None: # noqa: F811
...
@overload
def get(self, key: int, default: T) -> bytes | T: # noqa: F811
...
def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
utils.ensure(key < len(self.fields))
if self.data[key][0] != 1:
return default
return bytes(self.data[key][1:])
def get_bool(self, key: int) -> bool: # noqa: F811
return self.get(key) is not None
def get_int(
self, key: int, default: T | None = None
) -> int | T | None: # noqa: F811
encoded = self.get(key)
if encoded is None:
return default
else:
return int.from_bytes(encoded, "big")
def is_set(self, key: int) -> bool:
utils.ensure(key < len(self.fields))
return self.data[key][0] == 1
def set(self, key: int, value: bytes) -> None:
utils.ensure(key < len(self.fields))
utils.ensure(len(value) <= self.fields[key])
self.data[key][0] = 1
self.data[key][1:] = value
def set_bool(self, key: int, value: bool) -> None:
utils.ensure(
self._get_length(key) == 0
) # skipping get_length in production build
if value:
self.set(key, b"")
else:
self.delete(key)
def set_int(self, key: int, value: int) -> None:
length = self.fields[key]
encoded = value.to_bytes(length, "big")
# Ensure that the value fits within the length. Micropython's int.to_bytes()
# doesn't raise OverflowError.
assert int.from_bytes(encoded, "big") == value
self.set(key, encoded)
def delete(self, key: int) -> None:
utils.ensure(key < len(self.fields))
self.data[key][:] = b"\x00"
def clear(self) -> None:
for i in range(len(self.fields)):
self.delete(i)
def _get_length(self, key: int) -> int:
utils.ensure(key < len(self.fields))
return self.fields[key]
class SessionlessCache(DataCache):
def __init__(self) -> None:
self.fields = (
64, # APP_COMMON_SEED_WITHOUT_PASSPHRASE
1, # APP_COMMON_SAFETY_CHECKS_TEMPORARY
8, # APP_COMMON_REQUEST_PIN_LAST_UNLOCK
8, # APP_COMMON_BUSY_DEADLINE_MS
32, # APP_MISC_COSI_NONCE
32, # APP_MISC_COSI_COMMITMENT
0, # APP_RECOVERY_REPEATED_BACKUP_UNLOCKED
)
super().__init__()
def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
return super().get(key & ~SESSIONLESS_FLAG, default)
def get_bool(self, key: int) -> bool: # noqa: F811
return super().get_bool(key & ~SESSIONLESS_FLAG)
def get_int(
self, key: int, default: T | None = None
) -> int | T | None: # noqa: F811
return super().get_int(key & ~SESSIONLESS_FLAG, default)
def is_set(self, key: int) -> bool:
return super().is_set(key & ~SESSIONLESS_FLAG)
def set(self, key: int, value: bytes) -> None:
super().set(key & ~SESSIONLESS_FLAG, value)
def set_bool(self, key: int, value: bool) -> None:
super().set_bool(key & ~SESSIONLESS_FLAG, value)
def set_int(self, key: int, value: int) -> None:
super().set_int(key & ~SESSIONLESS_FLAG, value)
def delete(self, key: int) -> None:
super().delete(key & ~SESSIONLESS_FLAG)
def clear(self) -> None:
for i in range(len(self.fields)):
self.delete(i)

@ -0,0 +1,325 @@
import builtins
from micropython import const # pyright: ignore[reportMissingModuleSource]
from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
from storage.cache_common import DataCache
from trezor import utils
if TYPE_CHECKING:
from typing import TypeVar # pyright: ignore[reportShadowedImports]
T = TypeVar("T")
if __debug__:
from trezor import log
# THP specific constants
_MAX_CHANNELS_COUNT = const(10)
_MAX_SESSIONS_COUNT = const(20)
_CHANNEL_STATE_LENGTH = const(1)
_WIRE_INTERFACE_LENGTH = const(1)
_SESSION_STATE_LENGTH = const(1)
_CHANNEL_ID_LENGTH = const(2)
SESSION_ID_LENGTH = const(1)
BROADCAST_CHANNEL_ID = const(65535)
KEY_LENGTH = const(32)
TAG_LENGTH = const(16)
_UNALLOCATED_STATE = const(0)
MANAGEMENT_SESSION_ID = const(0)
class ConnectionCache(DataCache):
def __init__(self) -> None:
self.channel_id = bytearray(_CHANNEL_ID_LENGTH)
self.last_usage = 0
super().__init__()
def clear(self) -> None:
self.channel_id[:] = b""
self.last_usage = 0
super().clear()
class ChannelCache(ConnectionCache):
def __init__(self) -> None:
self.host_ephemeral_pubkey = bytearray(KEY_LENGTH)
self.state = bytearray(_CHANNEL_STATE_LENGTH)
self.iface = bytearray(1) # TODO add decoding
self.sync = 0x80 # can_send_bit | sync_receive_bit | sync_send_bit | rfu(5)
self.session_id_counter = 0x00
self.fields = (
32, # CHANNEL_HANDSHAKE_HASH
32, # CHANNEL_KEY_RECEIVE
32, # CHANNEL_KEY_SEND
8, # CHANNEL_NONCE_RECEIVE
8, # CHANNEL_NONCE_SEND
)
super().__init__()
def clear(self) -> None:
self.state[:] = bytearray(
int.to_bytes(0, _CHANNEL_STATE_LENGTH, "big")
) # Set state to UNALLOCATED
# TODO clear all keys
super().clear()
class SessionThpCache(ConnectionCache):
def __init__(self) -> None:
self.session_id = bytearray(SESSION_ID_LENGTH)
self.state = bytearray(_SESSION_STATE_LENGTH)
if utils.BITCOIN_ONLY:
self.fields = (
64, # APP_COMMON_SEED
2, # APP_COMMON_AUTHORIZATION_TYPE
128, # APP_COMMON_AUTHORIZATION_DATA
32, # APP_COMMON_NONCE
)
else:
self.fields = (
64, # APP_COMMON_SEED
2, # APP_COMMON_AUTHORIZATION_TYPE
128, # APP_COMMON_AUTHORIZATION_DATA
32, # APP_COMMON_NONCE
0, # APP_COMMON_DERIVE_CARDANO
96, # APP_CARDANO_ICARUS_SECRET
96, # APP_CARDANO_ICARUS_TREZOR_SECRET
0, # APP_MONERO_LIVE_REFRESH
)
super().__init__()
def clear(self) -> None:
self.state[:] = bytearray(int.to_bytes(0, 1, "big")) # Set state to UNALLOCATED
self.session_id[:] = b""
super().clear()
_CHANNELS: list[ChannelCache] = []
_SESSIONS: list[SessionThpCache] = []
def initialize() -> None:
global _CHANNELS
global _SESSIONS
for _ in range(_MAX_CHANNELS_COUNT):
_CHANNELS.append(ChannelCache())
for _ in range(_MAX_SESSIONS_COUNT):
_SESSIONS.append(SessionThpCache())
for channel in _CHANNELS:
channel.clear()
for session in _SESSIONS:
session.clear()
initialize()
# THP vars
_next_unauthenicated_session_index: int = 0 # TODO remove
# First unauthenticated channel will have index 0
_usage_counter = 0
# with this (arbitrary) value=4659, the first allocated channel will have cid=1234 (hex)
cid_counter: int = 4659 # TODO change to random value on start
def get_new_unauthenticated_channel(iface: bytes) -> ChannelCache:
if len(iface) != _WIRE_INTERFACE_LENGTH:
raise Exception("Invalid WireInterface (encoded) length")
new_cid = get_next_channel_id()
index = _get_next_unauthenticated_channel_index()
# clear sessions from replaced channel
if _get_channel_state(_CHANNELS[index]) != _UNALLOCATED_STATE:
old_cid = _CHANNELS[index].channel_id
clear_sessions_with_channel_id(old_cid)
_CHANNELS[index] = ChannelCache()
_CHANNELS[index].channel_id[:] = new_cid
_CHANNELS[index].last_usage = _get_usage_counter_and_increment()
_CHANNELS[index].state[:] = bytearray(
_UNALLOCATED_STATE.to_bytes(_CHANNEL_STATE_LENGTH, "big")
)
_CHANNELS[index].iface[:] = bytearray(iface)
return _CHANNELS[index]
def get_all_allocated_channels() -> list[ChannelCache]:
_list: list[ChannelCache] = []
for channel in _CHANNELS:
if _get_channel_state(channel) != _UNALLOCATED_STATE:
_list.append(channel)
return _list
def get_all_allocated_sessions() -> list[SessionThpCache]:
if __debug__:
from trezor.utils import get_bytes_as_str
_list: list[SessionThpCache] = []
for session in _SESSIONS:
if _get_session_state(session) != _UNALLOCATED_STATE:
_list.append(session)
if __debug__:
log.debug(
__name__,
"session with channel_id: %s and session_id: %s is in ALLOCATED state",
get_bytes_as_str(session.channel_id),
get_bytes_as_str(session.session_id),
)
elif __debug__:
log.debug(__name__, "session %s is in UNALLOCATED state", str(session))
return _list
def set_channel_host_ephemeral_key(channel: ChannelCache, key: bytearray) -> None:
if len(key) != KEY_LENGTH:
raise Exception("Invalid key length")
channel.host_ephemeral_pubkey = key
def get_new_session(channel: ChannelCache):
new_sid = get_next_session_id(channel)
index = _get_next_session_index()
_SESSIONS[index] = SessionThpCache()
_SESSIONS[index].channel_id[:] = channel.channel_id
_SESSIONS[index].session_id[:] = new_sid
_SESSIONS[index].last_usage = _get_usage_counter_and_increment()
channel.last_usage = (
_get_usage_counter_and_increment()
) # increment also use of the channel so it does not get replaced
_SESSIONS[index].state[:] = bytearray(
_UNALLOCATED_STATE.to_bytes(_SESSION_STATE_LENGTH, "big")
)
return _SESSIONS[index]
def _get_usage_counter() -> int:
global _usage_counter
return _usage_counter
def _get_usage_counter_and_increment() -> int:
global _usage_counter
_usage_counter += 1
return _usage_counter
def _get_next_unauthenticated_channel_index() -> int:
idx = _get_unallocated_channel_index()
if idx is not None:
return idx
return get_least_recently_used_item(_CHANNELS, max_count=_MAX_CHANNELS_COUNT)
def _get_next_session_index() -> int:
idx = _get_unallocated_session_index()
if idx is not None:
return idx
return get_least_recently_used_item(_SESSIONS, _MAX_SESSIONS_COUNT)
def _get_unallocated_channel_index() -> int | None:
for i in range(_MAX_CHANNELS_COUNT):
if _get_channel_state(_CHANNELS[i]) is _UNALLOCATED_STATE:
return i
return None
def _get_unallocated_session_index() -> int | None:
for i in range(_MAX_SESSIONS_COUNT):
if (_SESSIONS[i]) is _UNALLOCATED_STATE:
return i
return None
def _get_channel_state(channel: ChannelCache) -> int:
if channel is None:
return _UNALLOCATED_STATE
return int.from_bytes(channel.state, "big")
def _get_session_state(session: SessionThpCache) -> int:
if session is None:
return _UNALLOCATED_STATE
return int.from_bytes(session.state, "big")
def get_next_channel_id() -> bytes:
global cid_counter
while True:
cid_counter += 1
if cid_counter >= BROADCAST_CHANNEL_ID:
cid_counter = 1
if _is_cid_unique():
break
return cid_counter.to_bytes(_CHANNEL_ID_LENGTH, "big")
def get_next_session_id(channel: ChannelCache) -> bytes:
while True:
if channel.session_id_counter >= 255:
channel.session_id_counter = 1
else:
channel.session_id_counter += 1
if _is_session_id_unique(channel):
break
new_sid = channel.session_id_counter
return new_sid.to_bytes(SESSION_ID_LENGTH, "big")
def _is_session_id_unique(channel: ChannelCache) -> bool:
for session in _SESSIONS:
if session.channel_id == channel.channel_id:
if session.session_id == channel.session_id_counter:
return False
return True
def _is_cid_unique() -> bool:
for session in _SESSIONS:
if cid_counter == _get_cid(session):
return False
return True
def _get_cid(session: SessionThpCache) -> int:
return int.from_bytes(session.session_id[2:], "big")
def get_least_recently_used_item(
list: list[ChannelCache] | list[SessionThpCache], max_count: int
):
lru_counter = _get_usage_counter()
lru_item_index = 0
for i in range(max_count):
if list[i].last_usage < lru_counter:
lru_counter = list[i].last_usage
lru_item_index = i
return lru_item_index
def get_int_all_sessions(key: int) -> builtins.set[int]:
values = builtins.set()
for session in _SESSIONS:
encoded = session.get(key)
if encoded is not None:
values.add(int.from_bytes(encoded, "big"))
return values
def clear_sessions_with_channel_id(channel_id: bytes):
for session in _SESSIONS:
if session.channel_id == channel_id:
session.clear()
def clear_all() -> None:
for session in _SESSIONS:
session.clear()

@ -16,4 +16,5 @@ NotInitialized = 11
PinMismatch = 12
WipeCodeMismatch = 13
InvalidSession = 14
ThpUnallocatedSession = 15
FirmwareError = 99

@ -95,6 +95,24 @@ DebugLinkEraseSdCard = 9005
DebugLinkWatchLayout = 9006
DebugLinkResetDebugEvents = 9007
DebugLinkOptigaSetSecMax = 9008
ThpCreateNewSession = 1000
ThpNewSession = 1001
ThpStartPairingRequest = 1008
ThpPairingPreparationsFinished = 1009
ThpCredentialRequest = 1010
ThpCredentialResponse = 1011
ThpEndRequest = 1012
ThpEndResponse = 1013
ThpCodeEntryCommitment = 1016
ThpCodeEntryChallenge = 1017
ThpCodeEntryCpaceHost = 1018
ThpCodeEntryCpaceTrezor = 1019
ThpCodeEntryTag = 1020
ThpCodeEntrySecret = 1021
ThpQrCodeTag = 1024
ThpQrCodeSecret = 1025
ThpNfcUnidirectionalTag = 1032
ThpNfcUnidirectionalSecret = 1033
if not utils.BITCOIN_ONLY:
SetU2FCounter = 63
GetNextU2FCounter = 80

@ -0,0 +1,8 @@
# Automatically generated by pb2py
# fmt: off
# isort:skip_file
NoMethod = 1
CodeEntry = 2
QrCode = 3
NFC_Unidirectional = 4

@ -262,6 +262,24 @@ if TYPE_CHECKING:
SolanaAddress = 903
SolanaSignTx = 904
SolanaTxSignature = 905
ThpCreateNewSession = 1000
ThpNewSession = 1001
ThpStartPairingRequest = 1008
ThpPairingPreparationsFinished = 1009
ThpCredentialRequest = 1010
ThpCredentialResponse = 1011
ThpEndRequest = 1012
ThpEndResponse = 1013
ThpCodeEntryCommitment = 1016
ThpCodeEntryChallenge = 1017
ThpCodeEntryCpaceHost = 1018
ThpCodeEntryCpaceTrezor = 1019
ThpCodeEntryTag = 1020
ThpCodeEntrySecret = 1021
ThpQrCodeTag = 1024
ThpQrCodeSecret = 1025
ThpNfcUnidirectionalTag = 1032
ThpNfcUnidirectionalSecret = 1033
class FailureType(IntEnum):
UnexpectedMessage = 1
@ -278,6 +296,7 @@ if TYPE_CHECKING:
PinMismatch = 12
WipeCodeMismatch = 13
InvalidSession = 14
ThpUnallocatedSession = 15
FirmwareError = 99
class ButtonRequestType(IntEnum):
@ -570,3 +589,9 @@ if TYPE_CHECKING:
Yay = 0
Nay = 1
Pass = 2
class ThpPairingMethod(IntEnum):
NoMethod = 1
CodeEntry = 2
QrCode = 3
NFC_Unidirectional = 4

@ -14,9 +14,9 @@ from typing import TYPE_CHECKING
from trezor import io, log
if TYPE_CHECKING:
from typing import Any, Awaitable, Callable, Coroutine, Generator
from typing import Any, Awaitable, Callable, Coroutine, Generator, Union
Task = Coroutine | Generator
Task = Union[Coroutine, Generator, "wait"]
AwaitableTask = Task | Awaitable
Finalizer = Callable[[Task, Any], None]
@ -202,6 +202,13 @@ class Syscall:
pass
class Timeout(Exception):
pass
_TIMEOUT_ERROR = Timeout()
class sleep(Syscall):
"""Pause current task and resume it after given delay.
@ -233,11 +240,39 @@ class wait(Syscall):
>>> event, x, y = await loop.wait(io.TOUCH) # await touch event
"""
def __init__(self, msg_iface: int) -> None:
_DO_NOT_RESCHEDULE = Syscall()
def __init__(self, msg_iface: int, timeout_ms: int | None = None) -> None:
self.msg_iface = msg_iface
self.timeout_ms = timeout_ms
self.task: Task | None = None
def handle(self, task: Task) -> None:
pause(task, self.msg_iface)
self.task = task
pause(self, self.msg_iface)
if self.timeout_ms is not None:
deadline = utime.ticks_add(utime.ticks_ms(), self.timeout_ms)
schedule(self, _TIMEOUT_ERROR, deadline)
def send(self, __value: Any) -> Any:
assert self.task is not None
self.close()
_step(self.task, __value)
return self._DO_NOT_RESCHEDULE
throw = send
def close(self) -> None:
_queue.discard(self)
if self.msg_iface in _paused:
_paused[self.msg_iface].discard(self)
def __iter__(self) -> Generator:
try:
return (yield self)
finally:
# whichever way we got here, we must be removed from the paused list
self.close()
_type_gen: type[Generator] = type((lambda: (yield))())

@ -65,6 +65,7 @@ if TYPE_CHECKING:
from trezor.enums import StellarSignerType # noqa: F401
from trezor.enums import TezosBallotType # noqa: F401
from trezor.enums import TezosContractType # noqa: F401
from trezor.enums import ThpPairingMethod # noqa: F401
from trezor.enums import WordRequestType # noqa: F401
class BinanceGetAddress(protobuf.MessageType):
@ -6117,6 +6118,324 @@ if TYPE_CHECKING:
def is_type_of(cls, msg: Any) -> TypeGuard["TezosManagerTransfer"]:
return isinstance(msg, cls)
class ThpDeviceProperties(protobuf.MessageType):
internal_model: "str | None"
model_variant: "int | None"
bootloader_mode: "bool | None"
protocol_version: "int | None"
pairing_methods: "list[ThpPairingMethod]"
def __init__(
self,
*,
pairing_methods: "list[ThpPairingMethod] | None" = None,
internal_model: "str | None" = None,
model_variant: "int | None" = None,
bootloader_mode: "bool | None" = None,
protocol_version: "int | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpDeviceProperties"]:
return isinstance(msg, cls)
class ThpHandshakeCompletionReqNoisePayload(protobuf.MessageType):
host_pairing_credential: "bytes | None"
pairing_methods: "list[ThpPairingMethod]"
def __init__(
self,
*,
pairing_methods: "list[ThpPairingMethod] | None" = None,
host_pairing_credential: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpHandshakeCompletionReqNoisePayload"]:
return isinstance(msg, cls)
class ThpCreateNewSession(protobuf.MessageType):
passphrase: "str | None"
on_device: "bool | None"
derive_cardano: "bool | None"
def __init__(
self,
*,
passphrase: "str | None" = None,
on_device: "bool | None" = None,
derive_cardano: "bool | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCreateNewSession"]:
return isinstance(msg, cls)
class ThpNewSession(protobuf.MessageType):
new_session_id: "int | None"
def __init__(
self,
*,
new_session_id: "int | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpNewSession"]:
return isinstance(msg, cls)
class ThpStartPairingRequest(protobuf.MessageType):
host_name: "str | None"
def __init__(
self,
*,
host_name: "str | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpStartPairingRequest"]:
return isinstance(msg, cls)
class ThpPairingPreparationsFinished(protobuf.MessageType):
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpPairingPreparationsFinished"]:
return isinstance(msg, cls)
class ThpCodeEntryCommitment(protobuf.MessageType):
commitment: "bytes | None"
def __init__(
self,
*,
commitment: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntryCommitment"]:
return isinstance(msg, cls)
class ThpCodeEntryChallenge(protobuf.MessageType):
challenge: "bytes | None"
def __init__(
self,
*,
challenge: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntryChallenge"]:
return isinstance(msg, cls)
class ThpCodeEntryCpaceHost(protobuf.MessageType):
cpace_host_public_key: "bytes | None"
def __init__(
self,
*,
cpace_host_public_key: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntryCpaceHost"]:
return isinstance(msg, cls)
class ThpCodeEntryCpaceTrezor(protobuf.MessageType):
cpace_trezor_public_key: "bytes | None"
def __init__(
self,
*,
cpace_trezor_public_key: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntryCpaceTrezor"]:
return isinstance(msg, cls)
class ThpCodeEntryTag(protobuf.MessageType):
tag: "bytes | None"
def __init__(
self,
*,
tag: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntryTag"]:
return isinstance(msg, cls)
class ThpCodeEntrySecret(protobuf.MessageType):
secret: "bytes | None"
def __init__(
self,
*,
secret: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntrySecret"]:
return isinstance(msg, cls)
class ThpQrCodeTag(protobuf.MessageType):
tag: "bytes | None"
def __init__(
self,
*,
tag: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpQrCodeTag"]:
return isinstance(msg, cls)
class ThpQrCodeSecret(protobuf.MessageType):
secret: "bytes | None"
def __init__(
self,
*,
secret: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpQrCodeSecret"]:
return isinstance(msg, cls)
class ThpNfcUnidirectionalTag(protobuf.MessageType):
tag: "bytes | None"
def __init__(
self,
*,
tag: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpNfcUnidirectionalTag"]:
return isinstance(msg, cls)
class ThpNfcUnidirectionalSecret(protobuf.MessageType):
secret: "bytes | None"
def __init__(
self,
*,
secret: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpNfcUnidirectionalSecret"]:
return isinstance(msg, cls)
class ThpCredentialRequest(protobuf.MessageType):
host_static_pubkey: "bytes | None"
def __init__(
self,
*,
host_static_pubkey: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCredentialRequest"]:
return isinstance(msg, cls)
class ThpCredentialResponse(protobuf.MessageType):
trezor_static_pubkey: "bytes | None"
credential: "bytes | None"
def __init__(
self,
*,
trezor_static_pubkey: "bytes | None" = None,
credential: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCredentialResponse"]:
return isinstance(msg, cls)
class ThpEndRequest(protobuf.MessageType):
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpEndRequest"]:
return isinstance(msg, cls)
class ThpEndResponse(protobuf.MessageType):
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpEndResponse"]:
return isinstance(msg, cls)
class ThpCredentialMetadata(protobuf.MessageType):
host_name: "str | None"
def __init__(
self,
*,
host_name: "str | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCredentialMetadata"]:
return isinstance(msg, cls)
class ThpPairingCredential(protobuf.MessageType):
cred_metadata: "ThpCredentialMetadata | None"
mac: "bytes | None"
def __init__(
self,
*,
cred_metadata: "ThpCredentialMetadata | None" = None,
mac: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpPairingCredential"]:
return isinstance(msg, cls)
class ThpAuthenticatedCredentialData(protobuf.MessageType):
host_static_pubkey: "bytes | None"
cred_metadata: "ThpCredentialMetadata | None"
def __init__(
self,
*,
host_static_pubkey: "bytes | None" = None,
cred_metadata: "ThpCredentialMetadata | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpAuthenticatedCredentialData"]:
return isinstance(msg, cls)
class WebAuthnListResidentCredentials(protobuf.MessageType):
@classmethod

@ -54,7 +54,7 @@ class RustLayout(ui.Layout):
assert msg is None
def _paint(self) -> None:
import storage.cache as storage_cache
import storage.cache_common as storage_cache
painted = self.layout.paint()

@ -1,6 +1,6 @@
from typing import TYPE_CHECKING
import storage.cache as storage_cache
import storage.cache_common as storage_cache
import trezorui2
from trezor import TR, ui
@ -139,11 +139,13 @@ class Busyscreen(HomescreenBase):
)
async def __iter__(self) -> Any:
from trezor.wire import context
from apps.base import set_homescreen
# Handle timeout.
result = await super().__iter__()
assert result == trezorui2.CANCELLED
storage_cache.delete(storage_cache.APP_COMMON_BUSY_DEADLINE_MS)
context.cache_delete(storage_cache.APP_COMMON_BUSY_DEADLINE_MS)
set_homescreen()
return result

@ -55,14 +55,14 @@ class RustLayout(LayoutParentType[T]):
assert msg is None
def _paint(self) -> None:
import storage.cache as storage_cache
import storage.cache_common as cache_common
painted = self.layout.paint()
if painted:
ui.refresh()
if storage_cache.homescreen_shown is not None and painted:
storage_cache.homescreen_shown = None
if cache_common.homescreen_shown is not None and painted:
cache_common.homescreen_shown = None
if __debug__:
from trezor.enums import DebugPhysicalButton

@ -1,8 +1,9 @@
from typing import TYPE_CHECKING
import storage.cache as storage_cache
import storage.cache_common as cache_common
import trezorui2
from trezor import TR, ui
from trezor.wire import context
from . import RustLayout
@ -23,15 +24,15 @@ class HomescreenBase(RustLayout):
ui.refresh()
def _first_paint(self) -> None:
if storage_cache.homescreen_shown is not self.RENDER_INDICATOR:
if cache_common.homescreen_shown is not self.RENDER_INDICATOR:
super()._first_paint()
storage_cache.homescreen_shown = self.RENDER_INDICATOR
cache_common.homescreen_shown = self.RENDER_INDICATOR
else:
self._paint()
class Homescreen(HomescreenBase):
RENDER_INDICATOR = storage_cache.HOMESCREEN_ON
RENDER_INDICATOR = cache_common.HOMESCREEN_ON
def __init__(
self,
@ -47,7 +48,7 @@ class Homescreen(HomescreenBase):
elif notification_is_error:
level = 0
skip = storage_cache.homescreen_shown is self.RENDER_INDICATOR
skip = cache_common.homescreen_shown is self.RENDER_INDICATOR
super().__init__(
layout=trezorui2.show_homescreen(
label=label,
@ -73,7 +74,7 @@ class Homescreen(HomescreenBase):
class Lockscreen(HomescreenBase):
RENDER_INDICATOR = storage_cache.LOCKSCREEN_ON
RENDER_INDICATOR = cache_common.LOCKSCREEN_ON
def __init__(
self,
@ -82,9 +83,7 @@ class Lockscreen(HomescreenBase):
coinjoin_authorized: bool = False,
) -> None:
self.bootscreen = bootscreen
skip = (
not bootscreen and storage_cache.homescreen_shown is self.RENDER_INDICATOR
)
skip = not bootscreen and cache_common.homescreen_shown is self.RENDER_INDICATOR
super().__init__(
layout=trezorui2.show_lockscreen(
label=label,
@ -102,12 +101,12 @@ class Lockscreen(HomescreenBase):
class Busyscreen(HomescreenBase):
RENDER_INDICATOR = storage_cache.BUSYSCREEN_ON
RENDER_INDICATOR = cache_common.BUSYSCREEN_ON
def __init__(self, delay_ms: int) -> None:
from trezor import TR
skip = storage_cache.homescreen_shown is self.RENDER_INDICATOR
skip = cache_common.homescreen_shown is self.RENDER_INDICATOR
super().__init__(
layout=trezorui2.show_progress_coinjoin(
title=TR.coinjoin__waiting_for_others,
@ -123,6 +122,6 @@ class Busyscreen(HomescreenBase):
# Handle timeout.
result = await super().__iter__()
assert result == trezorui2.CANCELLED
storage_cache.delete(storage_cache.APP_COMMON_BUSY_DEADLINE_MS)
context.cache_delete(cache_common.APP_COMMON_BUSY_DEADLINE_MS)
set_homescreen()
return result

@ -57,14 +57,14 @@ class RustLayout(LayoutParentType[T]):
assert msg is None
def _paint(self) -> None:
import storage.cache as storage_cache
import storage.cache_common as cache_common
painted = self.layout.paint()
if painted:
ui.refresh()
if storage_cache.homescreen_shown is not None and painted:
storage_cache.homescreen_shown = None
if cache_common.homescreen_shown is not None and painted:
cache_common.homescreen_shown = None
if __debug__:

@ -1,8 +1,9 @@
from typing import TYPE_CHECKING
import storage.cache as storage_cache
import storage.cache_common as cache_common
import trezorui2
from trezor import TR, ui
from trezor.wire import context
from . import RustLayout
@ -23,9 +24,9 @@ class HomescreenBase(RustLayout):
ui.refresh()
def _first_paint(self) -> None:
if storage_cache.homescreen_shown is not self.RENDER_INDICATOR:
if cache_common.homescreen_shown is not self.RENDER_INDICATOR:
super()._first_paint()
storage_cache.homescreen_shown = self.RENDER_INDICATOR
cache_common.homescreen_shown = self.RENDER_INDICATOR
else:
self._paint()
@ -40,7 +41,7 @@ class HomescreenBase(RustLayout):
class Homescreen(HomescreenBase):
RENDER_INDICATOR = storage_cache.HOMESCREEN_ON
RENDER_INDICATOR = cache_common.HOMESCREEN_ON
def __init__(
self,
@ -58,7 +59,7 @@ class Homescreen(HomescreenBase):
elif notification_is_error:
level = 0
skip = storage_cache.homescreen_shown is self.RENDER_INDICATOR
skip = cache_common.homescreen_shown is self.RENDER_INDICATOR
super().__init__(
layout=trezorui2.show_homescreen(
label=label,
@ -84,7 +85,7 @@ class Homescreen(HomescreenBase):
class Lockscreen(HomescreenBase):
RENDER_INDICATOR = storage_cache.LOCKSCREEN_ON
RENDER_INDICATOR = cache_common.LOCKSCREEN_ON
def __init__(
self,
@ -97,9 +98,7 @@ class Lockscreen(HomescreenBase):
if bootscreen:
self.backlight_level = ui.BacklightLevels.NORMAL
skip = (
not bootscreen and storage_cache.homescreen_shown is self.RENDER_INDICATOR
)
skip = not bootscreen and cache_common.homescreen_shown is self.RENDER_INDICATOR
super().__init__(
layout=trezorui2.show_lockscreen(
label=label,
@ -117,12 +116,12 @@ class Lockscreen(HomescreenBase):
class Busyscreen(HomescreenBase):
RENDER_INDICATOR = storage_cache.BUSYSCREEN_ON
RENDER_INDICATOR = cache_common.BUSYSCREEN_ON
def __init__(self, delay_ms: int) -> None:
from trezor import TR
skip = storage_cache.homescreen_shown is self.RENDER_INDICATOR
skip = cache_common.homescreen_shown is self.RENDER_INDICATOR
super().__init__(
layout=trezorui2.show_progress_coinjoin(
title=TR.coinjoin__waiting_for_others,
@ -138,6 +137,6 @@ class Busyscreen(HomescreenBase):
# Handle timeout.
result = await super().__iter__()
assert result == trezorui2.CANCELLED
storage_cache.delete(storage_cache.APP_COMMON_BUSY_DEADLINE_MS)
context.cache_delete(cache_common.APP_COMMON_BUSY_DEADLINE_MS)
set_homescreen()
return result

@ -33,6 +33,8 @@ from typing import TYPE_CHECKING
DISABLE_ANIMATION = 0
DISABLE_ENCRYPTION: bool = False
if __debug__:
if EMULATOR:
import uos
@ -43,7 +45,13 @@ if __debug__:
LOG_MEMORY = 0
if TYPE_CHECKING:
from typing import Any, Iterator, Protocol, Sequence, TypeVar
from typing import ( # pyright: ignore[reportShadowedImports]
Any,
Iterator,
Protocol,
Sequence,
TypeVar,
)
from trezor.protobuf import MessageType
@ -109,6 +117,7 @@ def presize_module(modname: str, size: int) -> None:
if __debug__:
from ubinascii import hexlify
def mem_dump(filename: str) -> None:
from micropython import mem_info
@ -125,6 +134,9 @@ if __debug__:
else:
mem_info(True)
def get_bytes_as_str(a):
return hexlify(a).decode("utf-8")
def ensure(cond: bool, msg: str | None = None) -> None:
if not cond:

@ -5,7 +5,7 @@ Handles on-the-wire communication with a host computer. The communication is:
- Request / response.
- Protobuf-encoded, see `protobuf.py`.
- Wrapped in a simple envelope format, see `trezor/wire/codec_v1.py`.
- Wrapped in a simple envelope format, see `trezor/wire/codec_v1.py` or `trezor/wire/thp_v3.py`.
- Transferred over USB interface, or UDP in case of Unix emulation.
This module:
@ -16,22 +16,24 @@ This module:
## Session handler
When the `wire.setup` is called the `handle_session` coroutine is scheduled. The
When the `wire.setup` is called the `handle_session` (or `handle_thp_session`) coroutine is scheduled. The
`handle_session` waits for some messages to be received on some particular interface and
reads the message's header. When the message type is known the first handler is called. This way the
`handle_session` goes through all the workflows.
"""
from micropython import const
from typing import TYPE_CHECKING
from storage.cache import InvalidSessionError
from trezor import log, loop, protobuf, utils, workflow
from trezor.enums import FailureType
from trezor.messages import Failure
from trezor.wire import codec_v1, context
from trezor.wire.errors import ActionCancelled, DataError, Error, UnexpectedMessage
from trezor import log, loop, protobuf, utils
from trezor.wire import context, message_handler, protocol_common, thp_v3
from trezor.wire.context import UnexpectedMessageException
from trezor.wire.message_handler import (
WIRE_BUFFER,
WIRE_BUFFER_DEBUG,
failure,
find_handler,
)
# Import all errors into namespace, so that `wire.Error` is available from
# other packages.
@ -40,12 +42,11 @@ from trezor.wire.errors import * # isort:skip # noqa: F401,F403
if TYPE_CHECKING:
from trezorio import WireInterface
from typing import Any, Callable, Container, Coroutine, TypeVar
from typing import Any, Callable, Coroutine, TypeVar
Msg = TypeVar("Msg", bound=protobuf.MessageType)
HandlerTask = Coroutine[Any, Any, protobuf.MessageType]
Handler = Callable[[Msg], HandlerTask]
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
@ -54,160 +55,52 @@ EXPERIMENTAL_ENABLED = False
def setup(iface: WireInterface, is_debug_session: bool = False) -> None:
"""Initialize the wire stack on passed USB interface."""
loop.schedule(handle_session(iface, codec_v1.SESSION_ID, is_debug_session))
def wrap_protobuf_load(
buffer: bytes,
expected_type: type[LoadedMessageType],
) -> LoadedMessageType:
try:
msg = protobuf.decode(buffer, expected_type, EXPERIMENTAL_ENABLED)
if __debug__ and utils.EMULATOR:
log.debug(
__name__, "received message contents:\n%s", utils.dump_protobuf(msg)
)
return msg
except Exception as e:
if __debug__:
log.exception(__name__, e)
if e.args:
raise DataError("Failed to decode message: " + " ".join(e.args))
else:
raise DataError("Failed to decode message")
_PROTOBUF_BUFFER_SIZE = const(8192)
WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
if __debug__:
PROTOBUF_BUFFER_SIZE_DEBUG = 1024
WIRE_BUFFER_DEBUG = bytearray(PROTOBUF_BUFFER_SIZE_DEBUG)
async def _handle_single_message(
ctx: context.Context, msg: codec_v1.Message, use_workflow: bool
) -> bool:
"""Handle a message that was loaded from USB by the caller.
Find the appropriate handler, run it and write its result on the wire. In case
a problem is encountered at any point, write the appropriate error on the wire.
The return value indicates whether to override the default restarting behavior. If
`False` is returned, the caller is allowed to clear the loop and restart the
MicroPython machine (see `session.py`). This would lose all state and incurs a cost
in terms of repeated startup time. When handling the message didn't cause any
significant fragmentation (e.g., if decoding the message was skipped), or if
the type of message is supposed to be optimized and not disrupt the running state,
this function will return `True`.
"""
if __debug__:
"""Initialize the wire stack on passed WireInterface."""
if utils.USE_THP and not is_debug_session:
loop.schedule(handle_thp_session(iface, is_debug_session))
else:
loop.schedule(handle_session(iface, is_debug_session))
async def handle_thp_session(iface: WireInterface, is_debug_session: bool = False):
if __debug__ and is_debug_session:
ctx_buffer = WIRE_BUFFER_DEBUG
else:
ctx_buffer = WIRE_BUFFER
thp_v3.set_buffer(ctx_buffer)
# Take a mark of modules that are imported at this point, so we can
# roll back and un-import any others.
modules = utils.unimport_begin()
while True:
try:
msg_type = protobuf.type_for_wire(msg.type).MESSAGE_NAME
except Exception:
msg_type = f"{msg.type} - unknown message type"
log.debug(
__name__,
"%s:%x receive: <%s>",
ctx.iface.iface_num(),
ctx.sid,
msg_type,
)
res_msg: protobuf.MessageType | None = None
# We need to find a handler for this message type.
try:
handler = find_handler(ctx.iface, msg.type)
except Error as exc:
# Handlers are allowed to exception out. In that case, we can skip decoding
# and return the error.
await ctx.write(failure(exc))
return True
if msg.type in workflow.ALLOW_WHILE_LOCKED:
workflow.autolock_interrupts_workflow = False
# Here we make sure we always respond with a Failure response
# in case of any errors.
try:
# Find a protobuf.MessageType subclass that describes this
# message. Raises if the type is not found.
req_type = protobuf.type_for_wire(msg.type)
# Try to decode the message according to schema from
# `req_type`. Raises if the message is malformed.
req_msg = wrap_protobuf_load(msg.data, req_type)
# Create the handler task.
task = handler(req_msg)
# Run the workflow task. Workflow can do more on-the-wire
# communication inside, but it should eventually return a
# response message, or raise an exception (a rather common
# thing to do). Exceptions are handled in the code below.
if use_workflow:
# Spawn a workflow around the task. This ensures that concurrent
# workflows are shut down.
res_msg = await workflow.spawn(context.with_context(ctx, task))
else:
# For debug messages, ignore workflow processing and just await
# results of the handler.
res_msg = await task
except context.UnexpectedMessage:
# Workflow was trying to read a message from the wire, and
# something unexpected came in. See Context.read() for
# example, which expects some particular message and raises
# UnexpectedMessage if another one comes in.
#
# We process the unexpected message by aborting the current workflow and
# possibly starting a new one, initiated by that message. (The main usecase
# being, the host does not finish the workflow, we want other callers to
# be able to do their own thing.)
#
# The message is stored in the exception, which we re-raise for the caller
# to process. It is not a standard exception that should be logged and a result
# sent to the wire.
raise
except BaseException as exc:
# Either:
# - the message had a type that has a registered handler, but does not have
# a protobuf class
# - the message was not valid protobuf
# - workflow raised some kind of an exception while running
# - something canceled the workflow from the outside
if __debug__:
if isinstance(exc, ActionCancelled):
log.debug(__name__, "cancelled: %s", exc.message)
elif isinstance(exc, loop.TaskClosed):
log.debug(__name__, "cancelled: loop task was closed")
else:
log.exception(__name__, exc)
res_msg = failure(exc)
await thp_v3.thp_main_loop(iface, is_debug_session)
if res_msg is not None:
# perform the write outside the big try-except block, so that usb write
# problem bubbles up
await ctx.write(res_msg)
if not __debug__ or not is_debug_session:
# Unload modules imported by the workflow. Should not raise.
# This is not done for the debug session because the snapshot taken
# in a debug session would clear modules which are in use by the
# workflow running on wire.
utils.unimport_end(modules)
loop.clear()
return
# Look into `AVOID_RESTARTING_FOR` to see if this message should avoid restarting.
return msg.type in AVOID_RESTARTING_FOR
except Exception as exc:
# Log and try again. The session handler can only exit explicitly via
# loop.clear() above.
if __debug__:
log.exception(__name__, exc)
async def handle_session(
iface: WireInterface, session_id: int, is_debug_session: bool = False
) -> None:
async def handle_session(iface: WireInterface, is_debug_session: bool = False) -> None:
if __debug__ and is_debug_session:
ctx_buffer = WIRE_BUFFER_DEBUG
else:
ctx_buffer = WIRE_BUFFER
ctx = context.Context(iface, session_id, ctx_buffer)
next_msg: codec_v1.Message | None = None
ctx = context.CodecContext(iface, ctx_buffer)
next_msg: protocol_common.Message | None = None
if __debug__ and is_debug_session:
import apps.debug
@ -224,7 +117,7 @@ async def handle_session(
# wait for a new one coming from the wire.
try:
msg = await ctx.read_from_wire()
except codec_v1.CodecError as exc:
except protocol_common.WireError as exc:
if __debug__:
log.exception(__name__, exc)
await ctx.write(failure(exc))
@ -236,10 +129,10 @@ async def handle_session(
next_msg = None
try:
do_not_restart = await _handle_single_message(
ctx, msg, use_workflow=not is_debug_session
do_not_restart = await message_handler.handle_single_message(
ctx, msg, handler_finder=find_handler
)
except context.UnexpectedMessage as unexpected:
except UnexpectedMessageException as unexpected:
# The workflow was interrupted by an unexpected message. We need to
# process it as if it was a new message...
next_msg = unexpected.msg
@ -269,81 +162,3 @@ async def handle_session(
# loop.clear() above.
if __debug__:
log.exception(__name__, exc)
def find_handler(iface: WireInterface, msg_type: int) -> Handler:
import usb
from apps import workflow_handlers
handler = workflow_handlers.find_registered_handler(iface, msg_type)
if handler is None:
raise UnexpectedMessage("Unexpected message")
if __debug__ and iface is usb.iface_debug:
# no filtering allowed for debuglink
return handler
for filter in filters:
handler = filter(msg_type, handler)
return handler
filters: list[Callable[[int, Handler], Handler]] = []
"""Filters for the wire handler.
Filters are applied in order. Each filter gets a message id and a preceding handler. It
must either return a handler (the same one or a modified one), or raise an exception
that gets sent to wire directly.
Filters are not applied to debug sessions.
The filters are designed for:
* rejecting messages -- while in Recovery mode, most messages are not allowed
* adding additional behavior -- while device is soft-locked, a PIN screen will be shown
before allowing a message to trigger its original behavior.
For this, the filters are effectively deny-first. If an earlier filter rejects the
message, the later filters are not called. But if a filter adds behavior, the latest
filter "wins" and the latest behavior triggers first.
Please note that this behavior is really unsuited to anything other than what we are
using it for now. It might be necessary to modify the semantics if we need more complex
usecases.
NB: `filters` is currently public so callers can have control over where they insert
new filters, but removal should be done using `remove_filter`!
We should, however, change it such that filters must be added using an `add_filter`
and `filters` becomes private!
"""
def remove_filter(filter):
try:
filters.remove(filter)
except ValueError:
pass
AVOID_RESTARTING_FOR: Container[int] = ()
def failure(exc: BaseException) -> Failure:
if isinstance(exc, Error):
return Failure(code=exc.code, message=exc.message)
elif isinstance(exc, loop.TaskClosed):
return Failure(code=FailureType.ActionCancelled, message="Cancelled")
elif isinstance(exc, InvalidSessionError):
return Failure(code=FailureType.InvalidSession, message="Invalid session")
else:
# NOTE: when receiving generic `FirmwareError` on non-debug build,
# change the `if __debug__` to `if True` to get the full error message.
if __debug__:
message = str(exc)
else:
message = "Firmware error"
return Failure(code=FailureType.FirmwareError, message=message)
def unexpected_message() -> Failure:
return Failure(code=FailureType.UnexpectedMessage, message="Unexpected message")

@ -3,6 +3,7 @@ from micropython import const
from typing import TYPE_CHECKING
from trezor import io, loop, utils
from trezor.wire.protocol_common import Message, WireError
if TYPE_CHECKING:
from trezorio import WireInterface
@ -18,16 +19,10 @@ _REP_CONT_DATA = const(1) # offset of data in the continuation report
SESSION_ID = const(0)
class CodecError(Exception):
class CodecError(WireError):
pass
class Message:
def __init__(self, mtype: int, mdata: bytes) -> None:
self.type = mtype
self.data = mdata
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message:
read = loop.wait(iface.iface_num() | io.POLL_READ)

@ -15,9 +15,12 @@ for ButtonRequests. Of course, `context.wait()` transparently works in such situ
from typing import TYPE_CHECKING
from storage import cache, cache_codec
from storage.cache_common import SESSIONLESS_FLAG
from trezor import log, loop, protobuf
from trezor.wire import codec_v1
from . import codec_v1
from .protocol_common import Context, Message
if TYPE_CHECKING:
from trezorio import WireInterface
@ -32,6 +35,8 @@ if TYPE_CHECKING:
overload,
)
from storage.cache_common import DataCache
Msg = TypeVar("Msg", bound=protobuf.MessageType)
HandlerTask = Coroutine[Any, Any, protobuf.MessageType]
Handler = Callable[["Context", Msg], HandlerTask]
@ -41,31 +46,35 @@ if TYPE_CHECKING:
T = TypeVar("T")
class UnexpectedMessage(Exception):
class UnexpectedMessageException(Exception):
"""A message was received that is not part of the current workflow.
Utility exception to inform the session handler that the current workflow
should be aborted and a new one started as if `msg` was the first message.
"""
def __init__(self, msg: codec_v1.Message) -> None:
def __init__(self, msg: Message) -> None:
super().__init__()
self.msg = msg
class Context:
class CodecContext(Context):
"""Wire context.
Represents USB communication inside a particular session on a particular interface
Represents USB communication inside a particular session (channel) on a particular interface
(i.e., wire, debug, single BT connection, etc.)
"""
def __init__(self, iface: WireInterface, sid: int, buffer: bytearray) -> None:
def __init__(
self,
iface: WireInterface,
buffer: bytearray,
) -> None:
self.iface = iface
self.sid = sid
self.buffer = buffer
super().__init__(iface, codec_v1.SESSION_ID.to_bytes(2, "big"))
def read_from_wire(self) -> Awaitable[codec_v1.Message]:
def read_from_wire(self) -> Awaitable[Message]:
"""Read a whole message from the wire without parsing it."""
return codec_v1.read_message(self.iface, self.buffer)
@ -95,9 +104,8 @@ class Context:
if __debug__:
log.debug(
__name__,
"%s:%x expect: %s",
"%s: expect: %s",
self.iface.iface_num(),
self.sid,
expected_type.MESSAGE_NAME if expected_type else expected_types,
)
@ -107,7 +115,7 @@ class Context:
# If we got a message with unexpected type, raise the message via
# `UnexpectedMessageError` and let the session handler deal with it.
if msg.type not in expected_types:
raise UnexpectedMessage(msg)
raise UnexpectedMessageException(msg)
if expected_type is None:
expected_type = protobuf.type_for_wire(msg.type)
@ -115,14 +123,14 @@ class Context:
if __debug__:
log.debug(
__name__,
"%s:%x read: %s",
"%s: read: %s",
self.iface.iface_num(),
self.sid,
expected_type.MESSAGE_NAME,
)
# look up the protobuf class and parse the message
from . import wrap_protobuf_load
from . import message_handler # noqa: F401
from .message_handler import wrap_protobuf_load
return wrap_protobuf_load(msg.data, expected_type)
@ -131,9 +139,8 @@ class Context:
if __debug__:
log.debug(
__name__,
"%s:%x write: %s",
"%s: write: %s",
self.iface.iface_num(),
self.sid,
msg.MESSAGE_NAME,
)
@ -150,23 +157,19 @@ class Context:
buffer = bytearray(msg_size)
msg_size = protobuf.encode(buffer, msg)
await codec_v1.write_message(
self.iface,
msg.MESSAGE_WIRE_TYPE,
memoryview(buffer)[:msg_size],
)
async def call(
self,
msg: protobuf.MessageType,
expected_type: type[LoadedMessageType],
) -> LoadedMessageType:
assert expected_type.MESSAGE_WIRE_TYPE is not None
await self.write(msg)
del msg
return await self.read((expected_type.MESSAGE_WIRE_TYPE,), expected_type)
# ACCESS TO CACHE
@property
def cache(self) -> DataCache:
c = cache_codec.get_active_session()
if c is None:
raise Exception("There is no active session")
return c
CURRENT_CONTEXT: Context | None = None
@ -273,3 +276,65 @@ def with_context(ctx: Context, workflow: loop.Task) -> Generator:
send_exc = e
else:
send_exc = None
# ACCESS TO CACHE
if TYPE_CHECKING:
T = TypeVar("T")
@overload
def cache_get(key: int) -> bytes | None: # noqa: F811
...
@overload
def cache_get(key: int, default: T) -> bytes | T: # noqa: F811
...
def cache_get(key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
cache = _get_cache_for_key(key)
return cache.get(key, default)
def cache_get_bool(key: int) -> bool: # noqa: F811
cache = _get_cache_for_key(key)
return cache.get_bool(key)
def cache_get_int(key: int, default: T | None = None) -> int | T | None: # noqa: F811
cache = _get_cache_for_key(key)
return cache.get_int(key, default)
def cache_is_set(key: int) -> bool:
cache = _get_cache_for_key(key)
return cache.is_set(key)
def cache_set(key: int, value: bytes) -> None:
cache = _get_cache_for_key(key)
cache.set(key, value)
def cache_set_bool(key: int, value: bool) -> None:
cache = _get_cache_for_key(key)
cache.set_bool(key, value)
def cache_set_int(key: int, value: int) -> None:
cache = _get_cache_for_key(key)
cache.set_int(key, value)
def cache_delete(key: int) -> None:
cache = _get_cache_for_key(key)
cache.delete(key)
def _get_cache_for_key(key) -> DataCache:
if key & SESSIONLESS_FLAG:
return cache.get_sessionless_cache()
if CURRENT_CONTEXT:
return CURRENT_CONTEXT.cache
raise Exception("No wire context")

@ -0,0 +1,255 @@
from micropython import const
from typing import TYPE_CHECKING
from storage.cache_common import InvalidSessionError
from trezor import log, loop, protobuf, utils, workflow
from trezor.enums import FailureType
from trezor.messages import Failure
from trezor.wire.context import Context, UnexpectedMessageException, with_context
from trezor.wire.errors import ActionCancelled, DataError, Error, UnexpectedMessage
from trezor.wire.protocol_common import Message
# Import all errors into namespace, so that `wire.Error` is available from
# other packages.
from trezor.wire.errors import * # isort:skip # noqa: F401,F403
if TYPE_CHECKING:
from typing import Any, Callable, Container
from trezor.wire import Handler, LoadedMessageType
HandlerFinder = Callable[[Any], Handler | None]
# If set to False protobuf messages marked with "experimental_message" option are rejected.
EXPERIMENTAL_ENABLED = False
def wrap_protobuf_load(
buffer: bytes,
expected_type: type[LoadedMessageType],
) -> LoadedMessageType:
try:
msg = protobuf.decode(buffer, expected_type, EXPERIMENTAL_ENABLED)
if __debug__ and utils.EMULATOR:
log.debug(
__name__, "received message contents:\n%s", utils.dump_protobuf(msg)
)
return msg
except Exception as e:
if __debug__:
log.exception(__name__, e)
if e.args:
raise DataError("Failed to decode message: " + " ".join(e.args))
else:
raise DataError("Failed to decode message")
_PROTOBUF_BUFFER_SIZE = const(8192)
WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
if __debug__:
PROTOBUF_BUFFER_SIZE_DEBUG = 1024
WIRE_BUFFER_DEBUG = bytearray(PROTOBUF_BUFFER_SIZE_DEBUG)
async def handle_single_message(
ctx: Context,
msg: Message,
handler_finder: HandlerFinder,
) -> bool:
"""Handle a message that was loaded from USB by the caller.
Find the appropriate handler, run it and write its result on the wire. In case
a problem is encountered at any point, write the appropriate error on the wire.
The return value indicates whether to override the default restarting behavior. If
`False` is returned, the caller is allowed to clear the loop and restart the
MicroPython machine (see `session.py`). This would lose all state and incurs a cost
in terms of repeated startup time. When handling the message didn't cause any
significant fragmentation (e.g., if decoding the message was skipped), or if
the type of message is supposed to be optimized and not disrupt the running state,
this function will return `True`.
"""
if __debug__:
try:
msg_type = protobuf.type_for_wire(msg.type).MESSAGE_NAME
except Exception:
msg_type = f"{msg.type} - unknown message type"
if ctx.channel_id is not None:
sid = int.from_bytes(ctx.channel_id, "big")
log.debug(
__name__,
"%s:%x receive: <%s>",
ctx.iface.iface_num(),
sid,
msg_type,
)
else:
log.debug(
__name__,
"%s:unknown_sid receive: <%s>",
ctx.iface.iface_num(),
msg_type,
)
res_msg: protobuf.MessageType | None = None
# # We need to find a handler for this message type.
# try:
# handler = find_handler(ctx.iface, msg.type)
# except Error as exc:
# # Handlers are allowed to exception out. In that case, we can skip decoding
# # and return the error.
# await ctx.write(failure(exc))
# return True
# We need to find a handler for this message type. Should not raise.
handler: Handler | None = handler_finder(msg.type)
if handler is None:
# If no handler is found, we can skip decoding and directly
# respond with failure.
print("handler is none")
await ctx.write(unexpected_message())
return True
if msg.type in workflow.ALLOW_WHILE_LOCKED:
workflow.autolock_interrupts_workflow = False
# Here we make sure we always respond with a Failure response
# in case of any errors.
try:
# Find a protobuf.MessageType subclass that describes this
# message. Raises if the type is not found.
req_type = protobuf.type_for_wire(msg.type)
# Try to decode the message according to schema from
# `req_type`. Raises if the message is malformed.
req_msg = wrap_protobuf_load(msg.data, req_type)
# Create the handler task.
task = handler(req_msg)
# Run the workflow task. Workflow can do more on-the-wire
# communication inside, but it should eventually return a
# response message, or raise an exception (a rather common
# thing to do). Exceptions are handled in the code below.
# Spawn a workflow around the task. This ensures that concurrent
# workflows are shut down.
res_msg = await workflow.spawn(with_context(ctx, task))
except UnexpectedMessageException:
# Workflow was trying to read a message from the wire, and
# something unexpected came in. See Context.read() for
# example, which expects some particular message and raises
# UnexpectedMessage if another one comes in.
# In order not to lose the message, we return it to the caller.
# We process the unexpected message by aborting the current workflow and
# possibly starting a new one, initiated by that message. (The main usecase
# being, the host does not finish the workflow, we want other callers to
# be able to do their own thing.)
#
# The message is stored in the exception, which we re-raise for the caller
# to process. It is not a standard exception that should be logged and a result
# sent to the wire.
raise
except BaseException as exc:
# Either:
# - the message had a type that has a registered handler, but does not have
# a protobuf class
# - the message was not valid protobuf
# - workflow raised some kind of an exception while running
# - something canceled the workflow from the outside
if __debug__:
if isinstance(exc, ActionCancelled):
log.debug(__name__, "cancelled: %s", exc.message)
elif isinstance(exc, loop.TaskClosed):
log.debug(__name__, "cancelled: loop task was closed")
else:
log.exception(__name__, exc)
res_msg = failure(exc)
if res_msg is not None:
# perform the write outside the big try-except block, so that usb write
# problem bubbles up
await ctx.write(res_msg)
# Look into `AVOID_RESTARTING_FOR` to see if this message should avoid restarting.
return msg.type in AVOID_RESTARTING_FOR
AVOID_RESTARTING_FOR: Container[int] = ()
def failure(exc: BaseException) -> Failure:
if isinstance(exc, Error):
return Failure(code=exc.code, message=exc.message)
elif isinstance(exc, loop.TaskClosed):
return Failure(code=FailureType.ActionCancelled, message="Cancelled")
elif isinstance(exc, InvalidSessionError):
return Failure(code=FailureType.InvalidSession, message="Invalid session")
else:
# NOTE: when receiving generic `FirmwareError` on non-debug build,
# change the `if __debug__` to `if True` to get the full error message.
if __debug__:
message = str(exc)
else:
message = "Firmware error"
return Failure(code=FailureType.FirmwareError, message=message)
def unexpected_message() -> Failure:
return Failure(code=FailureType.UnexpectedMessage, message="Unexpected message")
def find_handler(msg_type: int) -> Handler:
from apps import workflow_handlers
handler = workflow_handlers.find_registered_handler(msg_type)
if handler is None:
raise UnexpectedMessage("Unexpected message")
for filter in filters:
handler = filter(msg_type, handler)
return handler
filters: list[Callable[[int, Handler], Handler]] = []
"""Filters for the wire handler.
Filters are applied in order. Each filter gets a message id and a preceding handler. It
must either return a handler (the same one or a modified one), or raise an exception
that gets sent to wire directly.
Filters are not applied to debug sessions.
The filters are designed for:
* rejecting messages -- while in Recovery mode, most messages are not allowed
* adding additional behavior -- while device is soft-locked, a PIN screen will be shown
before allowing a message to trigger its original behavior.
For this, the filters are effectively deny-first. If an earlier filter rejects the
message, the later filters are not called. But if a filter adds behavior, the latest
filter "wins" and the latest behavior triggers first.
Please note that this behavior is really unsuited to anything other than what we are
using it for now. It might be necessary to modify the semantics if we need more complex
usecases.
NB: `filters` is currently public so callers can have control over where they insert
new filters, but removal should be done using `remove_filter`!
We should, however, change it such that filters must be added using an `add_filter`
and `filters` becomes private!
"""
def remove_filter(filter):
try:
filters.remove(filter)
except ValueError:
pass

@ -0,0 +1,70 @@
from typing import TYPE_CHECKING
from trezor import protobuf
if TYPE_CHECKING:
from trezorio import WireInterface
from typing import Container, TypeVar, overload
from storage.cache_common import DataCache
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
T = TypeVar("T")
class Message:
def __init__(
self,
message_type: int,
message_data: bytes,
) -> None:
self.data = message_data
self.type = message_type
def to_bytes(self):
return self.type.to_bytes(2, "big") + self.data
class Context:
def __init__(self, iface: WireInterface, channel_id: bytes) -> None:
self.iface: WireInterface = iface
self.channel_id: bytes = channel_id
if TYPE_CHECKING:
@overload
async def read(
self, expected_types: Container[int]
) -> protobuf.MessageType: ...
@overload
async def read(
self, expected_types: Container[int], expected_type: type[LoadedMessageType]
) -> LoadedMessageType: ...
async def read(
self,
expected_types: Container[int],
expected_type: type[protobuf.MessageType] | None = None,
) -> protobuf.MessageType: ...
async def write(self, msg: protobuf.MessageType) -> None: ...
async def call(
self,
msg: protobuf.MessageType,
expected_type: type[LoadedMessageType],
) -> LoadedMessageType:
assert expected_type.MESSAGE_WIRE_TYPE is not None
await self.write(msg)
del msg
return await self.read((expected_type.MESSAGE_WIRE_TYPE,), expected_type)
@property
def cache(self) -> DataCache: ...
class WireError(Exception):
pass

@ -0,0 +1,75 @@
from typing import TYPE_CHECKING
from trezor.wire.protocol_common import WireError
class ThpError(WireError):
pass
class ThpDecryptionError(ThpError):
pass
class ThpUnallocatedSessionError(ThpError):
def __init__(self, session_id: int):
self.session_id = session_id
if TYPE_CHECKING:
from enum import IntEnum
else:
IntEnum = object
class ThpErrorType(IntEnum):
TRANSPORT_BUSY = 1
UNALLOCATED_CHANNEL = 2
DECRYPTION_FAILED = 3
class ChannelState(IntEnum):
UNALLOCATED = 0
TH1 = 1
TH2 = 2
TP1 = 3
TP2 = 4
TP3 = 5
TP4 = 6
TC1 = 7
ENCRYPTED_TRANSPORT = 8
class SessionState(IntEnum):
UNALLOCATED = 0
ALLOCATED = 1
MANAGEMENT = 2
class WireInterfaceType(IntEnum):
MOCK = 0
USB = 1
BLE = 2
def is_channel_state_pairing(state: int) -> bool:
if state in (
ChannelState.TP1,
ChannelState.TP2,
ChannelState.TP3,
ChannelState.TP4,
ChannelState.TC1,
):
return True
return False
if __debug__:
def state_to_str(state: int) -> str:
name = {
v: k for k, v in ChannelState.__dict__.items() if not k.startswith("__")
}.get(state)
if name is not None:
return name
return "UNKNOWN_STATE"

@ -0,0 +1,102 @@
from storage.cache_thp import ChannelCache
from trezor import log
from trezor.wire.thp import ThpError
def is_ack_valid(cache: ChannelCache, ack_bit: int) -> bool:
"""
Checks if:
- an ACK message is expected
- the received ACK message acknowledges correct sequence number (bit)
"""
if not _is_ack_expected(cache):
return False
if not _has_ack_correct_sync_bit(cache, ack_bit):
return False
return True
def _is_ack_expected(cache: ChannelCache) -> bool:
is_expected: bool = not is_sending_allowed(cache)
if __debug__ and not is_expected:
log.debug(__name__, "Received unexpected ACK message")
return is_expected
def _has_ack_correct_sync_bit(cache: ChannelCache, sync_bit: int) -> bool:
is_correct: bool = get_send_seq_bit(cache) == sync_bit
if __debug__ and not is_correct:
log.debug(__name__, "Received ACK message with wrong ack bit")
return is_correct
def is_sending_allowed(cache: ChannelCache) -> bool:
"""
Checks whether sending a message in the provided channel is allowed.
Note: Sending a message in a channel before receipt of ACK message for the previously
sent message (in the channel) is prohibited, as it can lead to desynchronization.
"""
return bool(cache.sync >> 7)
def get_send_seq_bit(cache: ChannelCache) -> int:
"""
Returns the sequential number (bit) of the next message to be sent
in the provided channel.
"""
return (cache.sync & 0x20) >> 5
def get_expected_receive_seq_bit(cache: ChannelCache) -> int:
"""
Returns the (expected) sequential number (bit) of the next message
to be received in the provided channel.
"""
return (cache.sync & 0x40) >> 6
def set_sending_allowed(cache: ChannelCache, sending_allowed: bool) -> None:
"""
Set the flag whether sending a message in this channel is allowed or not.
"""
cache.sync &= 0x7F
if sending_allowed:
cache.sync |= 0x80
def set_expected_receive_seq_bit(cache: ChannelCache, seq_bit: int) -> None:
"""
Set the expected sequential number (bit) of the next message to be received
in the provided channel
"""
if __debug__:
log.debug(__name__, "Set sync receive expected seq bit to %d", seq_bit)
if seq_bit not in (0, 1):
raise ThpError("Unexpected receive sync bit")
# set second bit to "seq_bit" value
cache.sync &= 0xBF
if seq_bit:
cache.sync |= 0x40
def _set_send_seq_bit(cache: ChannelCache, seq_bit: int) -> None:
if seq_bit not in (0, 1):
raise ThpError("Unexpected send seq bit")
if __debug__:
log.debug(__name__, "setting sync send seq bit to %d", seq_bit)
# set third bit to "seq_bit" value
cache.sync &= 0xDF
if seq_bit:
cache.sync |= 0x20
def set_send_seq_bit_to_opposite(cache: ChannelCache) -> None:
"""
Set the sequential bit of the "next message to be send" to the opposite value,
i.e. 1 -> 0 and 0 -> 1
"""
_set_send_seq_bit(cache=cache, seq_bit=1 - get_send_seq_bit(cache))

@ -0,0 +1,308 @@
import ustruct
from typing import TYPE_CHECKING
from storage.cache_common import (
CHANNEL_HANDSHAKE_HASH,
CHANNEL_KEY_RECEIVE,
CHANNEL_KEY_SEND,
CHANNEL_NONCE_RECEIVE,
CHANNEL_NONCE_SEND,
)
from storage.cache_thp import TAG_LENGTH, ChannelCache, clear_sessions_with_channel_id
from trezor import log, loop, protobuf, utils, workflow
from trezor.wire.thp.transmission_loop import TransmissionLoop
from . import ChannelState, ThpDecryptionError, ThpError
from . import alternating_bit_protocol as ABP
from . import (
control_byte,
crypto,
interface_manager,
memory_manager,
received_message_handler,
session_manager,
)
from .checksum import CHECKSUM_LENGTH
from .thp_messages import ENCRYPTED_TRANSPORT, PacketHeader
from .writer import (
CONT_HEADER_LENGTH,
INIT_HEADER_LENGTH,
write_payload_to_wire_and_add_checksum,
)
if __debug__:
from ubinascii import hexlify
from . import state_to_str
if TYPE_CHECKING:
from trezorio import WireInterface
from .pairing_context import PairingContext
from .session_context import GenericSessionContext
class Channel:
def __init__(self, channel_cache: ChannelCache) -> None:
if __debug__:
log.debug(__name__, "channel initialization")
self.iface: WireInterface = interface_manager.decode_iface(channel_cache.iface)
self.channel_cache: ChannelCache = channel_cache
self.is_cont_packet_expected: bool = False
self.expected_payload_length: int = 0
self.bytes_read: int = 0
self.buffer: utils.BufferType
self.channel_id: bytes = channel_cache.channel_id
self.selected_pairing_methods = []
self.sessions: dict[int, GenericSessionContext] = {}
self.write_task_spawn: loop.spawn | None = None
self.connection_context: PairingContext | None = None
self.transmission_loop: TransmissionLoop | None = None
self.handshake: crypto.Handshake | None = None
self._create_management_session()
def clear(self):
clear_sessions_with_channel_id(self.channel_id)
self.channel_cache.clear()
# ACCESS TO CHANNEL_DATA
def get_channel_id_int(self) -> int:
return int.from_bytes(self.channel_id, "big")
def get_channel_state(self) -> int:
state = int.from_bytes(self.channel_cache.state, "big")
if __debug__:
log.debug(__name__, "get_channel_state: %s", state_to_str(state))
return state
def set_channel_state(self, state: ChannelState) -> None:
self.channel_cache.state = bytearray(state.to_bytes(1, "big"))
if __debug__:
log.debug(__name__, "set_channel_state: %s", state_to_str(state))
def set_buffer(self, buffer: utils.BufferType) -> None:
self.buffer = buffer
if __debug__:
log.debug(__name__, "set_buffer: %s", type(self.buffer))
def _create_management_session(self) -> None:
session = session_manager.create_new_management_session(self)
self.sessions[session.session_id] = session
loop.schedule(session.handle())
# CALLED BY THP_MAIN_LOOP
async def receive_packet(self, packet: utils.BufferType):
if __debug__:
log.debug(__name__, "receive_packet")
await self._handle_received_packet(packet)
if __debug__:
log.debug(__name__, "self.buffer: %s", utils.get_bytes_as_str(self.buffer))
if self.expected_payload_length + INIT_HEADER_LENGTH == self.bytes_read:
self._finish_message()
await received_message_handler.handle_received_message(self, self.buffer)
elif self.expected_payload_length + INIT_HEADER_LENGTH > self.bytes_read:
self.is_cont_packet_expected = True
else:
raise ThpError(
"Read more bytes than is the expected length of the message!"
)
async def _handle_received_packet(self, packet: utils.BufferType) -> None:
ctrl_byte = packet[0]
if control_byte.is_continuation(ctrl_byte):
await self._handle_cont_packet(packet)
else:
await self._handle_init_packet(packet)
async def _handle_init_packet(self, packet: utils.BufferType) -> None:
if __debug__:
log.debug(__name__, "handle_init_packet")
# ctrl_byte, _, payload_length = ustruct.unpack(">BHH", packet) # TODO use this with single packet decryption
_, _, payload_length = ustruct.unpack(">BHH", packet)
self.expected_payload_length = payload_length
packet_payload = memoryview(packet)[INIT_HEADER_LENGTH:]
# If the channel does not "own" the buffer lock, decrypt first packet
# TODO do it only when needed!
# TODO FIX: If "_decrypt_single_packet_payload" is implemented, it will (possibly) break "decrypt_buffer" and nonces incrementation.
# On the other hand, without the single packet decryption, the "advanced" buffer selection cannot be implemented
# in "memory_manager.select_buffer", because the session id is unknown (encrypted).
# if control_byte.is_encrypted_transport(ctrl_byte):
# packet_payload = self._decrypt_single_packet_payload(packet_payload)
self.buffer = memory_manager.select_buffer(
self.get_channel_state(),
self.buffer,
packet_payload,
payload_length,
)
await self._buffer_packet_data(self.buffer, packet, 0)
if __debug__:
log.debug(__name__, "handle_init_packet - payload len: %d", payload_length)
log.debug(__name__, "handle_init_packet - buffer len: %d", len(self.buffer))
async def _handle_cont_packet(self, packet: utils.BufferType) -> None:
if __debug__:
log.debug(__name__, "handle_cont_packet")
if not self.is_cont_packet_expected:
raise ThpError("Continuation packet is not expected, ignoring")
await self._buffer_packet_data(self.buffer, packet, CONT_HEADER_LENGTH)
def _decrypt_single_packet_payload(
self, payload: utils.BufferType
) -> utils.BufferType:
# crypto.decrypt(b"\x00", b"\x00", payload_buffer, INIT_DATA_OFFSET, len(payload))
return payload
def decrypt_buffer(
self, message_length: int, offset: int = INIT_HEADER_LENGTH
) -> None:
noise_buffer = memoryview(self.buffer)[
offset : message_length - CHECKSUM_LENGTH - TAG_LENGTH
]
tag = self.buffer[
message_length
- CHECKSUM_LENGTH
- TAG_LENGTH : message_length
- CHECKSUM_LENGTH
]
if utils.DISABLE_ENCRYPTION:
is_tag_valid = tag == crypto.DUMMY_TAG
else:
key_receive = self.channel_cache.get(CHANNEL_KEY_RECEIVE)
nonce_receive = self.channel_cache.get_int(CHANNEL_NONCE_RECEIVE)
auth_data = self.channel_cache.get(CHANNEL_HANDSHAKE_HASH)
assert key_receive is not None
assert nonce_receive is not None
assert auth_data is not None
print("Buffer before decryption:", hexlify(noise_buffer))
is_tag_valid = crypto.dec(
noise_buffer, tag, key_receive, nonce_receive, auth_data
)
print("Buffer after decryption:", hexlify(noise_buffer))
self.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, nonce_receive + 1)
if __debug__:
log.debug(__name__, "Is decrypted tag valid? %s", str(is_tag_valid))
log.debug(__name__, "Received tag: %s", (hexlify(tag).decode()))
log.debug(__name__, "New nonce_receive: %i", nonce_receive + 1)
if not is_tag_valid:
raise ThpDecryptionError()
def _encrypt(self, buffer: utils.BufferType, noise_payload_len: int) -> None:
if __debug__:
log.debug(__name__, "encrypt")
assert len(buffer) >= noise_payload_len + TAG_LENGTH + CHECKSUM_LENGTH
noise_buffer = memoryview(buffer)[0:noise_payload_len]
if utils.DISABLE_ENCRYPTION:
tag = crypto.DUMMY_TAG
else:
key_send = self.channel_cache.get(CHANNEL_KEY_SEND)
nonce_send = self.channel_cache.get_int(CHANNEL_NONCE_SEND)
auth_data = self.channel_cache.get(CHANNEL_HANDSHAKE_HASH)
assert key_send is not None
assert nonce_send is not None
assert auth_data is not None
tag = crypto.enc(noise_buffer, key_send, nonce_send, auth_data)
self.channel_cache.set_int(CHANNEL_NONCE_SEND, nonce_send + 1)
if __debug__:
log.debug(__name__, "New nonce_send: %i", nonce_send + 1)
buffer[noise_payload_len : noise_payload_len + TAG_LENGTH] = tag
async def _buffer_packet_data(
self, payload_buffer: utils.BufferType, packet: utils.BufferType, offset: int
):
self.bytes_read += utils.memcpy(payload_buffer, self.bytes_read, packet, offset)
def _finish_message(self):
self.bytes_read = 0
self.expected_payload_length = 0
self.is_cont_packet_expected = False
# CALLED BY WORKFLOW / SESSION CONTEXT
async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None:
if __debug__:
log.debug(__name__, "write message: %s", msg.MESSAGE_NAME)
self.buffer = memory_manager.get_write_buffer(self.buffer, msg)
noise_payload_len = memory_manager.encode_into_buffer(
self.buffer, msg, session_id
)
await self.write_and_encrypt(self.buffer[:noise_payload_len])
async def write_error(self, err_type: int):
msg_data = err_type.to_bytes(1, "big")
length = len(msg_data) + CHECKSUM_LENGTH
header = PacketHeader.get_error_header(self.get_channel_id_int(), length)
await write_payload_to_wire_and_add_checksum(self.iface, header, msg_data)
async def write_and_encrypt(self, payload: bytes) -> None:
payload_length = len(payload)
if not isinstance(self.buffer, bytearray):
self.buffer = bytearray(self.buffer)
self._encrypt(self.buffer, payload_length)
payload_length = payload_length + TAG_LENGTH
if self.write_task_spawn is not None:
self.write_task_spawn.close() # UPS TODO might break something
print("\nCLOSED\n")
self._prepare_write()
self.write_task_spawn = loop.spawn(
self._write_encrypted_payload_loop(
ENCRYPTED_TRANSPORT, memoryview(self.buffer[:payload_length])
)
)
async def write_handshake_message(self, ctrl_byte: int, payload: bytes) -> None:
self._prepare_write()
self.write_task_spawn = loop.spawn(
self._write_encrypted_payload_loop(ctrl_byte, payload)
)
def _prepare_write(self) -> None:
# TODO add condition that disallows to write when can_send_message is false
ABP.set_sending_allowed(self.channel_cache, False)
async def _write_encrypted_payload_loop(
self, ctrl_byte: int, payload: bytes
) -> None:
if __debug__:
log.debug(__name__, "write_encrypted_payload_loop")
payload_len = len(payload) + CHECKSUM_LENGTH
sync_bit = ABP.get_send_seq_bit(self.channel_cache)
ctrl_byte = control_byte.add_seq_bit_to_ctrl_byte(ctrl_byte, sync_bit)
header = PacketHeader(ctrl_byte, self.get_channel_id_int(), payload_len)
self.transmission_loop = TransmissionLoop(self, header, payload)
await self.transmission_loop.start()
ABP.set_send_seq_bit_to_opposite(self.channel_cache)
# Let the main loop be restarted and clear loop, if there is no other
# workflow and the state is ENCRYPTED_TRANSPORT
if self._can_clear_loop():
if __debug__:
log.debug(__name__, "clearing loop from channel")
loop.clear()
def _can_clear_loop(self) -> bool:
return (
not workflow.tasks
) and self.get_channel_state() is ChannelState.ENCRYPTED_TRANSPORT

@ -0,0 +1,36 @@
from typing import TYPE_CHECKING
from storage import cache_thp
from trezor import utils
from . import ChannelState, interface_manager
from .channel import Channel
if TYPE_CHECKING:
from trezorio import WireInterface
def create_new_channel(iface: WireInterface, buffer: utils.BufferType) -> Channel:
"""
Creates a new channel for the interface `iface` with the buffer `buffer`.
"""
channel_cache = cache_thp.get_new_unauthenticated_channel(
interface_manager.encode_iface(iface)
)
r = Channel(channel_cache)
r.set_buffer(buffer)
r.set_channel_state(ChannelState.TH1)
return r
def load_cached_channels(buffer: utils.BufferType) -> dict[int, Channel]:
"""
Returns all allocated channels from cache.
"""
channels: dict[int, Channel] = {}
cached_channels = cache_thp.get_all_allocated_channels()
for c in cached_channels:
channels[int.from_bytes(c.channel_id, "big")] = Channel(c)
for c in channels.values():
c.set_buffer(buffer)
return channels

@ -0,0 +1,22 @@
from micropython import const
from trezor import utils
from trezor.crypto import crc
CHECKSUM_LENGTH = const(4)
def compute(data: bytes | utils.BufferType) -> bytes:
"""
Returns a CRC-32 checksum of the provided `data`.
"""
return crc.crc32(data).to_bytes(CHECKSUM_LENGTH, "big")
def is_valid(checksum: bytes | utils.BufferType, data: bytes) -> bool:
"""
Checks whether the CRC-32 checksum of the `data` is the same
as the checksum provided in `checksum`.
"""
data_checksum = compute(data)
return checksum == data_checksum

@ -0,0 +1,47 @@
from trezor.wire.thp import ThpError
from trezor.wire.thp.thp_messages import (
ACK_MASK,
ACK_MESSAGE,
CONTINUATION_PACKET,
CONTINUATION_PACKET_MASK,
DATA_MASK,
ENCRYPTED_TRANSPORT,
HANDSHAKE_COMP_REQ,
HANDSHAKE_INIT_REQ,
)
def add_seq_bit_to_ctrl_byte(ctrl_byte: int, seq_bit: int) -> int:
if seq_bit == 0:
return ctrl_byte & 0xEF
if seq_bit == 1:
return ctrl_byte | 0x10
raise ThpError("Unexpected sequence bit")
def add_ack_bit_to_ctrl_byte(ctrl_byte: int, ack_bit: int) -> int:
if ack_bit == 0:
return ctrl_byte & 0xF7
if ack_bit == 1:
return ctrl_byte | 0x08
raise ThpError("Unexpected acknowledgement bit")
def is_ack(ctrl_byte: int) -> bool:
return ctrl_byte & ACK_MASK == ACK_MESSAGE
def is_continuation(ctrl_byte: int) -> bool:
return ctrl_byte & CONTINUATION_PACKET_MASK == CONTINUATION_PACKET
def is_encrypted_transport(ctrl_byte: int) -> bool:
return ctrl_byte & DATA_MASK == ENCRYPTED_TRANSPORT
def is_handshake_init_req(ctrl_byte: int) -> bool:
return ctrl_byte & DATA_MASK == HANDSHAKE_INIT_REQ
def is_handshake_comp_req(ctrl_byte: int) -> bool:
return ctrl_byte & DATA_MASK == HANDSHAKE_COMP_REQ

@ -0,0 +1,34 @@
from trezor.crypto import elligator2, random
from trezor.crypto.curve import curve25519
from trezor.crypto.hashlib import sha512
_PREFIX = b"\x08\x43\x50\x61\x63\x65\x32\x35\x35\x06"
_PADDING = b"\x50\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x20"
class Cpace:
"""
CPace, a balanced composable PAKE: https://datatracker.ietf.org/doc/draft-irtf-cfrg-cpace/
"""
def __init__(self, cpace_host_public_key: bytes) -> None:
self.host_public_key: bytes = cpace_host_public_key
self.trezor_private_key: bytes
self.trezor_public_key: bytes
self.shared_secret: bytes
def generate_keys_and_secret(self, code_code_entry: bytes) -> None:
"""
Generate ephemeral key pair and a shared secret using Elligator2 with X25519.
"""
pregenerator = sha512(_PREFIX + code_code_entry + _PADDING).digest()[
:32
] # TODO add handshake hash
generator = elligator2.map_to_curve25519(pregenerator)
self.trezor_private_key = random.bytes(32)
if __debug__:
self.trezor_private_key = b"\xf8\xb9\xa1\x3a\xe1\x30\xb3\xe1\x5b\x8e\xd5\x80\x85\x4f\xfc\xaf\x63\x4d\x6b\x0a\x10\xfd\xe7\xba\xde\xfd\xc3\xd1\x8d\x1a\x83\xf5"
self.trezor_public_key = curve25519.multiply(self.trezor_private_key, generator)
self.shared_secret = curve25519.multiply(
self.trezor_private_key, self.host_public_key
)

@ -0,0 +1,206 @@
from micropython import const
from trezorcrypto import aesgcm, bip32, curve25519, hmac
from storage import device
from trezor import utils
from trezor.crypto.hashlib import sha256
from trezor.wire.thp import ThpDecryptionError
# The HARDENED flag is taken from apps.common.paths
# It is not imported to save on resources
HARDENED = const(0x8000_0000)
PUBKEY_LENGTH = const(32)
if utils.DISABLE_ENCRYPTION:
DUMMY_TAG = b"\xA0\xA1\xA2\xA3\xA4\xA5\xA6\xA7\xA8\xA9\xB0\xB1\xB2\xB3\xB4\xB5"
if __debug__:
from ubinascii import hexlify
def enc(buffer: utils.BufferType, key: bytes, nonce: int, auth_data: bytes) -> bytes:
"""
Encrypts the provided `buffer` with AES-GCM (in place).
Returns a 16-byte long encryption tag.
"""
print("ENCRYPT-------------> used key: ", hexlify(key))
print("ENCRYPT-------------> used nonce:", nonce)
iv = _get_iv_from_nonce(nonce)
aes_ctx = aesgcm(key, iv)
aes_ctx.auth(auth_data)
aes_ctx.encrypt_in_place(buffer)
return aes_ctx.finish()
def dec(
buffer: utils.BufferType, tag: bytes, key: bytes, nonce: int, auth_data: bytes
) -> bool:
"""
Decrypts the provided buffer (in place). Returns `True` if the provided authentication `tag` is the same as
the tag computed in decryption, otherwise it returns `False`.
"""
iv = _get_iv_from_nonce(nonce)
print("DECRYPT-------------> used key: ", hexlify(key))
print("DECRYPT-------------> used nonce:", nonce)
aes_ctx = aesgcm(key, iv)
aes_ctx.auth(auth_data)
aes_ctx.decrypt_in_place(buffer)
computed_tag = aes_ctx.finish()
return computed_tag == tag
class BusyDecoder:
def __init__(self, key: bytes, nonce: int, auth_data: bytes) -> None:
iv = _get_iv_from_nonce(nonce)
self.aes_ctx = aesgcm(key, iv)
self.aes_ctx.auth(auth_data)
def decrypt_part(self, part: utils.BufferType) -> None:
self.aes_ctx.decrypt_in_place(part)
def finish_and_check_tag(self, tag: bytes) -> bool:
computed_tag = self.aes_ctx.finish()
return computed_tag == tag
PROTOCOL_NAME = bytes("Noise_XX_25519_AESGCM_SHA256", "ascii")
IV_1 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
IV_2 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
class Handshake:
"""
`Handshake` holds (temporary) values and keys that are used during the creation of an encrypted channel.
The following values should be saved for future use before disposing of this object:
- `h` (handshake hash, can be used to bind other values to the channel)
- `key_receive` (key for decrypting incoming communication)
- `key_send` (key for encrypting outgoing communication)
"""
def __init__(self) -> None:
self.trezor_ephemeral_privkey: bytes
self.ck: bytes
self.k: bytes
self.h: bytes
self.key_receive: bytes
self.key_send: bytes
def handle_th1_crypto(
self,
device_properties: bytes,
host_ephemeral_pubkey: bytes,
) -> tuple[bytes, bytes, bytes]:
trezor_static_privkey, trezor_static_pubkey = _derive_static_key_pair()
self.trezor_ephemeral_privkey = curve25519.generate_secret()
trezor_ephemeral_pubkey = curve25519.publickey(self.trezor_ephemeral_privkey)
self.h = _hash_of_two(PROTOCOL_NAME, device_properties)
self.h = _hash_of_two(self.h, host_ephemeral_pubkey)
self.h = _hash_of_two(self.h, trezor_ephemeral_pubkey)
point = curve25519.multiply(
self.trezor_ephemeral_privkey, host_ephemeral_pubkey
)
self.ck, self.k = _hkdf(PROTOCOL_NAME, point)
mask = _hash_of_two(trezor_static_pubkey, trezor_ephemeral_pubkey)
trezor_masked_static_pubkey = curve25519.multiply(mask, trezor_static_pubkey)
aes_ctx = aesgcm(self.k, IV_1)
encrypted_trezor_static_pubkey = aes_ctx.encrypt(trezor_masked_static_pubkey)
print("TH1-ENCRYPT---------> used key: ", hexlify(self.k))
print("TH1-ENCRYPT---------> used nonce:", 0)
aes_ctx.auth(self.h)
tag_to_encrypted_key = aes_ctx.finish()
encrypted_trezor_static_pubkey = (
encrypted_trezor_static_pubkey + tag_to_encrypted_key
)
self.h = _hash_of_two(self.h, encrypted_trezor_static_pubkey)
point = curve25519.multiply(trezor_static_privkey, host_ephemeral_pubkey)
self.ck, self.k = _hkdf(self.ck, curve25519.multiply(mask, point))
aes_ctx = aesgcm(self.k, IV_1)
aes_ctx.auth(self.h)
tag = aes_ctx.finish()
self.h = _hash_of_two(self.h, tag)
return (trezor_ephemeral_pubkey, encrypted_trezor_static_pubkey, tag)
def handle_th2_crypto(
self,
encrypted_host_static_pubkey: utils.BufferType,
encrypted_payload: utils.BufferType,
):
aes_ctx = aesgcm(self.k, IV_2)
# The new value of hash `h` MUST be computed before the `encrypted_host_static_pubkey` is decrypted.
# However, decryption of `encrypted_host_static_pubkey` MUST use the previous value of `h` for
# authentication of the gcm tag.
aes_ctx.auth(self.h) # Authenticate with the previous value of `h`
self.h = _hash_of_two(self.h, encrypted_host_static_pubkey) # Compute new value
aes_ctx.decrypt_in_place(
memoryview(encrypted_host_static_pubkey)[:PUBKEY_LENGTH]
)
print("TH2-DECRYPT---------> used key: ", hexlify(self.k))
print("TH2-DECRYPT---------> used nonce:", 1)
host_static_pubkey = memoryview(encrypted_host_static_pubkey)[:PUBKEY_LENGTH]
tag = aes_ctx.finish()
if tag != encrypted_host_static_pubkey[-16:]:
raise ThpDecryptionError()
self.ck, self.k = _hkdf(
self.ck,
curve25519.multiply(self.trezor_ephemeral_privkey, host_static_pubkey),
)
aes_ctx = aesgcm(self.k, IV_1)
aes_ctx.auth(self.h)
aes_ctx.decrypt_in_place(memoryview(encrypted_payload)[:-16])
print("TH2-DECRYPT---------> used key: ", hexlify(self.k))
print("TH2-DECRYPT---------> used nonce:", 0)
tag = aes_ctx.finish()
if tag != encrypted_payload[-16:]:
raise ThpDecryptionError()
self.h = _hash_of_two(self.h, memoryview(encrypted_payload)[:-16])
self.key_receive, self.key_send = _hkdf(self.ck, b"")
print("TREZOR_KEY_RECEIVE:", hexlify(self.key_receive))
print("TREZOR_KEY_SEND: ", hexlify(self.key_send))
def get_handshake_completion_response(self, trezor_state: bytes) -> bytes:
aes_ctx = aesgcm(self.key_send, IV_1)
encrypted_trezor_state = aes_ctx.encrypt(trezor_state)
tag = aes_ctx.finish()
return encrypted_trezor_state + tag
def _derive_static_key_pair() -> tuple[bytes, bytes]:
node_int = HARDENED | int.from_bytes(b"\x00THP", "big")
node = bip32.from_seed(device.get_device_secret(), "curve25519")
node.derive(node_int)
trezor_static_privkey = node.private_key()
trezor_static_pubkey = node.public_key()[1:33]
# Note: the first byte (\x01) of the public key is removed, as it
# only indicates the type of the elliptic curve used
return trezor_static_privkey, trezor_static_pubkey
def get_trezor_static_pubkey() -> bytes:
_, pubkey = _derive_static_key_pair()
return pubkey
def _hkdf(chaining_key, input: bytes):
temp_key = hmac(hmac.SHA256, chaining_key, input).digest()
output_1 = hmac(hmac.SHA256, temp_key, b"\x01").digest()
ctx_output_2 = hmac(hmac.SHA256, temp_key, output_1)
ctx_output_2.update(b"\x02")
output_2 = ctx_output_2.digest()
return (output_1, output_2)
def _hash_of_two(part_1: bytes, part_2: bytes) -> bytes:
ctx = sha256(part_1)
ctx.update(part_2)
return ctx.digest()
def _get_iv_from_nonce(nonce: int) -> bytes:
utils.ensure(nonce <= 0xFFFFFFFFFFFFFFFF, "Nonce overflow, terminate the channel")
return bytes(4) + nonce.to_bytes(8, "big")

@ -0,0 +1,33 @@
from typing import TYPE_CHECKING
from trezor import protobuf
from trezor.enums import MessageType
from trezor.wire.errors import UnexpectedMessage
if TYPE_CHECKING:
from typing import Any, Callable, Coroutine
from trezor.messages import Features, GetFeatures
def find_management_session_message_handler(
msg_type: int,
) -> Callable[[Any], Coroutine[Any, Any, protobuf.MessageType]]:
if msg_type is MessageType.ThpCreateNewSession:
from apps.thp.create_session import create_new_session
return create_new_session
if msg_type is MessageType.GetFeatures:
return handle_GetFeatures
if __debug__:
if msg_type is MessageType.LoadDevice:
from apps.debug.load_device import load_device
return load_device
raise UnexpectedMessage("There is no handler available for this message")
async def handle_GetFeatures(msg: GetFeatures) -> Features:
from apps.base import get_features
return get_features()

@ -0,0 +1,32 @@
from typing import TYPE_CHECKING
import usb
_MOCK_INTERFACE_HID = b"\x00"
_WIRE_INTERFACE_USB = b"\x01"
if TYPE_CHECKING:
from trezorio import WireInterface
def decode_iface(cached_iface: bytes) -> WireInterface:
"""Decode the cached wire interface."""
if cached_iface == _WIRE_INTERFACE_USB:
iface = usb.iface_wire
if iface is None:
raise RuntimeError("There is no valid USB WireInterface")
return iface
if __debug__ and cached_iface == _MOCK_INTERFACE_HID:
raise NotImplementedError("Should return MockHID WireInterface")
# TODO implement bluetooth interface
raise Exception("Unknown WireInterface")
def encode_iface(iface: WireInterface) -> bytes:
"""Encode wire interface into bytes."""
if iface is usb.iface_wire:
return _WIRE_INTERFACE_USB
# TODO implement bluetooth interface
if __debug__:
return _MOCK_INTERFACE_HID
raise Exception("Unknown WireInterface")

@ -0,0 +1,124 @@
from storage.cache_thp import SESSION_ID_LENGTH, TAG_LENGTH
from trezor import log, protobuf, utils
from . import ChannelState, ThpError
from .checksum import CHECKSUM_LENGTH
from .writer import (
INIT_HEADER_LENGTH,
MAX_PAYLOAD_LEN,
MESSAGE_TYPE_LENGTH,
PACKET_LENGTH,
)
def select_buffer(
channel_state: int,
channel_buffer: utils.BufferType,
packet_payload: utils.BufferType,
payload_length: int,
) -> utils.BufferType:
if channel_state is ChannelState.ENCRYPTED_TRANSPORT:
session_id = packet_payload[0]
if session_id == 0:
pass
# TODO use small buffer
else:
pass
# TODO use big buffer but only if the channel owns the buffer lock.
# Otherwise send BUSY message and return
else:
pass
# TODO use small buffer
try:
# TODO for now, we create a new big buffer every time. It should be changed
buffer: utils.BufferType = _get_buffer_for_message(
payload_length, channel_buffer
)
return buffer
except Exception as e:
if __debug__:
log.exception(__name__, e)
raise Exception("Failed to create a buffer for channel") # TODO handle better
def get_write_buffer(
buffer: utils.BufferType, msg: protobuf.MessageType
) -> utils.BufferType:
msg_size = protobuf.encoded_length(msg)
payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size
required_min_size = payload_size + CHECKSUM_LENGTH + TAG_LENGTH
if required_min_size > len(buffer):
# message is too big, we need to allocate a new buffer
return bytearray(required_min_size)
return buffer
def encode_into_buffer(
buffer: utils.BufferType, msg: protobuf.MessageType, session_id: int
) -> int:
# cannot write message without wire type
assert msg.MESSAGE_WIRE_TYPE is not None
msg_size = protobuf.encoded_length(msg)
payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size
_encode_session_into_buffer(memoryview(buffer), session_id)
_encode_message_type_into_buffer(
memoryview(buffer), msg.MESSAGE_WIRE_TYPE, SESSION_ID_LENGTH
)
_encode_message_into_buffer(
memoryview(buffer), msg, SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH
)
return payload_size
def _encode_session_into_buffer(
buffer: memoryview, session_id: int, buffer_offset: int = 0
) -> None:
session_id_bytes = int.to_bytes(session_id, SESSION_ID_LENGTH, "big")
utils.memcpy(buffer, buffer_offset, session_id_bytes, 0)
def _encode_message_type_into_buffer(
buffer: memoryview, message_type: int, offset: int = 0
) -> None:
msg_type_bytes = int.to_bytes(message_type, MESSAGE_TYPE_LENGTH, "big")
utils.memcpy(buffer, offset, msg_type_bytes, 0)
def _encode_message_into_buffer(
buffer: memoryview, message: protobuf.MessageType, buffer_offset: int = 0
) -> None:
protobuf.encode(memoryview(buffer[buffer_offset:]), message)
def _get_buffer_for_message(
payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN
) -> utils.BufferType:
length = payload_length + INIT_HEADER_LENGTH
if __debug__:
log.debug(
__name__,
"get_buffer_for_message - length: %d, %s %s",
length,
"existing buffer type:",
type(existing_buffer),
)
if length > max_length:
raise ThpError("Message too large")
if length > len(existing_buffer):
# allocate a new buffer to fit the message
try:
payload: utils.BufferType = bytearray(length)
except MemoryError:
payload = bytearray(PACKET_LENGTH)
raise ThpError("Message too large")
return payload
# reuse a part of the supplied buffer
return memoryview(existing_buffer)[:length]

@ -0,0 +1,250 @@
from typing import TYPE_CHECKING
from ubinascii import hexlify
import trezorui2
from trezor import loop, protobuf, workflow
from trezor.ui.layouts.tt import RustLayout
from trezor.wire import context, message_handler, protocol_common
from trezor.wire.context import UnexpectedMessageException
from trezor.wire.errors import ActionCancelled
from trezor.wire.protocol_common import Context, Message
if TYPE_CHECKING:
from typing import Container
from .channel import Channel
from .cpace import Cpace
pass
if __debug__:
from trezor import log
class PairingDisplayData:
def __init__(self) -> None:
self.display_code_entry: bool = False
self.display_qr_code: bool = False
self.display_nfc_unidirectional: bool = False
self.code_code_entry: int | None = None
self.code_qr_code: bytes | None = None
self.code_nfc_unidirectional: bytes | None = None
def get_display_layout(self) -> RustLayout:
return RustLayout(
trezorui2.show_address_details( # noqa
qr_title="Scan QR code to pair",
address=self._get_code_qr_code_str(),
case_sensitive=True,
details_title="",
account="Code to rewrite:\n" + self._get_code_code_entry_str(),
path="",
xpubs=[],
)
)
def _get_code_code_entry_str(self) -> str:
if self.display_code_entry and self.code_code_entry is not None:
code_str = str(self.code_code_entry)
print("code_code_entry:", code_str)
return code_str[:3] + " " + code_str[3:]
return "NOT ALLOWED"
def _get_code_qr_code_str(self) -> str:
if self.display_qr_code and self.code_qr_code is not None:
code_str = (hexlify(self.code_qr_code)).decode("utf-8")
print("code_qr_code_hexlified:", code_str)
return code_str
return "QR CODE IS NOT SUPPOSED TO BE DISPLAYED!!!!"
class PairingContext(Context):
def __init__(self, channel_ctx: Channel) -> None:
super().__init__(channel_ctx.iface, channel_ctx.channel_id)
self.channel_ctx: Channel = channel_ctx
self.incoming_message = loop.chan()
self.secret: bytes = (
b"\xBA\xDA\x55\xBA\xDA\x55\xBA\xDA\x55\xBA\xDA\x55\xDE\xAD\xBE\xEF" # TODO generate randomly
)
self.display_data: PairingDisplayData = PairingDisplayData()
self.cpace: Cpace
self.host_name: str
async def handle(self, is_debug_session: bool = False) -> None:
if __debug__:
log.debug(__name__, "handle - start")
if is_debug_session:
import apps.debug
apps.debug.DEBUG_CONTEXT = self
take = self.incoming_message.take()
next_message: Message | None = None
while True:
try:
if next_message is None:
# If the previous run did not keep an unprocessed message for us,
# wait for a new one.
try:
message: Message = await take
except protocol_common.WireError as e:
if __debug__:
log.exception(__name__, e)
await self.write(message_handler.failure(e))
continue
else:
# Process the message from previous run.
message = next_message
next_message = None
try:
next_message = await handle_pairing_request_message(
self, message, use_workflow=not is_debug_session
)
except Exception as exc:
# Log and ignore. The session handler can only exit explicitly in the
# following finally block.
if __debug__:
log.exception(__name__, exc)
finally:
if not __debug__ or not is_debug_session:
# Unload modules imported by the workflow. Should not raise.
# This is not done for the debug session because the snapshot taken
# in a debug session would clear modules which are in use by the
# workflow running on wire.
# TODO utils.unimport_end(modules)
if next_message is None:
# Shut down the loop if there is no next message waiting.
return # pylint: disable=lost-exception
except Exception as exc:
# Log and try again. The session handler can only exit explicitly via
# loop.clear() above.
if __debug__:
log.exception(__name__, exc)
async def read(
self,
expected_types: Container[int],
expected_type: type[protobuf.MessageType] | None = None,
) -> protobuf.MessageType:
if __debug__:
exp_type: str = str(expected_type)
if expected_type is not None:
exp_type = expected_type.MESSAGE_NAME
log.debug(
__name__,
"Read - with expected types %s and expected type %s",
str(expected_types),
exp_type,
)
message: Message = await self.incoming_message.take()
if message.type not in expected_types:
raise UnexpectedMessageException(message)
if expected_type is None:
expected_type = protobuf.type_for_wire(message.type)
return message_handler.wrap_protobuf_load(message.data, expected_type)
async def write(self, msg: protobuf.MessageType) -> None:
return await self.channel_ctx.write(msg)
async def call(
self, msg: protobuf.MessageType, expected_type: type[protobuf.MessageType]
) -> protobuf.MessageType:
assert expected_type.MESSAGE_WIRE_TYPE is not None
await self.write(msg)
del msg
return await self.read((expected_type.MESSAGE_WIRE_TYPE,), expected_type)
async def call_any(
self, msg: protobuf.MessageType, *expected_types: int
) -> protobuf.MessageType:
await self.write(msg)
del msg
return await self.read(expected_types)
async def handle_pairing_request_message(
pairing_ctx: PairingContext,
msg: protocol_common.Message,
use_workflow: bool,
) -> protocol_common.Message | None:
res_msg: protobuf.MessageType | None = None
from apps.thp.pairing import handle_pairing_request
if msg.type in workflow.ALLOW_WHILE_LOCKED:
workflow.autolock_interrupts_workflow = False
# Here we make sure we always respond with a Failure response
# in case of any errors.
try:
# Find a protobuf.MessageType subclass that describes this
# message. Raises if the type is not found.
req_type = protobuf.type_for_wire(msg.type)
# Try to decode the message according to schema from
# `req_type`. Raises if the message is malformed.
req_msg = message_handler.wrap_protobuf_load(msg.data, req_type)
# Create the handler task.
task = handle_pairing_request(pairing_ctx, req_msg)
# Run the workflow task. Workflow can do more on-the-wire
# communication inside, but it should eventually return a
# response message, or raise an exception (a rather common
# thing to do). Exceptions are handled in the code below.
if use_workflow:
# Spawn a workflow around the task. This ensures that concurrent
# workflows are shut down.
res_msg = await workflow.spawn(context.with_context(pairing_ctx, task))
else:
# For debug messages, ignore workflow processing and just await
# results of the handler.
res_msg = await task
except UnexpectedMessageException as exc:
# Workflow was trying to read a message from the wire, and
# something unexpected came in. See Context.read() for
# example, which expects some particular message and raises
# UnexpectedMessage if another one comes in.
# In order not to lose the message, we return it to the caller.
# TODO:
# We might handle only the few common cases here, like
# Initialize and Cancel.
return exc.msg
except BaseException as exc:
# Either:
# - the message had a type that has a registered handler, but does not have
# a protobuf class
# - the message was not valid protobuf
# - workflow raised some kind of an exception while running
# - something canceled the workflow from the outside
if __debug__:
if isinstance(exc, ActionCancelled):
log.debug(__name__, "cancelled: %s", exc.message)
elif isinstance(exc, loop.TaskClosed):
log.debug(__name__, "cancelled: loop task was closed")
else:
log.exception(__name__, exc)
res_msg = message_handler.failure(exc)
if res_msg is not None:
# perform the write outside the big try-except block, so that usb write
# problem bubbles up
await pairing_ctx.write(res_msg)
return None

@ -0,0 +1,382 @@
import ustruct
from typing import TYPE_CHECKING
from storage.cache_common import (
CHANNEL_HANDSHAKE_HASH,
CHANNEL_KEY_RECEIVE,
CHANNEL_KEY_SEND,
CHANNEL_NONCE_RECEIVE,
CHANNEL_NONCE_SEND,
)
from storage.cache_thp import KEY_LENGTH, SESSION_ID_LENGTH, TAG_LENGTH
from trezor import log, loop, utils
from trezor.enums import FailureType
from trezor.messages import Failure
from ..errors import DataError
from ..protocol_common import Message
from . import (
ChannelState,
SessionState,
ThpDecryptionError,
ThpError,
ThpErrorType,
ThpUnallocatedSessionError,
)
from . import alternating_bit_protocol as ABP
from . import checksum, control_byte, is_channel_state_pairing, thp_messages
from .checksum import CHECKSUM_LENGTH
from .crypto import PUBKEY_LENGTH, Handshake
from .thp_messages import (
ACK_MESSAGE,
HANDSHAKE_COMP_RES,
HANDSHAKE_INIT_RES,
PacketHeader,
)
from .writer import (
INIT_HEADER_LENGTH,
MESSAGE_TYPE_LENGTH,
write_payload_to_wire_and_add_checksum,
)
if TYPE_CHECKING:
from trezor.messages import ThpHandshakeCompletionReqNoisePayload
from .channel import Channel
if __debug__:
from ubinascii import hexlify
from . import state_to_str
async def handle_received_message(
ctx: Channel, message_buffer: utils.BufferType
) -> None:
"""Handle a message received from the channel."""
if __debug__:
log.debug(__name__, "handle_received_message")
ctrl_byte, _, payload_length = ustruct.unpack(">BHH", message_buffer)
message_length = payload_length + INIT_HEADER_LENGTH
_check_checksum(message_length, message_buffer)
# Synchronization process
seq_bit = (ctrl_byte & 0x10) >> 4
ack_bit = (ctrl_byte & 0x08) >> 3
if __debug__:
log.debug(
__name__,
"handle_completed_message - seq bit of message: %d, ack bit of message: %d",
seq_bit,
ack_bit,
)
# 1: Handle ACKs
if control_byte.is_ack(ctrl_byte):
await _handle_ack(ctx, ack_bit)
return
if _should_have_ctrl_byte_encrypted_transport(
ctx
) and not control_byte.is_encrypted_transport(ctrl_byte):
raise ThpError("Message is not encrypted. Ignoring")
# 2: Handle message with unexpected sequential bit
if seq_bit != ABP.get_expected_receive_seq_bit(ctx.channel_cache):
if __debug__:
log.debug(__name__, "Received message with an unexpected sequential bit")
await _send_ack(ctx, ack_bit=seq_bit)
raise ThpError("Received message with an unexpected sequential bit")
# 3: Send ACK in response
await _send_ack(ctx, ack_bit=seq_bit)
ABP.set_expected_receive_seq_bit(ctx.channel_cache, 1 - seq_bit)
try:
await _handle_message_to_app_or_channel(
ctx, payload_length, message_length, ctrl_byte
)
except ThpUnallocatedSessionError as e:
error_message = Failure(code=FailureType.ThpUnallocatedSession)
await ctx.write(error_message, e.session_id)
print(e)
except ThpDecryptionError as e:
await ctx.write_error(ThpErrorType.DECRYPTION_FAILED)
# TODO ctx.write_error, i.e., ctx.send_channelClosed_error
# ctx.clear
print(e)
if __debug__:
log.debug(__name__, "handle_received_message - end")
async def _send_ack(ctx: Channel, ack_bit: int) -> None:
ctrl_byte = control_byte.add_ack_bit_to_ctrl_byte(ACK_MESSAGE, ack_bit)
header = PacketHeader(ctrl_byte, ctx.get_channel_id_int(), CHECKSUM_LENGTH)
if __debug__:
log.debug(
__name__,
"Writing ACK message to a channel with id: %d, ack_bit: %d",
ctx.get_channel_id_int(),
ack_bit,
)
await write_payload_to_wire_and_add_checksum(ctx.iface, header, b"")
def _check_checksum(message_length: int, message_buffer: utils.BufferType):
if __debug__:
log.debug(__name__, "check_checksum")
if not checksum.is_valid(
checksum=message_buffer[message_length - CHECKSUM_LENGTH : message_length],
data=memoryview(message_buffer)[: message_length - CHECKSUM_LENGTH],
):
if __debug__:
log.debug(__name__, "Invalid checksum, ignoring message.")
raise ThpError("Invalid checksum, ignoring message.")
async def _handle_ack(ctx: Channel, ack_bit: int):
if not ABP.is_ack_valid(ctx.channel_cache, ack_bit):
return
# ACK is expected and it has correct sync bit
if __debug__:
log.debug(__name__, "Received ACK message with correct ack bit")
if ctx.transmission_loop is not None:
ctx.transmission_loop.stop_immediately()
if __debug__:
log.debug(__name__, "Stopped transmission loop")
ABP.set_sending_allowed(ctx.channel_cache, True)
if ctx.write_task_spawn is not None:
if __debug__:
log.debug(__name__, 'Control to "write_encrypted_payload_loop" task')
await ctx.write_task_spawn
# Note that no the write_task_spawn could result in loop.clear(),
# which will result in termination of this function - any code after
# this await might not be executed
async def _handle_message_to_app_or_channel(
ctx: Channel,
payload_length: int,
message_length: int,
ctrl_byte: int,
) -> None:
state = ctx.get_channel_state()
if __debug__:
log.debug(__name__, "state: %s", state_to_str(state))
if state is ChannelState.ENCRYPTED_TRANSPORT:
await _handle_state_ENCRYPTED_TRANSPORT(ctx, message_length)
return
if state is ChannelState.TH1:
await _handle_state_TH1(ctx, payload_length, message_length, ctrl_byte)
return
if state is ChannelState.TH2:
await _handle_state_TH2(ctx, message_length, ctrl_byte)
return
if is_channel_state_pairing(state):
await _handle_pairing(ctx, message_length)
return
raise ThpError("Unimplemented channel state")
async def _handle_state_TH1(
ctx: Channel,
payload_length: int,
message_length: int,
ctrl_byte: int,
) -> None:
if __debug__:
log.debug(__name__, "handle_state_TH1")
if not control_byte.is_handshake_init_req(ctrl_byte):
raise ThpError("Message received is not a handshake init request!")
if not payload_length == PUBKEY_LENGTH + CHECKSUM_LENGTH:
raise ThpError("Message received is not a valid handshake init request!")
ctx.handshake = Handshake()
host_ephemeral_pubkey = bytearray(
ctx.buffer[INIT_HEADER_LENGTH : message_length - CHECKSUM_LENGTH]
)
trezor_ephemeral_pubkey, encrypted_trezor_static_pubkey, tag = (
ctx.handshake.handle_th1_crypto(
thp_messages.get_encoded_device_properties(), host_ephemeral_pubkey
)
)
if __debug__:
log.debug(
__name__,
"trezor ephemeral pubkey: %s",
hexlify(trezor_ephemeral_pubkey).decode(),
)
log.debug(
__name__,
"trezor masked static pubkey: %s",
hexlify(encrypted_trezor_static_pubkey).decode(),
)
log.debug(__name__, "tag: %s", hexlify(tag))
payload = trezor_ephemeral_pubkey + encrypted_trezor_static_pubkey + tag
# send handshake init response message
await ctx.write_handshake_message(HANDSHAKE_INIT_RES, payload)
ctx.set_channel_state(ChannelState.TH2)
return
async def _handle_state_TH2(ctx: Channel, message_length: int, ctrl_byte: int) -> None:
from apps.thp.credential_manager import validate_credential
if __debug__:
log.debug(__name__, "handle_state_TH2")
if not control_byte.is_handshake_comp_req(ctrl_byte):
raise ThpError("Message received is not a handshake completion request!")
if ctx.handshake is None:
raise Exception("Handshake object is not prepared. Retry handshake.")
host_encrypted_static_pubkey = memoryview(ctx.buffer)[
INIT_HEADER_LENGTH : INIT_HEADER_LENGTH + KEY_LENGTH + TAG_LENGTH
]
handshake_completion_request_noise_payload = memoryview(ctx.buffer)[
INIT_HEADER_LENGTH + KEY_LENGTH + TAG_LENGTH : message_length - CHECKSUM_LENGTH
]
ctx.handshake.handle_th2_crypto(
host_encrypted_static_pubkey, handshake_completion_request_noise_payload
)
ctx.channel_cache.set(CHANNEL_KEY_RECEIVE, ctx.handshake.key_receive)
ctx.channel_cache.set(CHANNEL_KEY_SEND, ctx.handshake.key_send)
ctx.channel_cache.set(CHANNEL_HANDSHAKE_HASH, ctx.handshake.h)
ctx.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, 0)
ctx.channel_cache.set_int(CHANNEL_NONCE_SEND, 1)
noise_payload = thp_messages.decode_message(
ctx.buffer[
INIT_HEADER_LENGTH
+ KEY_LENGTH
+ TAG_LENGTH : message_length
- CHECKSUM_LENGTH
- TAG_LENGTH
],
0,
"ThpHandshakeCompletionReqNoisePayload",
)
if TYPE_CHECKING:
assert ThpHandshakeCompletionReqNoisePayload.is_type_of(noise_payload)
for i in noise_payload.pairing_methods:
if i not in ctx.selected_pairing_methods:
ctx.selected_pairing_methods.append(i)
if __debug__:
log.debug(
__name__,
"host static pubkey: %s, noise payload: %s",
utils.get_bytes_as_str(host_encrypted_static_pubkey),
utils.get_bytes_as_str(handshake_completion_request_noise_payload),
)
# key is decoded in handshake._handle_th2_crypto
host_static_pubkey = host_encrypted_static_pubkey[:PUBKEY_LENGTH]
paired: bool = False
if noise_payload.host_pairing_credential is not None:
try: # TODO change try-except for something better
paired = validate_credential(
noise_payload.host_pairing_credential,
host_static_pubkey,
)
except DataError as e:
if __debug__:
log.exception(__name__, e)
pass
trezor_state = thp_messages.TREZOR_STATE_UNPAIRED
if paired:
trezor_state = thp_messages.TREZOR_STATE_PAIRED
# send hanshake completion response
await ctx.write_handshake_message(
HANDSHAKE_COMP_RES,
ctx.handshake.get_handshake_completion_response(trezor_state),
)
ctx.handshake = None
if paired:
ctx.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
else:
ctx.set_channel_state(ChannelState.TP1)
async def _handle_state_ENCRYPTED_TRANSPORT(ctx: Channel, message_length: int) -> None:
if __debug__:
log.debug(__name__, "handle_state_ENCRYPTED_TRANSPORT")
ctx.decrypt_buffer(message_length)
session_id, message_type = ustruct.unpack(
">BH", memoryview(ctx.buffer)[INIT_HEADER_LENGTH:]
)
if session_id not in ctx.sessions:
raise ThpUnallocatedSessionError(session_id)
session_state = ctx.sessions[session_id].get_session_state()
if session_state is SessionState.UNALLOCATED:
raise ThpUnallocatedSessionError(session_id)
ctx.sessions[session_id].incoming_message.publish(
Message(
message_type,
ctx.buffer[
INIT_HEADER_LENGTH
+ MESSAGE_TYPE_LENGTH
+ SESSION_ID_LENGTH : message_length
- CHECKSUM_LENGTH
- TAG_LENGTH
],
)
)
async def _handle_pairing(ctx: Channel, message_length: int) -> None:
from .pairing_context import PairingContext
if ctx.connection_context is None:
ctx.connection_context = PairingContext(ctx)
loop.schedule(ctx.connection_context.handle())
ctx.decrypt_buffer(message_length)
message_type = ustruct.unpack(
">H", ctx.buffer[INIT_HEADER_LENGTH + SESSION_ID_LENGTH :]
)[0]
ctx.connection_context.incoming_message.publish(
Message(
message_type,
ctx.buffer[
INIT_HEADER_LENGTH
+ MESSAGE_TYPE_LENGTH
+ SESSION_ID_LENGTH : message_length
- CHECKSUM_LENGTH
- TAG_LENGTH
],
)
)
def _should_have_ctrl_byte_encrypted_transport(ctx: Channel) -> bool:
if ctx.get_channel_state() in [
ChannelState.UNALLOCATED,
ChannelState.TH1,
ChannelState.TH2,
]:
return False
return True

@ -0,0 +1,223 @@
from typing import TYPE_CHECKING
from storage.cache_thp import MANAGEMENT_SESSION_ID, SessionThpCache
from trezor import log, loop, protobuf, utils
from trezor.wire import message_handler, protocol_common
from trezor.wire.context import UnexpectedMessageException
from trezor.wire.message_handler import AVOID_RESTARTING_FOR, failure, find_handler
from ..protocol_common import Context, Message
from . import SessionState
if TYPE_CHECKING:
from typing import Any, Awaitable, Container, TypeVar, overload
from storage.cache_common import DataCache
from ..message_handler import HandlerFinder
from .channel import Channel
pass
_EXIT_LOOP = True
_REPEAT_LOOP = False
if __debug__:
from trezor.utils import get_bytes_as_str
class GenericSessionContext(Context):
def __init__(self, channel: Channel, session_id: int) -> None:
super().__init__(channel.iface, channel.channel_id)
self.channel: Channel = channel
self.session_id: int = session_id
self.incoming_message = loop.chan()
self.handler_finder: HandlerFinder = find_handler
async def handle(self, is_debug_session: bool = False) -> None:
if __debug__:
self._handle_debug(is_debug_session)
take = self.incoming_message.take()
next_message: Message | None = None
# Take a mark of modules that are imported at this point, so we can
# roll back and un-import any others.
# TODO modules = utils.unimport_begin()
while True:
try:
if await self._handle_message(take, next_message, is_debug_session):
return
except UnexpectedMessageException as unexpected:
# The workflow was interrupted by an unexpected message. We need to
# process it as if it was a new message...
next_message = unexpected.msg
continue
except Exception as exc:
# Log and try again.
if __debug__:
log.exception(__name__, exc)
def _handle_debug(self, is_debug_session: bool) -> None:
log.debug(
__name__,
"handle - start (channel_id (bytes): %s, session_id: %d)",
get_bytes_as_str(self.channel_id),
self.session_id,
)
if is_debug_session:
import apps.debug
apps.debug.DEBUG_CONTEXT = self
async def _handle_message(
self,
take: Awaitable[Any],
next_message: Message | None,
is_debug_session: bool,
) -> bool:
try:
message = await self._get_message(take, next_message)
except protocol_common.WireError as e:
if __debug__:
log.exception(__name__, e)
await self.write(failure(e))
return _REPEAT_LOOP
try:
await message_handler.handle_single_message(
self,
message,
self.handler_finder,
)
except UnexpectedMessageException:
raise
except Exception as exc:
# Log and ignore. The session handler can only exit explicitly in the
# following finally block.
if __debug__:
log.exception(__name__, exc)
finally:
if not __debug__ or not is_debug_session:
# Unload modules imported by the workflow. Should not raise.
# This is not done for the debug session because the snapshot taken
# in a debug session would clear modules which are in use by the
# workflow running on wire.
# TODO utils.unimport_end(modules)
if next_message is None and message.type not in AVOID_RESTARTING_FOR:
# Shut down the loop if there is no next message waiting.
return _EXIT_LOOP # pylint: disable=lost-exception
return _REPEAT_LOOP # pylint: disable=lost-exception
async def _get_message(
self, take: Awaitable[Any], next_message: Message | None
) -> Message:
if next_message is None:
# If the previous run did not keep an unprocessed message for us,
# wait for a new one.
message: Message = await take
else:
# Process the message from previous run.
message = next_message
next_message = None
return message
async def read(
self,
expected_types: Container[int],
expected_type: type[protobuf.MessageType] | None = None,
) -> protobuf.MessageType:
if __debug__:
exp_type: str = str(expected_type)
if expected_type is not None:
exp_type = expected_type.MESSAGE_NAME
log.debug(
__name__,
"Read - with expected types %s and expected type %s",
str(expected_types),
exp_type,
)
message: Message = await self.incoming_message.take()
if message.type not in expected_types:
raise UnexpectedMessageException(message)
if expected_type is None:
expected_type = protobuf.type_for_wire(message.type)
return message_handler.wrap_protobuf_load(message.data, expected_type)
async def write(self, msg: protobuf.MessageType) -> None:
return await self.channel.write(msg, self.session_id)
def get_session_state(self) -> SessionState: ...
class ManagementSessionContext(GenericSessionContext):
def __init__(self, channel_ctx: Channel) -> None:
super().__init__(channel_ctx, MANAGEMENT_SESSION_ID)
from trezor.wire.thp.handler_provider import (
find_management_session_message_handler,
)
self.handler_finder = find_management_session_message_handler
def get_session_state(self) -> SessionState:
return SessionState.MANAGEMENT
class SessionContext(GenericSessionContext):
def __init__(self, channel_ctx: Channel, session_cache: SessionThpCache) -> None:
if channel_ctx.channel_id != session_cache.channel_id:
raise Exception(
"The session has different channel id than the provided channel context!"
)
session_id = int.from_bytes(session_cache.session_id, "big")
super().__init__(channel_ctx, session_id)
self.session_cache = session_cache
# ACCESS TO SESSION DATA
def get_session_state(self) -> SessionState:
state = int.from_bytes(self.session_cache.state, "big")
return SessionState(state)
def set_session_state(self, state: SessionState) -> None:
self.session_cache.state = bytearray(state.to_bytes(1, "big"))
# ACCESS TO CACHE
@property
def cache(self) -> DataCache:
return self.session_cache
if TYPE_CHECKING:
T = TypeVar("T")
@overload
def cache_get(self, key: int) -> bytes | None: # noqa: F811
...
@overload
def cache_get(self, key: int, default: T) -> bytes | T: # noqa: F811
...
def cache_get(
self, key: int, default: T | None = None
) -> bytes | T | None: # noqa: F811
utils.ensure(key < len(self.session_cache.fields))
if self.session_cache.data[key][0] != 1:
return default
return bytes(self.session_cache.data[key][1:])
def cache_is_set(self, key: int) -> bool:
return self.session_cache.is_set(key)
def cache_set(self, key: int, value: bytes) -> None:
utils.ensure(key < len(self.session_cache.fields))
utils.ensure(len(value) <= self.session_cache.fields[key])
self.session_cache.data[key][0] = 1
self.session_cache.data[key][1:] = value

@ -0,0 +1,48 @@
from typing import TYPE_CHECKING
from storage import cache_thp
from trezor import loop
from .session_context import (
GenericSessionContext,
ManagementSessionContext,
SessionContext,
)
if __debug__:
from trezor import log
if TYPE_CHECKING:
from .channel import Channel
def create_new_session(channel_ctx: Channel) -> SessionContext:
session_cache = cache_thp.get_new_session(channel_ctx.channel_cache)
return SessionContext(channel_ctx, session_cache)
def create_new_management_session(
channel_ctx: Channel,
) -> ManagementSessionContext:
return ManagementSessionContext(channel_ctx)
def load_cached_sessions(
channel_ctx: Channel,
) -> dict[int, GenericSessionContext]:
if __debug__:
log.debug(__name__, "load_cached_sessions")
sessions: dict[int, GenericSessionContext] = {}
cached_sessions = cache_thp.get_all_allocated_sessions()
if __debug__:
log.debug(
__name__,
"load_cached_sessions - loaded a total of %d sessions from cache",
len(cached_sessions),
)
for session in cached_sessions:
if session.channel_id == channel_ctx.channel_id:
sid = int.from_bytes(session.session_id, "big")
sessions[sid] = SessionContext(channel_ctx, session)
loop.schedule(sessions[sid].handle())
return sessions

@ -0,0 +1,120 @@
import ustruct
from micropython import const
from storage.cache_thp import BROADCAST_CHANNEL_ID
from trezor import protobuf, utils
from trezor.enums import ThpPairingMethod
from trezor.messages import ThpDeviceProperties
from .. import message_handler
CODEC_V1 = const(0x3F)
CONTINUATION_PACKET = const(0x80)
HANDSHAKE_INIT_REQ = const(0x00)
HANDSHAKE_INIT_RES = const(0x01)
HANDSHAKE_COMP_REQ = const(0x02)
HANDSHAKE_COMP_RES = const(0x03)
ENCRYPTED_TRANSPORT = const(0x04)
CONTINUATION_PACKET_MASK = const(0x80)
ACK_MASK = const(0xF7)
DATA_MASK = const(0xE7)
ACK_MESSAGE = const(0x20)
_ERROR = const(0x42)
CHANNEL_ALLOCATION_REQ = const(0x40)
_CHANNEL_ALLOCATION_RES = const(0x41)
TREZOR_STATE_UNPAIRED = b"\x00"
TREZOR_STATE_PAIRED = b"\x01"
if __debug__:
from trezor import log
class PacketHeader:
format_str_init = ">BHH"
format_str_cont = ">BH"
def __init__(self, ctrl_byte: int, cid: int, length: int) -> None:
self.ctrl_byte = ctrl_byte
self.cid = cid
self.length = length
def to_bytes(self) -> bytes:
return ustruct.pack(self.format_str_init, self.ctrl_byte, self.cid, self.length)
def pack_to_init_buffer(self, buffer, buffer_offset=0) -> None:
ustruct.pack_into(
self.format_str_init,
buffer,
buffer_offset,
self.ctrl_byte,
self.cid,
self.length,
)
def pack_to_cont_buffer(self, buffer, buffer_offset=0) -> None:
ustruct.pack_into(
self.format_str_cont, buffer, buffer_offset, CONTINUATION_PACKET, self.cid
)
@classmethod
def get_error_header(cls, cid, length):
return cls(_ERROR, cid, length)
@classmethod
def get_channel_allocation_response_header(cls, length):
return cls(_CHANNEL_ALLOCATION_RES, BROADCAST_CHANNEL_ID, length)
_ENCODED_DEVICE_PROPERTIES: bytes | None = None
_ENABLED_PAIRING_METHODS = [
ThpPairingMethod.CodeEntry,
ThpPairingMethod.QrCode,
ThpPairingMethod.NFC_Unidirectional,
]
def _get_device_properties() -> ThpDeviceProperties:
# TODO define model variants
return ThpDeviceProperties(
pairing_methods=_ENABLED_PAIRING_METHODS,
internal_model=utils.INTERNAL_MODEL,
model_variant=0,
bootloader_mode=False,
protocol_version=3,
)
def get_encoded_device_properties() -> bytes:
global _ENCODED_DEVICE_PROPERTIES
if _ENCODED_DEVICE_PROPERTIES is None:
props = _get_device_properties()
length = protobuf.encoded_length(props)
_ENCODED_DEVICE_PROPERTIES = bytearray(length)
protobuf.encode(_ENCODED_DEVICE_PROPERTIES, props)
return _ENCODED_DEVICE_PROPERTIES
def get_channel_allocation_response(nonce: bytes, new_cid: bytes) -> bytes:
props_msg = get_encoded_device_properties()
return nonce + new_cid + props_msg
def get_codec_v1_error_message() -> bytes:
return b"?##" # TODO add unsupported codec_v1 error message
def decode_message(
buffer: bytes, msg_type: int, message_name: str | None = None
) -> protobuf.MessageType:
if __debug__:
log.debug(__name__, "decode message")
if message_name is not None:
expected_type = protobuf.type_for_name(message_name)
else:
expected_type = protobuf.type_for_wire(msg_type)
x = message_handler.wrap_protobuf_load(buffer, expected_type)
return x

@ -0,0 +1,59 @@
from micropython import const
from typing import TYPE_CHECKING
from trezor import loop
from trezor.wire.thp.thp_messages import PacketHeader
from trezor.wire.thp.writer import write_payload_to_wire_and_add_checksum
if TYPE_CHECKING:
from trezor.wire.thp.channel import Channel
MAX_RETRANSMISSION_COUNT = const(50)
MIN_RETRANSMISSION_COUNT = const(2)
class TransmissionLoop:
def __init__(
self, channel: Channel, header: PacketHeader, transport_payload: bytes
) -> None:
self.channel: Channel = channel
self.header: PacketHeader = header
self.transport_payload: bytes = transport_payload
self.wait_task: loop.spawn | None = None
self.min_retransmisson_count_achieved: bool = False
async def start(self):
self.min_retransmisson_count_achieved = False
for i in range(MAX_RETRANSMISSION_COUNT):
if i >= MIN_RETRANSMISSION_COUNT:
self.min_retransmisson_count_achieved = True
await write_payload_to_wire_and_add_checksum(
self.channel.iface, self.header, self.transport_payload
)
self.wait_task = loop.spawn(self._wait(i))
try:
await self.wait_task
except loop.TaskClosed:
self.wait_task = None
break
def stop_immediately(self):
if self.wait_task is not None:
self.wait_task.close()
self.wait_task = None
async def stop_after_min_retransmission(self):
while not self.min_retransmisson_count_achieved and self.wait_task is not None:
await self._short_wait()
self.stop_immediately()
async def _wait(self, counter: int = 0) -> None:
timeout_ms = round(10200 - 1010000 / (counter + 100))
await loop.sleep(timeout_ms)
async def _short_wait(self):
loop.wait(50)
def __del__(self):
self.stop_immediately()

@ -0,0 +1,82 @@
from micropython import const
from trezorcrypto import crc
from typing import TYPE_CHECKING
from trezor import io, log, loop, utils
from trezor.wire.thp.thp_messages import PacketHeader
INIT_HEADER_LENGTH = const(5)
CONT_HEADER_LENGTH = const(3)
PACKET_LENGTH = const(64)
CHECKSUM_LENGTH = const(4)
MAX_PAYLOAD_LEN = const(60000)
MESSAGE_TYPE_LENGTH = const(2)
if TYPE_CHECKING:
from trezorio import WireInterface
from typing import Sequence
async def write_payload_to_wire_and_add_checksum(
iface: WireInterface, header: PacketHeader, transport_payload: bytes
):
header_checksum: int = crc.crc32(header.to_bytes())
checksum: bytes = crc.crc32(transport_payload, header_checksum).to_bytes(
CHECKSUM_LENGTH, "big"
)
data = (transport_payload, checksum)
await write_payloads_to_wire(iface, header, data)
async def write_payloads_to_wire(
iface: WireInterface, header: PacketHeader, data: Sequence[bytes]
):
n_of_data = len(data)
total_length = sum(len(item) for item in data)
current_data_idx = 0
current_data_offset = 0
packet = bytearray(PACKET_LENGTH)
header.pack_to_init_buffer(packet)
packet_offset: int = INIT_HEADER_LENGTH
packet_number = 0
nwritten = 0
while nwritten < total_length:
if packet_number == 1:
header.pack_to_cont_buffer(packet)
if packet_number >= 1 and nwritten >= total_length - PACKET_LENGTH:
packet[:] = bytearray(PACKET_LENGTH)
header.pack_to_cont_buffer(packet)
while True:
n = utils.memcpy(
packet, packet_offset, data[current_data_idx], current_data_offset
)
packet_offset += n
current_data_offset += n
nwritten += n
if packet_offset < PACKET_LENGTH:
current_data_idx += 1
current_data_offset = 0
if current_data_idx >= n_of_data:
break
elif packet_offset == PACKET_LENGTH:
break
else:
raise Exception("Should not happen!!!")
packet_number += 1
packet_offset = CONT_HEADER_LENGTH
await write_packet_to_wire(iface, packet)
async def write_packet_to_wire(iface: WireInterface, packet: bytes) -> None:
while True:
await loop.wait(iface.iface_num() | io.POLL_WRITE)
if __debug__:
log.debug(
__name__, "write_packet_to_wire: %s", utils.get_bytes_as_str(packet)
)
n = iface.write(packet)
if n == len(packet):
return

@ -0,0 +1,159 @@
from ubinascii import hexlify
import ustruct
from micropython import const
from typing import TYPE_CHECKING
from storage.cache_thp import BROADCAST_CHANNEL_ID
from trezor import io, log, loop, utils
from trezor.wire.thp import writer
from .thp import (
ChannelState,
ThpError,
ThpErrorType,
channel_manager,
checksum,
session_manager,
thp_messages,
)
from .thp.channel import Channel
from .thp.checksum import CHECKSUM_LENGTH
from .thp.thp_messages import CHANNEL_ALLOCATION_REQ, CODEC_V1, PacketHeader
from .thp.writer import (
MAX_PAYLOAD_LEN,
PACKET_LENGTH,
INIT_HEADER_LENGTH,
write_payload_to_wire_and_add_checksum,
)
if TYPE_CHECKING:
from trezorio import WireInterface
_CID_REQ_PAYLOAD_LENGTH = const(12)
_BUFFER: bytearray
_CHANNELS: dict[int, Channel] = {}
def set_buffer(buffer):
global _BUFFER
_BUFFER = buffer
async def thp_main_loop(iface: WireInterface, is_debug_session=False):
global _CHANNELS
global _BUFFER
_CHANNELS = channel_manager.load_cached_channels(_BUFFER)
for ch in _CHANNELS.values():
ch.sessions.update(session_manager.load_cached_sessions(ch))
read = loop.wait(iface.iface_num() | io.POLL_READ)
while True:
try:
if __debug__:
log.debug(__name__, "thp_main_loop")
packet = await read
ctrl_byte, cid = ustruct.unpack(">BH", packet)
if ctrl_byte == CODEC_V1:
await _handle_codec_v1(iface, packet)
if cid == BROADCAST_CHANNEL_ID:
await _handle_broadcast(iface, ctrl_byte, packet)
continue
if cid in _CHANNELS:
await _handle_allocated(iface, cid, packet)
else:
await _handle_unallocated(iface, cid)
except ThpError as e:
if __debug__:
log.exception(__name__, e)
async def _handle_codec_v1(iface: WireInterface, packet):
# If the received packet is not initial codec_v1 packet, do not send error message
if not packet[1:3] == b"##":
return
error_message = thp_messages.get_codec_v1_error_message()
await writer.write_packet_to_wire(iface, error_message)
async def _handle_broadcast(
iface: WireInterface, ctrl_byte: int, packet: utils.BufferType
) -> None:
global _BUFFER
if ctrl_byte != CHANNEL_ALLOCATION_REQ:
raise ThpError("Unexpected ctrl_byte in a broadcast channel packet")
if __debug__:
log.debug(__name__, "Received valid message on the broadcast channel")
length, nonce = ustruct.unpack(">H8s", packet[3:])
payload = _get_buffer_for_payload(length, packet[5:], _CID_REQ_PAYLOAD_LENGTH)
if not checksum.is_valid(
payload[-4:],
packet[: _CID_REQ_PAYLOAD_LENGTH + INIT_HEADER_LENGTH - CHECKSUM_LENGTH],
):
raise ThpError("Checksum is not valid")
new_channel: Channel = channel_manager.create_new_channel(iface, _BUFFER)
cid = int.from_bytes(new_channel.channel_id, "big")
_CHANNELS[cid] = new_channel
response_data = thp_messages.get_channel_allocation_response(
nonce, new_channel.channel_id
)
response_header = PacketHeader.get_channel_allocation_response_header(
len(response_data) + CHECKSUM_LENGTH,
)
if __debug__:
log.debug(__name__, "New channel allocated with id %d", cid)
await write_payload_to_wire_and_add_checksum(iface, response_header, response_data)
async def _handle_allocated(
iface: WireInterface, cid: int, packet: utils.BufferType
) -> None:
channel = _CHANNELS[cid]
if channel is None:
_handle_unallocated(iface, cid)
raise ThpError("Invalid state of a channel")
if channel.iface is not iface:
# TODO send error message to wire
raise ThpError("Channel has different WireInterface")
if channel.get_channel_state() != ChannelState.UNALLOCATED:
await channel.receive_packet(packet)
async def _handle_unallocated(iface, cid) -> None:
data = (ThpErrorType.UNALLOCATED_CHANNEL).to_bytes(1, "big")
header = PacketHeader.get_error_header(cid, len(data) + CHECKSUM_LENGTH)
await write_payload_to_wire_and_add_checksum(iface, header, data)
def _get_buffer_for_payload(
payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN
) -> utils.BufferType:
if payload_length > max_length:
raise ThpError("Message too large")
if payload_length > len(existing_buffer):
return _try_allocate_new_buffer(payload_length)
return _reuse_existing_buffer(payload_length, existing_buffer)
def _try_allocate_new_buffer(payload_length: int) -> utils.BufferType:
try:
payload: utils.BufferType = bytearray(payload_length)
except MemoryError:
payload = bytearray(PACKET_LENGTH)
raise ThpError("Message too large")
return payload
def _reuse_existing_buffer(
payload_length: int, existing_buffer: utils.BufferType
) -> utils.BufferType:
return memoryview(existing_buffer)[:payload_length]

@ -1,7 +1,7 @@
import utime
from typing import TYPE_CHECKING
import storage.cache
import storage.cache_common as cache_common
from trezor import log, loop
from trezor.enums import MessageType
@ -152,7 +152,7 @@ def close_others() -> None:
if not task.is_running():
task.close()
storage.cache.homescreen_shown = None
cache_common.homescreen_shown = None
# if tasks were running, closing the last of them will run start_default
@ -210,11 +210,11 @@ class IdleTimer:
time and saves it to storage.cache. This is done to avoid losing an
active timer when workflow restart happens and tasks are lost.
"""
if _restore_from_cache and storage.cache.autolock_last_touch is not None:
now = storage.cache.autolock_last_touch
if _restore_from_cache and cache_common.autolock_last_touch is not None:
now = cache_common.autolock_last_touch
else:
now = utime.ticks_ms()
storage.cache.autolock_last_touch = now
cache_common.autolock_last_touch = now
for callback, task in self.tasks.items():
timeout_us = self.timeouts[callback]

@ -0,0 +1,42 @@
#!/usr/bin/env bash
declare -a results
declare -i passed=0 failed=0 exit_code=0
declare COLOR_GREEN='\e[32m' COLOR_RED='\e[91m' COLOR_RESET='\e[39m'
MICROPYTHON="${MICROPYTHON:-../build/unix/trezor-emu-core -X heapsize=2M}"
print_summary() {
echo
echo 'Summary:'
echo '-------------------'
printf '%b\n' "${results[@]}"
if [ $exit_code == 0 ]; then
echo -e "${COLOR_GREEN}PASSED:${COLOR_RESET} $passed/$num_of_tests tests OK!"
else
echo -e "${COLOR_RED}FAILED:${COLOR_RESET} $failed/$num_of_tests tests failed!"
fi
}
trap 'print_summary; echo -e "${COLOR_RED}Interrupted by user!${COLOR_RESET}"; exit 1' SIGINT
cd $(dirname $0)
[ -z "$*" ] && tests=(test_trezor.wire.t*.py ) || tests=($*)
declare -i num_of_tests=${#tests[@]}
for test_case in ${tests[@]}; do
echo ${MICROPYTHON}
echo ${test_case}
echo
if $MICROPYTHON $test_case; then
results+=("${COLOR_GREEN}OK:${COLOR_RESET} $test_case")
((passed++))
else
results+=("${COLOR_RED}FAIL:${COLOR_RESET} $test_case")
((failed++))
exit_code=1
fi
done
print_summary
exit $exit_code

@ -1,20 +1,31 @@
from common import * # isort:skip
from common import * # isort:skip # noqa: F403
from mock_storage import mock_storage
from storage import cache
from trezor.messages import EndSession, Initialize
from storage import cache, cache_codec, cache_thp
from storage.cache_common import InvalidSessionError
from trezor import utils
from trezor.messages import Initialize
from trezor.messages import EndSession
from apps.base import handle_EndSession, handle_Initialize
KEY = 0
if utils.USE_THP:
_PROTOCOL_CACHE = cache_thp
else:
_PROTOCOL_CACHE = cache_codec
# Function moved from cache.py, as it was not used there
def is_session_started() -> bool:
return cache._active_session_idx is not None
return _PROTOCOL_CACHE.get_active_session() is not None
class TestStorageCache(unittest.TestCase):
class TestStorageCache(
unittest.TestCase
): # noqa: F405 # pyright: ignore[reportUndefinedVariable]
def setUp(self):
cache.clear_all()
@ -25,9 +36,9 @@ class TestStorageCache(unittest.TestCase):
self.assertNotEqual(session_id_a, session_id_b)
cache.clear_all()
with self.assertRaises(cache.InvalidSessionError):
with self.assertRaises(InvalidSessionError):
cache.set(KEY, "something")
with self.assertRaises(cache.InvalidSessionError):
with self.assertRaises(InvalidSessionError):
cache.get(KEY)
def test_end_session(self):
@ -36,7 +47,7 @@ class TestStorageCache(unittest.TestCase):
cache.set(KEY, b"A")
cache.end_current_session()
self.assertFalse(is_session_started())
self.assertRaises(cache.InvalidSessionError, cache.get, KEY)
self.assertRaises(InvalidSessionError, cache.get, KEY)
# ending an ended session should be a no-op
cache.end_current_session()
@ -63,7 +74,7 @@ class TestStorageCache(unittest.TestCase):
session_id = cache.start_session()
self.assertEqual(cache.start_session(session_id), session_id)
cache.set(KEY, b"A")
for i in range(cache._MAX_SESSIONS_COUNT):
for i in range(_PROTOCOL_CACHE._MAX_SESSIONS_COUNT):
cache.start_session()
self.assertNotEqual(cache.start_session(session_id), session_id)
self.assertIsNone(cache.get(KEY))
@ -83,7 +94,7 @@ class TestStorageCache(unittest.TestCase):
self.assertEqual(cache.get(KEY), b"hello")
cache.clear_all()
with self.assertRaises(cache.InvalidSessionError):
with self.assertRaises(InvalidSessionError):
cache.get(KEY)
def test_get_set_int(self):
@ -101,7 +112,7 @@ class TestStorageCache(unittest.TestCase):
self.assertEqual(cache.get_int(KEY), 1234)
cache.clear_all()
with self.assertRaises(cache.InvalidSessionError):
with self.assertRaises(InvalidSessionError):
cache.get_int(KEY)
def test_delete(self):
@ -186,6 +197,9 @@ class TestStorageCache(unittest.TestCase):
@mock_storage
def test_Initialize(self):
if utils.USE_THP: # INITIALIZE SHOULD NOT BE IN THP!!! TODO
return
def call_Initialize(**kwargs):
msg = Initialize(**kwargs)
return await_result(handle_Initialize(msg))
@ -210,7 +224,7 @@ class TestStorageCache(unittest.TestCase):
self.assertEqual(cache.get(KEY), b"hello")
# supplying a different session ID starts a new cache
call_Initialize(session_id=b"A" * cache._SESSION_ID_LENGTH)
call_Initialize(session_id=b"A" * _PROTOCOL_CACHE.SESSION_ID_LENGTH)
self.assertIsNone(cache.get(KEY))
# but resuming a session loads the previous one
@ -218,13 +232,14 @@ class TestStorageCache(unittest.TestCase):
self.assertEqual(cache.get(KEY), b"hello")
def test_EndSession(self):
self.assertRaises(cache.InvalidSessionError, cache.get, KEY)
cache.start_session()
self.assertRaises(InvalidSessionError, cache.get, KEY)
session_id = cache.start_session()
self.assertTrue(is_session_started())
self.assertIsNone(cache.get(KEY))
await_result(handle_EndSession(EndSession()))
self.assertFalse(is_session_started())
self.assertRaises(cache.InvalidSessionError, cache.get, KEY)
self.assertRaises(InvalidSessionError, cache.get, KEY)
if __name__ == "__main__":

@ -0,0 +1,95 @@
from common import *
from trezor import utils
if utils.USE_THP:
from trezor.wire.thp import checksum
@unittest.skipUnless(utils.USE_THP, "only needed for THP")
class TestTrezorHostProtocolChecksum(unittest.TestCase):
vectors_correct = [
(
b"",
b"\x00\x00\x00\x00",
),
(
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
b"\x19\x0A\x55\xAD",
),
(
bytes("a", "ascii"),
b"\xE8\xB7\xBE\x43",
),
(
bytes("abc", "ascii"),
b"\x35\x24\x41\xC2",
),
(
bytes("123456789", "ascii"),
b"\xCB\xF4\x39\x26",
),
(
bytes(
"12345678901234567890123456789012345678901234567890123456789012345678901234567890",
"ascii",
),
b"\x7C\xA9\x4A\x72",
),
(
b"\x76\x61\x72\x69\x6F\x75\x73\x20\x43\x52\x43\x20\x61\x6C\x67\x6F\x72\x69\x74\x68\x6D\x73\x20\x69\x6E\x70\x75\x74\x20\x64\x61\x74\x61",
b"\x9B\xD3\x66\xAE",
),
(
b"\x67\x3a\x5f\x0e\x39\xc0\x3c\x79\x58\x22\x74\x76\x64\x9e\x36\xe9\x0b\x04\x8c\xd2\xc0\x4d\x76\x63\x1a\xa2\x17\x85\xe8\x50\xa7\x14\x18\xfb\x86\xed\xa3\x59\x2d\x62\x62\x49\x64\x62\x26\x12\xdb\x95\x3d\xd6\xb5\xca\x4b\x22\x0d\xc5\x78\xb2\x12\x97\x8e\x54\x4e\x06\xb7\x9c\x90\xf5\xa0\x21\xa6\xc7\xd8\x39\xfd\xea\x3a\xf1\x7b\xa2\xe8\x71\x41\xd6\xcb\x1e\x5b\x0e\x29\xf7\x0c\xc7\x57\x8b\x53\x20\x1d\x2b\x41\x1c\x25\xf9\x07\xbb\xb4\x37\x79\x6a\x13\x1f\x6c\x43\x71\xc1\x1e\x70\xe6\x74\xd3\x9c\xbf\x32\x15\xee\xf2\xa7\x86\xbe\x59\x99\xc4\x10\x09\x8a\x6a\xaa\xd4\xd1\xd0\x71\xd2\x06\x1a\xdd\x2a\xa0\x08\xeb\x08\x6c\xfb\xd2\x2d\xfb\xaa\x72\x56\xeb\xd1\x92\x92\xe5\x0e\x95\x67\xf8\x38\xc3\xab\x59\x37\xe6\xfd\x42\xb0\xd0\x31\xd0\xcb\x8a\x66\xce\x2d\x53\x72\x1e\x72\xd3\x84\x25\xb0\xb8\x93\xd2\x61\x5b\x32\xd5\xe7\xe4\x0e\x31\x11\xaf\xdc\xb4\xb8\xee\xa4\x55\x16\x5f\x78\x86\x8b\x50\x4d\xc5\x6d\x6e\xfc\xe1\x6b\x06\x5b\x37\x84\x2a\x67\x95\x28\x00\xa4\xd1\x32\x9f\xbf\xe1\x64\xf8\x17\x47\xe1\xad\x8b\x72\xd2\xd9\x45\x5b\x73\x43\x3c\xe6\x21\xf7\x53\xa3\x73\xf9\x2a\xb0\xe9\x75\x5e\xa6\xbe\x9a\xad\xfc\xed\xb5\x46\x5b\x9f\xa9\x5a\x4f\xcb\xb6\x60\x96\x31\x91\x42\xca\xaf\xee\xa5\x0c\xe0\xab\x3e\x83\xb8\xac\x88\x10\x2c\x63\xd3\xc9\xd2\xf2\x44\xef\xea\x3d\x19\x24\x3c\x5b\xe7\x0c\x52\xfd\xfe\x47\x41\x14\xd5\x4c\x67\x8d\xdb\xe5\xd9\xfa\x67\x9c\x06\x31\x01\x92\xba\x96\xc4\x0d\xef\xf7\xc1\xe9\x23\x28\x0f\xae\x27\x9b\xff\x28\x0b\x3e\x85\x0c\xae\x02\xda\x27\xb6\x04\x51\x04\x43\x04\x99\x8c\xa3\x97\x1d\x84\xec\x55\x59\xfb\xf3\x84\xe5\xf8\x40\xf8\x5f\x81\x65\x92\x4c\x92\x7a\x07\x51\x8d\x6f\xff\x8d\x15\x36\x5c\x57\x7a\x5b\x3a\x63\x1c\x87\x65\xee\x54\xd5\x96\x50\x73\x1a\x9c\xff\x59\xe5\xea\x6f\x89\xd2\xbb\xa9\x6a\x12\x21\xf5\x08\x8e\x8a\xc0\xd8\xf5\x14\xe9\x9d\x7e\x99\x13\x88\x29\xa8\xb4\x22\x2a\x41\x7c\xc5\x10\xdf\x11\x5e\xf8\x8d\x0e\xd9\x98\xd5\xaf\xa8\xf9\x55\x1e\xe3\x29\xcd\x2c\x51\x7b\x8a\x8d\x52\xaa\x8b\x87\xae\x8e\xb2\xfa\x31\x27\x60\x90\xcb\x01\x6f\x7a\x79\x38\x04\x05\x7c\x11\x79\x10\x40\x33\x70\x75\xfd\x0b\x88\xa5\xcd\x35\xd8\xa6\x3b\xb0\x45\x82\x64\xd1\xb5\xdc\x06\xc9\x89\xf4\x16\x3e\xc7\xb3\xf1\x9d\xd3\xc5\xe3\xaf\xe8\x25\x86\x7a\x4a\xfd\x10\x5d\x20\xe5\x76\x5a\x22\x5f\x8f\xbc\xaa\x97\xee\xf2\xc2\x4c\x0e\xdc\x7b\xc4\xee\x53\xa3\xe0\xfa\xcd\x1e\x4e\x54\x1d\x5e\xe1\x51\x17\x1f\x1a\x75\x7f\xed\x12\xd7\xf7\xe3\x18\x56\x24\xcf\xc6\x96\x30\x77\x0d\x73\x98\x9c\x09\x69\xa3\xbc\x96\x5e\xaf\xde\x76\xa4\x66\x04\x6b\x36\x2a\xac\x6d\x37\xf8\x1e\xe1\x2a\x3e\x42\x2d\x1d\xe6\x46\xdd\x28\xb9\x08\x44\xa1\x9e\xb2\x22\x7a\x45\x8a\x37\x39\x74\xb4\xae\xc8\x3b\x40\xf7\xec\xbf\xfd\xe5\xde\xb2\x83\x5e\xa4\x46\x19\xa6\x9d\xb0\xe8\x76\x80\xbd\xc1\x80\x7a\xd9\xeb\xe7\x90\x5b\x81\x25\x21\xd9\x5b\x4a\x80\x48\x92\x71\x77\x04\xb2\xac\x05\xc9\xdf\x5e\x44\x5a\xae\x6e\xb3\xd8\x30\x5e\xdc\x77\x2f\x79\xc2\x8e\x8b\x28\x24\x06\x1b\x6f\x8d\x88\x53\x80\x55\x0c\x3a\x7b\x85\xb8\x96\x85\xe9\xf0\x57\x63\xfe\x32\x80\xff\x57\xc9\x3c\xdb\xf6\xcd\x67\x14\x47\x6c\x43\x3d\x6d\x48\x3f\x9c\x00\x60\x0e\xf5\x94\xe4\x52\x97\x86\xcd\xac\xbc\xe4\xe3\xe7\xee\xa2\x91\x6e\x92\xbb\xd1\x55\x0c\x5c\x0d\x63\xdb\x6b\xb8\x6e\x45\x48\x0f\xdf\x44\x48\xd2\xf5\xf7\x4d\x7b\xd4\x4d\xd3\xcd\xcd\x5b\x40\x60\xb1\xb2\x8e\xc9\x9a\x65\xc5\x06\x24\xcf\xe9\xcc\x5e\x2c\x49\x47\x38\x45\x5d\xc5\xc0\x0d\x8a\x07\x1c\xb3\xbb\xb1\x69\xf5\x6d\x0e\x9c\x96\x14\x93\x58\x0c\xc9\x48\x74\xfc\x35\xda\x7d\x4e\x32\x73\xa3\x77\x4a\x9e\xc5\xd1\x08\xfe\xa6\xa0\xf1\x66\x72\xea\xc7\xae\x21\x81\x0e\x8a\xba\x99\x06\x97\xfc\xc6\x2b\x69\x53\xc6\x67\xec\x5d\xa1\xfc\xa1\x3b\xdd\x2a\xd6\x8f\x31\xa7\x8d\xec\xfe\x0a\x3b\x6b\x39\x70\x70\x09\x72\x12\xbc\x84\x67\xca\xd2\x4a\x17\x33\x94\x45\x25\xc7\xfd\x1e\xa2\x4a\x9e\x27\x9d\xfb\x87\xea\xe4\xfd\xb0\x11\x06\x9d\x72\xb9\x1d\xea\x9b\x81\x2e\x6a\x36\x76\x62\xfa\xbe\x96\x67\x7d\x35\xdd\x5e\x5c\x4f\x41\x0d\xce\xdb\x13\xb0\x46\x89\x92\x45\x02\x39\x0f\xe6\xd1\x20\x96\x1c\x34\x00\x8c\xc9\xdf\xe3\xf0\xb6\x92\x3a\xda\x5c\x96\xd9\x0b\x7d\x57\xf5\x78\x11\xc0\xcf\xbf\xb0\x92\x3d\xe5\x6a\x67\x34\xce\xd9\x16\x08\xa0\x09\x42\x0b\x07\x13\x7c\x73\x0c\xc6\x50\x17\x42\xcf\xd9\x85\xd9\x23\x3c\xb1\x40\x40\x0f\x94\x20\xed\x2d\xbf\x10\x44\x6e\x64\x65\xe5\x1d\x5f\xec\x24\xd8\x4b\xe8\xc2\xfb\x06\x11\x24\x3f\xdf\x54\x2d\xe8\x4d\xc2\x1c\x27\x11\xb8\xb3\xd4",
b"\x6B\xA4\xEC\x92",
),
]
vectors_incorrect = [
(
b"",
b"\x00\x00\x00\x00\x00",
),
(
b"",
b"",
),
(
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
b"\x19\x0A\x55\xAE",
),
(
bytes("A", "ascii"),
b"\xE8\xB7\xBE\x43",
),
(
bytes("abc ", "ascii"),
b"\x35\x24\x41\xC2",
),
(
bytes("1234567890", "ascii"),
b"\xCB\xF4\x39\x26",
),
(
bytes(
"1234567890123456789012345678901234567890123456789012345678901234567890123456789",
"ascii",
),
b"\x7C\xA9\x4A\x72",
),
]
def test_computation(self):
for data, chksum in self.vectors_correct:
self.assertEqual(checksum.compute(data), chksum)
def test_validation_correct(self):
for data, chksum in self.vectors_correct:
self.assertTrue(checksum.is_valid(chksum, data))
def test_validation_incorrect(self):
for data, chksum in self.vectors_incorrect:
self.assertFalse(checksum.is_valid(chksum, data))
if __name__ == "__main__":
unittest.main()

@ -0,0 +1,63 @@
from common import *
from trezor import config, utils
from trezor import log
if utils.USE_THP:
from apps.thp import credential_manager
from trezor.messages import ThpCredentialMetadata
def _issue_credential(host_name: str, host_static_pubkey: bytes) -> bytes:
metadata = ThpCredentialMetadata(host_name=host_name)
return credential_manager.issue_credential(host_static_pubkey, metadata)
def _dummy_log(name: str, msg: str, *args):
pass
log.debug = _dummy_log
@unittest.skipUnless(utils.USE_THP, "only needed for THP")
class TestTrezorHostProtocolCredentialManager(unittest.TestCase):
def setUp(self):
config.init()
config.wipe()
def test_derive_cred_auth_key(self):
key1 = credential_manager.derive_cred_auth_key()
key2 = credential_manager.derive_cred_auth_key()
self.assertEqual(len(key1), 32)
self.assertEqual(key1, key2)
def test_invalidate_cred_auth_key(self):
key1 = credential_manager.derive_cred_auth_key()
credential_manager.invalidate_cred_auth_key()
key2 = credential_manager.derive_cred_auth_key()
self.assertNotEqual(key1, key2)
def test_credentials(self):
DUMMY_KEY_1 = b"\x00\x00"
DUMMY_KEY_2 = b"\xff\xff"
HOST_NAME_1 = "host_name"
HOST_NAME_2 = "different host_name"
cred_1 = _issue_credential(HOST_NAME_1, DUMMY_KEY_1)
cred_2 = _issue_credential(HOST_NAME_1, DUMMY_KEY_1)
self.assertEqual(cred_1, cred_2)
cred_3 = _issue_credential(HOST_NAME_2, DUMMY_KEY_1)
self.assertNotEqual(cred_1, cred_3)
self.assertTrue(credential_manager.validate_credential(cred_1, DUMMY_KEY_1))
self.assertTrue(credential_manager.validate_credential(cred_3, DUMMY_KEY_1))
self.assertFalse(credential_manager.validate_credential(cred_1, DUMMY_KEY_2))
credential_manager.invalidate_cred_auth_key()
cred_4 = _issue_credential(HOST_NAME_1, DUMMY_KEY_1)
self.assertNotEqual(cred_1, cred_4)
self.assertFalse(credential_manager.validate_credential(cred_1, DUMMY_KEY_1))
self.assertFalse(credential_manager.validate_credential(cred_3, DUMMY_KEY_1))
self.assertTrue(credential_manager.validate_credential(cred_4, DUMMY_KEY_1))
if __name__ == "__main__":
unittest.main()

@ -0,0 +1,150 @@
from common import *
import storage
from trezor import utils
from trezor.wire.thp.crypto import IV_1, IV_2, Handshake
from trezorcrypto import aesgcm, curve25519
if utils.USE_THP:
from trezor.wire.thp import crypto
def get_dummy_device_secret():
return b"\x01\x02\x03\x04\x05\x06\x07\x08\x01\x02\x03\x04\x05\x06\x07\x08"
@unittest.skipUnless(utils.USE_THP, "only needed for THP")
class TestTrezorHostProtocolCrypto(unittest.TestCase):
key_1 = b"\x00\x01\x02\x03\x04\x05\x06\x07\x00\x01\x02\x03\x04\x05\x06\x07\x00\x01\x02\x03\x04\x05\x06\x07\x00\x01\x02\x03\x04\x05\x06\x07"
handshake = Handshake()
# 0:key, 1:nonce, 2:auth_data, 3:plaintext, 4:expected_ciphertext, 5:expected_tag
vectors_enc = [
(
key_1,
0,
b"\x55\x64",
b"\x00\x01\x02\x03\x04\05\x06\x07\x08\x09",
b"e2c9dd152fbee5821ea7",
b"10625812de81b14a46b9f1e5100a6d0c",
),
(
key_1,
1,
b"\x55\x64",
b"\x00\x01\x02\x03\x04\05\x06\x07\x08\x09",
b"79811619ddb07c2b99f8",
b"71c6b872cdc499a7e9a3c7441f053214",
),
(
key_1,
369,
b"\x55\x64",
b"\x00\x01\x02\x03\x04\05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
b"03bd030390f2dfe815a61c2b157a064f",
b"c1200f8a7ae9a6d32cef0fff878d55c2",
),
(
key_1,
369,
b"\x55\x64\x73\x82\x91",
b"\x00\x01\x02\x03\x04\05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
b"03bd030390f2dfe815a61c2b157a064f",
b"693ac160cd93a20f7fc255f049d808d0",
),
]
# 0:chaining key, 1:input, 2:output_1, 3:output:2
vectors_hkdf = [
(
crypto.PROTOCOL_NAME,
b"\x01\x02",
b"c784373a217d6be057cddc6068e6748f255fc8beb6f99b7b90cbc64aad947514",
b"12695451e29bf08ffe5e4e6ab734b0c3d7cdd99b16cd409f57bd4eaa874944ba",
),
(
b"\xc7\x84\x37\x3a\x21\x7d\x6b\xe0\x57\xcd\xdc\x60\x68\xe6\x74\x8f\x25\x5f\xc8\xbe\xb6\xf9\x9b\x7b\x90\xcb\xc6\x4a\xad\x94\x75\x14",
b"\x31\x41\x59\x26\x52\x12\x34\x56\x78\x89\x04\xaa",
b"f88c1e08d5c3bae8f6e4a3d3324c8cbc60a805603e399e69c4bf4eacb27c2f48",
b"5f0216bdb7110ee05372286974da8c9c8b96e2efa15b4af430755f462bd79a76",
),
]
vectors_iv = [
(0, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"),
(1, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"),
(7, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x07"),
(1025, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04\x01"),
(4294967295, b"\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff"),
(0xFFFFFFFFFFFFFFFF, b"\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff"),
]
def setUp(self):
utils.DISABLE_ENCRYPTION = False
def test_encryption(self):
for v in self.vectors_enc:
buffer = bytearray(v[3])
tag = crypto.enc(buffer, v[0], v[1], v[2])
self.assertEqual(hexlify(buffer), v[4])
self.assertEqual(hexlify(tag), v[5])
self.assertTrue(crypto.dec(buffer, tag, v[0], v[1], v[2]))
self.assertEqual(buffer, v[3])
def test_hkdf(self):
for v in self.vectors_hkdf:
ck, k = crypto._hkdf(v[0], v[1])
self.assertEqual(hexlify(ck), v[2])
self.assertEqual(hexlify(k), v[3])
def test_iv_from_nonce(self):
for v in self.vectors_iv:
x = v[0]
y = x.to_bytes(8, "big")
iv = crypto._get_iv_from_nonce(v[0])
self.assertEqual(iv, v[1])
with self.assertRaises(AssertionError) as e:
iv = crypto._get_iv_from_nonce(0xFFFFFFFFFFFFFFFF + 1)
self.assertEqual(e.value.value, "Nonce overflow, terminate the channel")
def test_incorrect_vectors(self):
pass
def test_th1_crypto(self):
storage.device.get_device_secret = get_dummy_device_secret
handshake = self.handshake
host_ephemeral_privkey = curve25519.generate_secret()
host_ephemeral_pubkey = curve25519.publickey(host_ephemeral_privkey)
handshake.handle_th1_crypto(b"", host_ephemeral_pubkey)
def test_th2_crypto(self):
handshake = self.handshake
host_static_privkey = curve25519.generate_secret()
host_static_pubkey = curve25519.publickey(host_static_privkey)
aes_ctx = aesgcm(handshake.k, IV_2)
aes_ctx.auth(handshake.h)
encrypted_host_static_pubkey = bytearray(
aes_ctx.encrypt(host_static_pubkey) + aes_ctx.finish()
)
# Code to encrypt Host's noise encrypted payload correctly:
protomsg = bytearray(b"\x10\x02\x10\x03")
temp_k = handshake.k
temp_h = handshake.h
temp_h = crypto._hash_of_two(temp_h, encrypted_host_static_pubkey)
_, temp_k = crypto._hkdf(
handshake.ck,
curve25519.multiply(handshake.trezor_ephemeral_privkey, host_static_pubkey),
)
aes_ctx = aesgcm(temp_k, IV_1)
aes_ctx.encrypt_in_place(protomsg)
aes_ctx.auth(temp_h)
tag = aes_ctx.finish()
encrypted_payload = bytearray(protomsg + tag)
# end of encrypted payload generation
handshake.handle_th2_crypto(encrypted_host_static_pubkey, encrypted_payload)
self.assertEqual(encrypted_payload[:4], b"\x10\x02\x10\x03")
if __name__ == "__main__":
unittest.main()

@ -0,0 +1,260 @@
from common import *
from apps.thp import pairing
from storage.cache_common import (
CHANNEL_HANDSHAKE_HASH,
CHANNEL_KEY_RECEIVE,
CHANNEL_KEY_SEND,
CHANNEL_NONCE_RECEIVE,
CHANNEL_NONCE_SEND,
)
from trezor.enums import ThpPairingMethod, MessageType
from trezor.wire.errors import UnexpectedMessage
from trezor.wire.protocol_common import Message
from trezor.wire.thp.crypto import Handshake
from trezor.wire.thp.pairing_context import PairingContext
from trezor.messages import (
ThpCodeEntryChallenge,
ThpCodeEntryCpaceHost,
ThpCodeEntryTag,
ThpCredentialRequest,
ThpEndRequest,
ThpStartPairingRequest,
)
from trezor import io, config, log, protobuf
from trezor.loop import wait
from trezor.wire import thp_v3
from trezor.wire.thp import interface_manager
from storage import cache_thp
from trezor.wire.thp import ChannelState
from trezor.crypto import elligator2
from trezor.crypto.curve import curve25519
# Disable log.debug for the test
log.debug = lambda name, msg, *args: None
class MockHID:
def __init__(self, num):
self.num = num
self.data = []
def iface_num(self):
return self.num
def write(self, msg):
self.data.append(bytearray(msg))
return len(msg)
def wait_object(self, mode):
return wait(mode | self.num)
def dummy_decode_iface(cached_iface: bytes):
return MockHID(0xDEADBEEF)
def getBytes(a):
return hexlify(a).decode("utf-8")
def get_dummy_key() -> bytes:
return b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x01\x02\x03\x04\x05\x06\x07\x08\x09\x20\x01\x02\x03\x04\x05\x06\x07\x08\x09\x30\x31"
class TestTrezorHostProtocol(unittest.TestCase):
def setUp(self):
self.interface = MockHID(0xDEADBEEF)
buffer = bytearray(64)
thp_v3.set_buffer(buffer)
interface_manager.decode_iface = dummy_decode_iface
def test_simple(self):
self.assertTrue(True)
def test_channel_allocation(self):
cid_req = (
b"\x40\xff\xff\x00\x0c\x00\x11\x22\x33\x44\x55\x66\x77\x96\x64\x3c\x6c"
)
expected_response = "41ffff0020001122334455667712340a04543254311000180020032802280328048ed892b3000000000000000000000000000000000000000000000000000000"
test_counter = cache_thp.cid_counter + 1
self.assertEqual(len(thp_v3._CHANNELS), 0)
self.assertFalse(test_counter in thp_v3._CHANNELS)
gen = thp_v3.thp_main_loop(self.interface, is_debug_session=True)
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
gen.send(cid_req)
gen.send(None)
self.assertEqual(
getBytes(self.interface.data[-1]),
expected_response,
)
self.assertTrue(test_counter in thp_v3._CHANNELS)
self.assertEqual(len(thp_v3._CHANNELS), 1)
gen.send(cid_req)
gen.send(None)
gen.send(cid_req)
gen.send(None)
def test_channel_default_state_is_TH1(self):
self.assertEqual(thp_v3._CHANNELS[4660].get_channel_state(), ChannelState.TH1)
def test_channel_errors(self):
gen = thp_v3.thp_main_loop(self.interface, is_debug_session=True)
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
message_to_channel_789a = (
b"\x04\x78\x9a\x00\x0c\x00\x11\x22\x33\x44\x55\x66\x77\x96\x64\x3c\x6c"
)
gen.send(message_to_channel_789a)
gen.send(None)
unallocated_chanel_error_on_channel_789a = "42789a0005027b743563000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
self.assertEqual(
getBytes(self.interface.data[-1]),
unallocated_chanel_error_on_channel_789a,
)
config.init()
config.wipe()
channel = thp_v3._CHANNELS[4661]
channel.iface = self.interface
channel.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
message_with_invalid_tag = b"\x04\x12\x35\x00\x14\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\xe1\xfc\xc6\xe0"
channel.channel_cache.set(CHANNEL_KEY_RECEIVE, get_dummy_key())
channel.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, 0)
channel.channel_cache.set(CHANNEL_HANDSHAKE_HASH, b"")
gen.send(message_with_invalid_tag)
gen.send(None)
ack_on_received_message = "2012350004d83ea46f00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
self.assertEqual(
getBytes(self.interface.data[-1]),
ack_on_received_message,
)
gen.send(None)
decryption_failed_error_on_channel_1235 = "421235000503caf9634a000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
self.assertEqual(
getBytes(self.interface.data[-1]),
decryption_failed_error_on_channel_1235,
)
channel = thp_v3._CHANNELS[4662]
channel.iface = self.interface
channel.set_channel_state(ChannelState.TH2)
message_with_invalid_tag = b"\x0a\x12\x36\x00\x14\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\x91\x65\x4c\xf9"
channel.channel_cache.set(CHANNEL_KEY_RECEIVE, get_dummy_key())
channel.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, 0)
channel.channel_cache.set(CHANNEL_HANDSHAKE_HASH, b"")
# gen.send(message_with_invalid_tag)
# gen.send(None)
# gen.send(None)
for i in self.interface.data:
print(hexlify(i))
def test_skip_pairing(self):
config.init()
config.wipe()
channel = thp_v3._CHANNELS[4660]
channel.selected_pairing_methods = [
ThpPairingMethod.NoMethod,
ThpPairingMethod.CodeEntry,
ThpPairingMethod.NFC_Unidirectional,
ThpPairingMethod.QrCode,
]
pairing_ctx = PairingContext(channel)
request_message = ThpStartPairingRequest()
channel.set_channel_state(ChannelState.TP1)
gen = pairing.handle_pairing_request(pairing_ctx, request_message)
with self.assertRaises(StopIteration):
gen.send(None)
self.assertEqual(channel.get_channel_state(), ChannelState.ENCRYPTED_TRANSPORT)
# Teardown: set back initial channel state value
channel.set_channel_state(ChannelState.TH1)
def test_pairing(self):
config.init()
config.wipe()
channel = thp_v3._CHANNELS[4660]
channel.selected_pairing_methods = [
ThpPairingMethod.CodeEntry,
ThpPairingMethod.NFC_Unidirectional,
ThpPairingMethod.QrCode,
]
pairing_ctx = PairingContext(channel)
request_message = ThpStartPairingRequest()
with self.assertRaises(UnexpectedMessage) as e:
pairing.handle_pairing_request(pairing_ctx, request_message)
print(e.value.message)
channel.set_channel_state(ChannelState.TP1)
gen = pairing.handle_pairing_request(pairing_ctx, request_message)
channel.channel_cache.set(CHANNEL_KEY_SEND, get_dummy_key())
channel.channel_cache.set_int(CHANNEL_NONCE_SEND, 0)
channel.channel_cache.set(CHANNEL_HANDSHAKE_HASH, b"")
gen.send(None)
async def _dummy(ctx: PairingContext, expected_types):
return await ctx.read([1018, 1024])
pairing.show_display_data = _dummy
msg_code_entry = ThpCodeEntryChallenge(challenge=b"\x12\x34")
buffer: bytearray = bytearray(protobuf.encoded_length(msg_code_entry))
protobuf.encode(buffer, msg_code_entry)
code_entry_challenge = Message(MessageType.ThpCodeEntryChallenge, buffer)
gen.send(code_entry_challenge)
# tag_qrc = b"\x55\xdf\x6c\xba\x0b\xe9\x5e\xd1\x4b\x78\x61\xec\xfa\x07\x9b\x5d\x37\x60\xd8\x79\x9c\xd7\x89\xb4\x22\xc1\x6f\x39\xde\x8f\x3b\xc3"
# tag_nfc = b"\x8f\xf0\xfa\x37\x0a\x5b\xdb\x29\x32\x21\xd8\x2f\x95\xdd\xb6\xb8\xee\xfd\x28\x6f\x56\x9f\xa9\x0b\x64\x8c\xfc\x62\x46\x5a\xdd\xd0"
pregenerator_host = b"\xf6\x94\xc3\x6f\xb3\xbd\xfb\xba\x2f\xfd\x0c\xd0\x71\xed\x54\x76\x73\x64\x37\xfa\x25\x85\x12\x8d\xcf\xb5\x6c\x02\xaf\x9d\xe8\xbe"
generator_host = elligator2.map_to_curve25519(pregenerator_host)
cpace_host_private_key = b"\x02\x80\x70\x3c\x06\x45\x19\x75\x87\x0c\x82\xe1\x64\x11\xc0\x18\x13\xb2\x29\x04\xb3\xf0\xe4\x1e\x6b\xfd\x77\x63\x11\x73\x07\xa9"
cpace_host_public_key: bytes = curve25519.multiply(
cpace_host_private_key, generator_host
)
msg = ThpCodeEntryCpaceHost(cpace_host_public_key=cpace_host_public_key)
# msg = ThpQrCodeTag(tag=tag_qrc)
# msg = ThpNfcUnidirectionalTag(tag=tag_nfc)
buffer: bytearray = bytearray(protobuf.encoded_length(msg))
protobuf.encode(buffer, msg)
user_message = Message(MessageType.ThpCodeEntryCpaceHost, buffer)
gen.send(user_message)
tag_ent = b"\xf5\x20\xee\xae\xb8\xa9\x65\x3e\x77\x89\x8f\x81\x8d\x03\x4d\xaa\x93\x79\xc3\xe4\x89\x3c\xb8\x31\x42\xdc\x01\x57\x2d\x5d\x11\xb5"
msg = ThpCodeEntryTag(tag=tag_ent)
buffer: bytearray = bytearray(protobuf.encoded_length(msg))
protobuf.encode(buffer, msg)
user_message = Message(MessageType.ThpCodeEntryTag, buffer)
gen.send(user_message)
host_static_pubkey = b"\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77"
msg = ThpCredentialRequest(host_static_pubkey=host_static_pubkey)
buffer: bytearray = bytearray(protobuf.encoded_length(msg))
protobuf.encode(buffer, msg)
credential_request = Message(MessageType.ThpCredentialRequest, buffer)
gen.send(credential_request)
msg = ThpEndRequest()
buffer: bytearray = bytearray(protobuf.encoded_length(msg))
protobuf.encode(buffer, msg)
end_request = Message(1012, buffer)
with self.assertRaises(StopIteration) as e:
gen.send(end_request)
print("response message:", e.value.value.MESSAGE_NAME)
if __name__ == "__main__":
unittest.main()

@ -0,0 +1,167 @@
from common import *
from trezor import utils
if utils.USE_THP:
from trezor.wire.thp import writer
from trezor.wire.thp.thp_messages import PacketHeader, ENCRYPTED_TRANSPORT
if __debug__:
# Disable log.debug for the test
from trezor import log
log.debug = lambda name, msg, *args: None
class MockHID:
def __init__(self, num):
self.num = num
self.data = []
def iface_num(self):
return self.num
def write(self, msg):
self.data.append(bytearray(msg))
return len(msg)
def wait_object(self, mode):
return wait(mode | self.num)
@unittest.skipUnless(utils.USE_THP, "only needed for THP")
class TestTrezorHostProtocolWriter(unittest.TestCase):
short_payload_expected = b"04123400050700000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
longer_payload_expected = [
b"0412340100000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a",
b"8012343b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f7071727374757677",
b"80123478797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4",
b"801234b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1",
b"801234f2f3f4f5f6f7f8f9fafbfcfdfeff0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
]
eight_longer_payloads_expected = [
b"0412340800000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a",
b"8012343b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f7071727374757677",
b"80123478797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4",
b"801234b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1",
b"801234f2f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e",
b"8012342f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b",
b"8012346c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8",
b"801234a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5",
b"801234e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122",
b"801234232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f",
b"801234606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c",
b"8012349d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9",
b"801234dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a0b0c0d0e0f10111213141516",
b"8012341718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f50515253",
b"8012345455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f90",
b"8012349192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccd",
b"801234cecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a",
b"8012340b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f4041424344454647",
b"80123448494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f8081828384",
b"80123485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1",
b"801234c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfe",
b"801234ff000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b",
b"8012343c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f707172737475767778",
b"801234797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5",
b"801234b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2",
b"801234f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f",
b"801234303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c",
b"8012346d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9",
b"801234aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6",
b"801234e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20212223",
b"8012342425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f60",
b"8012346162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d",
b"8012349e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9da",
b"801234dbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000000000000000000000000000000000000000000000000",
]
empty_payload_with_checksum_expected = b"0412340004edbd479c00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
longer_payload_with_checksum_expected = [
b"0412340100000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a",
b"8012343b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f7071727374757677",
b"80123478797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4",
b"801234b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1",
b"801234f2f3f4f5f6f7f8f9fafbfcfdfefff40c65ee00000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
]
def setUp(self):
self.interface = MockHID(0xDEADBEEF)
def test_write_empty_payload(self):
header = PacketHeader(ENCRYPTED_TRANSPORT, 4660, 4)
gen = writer.write_payloads_to_wire(self.interface, header, (b"",))
with self.assertRaises(StopIteration):
gen.send(None)
self.assertEqual(len(self.interface.data), 0)
def test_write_short_payload(self):
header = PacketHeader(ENCRYPTED_TRANSPORT, 4660, 5)
data = b"\x07"
gen = writer.write_payloads_to_wire(self.interface, header, (data,))
gen.send(None)
with self.assertRaises(StopIteration):
gen.send(None)
self.assertEqual(hexlify(self.interface.data[0]), self.short_payload_expected)
def test_write_longer_payload(self):
data = bytearray(range(256))
header = PacketHeader(ENCRYPTED_TRANSPORT, 4660, 256)
gen = writer.write_payloads_to_wire(self.interface, header, (data,))
for i in range(5):
gen.send(None)
with self.assertRaises(StopIteration):
gen.send(None)
for i in range(len(self.longer_payload_expected)):
self.assertEqual(
hexlify(self.interface.data[i]), self.longer_payload_expected[i]
)
def test_write_eight_longer_payloads(self):
data = bytearray(range(256))
header = PacketHeader(ENCRYPTED_TRANSPORT, 4660, 2048)
gen = writer.write_payloads_to_wire(
self.interface, header, (data, data, data, data, data, data, data, data)
)
for i in range(34):
gen.send(None)
with self.assertRaises(StopIteration):
gen.send(None)
for i in range(len(self.eight_longer_payloads_expected)):
self.assertEqual(
hexlify(self.interface.data[i]), self.eight_longer_payloads_expected[i]
)
def test_write_empty_payload_with_checksum(self):
header = PacketHeader(ENCRYPTED_TRANSPORT, 4660, 4)
gen = writer.write_payload_to_wire_and_add_checksum(self.interface, header, b"")
gen.send(None)
with self.assertRaises(StopIteration):
gen.send(None)
self.assertEqual(
hexlify(self.interface.data[0]), self.empty_payload_with_checksum_expected
)
def test_write_longer_payload_with_checksum(self):
data = bytearray(range(256))
header = PacketHeader(ENCRYPTED_TRANSPORT, 4660, 256)
gen = writer.write_payload_to_wire_and_add_checksum(
self.interface, header, data
)
for i in range(5):
gen.send(None)
with self.assertRaises(StopIteration):
gen.send(None)
for i in range(len(self.longer_payload_with_checksum_expected)):
self.assertEqual(
hexlify(self.interface.data[i]),
self.longer_payload_with_checksum_expected[i],
)
if __name__ == "__main__":
unittest.main()

@ -0,0 +1,352 @@
from common import *
from typing import TYPE_CHECKING
from storage.cache_thp import BROADCAST_CHANNEL_ID
import trezor.wire.thp
from trezor.wire.thp import alternating_bit_protocol as ABP
from trezor.wire.thp.writer import PACKET_LENGTH
from ubinascii import hexlify
import ustruct
from trezor import io, log, utils
from trezor.loop import wait
from trezor.utils import chunks
from trezor.wire import thp_v3
from trezor.wire.protocol_common import Message
from trezor.wire.thp import checksum
from trezor.wire.thp.checksum import CHECKSUM_LENGTH
# Disable log.debug for the test
log.debug = lambda name, msg, *args: None
if TYPE_CHECKING:
from trezorio import WireInterface
class MockHID:
def __init__(self, num):
self.num = num
self.data = []
def iface_num(self):
return self.num
def write(self, msg):
self.data.append(bytearray(msg))
return len(msg)
def wait_object(self, mode):
return wait(mode | self.num)
MESSAGE_TYPE = 0x4242
MESSAGE_TYPE_BYTES = b"\x42\x42"
_MESSAGE_TYPE_LEN = 2
PLAINTEXT_0 = 0x01
PLAINTEXT_1 = 0x11
COMMON_CID = 4660
CONT = 0x80
HEADER_INIT_LENGTH = 5
HEADER_CONT_LENGTH = 3
INIT_MESSAGE_DATA_LENGTH = PACKET_LENGTH - HEADER_INIT_LENGTH - _MESSAGE_TYPE_LEN
def make_header(ctrl_byte, cid, length):
return ustruct.pack(">BHH", ctrl_byte, cid, length)
def make_cont_header():
return ustruct.pack(">BH", CONT, COMMON_CID)
def makeSimpleMessage(header, message_type, message_data):
return header + ustruct.pack(">H", message_type) + message_data
def makeCidRequest(header, message_data):
return header + message_data
def printBytes(a):
print(hexlify(a).decode("utf-8"))
def getPlaintext() -> bytes:
if ABP.get_expected_receive_seq_bit(THP.get_active_session()) == 1:
return PLAINTEXT_1
return PLAINTEXT_0
async def deprecated_read_message(
iface: WireInterface, buffer: utils.BufferType
) -> Message:
return Message(-1, b"\x00")
async def deprecated_write_message(
iface: WireInterface, message: Message, is_retransmission: bool = False
) -> None:
pass
# This test suite is an adaptation of test_trezor.wire.codec_v1
class TestWireTrezorHostProtocolV1(unittest.TestCase):
def setUp(self):
self.interface = MockHID(0xDEADBEEF)
if not utils.USE_THP:
import storage.cache_thp # noqa: F401
def _simple(self):
cid_req_header = make_header(
ctrl_byte=0x40, cid=BROADCAST_CHANNEL_ID, length=12
)
cid_request_dummy_data = b"\x00\x11\x22\x33\x44\x55\x66\x77\x96\x64\x3c\x6c"
cid_req_message = makeCidRequest(cid_req_header, cid_request_dummy_data)
message_header = make_header(ctrl_byte=0x01, cid=COMMON_CID, length=18)
cid_request_dummy_data_checksum = b"\x67\x8e\xac\xe0"
message = makeSimpleMessage(
message_header,
MESSAGE_TYPE,
cid_request_dummy_data + cid_request_dummy_data_checksum,
)
buffer = bytearray(64)
gen = deprecated_read_message(self.interface, buffer)
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
gen.send(cid_req_message)
gen.send(None)
gen.send(message)
with self.assertRaises(StopIteration) as e:
gen.send(None)
# e.value is StopIteration. e.value.value is the return value of the call
result = e.value.value
self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data, cid_request_dummy_data)
buffer_without_zeroes = buffer[: len(message) - 5]
message_without_header = message[5:]
# message should have been read into the buffer
self.assertEqual(buffer_without_zeroes, message_without_header)
def _read_one_packet(self):
# zero length message - just a header
PLAINTEXT = getPlaintext()
header = make_header(
PLAINTEXT, cid=COMMON_CID, length=_MESSAGE_TYPE_LEN + CHECKSUM_LENGTH
)
chksum = checksum.compute(header + MESSAGE_TYPE_BYTES)
message = header + MESSAGE_TYPE_BYTES + chksum
buffer = bytearray(64)
gen = deprecated_read_message(self.interface, buffer)
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
gen.send(message)
with self.assertRaises(StopIteration) as e:
gen.send(None)
# e.value is StopIteration. e.value.value is the return value of the call
result = e.value.value
self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data, b"")
# message should have been read into the buffer
self.assertEqual(buffer, MESSAGE_TYPE_BYTES + chksum + b"\x00" * 58)
def _read_many_packets(self):
message = bytes(range(256))
header = make_header(
getPlaintext(),
COMMON_CID,
len(message) + _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH,
)
chksum = checksum.compute(header + MESSAGE_TYPE_BYTES + message)
# message = MESSAGE_TYPE_BYTES + message + checksum
# first packet is init header + 59 bytes of data
# other packets are cont header + 61 bytes of data
cont_header = make_cont_header()
packets = [header + MESSAGE_TYPE_BYTES + message[:INIT_MESSAGE_DATA_LENGTH]] + [
cont_header + chunk
for chunk in chunks(
message[INIT_MESSAGE_DATA_LENGTH:] + chksum,
64 - HEADER_CONT_LENGTH,
)
]
buffer = bytearray(262)
gen = deprecated_read_message(self.interface, buffer)
query = gen.send(None)
for packet in packets:
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
query = gen.send(packet)
# last packet will stop
with self.assertRaises(StopIteration) as e:
gen.send(None)
# e.value is StopIteration. e.value.value is the return value of the call
result = e.value.value
self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data, message)
# message should have been read into the buffer )
self.assertEqual(buffer, MESSAGE_TYPE_BYTES + message + chksum)
def _read_large_message(self):
message = b"hello world"
header = make_header(
getPlaintext(),
COMMON_CID,
_MESSAGE_TYPE_LEN + len(message) + CHECKSUM_LENGTH,
)
packet = (
header
+ MESSAGE_TYPE_BYTES
+ message
+ checksum.compute(header + MESSAGE_TYPE_BYTES + message)
)
# make sure we fit into one packet, to make this easier
self.assertTrue(len(packet) <= thp_v3.PACKET_LENGTH)
buffer = bytearray(1)
self.assertTrue(len(buffer) <= len(packet))
gen = deprecated_read_message(self.interface, buffer)
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
gen.send(packet)
with self.assertRaises(StopIteration) as e:
gen.send(None)
# e.value is StopIteration. e.value.value is the return value of the call
result = e.value.value
self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data, message)
# read should have allocated its own buffer and not touch ours
self.assertEqual(buffer, b"\x00")
def _roundtrip(self):
message_payload = bytes(range(256))
message = Message(
MESSAGE_TYPE, message_payload, 1
) # TODO use different session id
gen = deprecated_write_message(self.interface, message)
# exhaust the iterator:
# (XXX we can only do this because the iterator is only accepting None and returns None)
for query in gen:
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
buffer = bytearray(1024)
gen = deprecated_read_message(self.interface, buffer)
query = gen.send(None)
for packet in self.interface.data:
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
printBytes(packet)
query = gen.send(packet)
with self.assertRaises(StopIteration) as e:
gen.send(None)
result = e.value.value
self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data, message.data)
def _write_one_packet(self):
message = Message(MESSAGE_TYPE, b"")
gen = deprecated_write_message(self.interface, message)
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
with self.assertRaises(StopIteration):
gen.send(None)
header = make_header(
getPlaintext(), COMMON_CID, _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH
)
expected_message = (
header
+ MESSAGE_TYPE_BYTES
+ checksum.compute(header + MESSAGE_TYPE_BYTES)
+ b"\x00" * (INIT_MESSAGE_DATA_LENGTH - CHECKSUM_LENGTH)
)
self.assertTrue(self.interface.data == [expected_message])
def _write_multiple_packets(self):
message_payload = bytes(range(256))
message = Message(MESSAGE_TYPE, message_payload)
gen = deprecated_write_message(self.interface, message)
header = make_header(
PLAINTEXT_1,
COMMON_CID,
len(message.data) + _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH,
)
cont_header = make_cont_header()
chksum = checksum.compute(
header + message.type.to_bytes(2, "big") + message.data
)
packets = [
header + MESSAGE_TYPE_BYTES + message.data[:INIT_MESSAGE_DATA_LENGTH]
] + [
cont_header + chunk
for chunk in chunks(
message.data[INIT_MESSAGE_DATA_LENGTH:] + chksum,
thp_v3.PACKET_LENGTH - HEADER_CONT_LENGTH,
)
]
for _ in packets:
# we receive as many queries as there are packets
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
# the first sent None only started the generator. the len(packets)-th None
# will finish writing and raise StopIteration
with self.assertRaises(StopIteration):
gen.send(None)
# packets must be identical up to the last one
self.assertListEqual(packets[:-1], self.interface.data[:-1])
# last packet must be identical up to message length. remaining bytes in
# the 64-byte packets are garbage -- in particular, it's the bytes of the
# previous packet
last_packet = packets[-1] + packets[-2][len(packets[-1]) :]
self.assertEqual(last_packet, self.interface.data[-1])
def _read_huge_packet(self):
PACKET_COUNT = 1180
# message that takes up 1 180 USB packets
message_size = (PACKET_COUNT - 1) * (
PACKET_LENGTH - HEADER_CONT_LENGTH - CHECKSUM_LENGTH - _MESSAGE_TYPE_LEN
) + INIT_MESSAGE_DATA_LENGTH
# ensure that a message this big won't fit into memory
# Note: this control is changed, because THP has only 2 byte length field
self.assertTrue(message_size > thp_v3.MAX_PAYLOAD_LEN)
# self.assertRaises(MemoryError, bytearray, message_size)
header = make_header(PLAINTEXT_1, COMMON_CID, message_size)
packet = header + MESSAGE_TYPE_BYTES + (b"\x00" * INIT_MESSAGE_DATA_LENGTH)
buffer = bytearray(65536)
gen = deprecated_read_message(self.interface, buffer)
query = gen.send(None)
# THP returns "Message too large" error after reading the message size,
# it is different from codec_v1 as it does not allow big enough messages
# to raise MemoryError in this test
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
with self.assertRaises(trezor.wire.thp.ThpError) as e:
query = gen.send(packet)
self.assertEqual(e.value.args[0], "Message too large")
if __name__ == "__main__":
unittest.main()

@ -106,44 +106,44 @@ Frozen version. That means you do not need any other files to run it,
it is just a single binary file that you can execute directly.
**Are you looking for a Trezor T emulator? This is most likely it.**
### [core unix frozen R debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L317)
### [core unix frozen R debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L318)
### [core unix frozen T3T1 debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L332)
### [core unix frozen T3T1 debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L333)
### [core unix frozen R debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L346)
### [core unix frozen R debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L347)
### [core unix frozen T3T1 debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L369)
### [core unix frozen T3T1 debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L370)
### [core unix frozen debug asan build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L392)
### [core unix frozen debug asan build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L393)
### [core unix frozen debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L408)
### [core unix frozen debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L409)
### [core macos frozen regular build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L430)
### [core macos frozen regular build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L431)
### [crypto build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L455)
### [crypto build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L456)
Build of our cryptographic library, which is then incorporated into the other builds.
### [legacy fw regular build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L485)
### [legacy fw regular build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L486)
### [legacy fw regular debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L501)
### [legacy fw regular debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L502)
### [legacy fw btconly build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L518)
### [legacy fw btconly build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L519)
### [legacy fw btconly debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L537)
### [legacy fw btconly debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L538)
### [legacy emu regular debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L558)
### [legacy emu regular debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L559)
Regular version (not only Bitcoin) of above.
**Are you looking for a Trezor One emulator? This is most likely it.**
### [legacy emu regular debug asan build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L573)
### [legacy emu regular debug asan build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L574)
### [legacy emu regular debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L591)
### [legacy emu regular debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L592)
### [legacy emu btconly debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L617)
### [legacy emu btconly debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L618)
Build of Legacy into UNIX emulator. Use keyboard arrows to emulate button presses.
Bitcoin-only version.
### [legacy emu btconly debug asan build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L634)
### [legacy emu btconly debug asan build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L635)
---
## TEST stage - [test.yml](https://github.com/trezor/trezor-firmware/blob/master/ci/test.yml)

@ -191,6 +191,9 @@ void fsm_sendFailure(FailureType code, const char *text)
case FailureType_Failure_InvalidSession:
text = _("Invalid session");
break;
case FailureType_Failure_ThpUnallocatedSession:
text = _("Unallocated session");
break;
case FailureType_Failure_FirmwareError:
text = _("Firmware error");
break;

@ -10,7 +10,7 @@ SKIPPED_MESSAGES := Binance Cardano DebugMonero Eos Monero Ontology Ripple SdPro
EthereumTypedDataValueRequest EthereumTypedDataValueAck ShowDeviceTutorial \
UnlockBootloader AuthenticateDevice AuthenticityProof \
Solana StellarClaimClaimableBalanceOp \
ChangeLanguage TranslationDataRequest TranslationDataAck \
ChangeLanguage TranslationDataRequest TranslationDataAck Thp \
SetBrightness DebugLinkOptigaSetSecMax \
ifeq ($(BITCOIN_ONLY), 1)

@ -0,0 +1 @@
../../vendor/trezor-common/protob/messages-thp.proto

@ -270,6 +270,24 @@ class MessageType(IntEnum):
SolanaAddress = 903
SolanaSignTx = 904
SolanaTxSignature = 905
ThpCreateNewSession = 1000
ThpNewSession = 1001
ThpStartPairingRequest = 1008
ThpPairingPreparationsFinished = 1009
ThpCredentialRequest = 1010
ThpCredentialResponse = 1011
ThpEndRequest = 1012
ThpEndResponse = 1013
ThpCodeEntryCommitment = 1016
ThpCodeEntryChallenge = 1017
ThpCodeEntryCpaceHost = 1018
ThpCodeEntryCpaceTrezor = 1019
ThpCodeEntryTag = 1020
ThpCodeEntrySecret = 1021
ThpQrCodeTag = 1024
ThpQrCodeSecret = 1025
ThpNfcUnidirectionalTag = 1032
ThpNfcUnidirectionalSecret = 1033
class FailureType(IntEnum):
@ -287,6 +305,7 @@ class FailureType(IntEnum):
PinMismatch = 12
WipeCodeMismatch = 13
InvalidSession = 14
ThpUnallocatedSession = 15
FirmwareError = 99
@ -626,6 +645,13 @@ class TezosBallotType(IntEnum):
Pass = 2
class ThpPairingMethod(IntEnum):
NoMethod = 1
CodeEntry = 2
QrCode = 3
NFC_Unidirectional = 4
class BinanceGetAddress(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 700
FIELDS = {
@ -7753,6 +7779,328 @@ class TezosManagerTransfer(protobuf.MessageType):
self.amount = amount
class ThpDeviceProperties(protobuf.MessageType):
MESSAGE_WIRE_TYPE = None
FIELDS = {
1: protobuf.Field("internal_model", "string", repeated=False, required=False, default=None),
2: protobuf.Field("model_variant", "uint32", repeated=False, required=False, default=None),
3: protobuf.Field("bootloader_mode", "bool", repeated=False, required=False, default=None),
4: protobuf.Field("protocol_version", "uint32", repeated=False, required=False, default=None),
5: protobuf.Field("pairing_methods", "ThpPairingMethod", repeated=True, required=False, default=None),
}
def __init__(
self,
*,
pairing_methods: Optional[Sequence["ThpPairingMethod"]] = None,
internal_model: Optional["str"] = None,
model_variant: Optional["int"] = None,
bootloader_mode: Optional["bool"] = None,
protocol_version: Optional["int"] = None,
) -> None:
self.pairing_methods: Sequence["ThpPairingMethod"] = pairing_methods if pairing_methods is not None else []
self.internal_model = internal_model
self.model_variant = model_variant
self.bootloader_mode = bootloader_mode
self.protocol_version = protocol_version
class ThpHandshakeCompletionReqNoisePayload(protobuf.MessageType):
MESSAGE_WIRE_TYPE = None
FIELDS = {
1: protobuf.Field("host_pairing_credential", "bytes", repeated=False, required=False, default=None),
2: protobuf.Field("pairing_methods", "ThpPairingMethod", repeated=True, required=False, default=None),
}
def __init__(
self,
*,
pairing_methods: Optional[Sequence["ThpPairingMethod"]] = None,
host_pairing_credential: Optional["bytes"] = None,
) -> None:
self.pairing_methods: Sequence["ThpPairingMethod"] = pairing_methods if pairing_methods is not None else []
self.host_pairing_credential = host_pairing_credential
class ThpCreateNewSession(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1000
FIELDS = {
1: protobuf.Field("passphrase", "string", repeated=False, required=False, default=None),
2: protobuf.Field("on_device", "bool", repeated=False, required=False, default=None),
3: protobuf.Field("derive_cardano", "bool", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
passphrase: Optional["str"] = None,
on_device: Optional["bool"] = None,
derive_cardano: Optional["bool"] = None,
) -> None:
self.passphrase = passphrase
self.on_device = on_device
self.derive_cardano = derive_cardano
class ThpNewSession(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1001
FIELDS = {
1: protobuf.Field("new_session_id", "uint32", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
new_session_id: Optional["int"] = None,
) -> None:
self.new_session_id = new_session_id
class ThpStartPairingRequest(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1008
FIELDS = {
1: protobuf.Field("host_name", "string", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
host_name: Optional["str"] = None,
) -> None:
self.host_name = host_name
class ThpPairingPreparationsFinished(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1009
class ThpCodeEntryCommitment(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1016
FIELDS = {
1: protobuf.Field("commitment", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
commitment: Optional["bytes"] = None,
) -> None:
self.commitment = commitment
class ThpCodeEntryChallenge(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1017
FIELDS = {
1: protobuf.Field("challenge", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
challenge: Optional["bytes"] = None,
) -> None:
self.challenge = challenge
class ThpCodeEntryCpaceHost(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1018
FIELDS = {
1: protobuf.Field("cpace_host_public_key", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
cpace_host_public_key: Optional["bytes"] = None,
) -> None:
self.cpace_host_public_key = cpace_host_public_key
class ThpCodeEntryCpaceTrezor(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1019
FIELDS = {
1: protobuf.Field("cpace_trezor_public_key", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
cpace_trezor_public_key: Optional["bytes"] = None,
) -> None:
self.cpace_trezor_public_key = cpace_trezor_public_key
class ThpCodeEntryTag(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1020
FIELDS = {
2: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
tag: Optional["bytes"] = None,
) -> None:
self.tag = tag
class ThpCodeEntrySecret(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1021
FIELDS = {
1: protobuf.Field("secret", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
secret: Optional["bytes"] = None,
) -> None:
self.secret = secret
class ThpQrCodeTag(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1024
FIELDS = {
1: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
tag: Optional["bytes"] = None,
) -> None:
self.tag = tag
class ThpQrCodeSecret(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1025
FIELDS = {
1: protobuf.Field("secret", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
secret: Optional["bytes"] = None,
) -> None:
self.secret = secret
class ThpNfcUnidirectionalTag(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1032
FIELDS = {
1: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
tag: Optional["bytes"] = None,
) -> None:
self.tag = tag
class ThpNfcUnidirectionalSecret(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1033
FIELDS = {
1: protobuf.Field("secret", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
secret: Optional["bytes"] = None,
) -> None:
self.secret = secret
class ThpCredentialRequest(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1010
FIELDS = {
1: protobuf.Field("host_static_pubkey", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
host_static_pubkey: Optional["bytes"] = None,
) -> None:
self.host_static_pubkey = host_static_pubkey
class ThpCredentialResponse(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1011
FIELDS = {
1: protobuf.Field("trezor_static_pubkey", "bytes", repeated=False, required=False, default=None),
2: protobuf.Field("credential", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
trezor_static_pubkey: Optional["bytes"] = None,
credential: Optional["bytes"] = None,
) -> None:
self.trezor_static_pubkey = trezor_static_pubkey
self.credential = credential
class ThpEndRequest(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1012
class ThpEndResponse(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1013
class ThpCredentialMetadata(protobuf.MessageType):
MESSAGE_WIRE_TYPE = None
FIELDS = {
1: protobuf.Field("host_name", "string", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
host_name: Optional["str"] = None,
) -> None:
self.host_name = host_name
class ThpPairingCredential(protobuf.MessageType):
MESSAGE_WIRE_TYPE = None
FIELDS = {
1: protobuf.Field("cred_metadata", "ThpCredentialMetadata", repeated=False, required=False, default=None),
2: protobuf.Field("mac", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
cred_metadata: Optional["ThpCredentialMetadata"] = None,
mac: Optional["bytes"] = None,
) -> None:
self.cred_metadata = cred_metadata
self.mac = mac
class ThpAuthenticatedCredentialData(protobuf.MessageType):
MESSAGE_WIRE_TYPE = None
FIELDS = {
1: protobuf.Field("host_static_pubkey", "bytes", repeated=False, required=False, default=None),
2: protobuf.Field("cred_metadata", "ThpCredentialMetadata", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
host_static_pubkey: Optional["bytes"] = None,
cred_metadata: Optional["ThpCredentialMetadata"] = None,
) -> None:
self.host_static_pubkey = host_static_pubkey
self.cred_metadata = cred_metadata
class WebAuthnListResidentCredentials(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 800

@ -80,6 +80,24 @@ trezor_message_impl! {
DebugLinkWatchLayout => MessageType_DebugLinkWatchLayout,
DebugLinkResetDebugEvents => MessageType_DebugLinkResetDebugEvents,
DebugLinkOptigaSetSecMax => MessageType_DebugLinkOptigaSetSecMax,
ThpCreateNewSession => MessageType_ThpCreateNewSession,
ThpNewSession => MessageType_ThpNewSession,
ThpStartPairingRequest => MessageType_ThpStartPairingRequest,
ThpPairingPreparationsFinished => MessageType_ThpPairingPreparationsFinished,
ThpCredentialRequest => MessageType_ThpCredentialRequest,
ThpCredentialResponse => MessageType_ThpCredentialResponse,
ThpEndRequest => MessageType_ThpEndRequest,
ThpEndResponse => MessageType_ThpEndResponse,
ThpCodeEntryCommitment => MessageType_ThpCodeEntryCommitment,
ThpCodeEntryChallenge => MessageType_ThpCodeEntryChallenge,
ThpCodeEntryCpaceHost => MessageType_ThpCodeEntryCpaceHost,
ThpCodeEntryCpaceTrezor => MessageType_ThpCodeEntryCpaceTrezor,
ThpCodeEntryTag => MessageType_ThpCodeEntryTag,
ThpCodeEntrySecret => MessageType_ThpCodeEntrySecret,
ThpQrCodeTag => MessageType_ThpQrCodeTag,
ThpQrCodeSecret => MessageType_ThpQrCodeSecret,
ThpNfcUnidirectionalTag => MessageType_ThpNfcUnidirectionalTag,
ThpNfcUnidirectionalSecret => MessageType_ThpNfcUnidirectionalSecret,
}
#[cfg(feature = "binance")]

@ -510,6 +510,42 @@ pub enum MessageType {
MessageType_SolanaSignTx = 904,
// @@protoc_insertion_point(enum_value:hw.trezor.messages.MessageType.MessageType_SolanaTxSignature)
MessageType_SolanaTxSignature = 905,
// @@protoc_insertion_point(enum_value:hw.trezor.messages.MessageType.MessageType_ThpCreateNewSession)
MessageType_ThpCreateNewSession = 1000,
// @@protoc_insertion_point(enum_value:hw.trezor.messages.MessageType.MessageType_ThpNewSession)
MessageType_ThpNewSession = 1001,
// @@protoc_insertion_point(enum_value:hw.trezor.messages.MessageType.MessageType_ThpStartPairingRequest)
MessageType_ThpStartPairingRequest = 1008,
// @@protoc_insertion_point(enum_value:hw.trezor.messages.MessageType.MessageType_ThpPairingPreparationsFinished)
MessageType_ThpPairingPreparationsFinished = 1009,
// @@protoc_insertion_point(enum_value:hw.trezor.messages.MessageType.MessageType_ThpCredentialRequest)
MessageType_ThpCredentialRequest = 1010,
// @@protoc_insertion_point(enum_value:hw.trezor.messages.MessageType.MessageType_ThpCredentialResponse)
MessageType_ThpCredentialResponse = 1011,
// @@protoc_insertion_point(enum_value:hw.trezor.messages.MessageType.MessageType_ThpEndRequest)
MessageType_ThpEndRequest = 1012,
// @@protoc_insertion_point(enum_value:hw.trezor.messages.MessageType.MessageType_ThpEndResponse)
MessageType_ThpEndResponse = 1013,
// @@protoc_insertion_point(enum_value:hw.trezor.messages.MessageType.MessageType_ThpCodeEntryCommitment)
MessageType_ThpCodeEntryCommitment = 1016,
// @@protoc_insertion_point(enum_value:hw.trezor.messages.MessageType.MessageType_ThpCodeEntryChallenge)
MessageType_ThpCodeEntryChallenge = 1017,
// @@protoc_insertion_point(enum_value:hw.trezor.messages.MessageType.MessageType_ThpCodeEntryCpaceHost)
MessageType_ThpCodeEntryCpaceHost = 1018,
// @@protoc_insertion_point(enum_value:hw.trezor.messages.MessageType.MessageType_ThpCodeEntryCpaceTrezor)
MessageType_ThpCodeEntryCpaceTrezor = 1019,
// @@protoc_insertion_point(enum_value:hw.trezor.messages.MessageType.MessageType_ThpCodeEntryTag)
MessageType_ThpCodeEntryTag = 1020,
// @@protoc_insertion_point(enum_value:hw.trezor.messages.MessageType.MessageType_ThpCodeEntrySecret)
MessageType_ThpCodeEntrySecret = 1021,
// @@protoc_insertion_point(enum_value:hw.trezor.messages.MessageType.MessageType_ThpQrCodeTag)
MessageType_ThpQrCodeTag = 1024,
// @@protoc_insertion_point(enum_value:hw.trezor.messages.MessageType.MessageType_ThpQrCodeSecret)
MessageType_ThpQrCodeSecret = 1025,
// @@protoc_insertion_point(enum_value:hw.trezor.messages.MessageType.MessageType_ThpNfcUnidirectionalTag)
MessageType_ThpNfcUnidirectionalTag = 1032,
// @@protoc_insertion_point(enum_value:hw.trezor.messages.MessageType.MessageType_ThpNfcUnidirectionalSecret)
MessageType_ThpNfcUnidirectionalSecret = 1033,
}
impl ::protobuf::Enum for MessageType {
@ -762,6 +798,24 @@ impl ::protobuf::Enum for MessageType {
903 => ::std::option::Option::Some(MessageType::MessageType_SolanaAddress),
904 => ::std::option::Option::Some(MessageType::MessageType_SolanaSignTx),
905 => ::std::option::Option::Some(MessageType::MessageType_SolanaTxSignature),
1000 => ::std::option::Option::Some(MessageType::MessageType_ThpCreateNewSession),
1001 => ::std::option::Option::Some(MessageType::MessageType_ThpNewSession),
1008 => ::std::option::Option::Some(MessageType::MessageType_ThpStartPairingRequest),
1009 => ::std::option::Option::Some(MessageType::MessageType_ThpPairingPreparationsFinished),
1010 => ::std::option::Option::Some(MessageType::MessageType_ThpCredentialRequest),
1011 => ::std::option::Option::Some(MessageType::MessageType_ThpCredentialResponse),
1012 => ::std::option::Option::Some(MessageType::MessageType_ThpEndRequest),
1013 => ::std::option::Option::Some(MessageType::MessageType_ThpEndResponse),
1016 => ::std::option::Option::Some(MessageType::MessageType_ThpCodeEntryCommitment),
1017 => ::std::option::Option::Some(MessageType::MessageType_ThpCodeEntryChallenge),
1018 => ::std::option::Option::Some(MessageType::MessageType_ThpCodeEntryCpaceHost),
1019 => ::std::option::Option::Some(MessageType::MessageType_ThpCodeEntryCpaceTrezor),
1020 => ::std::option::Option::Some(MessageType::MessageType_ThpCodeEntryTag),
1021 => ::std::option::Option::Some(MessageType::MessageType_ThpCodeEntrySecret),
1024 => ::std::option::Option::Some(MessageType::MessageType_ThpQrCodeTag),
1025 => ::std::option::Option::Some(MessageType::MessageType_ThpQrCodeSecret),
1032 => ::std::option::Option::Some(MessageType::MessageType_ThpNfcUnidirectionalTag),
1033 => ::std::option::Option::Some(MessageType::MessageType_ThpNfcUnidirectionalSecret),
_ => ::std::option::Option::None
}
}
@ -1009,6 +1063,24 @@ impl ::protobuf::Enum for MessageType {
"MessageType_SolanaAddress" => ::std::option::Option::Some(MessageType::MessageType_SolanaAddress),
"MessageType_SolanaSignTx" => ::std::option::Option::Some(MessageType::MessageType_SolanaSignTx),
"MessageType_SolanaTxSignature" => ::std::option::Option::Some(MessageType::MessageType_SolanaTxSignature),
"MessageType_ThpCreateNewSession" => ::std::option::Option::Some(MessageType::MessageType_ThpCreateNewSession),
"MessageType_ThpNewSession" => ::std::option::Option::Some(MessageType::MessageType_ThpNewSession),
"MessageType_ThpStartPairingRequest" => ::std::option::Option::Some(MessageType::MessageType_ThpStartPairingRequest),
"MessageType_ThpPairingPreparationsFinished" => ::std::option::Option::Some(MessageType::MessageType_ThpPairingPreparationsFinished),
"MessageType_ThpCredentialRequest" => ::std::option::Option::Some(MessageType::MessageType_ThpCredentialRequest),
"MessageType_ThpCredentialResponse" => ::std::option::Option::Some(MessageType::MessageType_ThpCredentialResponse),
"MessageType_ThpEndRequest" => ::std::option::Option::Some(MessageType::MessageType_ThpEndRequest),
"MessageType_ThpEndResponse" => ::std::option::Option::Some(MessageType::MessageType_ThpEndResponse),
"MessageType_ThpCodeEntryCommitment" => ::std::option::Option::Some(MessageType::MessageType_ThpCodeEntryCommitment),
"MessageType_ThpCodeEntryChallenge" => ::std::option::Option::Some(MessageType::MessageType_ThpCodeEntryChallenge),
"MessageType_ThpCodeEntryCpaceHost" => ::std::option::Option::Some(MessageType::MessageType_ThpCodeEntryCpaceHost),
"MessageType_ThpCodeEntryCpaceTrezor" => ::std::option::Option::Some(MessageType::MessageType_ThpCodeEntryCpaceTrezor),
"MessageType_ThpCodeEntryTag" => ::std::option::Option::Some(MessageType::MessageType_ThpCodeEntryTag),
"MessageType_ThpCodeEntrySecret" => ::std::option::Option::Some(MessageType::MessageType_ThpCodeEntrySecret),
"MessageType_ThpQrCodeTag" => ::std::option::Option::Some(MessageType::MessageType_ThpQrCodeTag),
"MessageType_ThpQrCodeSecret" => ::std::option::Option::Some(MessageType::MessageType_ThpQrCodeSecret),
"MessageType_ThpNfcUnidirectionalTag" => ::std::option::Option::Some(MessageType::MessageType_ThpNfcUnidirectionalTag),
"MessageType_ThpNfcUnidirectionalSecret" => ::std::option::Option::Some(MessageType::MessageType_ThpNfcUnidirectionalSecret),
_ => ::std::option::Option::None
}
}
@ -1255,6 +1327,24 @@ impl ::protobuf::Enum for MessageType {
MessageType::MessageType_SolanaAddress,
MessageType::MessageType_SolanaSignTx,
MessageType::MessageType_SolanaTxSignature,
MessageType::MessageType_ThpCreateNewSession,
MessageType::MessageType_ThpNewSession,
MessageType::MessageType_ThpStartPairingRequest,
MessageType::MessageType_ThpPairingPreparationsFinished,
MessageType::MessageType_ThpCredentialRequest,
MessageType::MessageType_ThpCredentialResponse,
MessageType::MessageType_ThpEndRequest,
MessageType::MessageType_ThpEndResponse,
MessageType::MessageType_ThpCodeEntryCommitment,
MessageType::MessageType_ThpCodeEntryChallenge,
MessageType::MessageType_ThpCodeEntryCpaceHost,
MessageType::MessageType_ThpCodeEntryCpaceTrezor,
MessageType::MessageType_ThpCodeEntryTag,
MessageType::MessageType_ThpCodeEntrySecret,
MessageType::MessageType_ThpQrCodeTag,
MessageType::MessageType_ThpQrCodeSecret,
MessageType::MessageType_ThpNfcUnidirectionalTag,
MessageType::MessageType_ThpNfcUnidirectionalSecret,
];
}
@ -1507,6 +1597,24 @@ impl ::protobuf::EnumFull for MessageType {
MessageType::MessageType_SolanaAddress => 238,
MessageType::MessageType_SolanaSignTx => 239,
MessageType::MessageType_SolanaTxSignature => 240,
MessageType::MessageType_ThpCreateNewSession => 241,
MessageType::MessageType_ThpNewSession => 242,
MessageType::MessageType_ThpStartPairingRequest => 243,
MessageType::MessageType_ThpPairingPreparationsFinished => 244,
MessageType::MessageType_ThpCredentialRequest => 245,
MessageType::MessageType_ThpCredentialResponse => 246,
MessageType::MessageType_ThpEndRequest => 247,
MessageType::MessageType_ThpEndResponse => 248,
MessageType::MessageType_ThpCodeEntryCommitment => 249,
MessageType::MessageType_ThpCodeEntryChallenge => 250,
MessageType::MessageType_ThpCodeEntryCpaceHost => 251,
MessageType::MessageType_ThpCodeEntryCpaceTrezor => 252,
MessageType::MessageType_ThpCodeEntryTag => 253,
MessageType::MessageType_ThpCodeEntrySecret => 254,
MessageType::MessageType_ThpQrCodeTag => 255,
MessageType::MessageType_ThpQrCodeSecret => 256,
MessageType::MessageType_ThpNfcUnidirectionalTag => 257,
MessageType::MessageType_ThpNfcUnidirectionalSecret => 258,
};
Self::enum_descriptor().value_by_index(index)
}
@ -1541,6 +1649,14 @@ pub mod exts {
pub const wire_no_fsm: ::protobuf::ext::ExtFieldOptional<::protobuf::descriptor::EnumValueOptions, bool> = ::protobuf::ext::ExtFieldOptional::new(50008, ::protobuf::descriptor::field_descriptor_proto::Type::TYPE_BOOL);
pub const channel_in: ::protobuf::ext::ExtFieldOptional<::protobuf::descriptor::EnumValueOptions, bool> = ::protobuf::ext::ExtFieldOptional::new(50009, ::protobuf::descriptor::field_descriptor_proto::Type::TYPE_BOOL);
pub const channel_out: ::protobuf::ext::ExtFieldOptional<::protobuf::descriptor::EnumValueOptions, bool> = ::protobuf::ext::ExtFieldOptional::new(50010, ::protobuf::descriptor::field_descriptor_proto::Type::TYPE_BOOL);
pub const pairing_in: ::protobuf::ext::ExtFieldOptional<::protobuf::descriptor::EnumValueOptions, bool> = ::protobuf::ext::ExtFieldOptional::new(50011, ::protobuf::descriptor::field_descriptor_proto::Type::TYPE_BOOL);
pub const pairing_out: ::protobuf::ext::ExtFieldOptional<::protobuf::descriptor::EnumValueOptions, bool> = ::protobuf::ext::ExtFieldOptional::new(50012, ::protobuf::descriptor::field_descriptor_proto::Type::TYPE_BOOL);
pub const bitcoin_only: ::protobuf::ext::ExtFieldOptional<::protobuf::descriptor::EnumValueOptions, bool> = ::protobuf::ext::ExtFieldOptional::new(60000, ::protobuf::descriptor::field_descriptor_proto::Type::TYPE_BOOL);
pub const has_bitcoin_only_values: ::protobuf::ext::ExtFieldOptional<::protobuf::descriptor::EnumOptions, bool> = ::protobuf::ext::ExtFieldOptional::new(51001, ::protobuf::descriptor::field_descriptor_proto::Type::TYPE_BOOL);
@ -1556,7 +1672,7 @@ pub mod exts {
static file_descriptor_proto_data: &'static [u8] = b"\
\n\x0emessages.proto\x12\x12hw.trezor.messages\x1a\x20google/protobuf/de\
scriptor.proto*\xe2S\n\x0bMessageType\x12(\n\x16MessageType_Initialize\
scriptor.proto*\xc8Z\n\x0bMessageType\x12(\n\x16MessageType_Initialize\
\x10\0\x1a\x0c\x80\xa6\x1d\x01\xb0\xb5\x18\x01\x90\xb5\x18\x01\x12\x1e\n\
\x10MessageType_Ping\x10\x01\x1a\x08\x80\xa6\x1d\x01\x90\xb5\x18\x01\x12\
%\n\x13MessageType_Success\x10\x02\x1a\x0c\x80\xa6\x1d\x01\xa8\xb5\x18\
@ -1829,30 +1945,59 @@ static file_descriptor_proto_data: &'static [u8] = b"\
\x07\x1a\x04\x90\xb5\x18\x01\x12$\n\x19MessageType_SolanaAddress\x10\x87\
\x07\x1a\x04\x98\xb5\x18\x01\x12#\n\x18MessageType_SolanaSignTx\x10\x88\
\x07\x1a\x04\x90\xb5\x18\x01\x12(\n\x1dMessageType_SolanaTxSignature\x10\
\x89\x07\x1a\x04\x98\xb5\x18\x01\x1a\x04\xc8\xf3\x18\x01\"\x04\x08Z\x10\
\\\"\x04\x08G\x10J\"\x04\x08r\x10z\"\x06\x08\xdb\x01\x10\xdb\x01\"\x06\
\x08\xe0\x01\x10\xe0\x01\"\x06\x08\xac\x02\x10\xb0\x02\"\x06\x08\xb5\x02\
\x10\xb8\x02:<\n\x07wire_in\x18\xd2\x86\x03\x20\x01(\x08\x12!.google.pro\
tobuf.EnumValueOptionsR\x06wireIn:>\n\x08wire_out\x18\xd3\x86\x03\x20\
\x01(\x08\x12!.google.protobuf.EnumValueOptionsR\x07wireOut:G\n\rwire_de\
bug_in\x18\xd4\x86\x03\x20\x01(\x08\x12!.google.protobuf.EnumValueOption\
sR\x0bwireDebugIn:I\n\x0ewire_debug_out\x18\xd5\x86\x03\x20\x01(\x08\x12\
!.google.protobuf.EnumValueOptionsR\x0cwireDebugOut:@\n\twire_tiny\x18\
\xd6\x86\x03\x20\x01(\x08\x12!.google.protobuf.EnumValueOptionsR\x08wire\
Tiny:L\n\x0fwire_bootloader\x18\xd7\x86\x03\x20\x01(\x08\x12!.google.pro\
tobuf.EnumValueOptionsR\x0ewireBootloader:C\n\x0bwire_no_fsm\x18\xd8\x86\
\x03\x20\x01(\x08\x12!.google.protobuf.EnumValueOptionsR\twireNoFsm:F\n\
\x0cbitcoin_only\x18\xe0\xd4\x03\x20\x01(\x08\x12!.google.protobuf.EnumV\
alueOptionsR\x0bbitcoinOnly:U\n\x17has_bitcoin_only_values\x18\xb9\x8e\
\x03\x20\x01(\x08\x12\x1c.google.protobuf.EnumOptionsR\x14hasBitcoinOnly\
Values:T\n\x14experimental_message\x18\xa1\x96\x03\x20\x01(\x08\x12\x1f.\
google.protobuf.MessageOptionsR\x13experimentalMessage:>\n\twire_type\
\x18\xa2\x96\x03\x20\x01(\r\x12\x1f.google.protobuf.MessageOptionsR\x08w\
ireType:N\n\x12experimental_field\x18\x89\x9e\x03\x20\x01(\x08\x12\x1d.g\
oogle.protobuf.FieldOptionsR\x11experimentalField:U\n\x17include_in_bitc\
oin_only\x18\xe0\xd4\x03\x20\x01(\x08\x12\x1c.google.protobuf.FileOption\
sR\x14includeInBitcoinOnlyB8\n#com.satoshilabs.trezor.lib.protobufB\rTre\
zorMessage\x80\xa6\x1d\x01\
\x89\x07\x1a\x04\x98\xb5\x18\x01\x12.\n\x1fMessageType_ThpCreateNewSessi\
on\x10\xe8\x07\x1a\x08\x80\xa6\x1d\x01\xc8\xb5\x18\x01\x12(\n\x19Message\
Type_ThpNewSession\x10\xe9\x07\x1a\x08\x80\xa6\x1d\x01\xd0\xb5\x18\x01\
\x121\n\"MessageType_ThpStartPairingRequest\x10\xf0\x07\x1a\x08\x80\xa6\
\x1d\x01\xd8\xb5\x18\x01\x129\n*MessageType_ThpPairingPreparationsFinish\
ed\x10\xf1\x07\x1a\x08\x80\xa6\x1d\x01\xe0\xb5\x18\x01\x12/\n\x20Message\
Type_ThpCredentialRequest\x10\xf2\x07\x1a\x08\x80\xa6\x1d\x01\xd8\xb5\
\x18\x01\x120\n!MessageType_ThpCredentialResponse\x10\xf3\x07\x1a\x08\
\x80\xa6\x1d\x01\xe0\xb5\x18\x01\x12(\n\x19MessageType_ThpEndRequest\x10\
\xf4\x07\x1a\x08\x80\xa6\x1d\x01\xd8\xb5\x18\x01\x12)\n\x1aMessageType_T\
hpEndResponse\x10\xf5\x07\x1a\x08\x80\xa6\x1d\x01\xe0\xb5\x18\x01\x121\n\
\"MessageType_ThpCodeEntryCommitment\x10\xf8\x07\x1a\x08\x80\xa6\x1d\x01\
\xe0\xb5\x18\x01\x120\n!MessageType_ThpCodeEntryChallenge\x10\xf9\x07\
\x1a\x08\x80\xa6\x1d\x01\xd8\xb5\x18\x01\x120\n!MessageType_ThpCodeEntry\
CpaceHost\x10\xfa\x07\x1a\x08\x80\xa6\x1d\x01\xd8\xb5\x18\x01\x122\n#Mes\
sageType_ThpCodeEntryCpaceTrezor\x10\xfb\x07\x1a\x08\x80\xa6\x1d\x01\xe0\
\xb5\x18\x01\x12*\n\x1bMessageType_ThpCodeEntryTag\x10\xfc\x07\x1a\x08\
\x80\xa6\x1d\x01\xd8\xb5\x18\x01\x12-\n\x1eMessageType_ThpCodeEntrySecre\
t\x10\xfd\x07\x1a\x08\x80\xa6\x1d\x01\xe0\xb5\x18\x01\x12'\n\x18MessageT\
ype_ThpQrCodeTag\x10\x80\x08\x1a\x08\x80\xa6\x1d\x01\xd8\xb5\x18\x01\x12\
*\n\x1bMessageType_ThpQrCodeSecret\x10\x81\x08\x1a\x08\x80\xa6\x1d\x01\
\xe0\xb5\x18\x01\x122\n#MessageType_ThpNfcUnidirectionalTag\x10\x88\x08\
\x1a\x08\x80\xa6\x1d\x01\xd8\xb5\x18\x01\x125\n&MessageType_ThpNfcUnidir\
ectionalSecret\x10\x89\x08\x1a\x08\x80\xa6\x1d\x01\xd8\xb5\x18\x01\x1a\
\x04\xc8\xf3\x18\x01\"\x04\x08Z\x10\\\"\x04\x08G\x10J\"\x04\x08r\x10z\"\
\x06\x08\xdb\x01\x10\xdb\x01\"\x06\x08\xe0\x01\x10\xe0\x01\"\x06\x08\xac\
\x02\x10\xb0\x02\"\x06\x08\xb5\x02\x10\xb8\x02:<\n\x07wire_in\x18\xd2\
\x86\x03\x20\x01(\x08\x12!.google.protobuf.EnumValueOptionsR\x06wireIn:>\
\n\x08wire_out\x18\xd3\x86\x03\x20\x01(\x08\x12!.google.protobuf.EnumVal\
ueOptionsR\x07wireOut:G\n\rwire_debug_in\x18\xd4\x86\x03\x20\x01(\x08\
\x12!.google.protobuf.EnumValueOptionsR\x0bwireDebugIn:I\n\x0ewire_debug\
_out\x18\xd5\x86\x03\x20\x01(\x08\x12!.google.protobuf.EnumValueOptionsR\
\x0cwireDebugOut:@\n\twire_tiny\x18\xd6\x86\x03\x20\x01(\x08\x12!.google\
.protobuf.EnumValueOptionsR\x08wireTiny:L\n\x0fwire_bootloader\x18\xd7\
\x86\x03\x20\x01(\x08\x12!.google.protobuf.EnumValueOptionsR\x0ewireBoot\
loader:C\n\x0bwire_no_fsm\x18\xd8\x86\x03\x20\x01(\x08\x12!.google.proto\
buf.EnumValueOptionsR\twireNoFsm:B\n\nchannel_in\x18\xd9\x86\x03\x20\x01\
(\x08\x12!.google.protobuf.EnumValueOptionsR\tchannelIn:D\n\x0bchannel_o\
ut\x18\xda\x86\x03\x20\x01(\x08\x12!.google.protobuf.EnumValueOptionsR\n\
channelOut:B\n\npairing_in\x18\xdb\x86\x03\x20\x01(\x08\x12!.google.prot\
obuf.EnumValueOptionsR\tpairingIn:D\n\x0bpairing_out\x18\xdc\x86\x03\x20\
\x01(\x08\x12!.google.protobuf.EnumValueOptionsR\npairingOut:F\n\x0cbitc\
oin_only\x18\xe0\xd4\x03\x20\x01(\x08\x12!.google.protobuf.EnumValueOpti\
onsR\x0bbitcoinOnly:U\n\x17has_bitcoin_only_values\x18\xb9\x8e\x03\x20\
\x01(\x08\x12\x1c.google.protobuf.EnumOptionsR\x14hasBitcoinOnlyValues:T\
\n\x14experimental_message\x18\xa1\x96\x03\x20\x01(\x08\x12\x1f.google.p\
rotobuf.MessageOptionsR\x13experimentalMessage:>\n\twire_type\x18\xa2\
\x96\x03\x20\x01(\r\x12\x1f.google.protobuf.MessageOptionsR\x08wireType:\
N\n\x12experimental_field\x18\x89\x9e\x03\x20\x01(\x08\x12\x1d.google.pr\
otobuf.FieldOptionsR\x11experimentalField:U\n\x17include_in_bitcoin_only\
\x18\xe0\xd4\x03\x20\x01(\x08\x12\x1c.google.protobuf.FileOptionsR\x14in\
cludeInBitcoinOnlyB8\n#com.satoshilabs.trezor.lib.protobufB\rTrezorMessa\
ge\x80\xa6\x1d\x01\
";
/// `FileDescriptorProto` object which was a source for this generated file

@ -414,6 +414,8 @@ pub mod failure {
Failure_WipeCodeMismatch = 13,
// @@protoc_insertion_point(enum_value:hw.trezor.messages.common.Failure.FailureType.Failure_InvalidSession)
Failure_InvalidSession = 14,
// @@protoc_insertion_point(enum_value:hw.trezor.messages.common.Failure.FailureType.Failure_ThpUnallocatedSession)
Failure_ThpUnallocatedSession = 15,
// @@protoc_insertion_point(enum_value:hw.trezor.messages.common.Failure.FailureType.Failure_FirmwareError)
Failure_FirmwareError = 99,
}
@ -441,6 +443,7 @@ pub mod failure {
12 => ::std::option::Option::Some(FailureType::Failure_PinMismatch),
13 => ::std::option::Option::Some(FailureType::Failure_WipeCodeMismatch),
14 => ::std::option::Option::Some(FailureType::Failure_InvalidSession),
15 => ::std::option::Option::Some(FailureType::Failure_ThpUnallocatedSession),
99 => ::std::option::Option::Some(FailureType::Failure_FirmwareError),
_ => ::std::option::Option::None
}
@ -462,6 +465,7 @@ pub mod failure {
"Failure_PinMismatch" => ::std::option::Option::Some(FailureType::Failure_PinMismatch),
"Failure_WipeCodeMismatch" => ::std::option::Option::Some(FailureType::Failure_WipeCodeMismatch),
"Failure_InvalidSession" => ::std::option::Option::Some(FailureType::Failure_InvalidSession),
"Failure_ThpUnallocatedSession" => ::std::option::Option::Some(FailureType::Failure_ThpUnallocatedSession),
"Failure_FirmwareError" => ::std::option::Option::Some(FailureType::Failure_FirmwareError),
_ => ::std::option::Option::None
}
@ -482,6 +486,7 @@ pub mod failure {
FailureType::Failure_PinMismatch,
FailureType::Failure_WipeCodeMismatch,
FailureType::Failure_InvalidSession,
FailureType::Failure_ThpUnallocatedSession,
FailureType::Failure_FirmwareError,
];
}
@ -508,7 +513,8 @@ pub mod failure {
FailureType::Failure_PinMismatch => 11,
FailureType::Failure_WipeCodeMismatch => 12,
FailureType::Failure_InvalidSession => 13,
FailureType::Failure_FirmwareError => 14,
FailureType::Failure_ThpUnallocatedSession => 14,
FailureType::Failure_FirmwareError => 15,
};
Self::enum_descriptor().value_by_index(index)
}
@ -2481,9 +2487,9 @@ impl ::protobuf::reflect::ProtobufValue for HDNodeType {
static file_descriptor_proto_data: &'static [u8] = b"\
\n\x15messages-common.proto\x12\x19hw.trezor.messages.common\x1a\x0emess\
ages.proto\"%\n\x07Success\x12\x1a\n\x07message\x18\x01\x20\x01(\t:\0R\
\x07message\"\x8f\x04\n\x07Failure\x12B\n\x04code\x18\x01\x20\x01(\x0e2.\
\x07message\"\xb2\x04\n\x07Failure\x12B\n\x04code\x18\x01\x20\x01(\x0e2.\
.hw.trezor.messages.common.Failure.FailureTypeR\x04code\x12\x18\n\x07mes\
sage\x18\x02\x20\x01(\tR\x07message\"\xa5\x03\n\x0bFailureType\x12\x1d\n\
sage\x18\x02\x20\x01(\tR\x07message\"\xc8\x03\n\x0bFailureType\x12\x1d\n\
\x19Failure_UnexpectedMessage\x10\x01\x12\x1a\n\x16Failure_ButtonExpecte\
d\x10\x02\x12\x15\n\x11Failure_DataError\x10\x03\x12\x1b\n\x17Failure_Ac\
tionCancelled\x10\x04\x12\x17\n\x13Failure_PinExpected\x10\x05\x12\x18\n\
@ -2492,44 +2498,45 @@ static file_descriptor_proto_data: &'static [u8] = b"\
essError\x10\t\x12\x1a\n\x16Failure_NotEnoughFunds\x10\n\x12\x1a\n\x16Fa\
ilure_NotInitialized\x10\x0b\x12\x17\n\x13Failure_PinMismatch\x10\x0c\
\x12\x1c\n\x18Failure_WipeCodeMismatch\x10\r\x12\x1a\n\x16Failure_Invali\
dSession\x10\x0e\x12\x19\n\x15Failure_FirmwareError\x10c\"\xab\x06\n\rBu\
ttonRequest\x12N\n\x04code\x18\x01\x20\x01(\x0e2:.hw.trezor.messages.com\
mon.ButtonRequest.ButtonRequestTypeR\x04code\x12\x14\n\x05pages\x18\x02\
\x20\x01(\rR\x05pages\x12\x12\n\x04name\x18\x04\x20\x01(\tR\x04name\"\
\x99\x05\n\x11ButtonRequestType\x12\x17\n\x13ButtonRequest_Other\x10\x01\
\x12\"\n\x1eButtonRequest_FeeOverThreshold\x10\x02\x12\x1f\n\x1bButtonRe\
quest_ConfirmOutput\x10\x03\x12\x1d\n\x19ButtonRequest_ResetDevice\x10\
\x04\x12\x1d\n\x19ButtonRequest_ConfirmWord\x10\x05\x12\x1c\n\x18ButtonR\
equest_WipeDevice\x10\x06\x12\x1d\n\x19ButtonRequest_ProtectCall\x10\x07\
\x12\x18\n\x14ButtonRequest_SignTx\x10\x08\x12\x1f\n\x1bButtonRequest_Fi\
rmwareCheck\x10\t\x12\x19\n\x15ButtonRequest_Address\x10\n\x12\x1b\n\x17\
ButtonRequest_PublicKey\x10\x0b\x12#\n\x1fButtonRequest_MnemonicWordCoun\
t\x10\x0c\x12\x1f\n\x1bButtonRequest_MnemonicInput\x10\r\x120\n(_Depreca\
ted_ButtonRequest_PassphraseType\x10\x0e\x1a\x02\x08\x01\x12'\n#ButtonRe\
quest_UnknownDerivationPath\x10\x0f\x12\"\n\x1eButtonRequest_RecoveryHom\
epage\x10\x10\x12\x19\n\x15ButtonRequest_Success\x10\x11\x12\x19\n\x15Bu\
ttonRequest_Warning\x10\x12\x12!\n\x1dButtonRequest_PassphraseEntry\x10\
\x13\x12\x1a\n\x16ButtonRequest_PinEntry\x10\x14J\x04\x08\x03\x10\x04\"\
\x0b\n\tButtonAck\"\xbb\x02\n\x10PinMatrixRequest\x12T\n\x04type\x18\x01\
\x20\x01(\x0e2@.hw.trezor.messages.common.PinMatrixRequest.PinMatrixRequ\
estTypeR\x04type\"\xd0\x01\n\x14PinMatrixRequestType\x12\x20\n\x1cPinMat\
rixRequestType_Current\x10\x01\x12!\n\x1dPinMatrixRequestType_NewFirst\
\x10\x02\x12\"\n\x1ePinMatrixRequestType_NewSecond\x10\x03\x12&\n\"PinMa\
trixRequestType_WipeCodeFirst\x10\x04\x12'\n#PinMatrixRequestType_WipeCo\
deSecond\x10\x05\"\x20\n\x0cPinMatrixAck\x12\x10\n\x03pin\x18\x01\x20\
\x02(\tR\x03pin\"5\n\x11PassphraseRequest\x12\x20\n\n_on_device\x18\x01\
\x20\x01(\x08R\x08OnDeviceB\x02\x18\x01\"g\n\rPassphraseAck\x12\x1e\n\np\
assphrase\x18\x01\x20\x01(\tR\npassphrase\x12\x19\n\x06_state\x18\x02\
\x20\x01(\x0cR\x05StateB\x02\x18\x01\x12\x1b\n\ton_device\x18\x03\x20\
\x01(\x08R\x08onDevice\"=\n!Deprecated_PassphraseStateRequest\x12\x14\n\
\x05state\x18\x01\x20\x01(\x0cR\x05state:\x02\x18\x01\"#\n\x1dDeprecated\
_PassphraseStateAck:\x02\x18\x01\"\xc0\x01\n\nHDNodeType\x12\x14\n\x05de\
pth\x18\x01\x20\x02(\rR\x05depth\x12\x20\n\x0bfingerprint\x18\x02\x20\
\x02(\rR\x0bfingerprint\x12\x1b\n\tchild_num\x18\x03\x20\x02(\rR\x08chil\
dNum\x12\x1d\n\nchain_code\x18\x04\x20\x02(\x0cR\tchainCode\x12\x1f\n\
\x0bprivate_key\x18\x05\x20\x01(\x0cR\nprivateKey\x12\x1d\n\npublic_key\
\x18\x06\x20\x02(\x0cR\tpublicKeyB>\n#com.satoshilabs.trezor.lib.protobu\
fB\x13TrezorMessageCommon\x80\xa6\x1d\x01\
dSession\x10\x0e\x12!\n\x1dFailure_ThpUnallocatedSession\x10\x0f\x12\x19\
\n\x15Failure_FirmwareError\x10c\"\xab\x06\n\rButtonRequest\x12N\n\x04co\
de\x18\x01\x20\x01(\x0e2:.hw.trezor.messages.common.ButtonRequest.Button\
RequestTypeR\x04code\x12\x14\n\x05pages\x18\x02\x20\x01(\rR\x05pages\x12\
\x12\n\x04name\x18\x04\x20\x01(\tR\x04name\"\x99\x05\n\x11ButtonRequestT\
ype\x12\x17\n\x13ButtonRequest_Other\x10\x01\x12\"\n\x1eButtonRequest_Fe\
eOverThreshold\x10\x02\x12\x1f\n\x1bButtonRequest_ConfirmOutput\x10\x03\
\x12\x1d\n\x19ButtonRequest_ResetDevice\x10\x04\x12\x1d\n\x19ButtonReque\
st_ConfirmWord\x10\x05\x12\x1c\n\x18ButtonRequest_WipeDevice\x10\x06\x12\
\x1d\n\x19ButtonRequest_ProtectCall\x10\x07\x12\x18\n\x14ButtonRequest_S\
ignTx\x10\x08\x12\x1f\n\x1bButtonRequest_FirmwareCheck\x10\t\x12\x19\n\
\x15ButtonRequest_Address\x10\n\x12\x1b\n\x17ButtonRequest_PublicKey\x10\
\x0b\x12#\n\x1fButtonRequest_MnemonicWordCount\x10\x0c\x12\x1f\n\x1bButt\
onRequest_MnemonicInput\x10\r\x120\n(_Deprecated_ButtonRequest_Passphras\
eType\x10\x0e\x1a\x02\x08\x01\x12'\n#ButtonRequest_UnknownDerivationPath\
\x10\x0f\x12\"\n\x1eButtonRequest_RecoveryHomepage\x10\x10\x12\x19\n\x15\
ButtonRequest_Success\x10\x11\x12\x19\n\x15ButtonRequest_Warning\x10\x12\
\x12!\n\x1dButtonRequest_PassphraseEntry\x10\x13\x12\x1a\n\x16ButtonRequ\
est_PinEntry\x10\x14J\x04\x08\x03\x10\x04\"\x0b\n\tButtonAck\"\xbb\x02\n\
\x10PinMatrixRequest\x12T\n\x04type\x18\x01\x20\x01(\x0e2@.hw.trezor.mes\
sages.common.PinMatrixRequest.PinMatrixRequestTypeR\x04type\"\xd0\x01\n\
\x14PinMatrixRequestType\x12\x20\n\x1cPinMatrixRequestType_Current\x10\
\x01\x12!\n\x1dPinMatrixRequestType_NewFirst\x10\x02\x12\"\n\x1ePinMatri\
xRequestType_NewSecond\x10\x03\x12&\n\"PinMatrixRequestType_WipeCodeFirs\
t\x10\x04\x12'\n#PinMatrixRequestType_WipeCodeSecond\x10\x05\"\x20\n\x0c\
PinMatrixAck\x12\x10\n\x03pin\x18\x01\x20\x02(\tR\x03pin\"5\n\x11Passphr\
aseRequest\x12\x20\n\n_on_device\x18\x01\x20\x01(\x08R\x08OnDeviceB\x02\
\x18\x01\"g\n\rPassphraseAck\x12\x1e\n\npassphrase\x18\x01\x20\x01(\tR\n\
passphrase\x12\x19\n\x06_state\x18\x02\x20\x01(\x0cR\x05StateB\x02\x18\
\x01\x12\x1b\n\ton_device\x18\x03\x20\x01(\x08R\x08onDevice\"=\n!Depreca\
ted_PassphraseStateRequest\x12\x14\n\x05state\x18\x01\x20\x01(\x0cR\x05s\
tate:\x02\x18\x01\"#\n\x1dDeprecated_PassphraseStateAck:\x02\x18\x01\"\
\xc0\x01\n\nHDNodeType\x12\x14\n\x05depth\x18\x01\x20\x02(\rR\x05depth\
\x12\x20\n\x0bfingerprint\x18\x02\x20\x02(\rR\x0bfingerprint\x12\x1b\n\t\
child_num\x18\x03\x20\x02(\rR\x08childNum\x12\x1d\n\nchain_code\x18\x04\
\x20\x02(\x0cR\tchainCode\x12\x1f\n\x0bprivate_key\x18\x05\x20\x01(\x0cR\
\nprivateKey\x12\x1d\n\npublic_key\x18\x06\x20\x02(\x0cR\tpublicKeyB>\n#\
com.satoshilabs.trezor.lib.protobufB\x13TrezorMessageCommon\x80\xa6\x1d\
\x01\
";
/// `FileDescriptorProto` object which was a source for this generated file

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save