1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-13 02:58:57 +00:00

feat(core): implement a new Trezor-Host protocol

This commit is contained in:
M1nd3r 2024-07-22 08:40:27 +02:00 committed by matejcik
parent b02c7c4895
commit f00011d480
121 changed files with 12447 additions and 1243 deletions

View File

@ -307,6 +307,7 @@ core unix frozen debug build:
needs: []
variables:
PYOPT: "0"
THP: "1"
script:
- $NIX_SHELL --run "poetry run make -C core build_unix_frozen"
artifacts:

View File

@ -39,6 +39,7 @@ message Failure {
Failure_PinMismatch = 12;
Failure_WipeCodeMismatch = 13;
Failure_InvalidSession = 14;
Failure_ThpUnallocatedSession=15;
Failure_FirmwareError = 99;
}
}

View File

@ -110,6 +110,8 @@ message DebugLinkGetState {
// trezor-core only - wait until current layout changes
// changed in 2.6.4: multiple wait types instead of true/false.
optional DebugWaitType wait_layout = 3 [default=IMMEDIATE];
optional uint32 thp_channel_id=4; // THP only - used to get information from particular channel
}
/**
@ -130,6 +132,8 @@ message DebugLinkState {
optional uint32 reset_word_pos = 11; // index of mnemonic word the device is expecting during ResetDevice workflow
optional management.BackupType mnemonic_type = 12; // current mnemonic type (BIP-39/SLIP-39)
repeated string tokens = 13; // current layout represented as a list of string tokens
optional uint32 thp_pairing_code_entry_code = 14;
optional bytes thp_pairing_secret = 15;
}
/**

View File

@ -0,0 +1,219 @@
syntax = "proto2";
package hw.trezor.messages.thp;
// Sugar for easier handling in Java
option java_package = "com.satoshilabs.trezor.lib.protobuf";
option java_outer_classname = "TrezorMessageThp";
option (include_in_bitcoin_only) = true;
import "messages.proto";
/**
* Numeric identifiers of pairing methods.
* @embed
*/
enum ThpPairingMethod {
NoMethod = 1; // Trust without MITM protection.
CodeEntry = 2; // User types code diplayed on Trezor into the host application.
QrCode = 3; // User scans code displayed on Trezor into host application.
NFC_Unidirectional = 4; // Trezor transmits an authentication key to the host device via NFC.
}
/**
* @embed
*/
message ThpDeviceProperties {
optional string internal_model = 1; // Internal model name e.g. "T2B1".
optional uint32 model_variant = 2; // Encodes the device properties such as color.
optional bool bootloader_mode = 3; // Indicates whether the device is in bootloader or firmware mode.
optional uint32 protocol_version = 4; // The communication protocol version supported by the firmware.
repeated ThpPairingMethod pairing_methods = 5; // The pairing methods supported by the Trezor.
}
/**
* @embed
*/
message ThpHandshakeCompletionReqNoisePayload {
optional bytes host_pairing_credential = 1; // Host's pairing credential
repeated ThpPairingMethod pairing_methods = 2; // The pairing methods chosen by the host
}
/**
* Request: Ask device for a new session with given passphrase.
* @start
* @next ThpNewSession
*/
message ThpCreateNewSession{
optional string passphrase = 1;
optional bool on_device = 2; // User wants to enter passphrase on the device
optional bool derive_cardano = 3; // If True, Cardano keys will be derived. Ignored with BTC-only
}
/**
* Response: Contains session_id of the newly created session.
* @end
*/
message ThpNewSession{
optional uint32 new_session_id = 1;
}
/**
* Request: Start pairing process.
* @start
* @next ThpCodeEntryCommitment
* @next ThpPairingPreparationsFinished
*/
message ThpStartPairingRequest{
optional string host_name = 1; // Human-readable host name
}
/**
* Response: Pairing is ready for user input / OOB communication.
* @next ThpCodeEntryCpace
* @next ThpQrCodeTag
* @next ThpNfcUnidirectionalTag
*/
message ThpPairingPreparationsFinished{
}
/**
* Response: If Code Entry is an allowed pairing option, Trezor responds with a commitment.
* @next ThpCodeEntryChallenge
*/
message ThpCodeEntryCommitment {
optional bytes commitment = 1; // SHA-256 of Trezor's random 32-byte secret
}
/**
* Response: Host responds to Trezor's Code Entry commitment with a challenge.
* @next ThpPairingPreparationsFinished
*/
message ThpCodeEntryChallenge {
optional bytes challenge = 1; // host's random 32-byte challenge
}
/**
* Request: User selected Code Entry option in Host. Host starts CPACE protocol with Trezor.
* @next ThpCodeEntryCpaceTrezor
*/
message ThpCodeEntryCpaceHost {
optional bytes cpace_host_public_key = 1; // Host's ephemeral CPace public key
}
/**
* Response: Trezor continues with the CPACE protocol.
* @next ThpCodeEntryTag
*/
message ThpCodeEntryCpaceTrezor {
optional bytes cpace_trezor_public_key = 1; // Trezor's ephemeral CPace public key
}
/**
* Response: Host continues with the CPACE protocol.
* @next ThpCodeEntrySecret
*/
message ThpCodeEntryTag {
optional bytes tag = 2; // SHA-256 of shared secret
}
/**
* Response: Trezor finishes the CPACE protocol.
* @next ThpCredentialRequest
* @next ThpEndRequest
*/
message ThpCodeEntrySecret {
optional bytes secret = 1; // Trezor's secret
}
/**
* Request: User selected QR Code pairing option. Host sends a QR Tag.
* @next ThpQrCodeSecret
*/
message ThpQrCodeTag {
optional bytes tag = 1; // SHA-256 of shared secret
}
/**
* Response: Trezor sends the QR secret.
* @next ThpCredentialRequest
* @next ThpEndRequest
*/
message ThpQrCodeSecret {
optional bytes secret = 1; // Trezor's secret
}
/**
* Request: User selected Unidirectional NFC pairing option. Host sends an Unidirectional NFC Tag.
* @next ThpNfcUnidirectionalSecret
*/
message ThpNfcUnidirectionalTag {
optional bytes tag = 1; // SHA-256 of shared secret
}
/**
* Response: Trezor sends the Unidirectioal NFC secret.
* @next ThpCredentialRequest
* @next ThpEndRequest
*/
message ThpNfcUnidirectionalSecret {
optional bytes secret = 1; // Trezor's secret
}
/**
* Request: Host requests issuance of a new pairing credential.
* @start
* @next ThpCredentialResponse
*/
message ThpCredentialRequest {
optional bytes host_static_pubkey = 1; // Host's static public key used in the handshake.
}
/**
* Response: Trezor issues a new pairing credential.
* @next ThpCredentialRequest
* @next ThpEndRequest
*/
message ThpCredentialResponse {
optional bytes trezor_static_pubkey = 1; // Trezor's static public key used in the handshake.
optional bytes credential = 2; // The pairing credential issued by the Trezor to the host.
}
/**
* Request: Host requests transition to the encrypted traffic phase.
* @start
* @next ThpEndResponse
*/
message ThpEndRequest {}
/**
* Response: Trezor approves transition to the encrypted traffic phase
* @end
*/
message ThpEndResponse {}
/**
* Only for internal use.
* @embed
*/
message ThpCredentialMetadata {
optional string host_name = 1; // Human-readable host name
}
/**
* Only for internal use.
* @embed
*/
message ThpPairingCredential {
optional ThpCredentialMetadata cred_metadata = 1; // Credential metadata
optional bytes mac = 2; // Message authentication code generated by the Trezor
}
/**
* Only for internal use.
* @embed
*/
message ThpAuthenticatedCredentialData {
optional bytes host_static_pubkey = 1; // Host's static public key used in the handshake
optional ThpCredentialMetadata cred_metadata = 2; // Credential metadata
}

View File

@ -42,6 +42,10 @@ extend google.protobuf.EnumValueOptions {
optional bool wire_tiny = 50006; // message is handled by Trezor when the USB stack is in tiny mode
optional bool wire_bootloader = 50007; // message is only handled by Trezor Bootloader
optional bool wire_no_fsm = 50008; // message is not handled by Trezor unless the USB stack is in tiny mode
optional bool channel_in = 50009;
optional bool channel_out = 50010;
optional bool pairing_in = 50011;
optional bool pairing_out = 50012;
optional bool bitcoin_only = 60000; // enum value is available on BITCOIN_ONLY build
// (messages not marked bitcoin_only will be EXCLUDED)
@ -375,4 +379,26 @@ enum MessageType {
MessageType_SolanaAddress = 903 [(wire_out) = true];
MessageType_SolanaSignTx = 904 [(wire_in) = true];
MessageType_SolanaTxSignature = 905 [(wire_out) = true];
// THP
MessageType_ThpCreateNewSession = 1000[(bitcoin_only)=true, (channel_in) = true];
MessageType_ThpNewSession = 1001[(bitcoin_only)=true, (channel_out) = true];
MessageType_ThpStartPairingRequest = 1008 [(bitcoin_only) = true, (pairing_in) = true];
MessageType_ThpPairingPreparationsFinished = 1009 [(bitcoin_only) = true, (pairing_out) = true];
MessageType_ThpCredentialRequest = 1010 [(bitcoin_only) = true, (pairing_in) = true];
MessageType_ThpCredentialResponse = 1011 [(bitcoin_only) = true, (pairing_out) = true];
MessageType_ThpEndRequest = 1012 [(bitcoin_only) = true, (pairing_in) = true];
MessageType_ThpEndResponse = 1013[(bitcoin_only) = true, (pairing_out) = true];
MessageType_ThpCodeEntryCommitment = 1016[(bitcoin_only)=true, (pairing_out) = true];
MessageType_ThpCodeEntryChallenge = 1017[(bitcoin_only)=true, (pairing_in) = true];
MessageType_ThpCodeEntryCpaceHost = 1018[(bitcoin_only)=true, (pairing_in) = true];
MessageType_ThpCodeEntryCpaceTrezor = 1019[(bitcoin_only)=true, (pairing_out) = true];
MessageType_ThpCodeEntryTag = 1020[(bitcoin_only)=true, (pairing_in) = true];
MessageType_ThpCodeEntrySecret = 1021[(bitcoin_only)=true, (pairing_out) = true];
MessageType_ThpQrCodeTag = 1024[(bitcoin_only)=true, (pairing_in) = true];
MessageType_ThpQrCodeSecret = 1025[(bitcoin_only)=true, (pairing_out) = true];
MessageType_ThpNfcUnidirectionalTag = 1032[(bitcoin_only)=true, (pairing_in) = true];
MessageType_ThpNfcUnidirectionalSecret = 1033[(bitcoin_only)=true, (pairing_in) = true];
}

View File

@ -558,6 +558,8 @@ class RustBlobRenderer:
enums = []
cursor = 0
for enum in sorted(self.descriptor.enums, key=lambda e: e.name):
if enum.name == "MessageType":
continue
self.enum_map[enum.name] = cursor
enum_blob = ENUM_ENTRY.build(sorted(v.number for v in enum.value))
enums.append(enum_blob)

View File

@ -291,10 +291,16 @@ build_unix: templates ## build unix port
build_unix_frozen: templates build_cross ## build unix port with frozen modules
$(SCONS) CFLAGS="$(CFLAGS)" $(UNIX_BUILD_DIR)/trezor-emu-core $(UNIX_PORT_OPTS) \
TREZOR_MODEL="$(TREZOR_MODEL)" CMAKELISTS="$(CMAKELISTS)" \
TREZOR_MODEL="$(TREZOR_MODEL)" CMAKELISTS="$(CMAKELISTS)" THP="$(THP)"\
PYOPT="$(PYOPT)" BITCOIN_ONLY="$(BITCOIN_ONLY)" TREZOR_EMULATOR_ASAN="$(ADDRESS_SANITIZER)" \
TREZOR_MEMPERF="$(TREZOR_MEMPERF)" TREZOR_EMULATOR_FROZEN=1 NEW_RENDERING="$(NEW_RENDERING)"
build_unix_frozen_debug: templates build_cross ## build unix port with frozen modules and DEBUG (PYOPT="0")
$(SCONS) CFLAGS="$(CFLAGS)" $(UNIX_BUILD_DIR)/trezor-emu-core $(UNIX_PORT_OPTS) \
TREZOR_MODEL="$(TREZOR_MODEL)" CMAKELISTS="$(CMAKELISTS)" THP="$(THP)"\
PYOPT="0" BITCOIN_ONLY="$(BITCOIN_ONLY)" TREZOR_EMULATOR_ASAN="$(ADDRESS_SANITIZER)" \
TREZOR_MEMPERF="$(TREZOR_MEMPERF)" TREZOR_EMULATOR_FROZEN=1
build_unix_debug: templates ## build unix port
$(SCONS) --max-drift=1 CFLAGS="$(CFLAGS)" $(UNIX_BUILD_DIR)/trezor-emu-core $(UNIX_PORT_OPTS) \
TREZOR_MODEL="$(TREZOR_MODEL)" CMAKELISTS="$(CMAKELISTS)" \

View File

@ -17,14 +17,20 @@ HW_REVISION = ARGUMENTS.get('HW_REVISION', None)
THP = ARGUMENTS.get('THP', '0') == '1' # Trezor-Host Protocol
NEW_RENDERING = ARGUMENTS.get('NEW_RENDERING', '1') == '1' or TREZOR_MODEL in ('T3T1',)
FEATURE_FLAGS = {
"RDI": True,
"SECP256K1_ZKP": True, # required for trezor.crypto.curve.bip340 (BIP340/Taproot)
"SYSTEM_VIEW": False,
"AES_GCM": False,
}
if THP:
FEATURE_FLAGS = {
"RDI": True,
"SECP256K1_ZKP": True, # required for trezor.crypto.curve.bip340 (BIP340/Taproot)
"SYSTEM_VIEW": False,
"AES_GCM": True, # Required for THP encryption
}
else:
FEATURE_FLAGS = {
"RDI": True,
"SECP256K1_ZKP": True, # required for trezor.crypto.curve.bip340 (BIP340/Taproot)
"SYSTEM_VIEW": False,
"AES_GCM": False,
}
FEATURES_WANTED = ["input", "sbu", "sd_card", "rgb_led", "dma2d", "consumption_mask", "usb" ,"optiga", "haptic"]
if DISABLE_OPTIGA and PYOPT == '0':
FEATURES_WANTED.remove("optiga")
@ -653,6 +659,8 @@ if FROZEN:
else:
raise ValueError('Unknown layout')
if THP:
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/thp/*.py'))
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/*.py'))
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'storage/*.py',
@ -702,6 +710,8 @@ if FROZEN:
SOURCE_PY_DIR + 'apps/bitcoin/sign_tx/zcash_v4.py',
])
)
if THP:
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'apps/thp/*.py'))
if EVERYTHING:
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'apps/binance/*.py'))
@ -779,7 +789,7 @@ if FROZEN:
source_files = SOURCE_MOD + SOURCE_MOD_CRYPTO + SOURCE_FIRMWARE + SOURCE_MICROPYTHON + SOURCE_MICROPYTHON_SPEED + SOURCE_HAL
obj_program = []
obj_program.extend(env.Object(source=SOURCE_MOD))
obj_program.extend(env.Object(source=SOURCE_MOD_CRYPTO, CCFLAGS='$CCFLAGS -ftrivial-auto-var-init=zero'))
obj_program.extend(env.Object(source=SOURCE_MOD_CRYPTO))
if FEATURE_FLAGS["SECP256K1_ZKP"]:
obj_program.extend(env.Object(source=SOURCE_MOD_SECP256K1_ZKP, CCFLAGS='$CCFLAGS -Wno-unused-function'))
source_files.extend(SOURCE_MOD_SECP256K1_ZKP)

View File

@ -684,6 +684,8 @@ if FROZEN:
else:
raise ValueError('Unknown layout')
if THP:
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/thp/*.py'))
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/*.py'))
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'storage/*.py',
@ -777,6 +779,9 @@ if FROZEN:
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'apps/tezos/*.py'))
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/enums/Tezos*.py'))
if THP:
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'apps/thp/*.py'))
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'apps/zcash/*.py'))
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'apps/webauthn/*.py'))

View File

@ -111,9 +111,9 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorcrypto_AesGcm_encrypt_obj,
mod_trezorcrypto_AesGcm_encrypt);
/// def encrypt_in_place(self, data: bytearray | memoryview) -> int:
/// """
/// Encrypt data chunk in place. Returns the length of the encrypted data.
/// """
/// """
/// Encrypt data chunk in place. Returns the length of the encrypted data.
/// """
STATIC mp_obj_t mod_trezorcrypto_AesGcm_encrypt_in_place(mp_obj_t self,
mp_obj_t data) {
mp_obj_AesGcm_t *o = MP_OBJ_TO_PTR(self);
@ -158,9 +158,9 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorcrypto_AesGcm_decrypt_obj,
mod_trezorcrypto_AesGcm_decrypt);
/// def decrypt_in_place(self, data: bytearray | memoryview) -> int:
/// """
/// Decrypt data chunk in place. Returns the length of the decrypted data.
/// """
/// """
/// Decrypt data chunk in place. Returns the length of the decrypted data.
/// """
STATIC mp_obj_t mod_trezorcrypto_AesGcm_decrypt_in_place(mp_obj_t self,
mp_obj_t data) {
mp_obj_AesGcm_t *o = MP_OBJ_TO_PTR(self);

View File

@ -441,7 +441,7 @@ STATIC mp_obj_tuple_t mod_trezorutils_version_obj = {
/// UI_LAYOUT: str
/// """UI layout identifier ("tt" for model T, "tr" for models One and R)."""
/// USE_THP: bool
/// """Whether the firmware supports Trezor-Host Protocol (version 3)."""
/// """Whether the firmware supports the Trezor-Host Protocol."""
STATIC const mp_rom_map_elem_t mp_module_trezorutils_globals_table[] = {
{MP_ROM_QSTR(MP_QSTR___name__), MP_ROM_QSTR(MP_QSTR_trezorutils)},

View File

@ -356,7 +356,7 @@ pub static mp_module_trezorproto: Module = obj_module! {
/// """Calculate length of encoding of the specified message."""
Qstr::MP_QSTR_encoded_length => obj_fn_1!(protobuf_len).as_obj(),
/// def encode(buffer: bytearray, msg: MessageType) -> int:
/// def encode(buffer: bytearray | memoryview, msg: MessageType) -> int:
/// """Encode the message into the specified buffer. Return length of
/// encoding."""
Qstr::MP_QSTR_encode => obj_fn_2!(protobuf_encode).as_obj()

View File

@ -55,9 +55,9 @@ class aesgcm:
"""
def encrypt_in_place(self, data: bytearray | memoryview) -> int:
"""
Encrypt data chunk in place. Returns the length of the encrypted data.
"""
"""
Encrypt data chunk in place. Returns the length of the encrypted data.
"""
def decrypt(self, data: bytes) -> bytes:
"""
@ -65,9 +65,9 @@ class aesgcm:
"""
def decrypt_in_place(self, data: bytearray | memoryview) -> int:
"""
Decrypt data chunk in place. Returns the length of the decrypted data.
"""
"""
Decrypt data chunk in place. Returns the length of the decrypted data.
"""
def auth(self, data: bytes) -> None:
"""

View File

@ -42,6 +42,6 @@ def encoded_length(msg: MessageType) -> int:
# rust/src/protobuf/obj.rs
def encode(buffer: bytearray, msg: MessageType) -> int:
def encode(buffer: bytearray | memoryview, msg: MessageType) -> int:
"""Encode the message into the specified buffer. Return length of
encoding."""

View File

@ -151,4 +151,4 @@ BITCOIN_ONLY: bool
UI_LAYOUT: str
"""UI layout identifier ("tt" for model T, "tr" for models One and R)."""
USE_THP: bool
"""Whether the firmware supports Trezor-Host Protocol (version 3)."""
"""Whether the firmware supports the Trezor-Host Protocol."""

View File

@ -47,6 +47,12 @@ storage
import storage
storage.cache
import storage.cache
storage.cache_codec
import storage.cache_codec
storage.cache_common
import storage.cache_common
storage.cache_thp
import storage.cache_thp
storage.common
import storage.common
storage.debug
@ -135,6 +141,8 @@ trezor.enums.SafetyCheckLevel
import trezor.enums.SafetyCheckLevel
trezor.enums.SdProtectOperationType
import trezor.enums.SdProtectOperationType
trezor.enums.ThpPairingMethod
import trezor.enums.ThpPairingMethod
trezor.enums.WordRequestType
import trezor.enums.WordRequestType
trezor.enums
@ -205,6 +213,48 @@ trezor.wire.context
import trezor.wire.context
trezor.wire.errors
import trezor.wire.errors
trezor.wire.message_handler
import trezor.wire.message_handler
trezor.wire.protocol_common
import trezor.wire.protocol_common
trezor.wire.thp
import trezor.wire.thp
trezor.wire.thp.alternating_bit_protocol
import trezor.wire.thp.alternating_bit_protocol
trezor.wire.thp.channel
import trezor.wire.thp.channel
trezor.wire.thp.channel_manager
import trezor.wire.thp.channel_manager
trezor.wire.thp.checksum
import trezor.wire.thp.checksum
trezor.wire.thp.control_byte
import trezor.wire.thp.control_byte
trezor.wire.thp.cpace
import trezor.wire.thp.cpace
trezor.wire.thp.crypto
import trezor.wire.thp.crypto
trezor.wire.thp.handler_provider
import trezor.wire.thp.handler_provider
trezor.wire.thp.interface_manager
import trezor.wire.thp.interface_manager
trezor.wire.thp.memory_manager
import trezor.wire.thp.memory_manager
trezor.wire.thp.pairing_context
import trezor.wire.thp.pairing_context
trezor.wire.thp.received_message_handler
import trezor.wire.thp.received_message_handler
trezor.wire.thp.session_context
import trezor.wire.thp.session_context
trezor.wire.thp.session_manager
import trezor.wire.thp.session_manager
trezor.wire.thp.thp_messages
import trezor.wire.thp.thp_messages
trezor.wire.thp.transmission_loop
import trezor.wire.thp.transmission_loop
trezor.wire.thp.writer
import trezor.wire.thp.writer
trezor.wire.thp_main
import trezor.wire.thp_main
trezor.workflow
import trezor.workflow
apps
@ -289,6 +339,8 @@ apps.common.backup
import apps.common.backup
apps.common.backup_types
import apps.common.backup_types
apps.common.cache
import apps.common.cache
apps.common.cbor
import apps.common.cbor
apps.common.coininfo
@ -377,6 +429,14 @@ apps.misc.get_firmware_hash
import apps.misc.get_firmware_hash
apps.misc.sign_identity
import apps.misc.sign_identity
apps.thp
import apps.thp
apps.thp.create_session
import apps.thp.create_session
apps.thp.credential_manager
import apps.thp.credential_manager
apps.thp.pairing
import apps.thp.pairing
apps.workflow_handlers
import apps.workflow_handlers

View File

@ -1,11 +1,15 @@
from typing import TYPE_CHECKING
import storage.cache as storage_cache
import storage.cache_codec as cache_codec
import storage.device as storage_device
from storage.cache import check_thp_is_not_used
from storage.cache_common import APP_COMMON_BUSY_DEADLINE_MS, APP_COMMON_SEED
from trezor import TR, config, utils, wire, workflow
from trezor.enums import HomescreenFormat, MessageType
from trezor.messages import Success, UnlockPath
from trezor.ui.layouts import confirm_action
from trezor.wire import context
from trezor.wire.message_handler import filters, remove_filter
from . import workflow_handlers
@ -34,7 +38,7 @@ def busy_expiry_ms() -> int:
Returns the time left until the busy state expires or 0 if the device is not in the busy state.
"""
busy_deadline_ms = storage_cache.get_int(storage_cache.APP_COMMON_BUSY_DEADLINE_MS)
busy_deadline_ms = context.cache_get_int(APP_COMMON_BUSY_DEADLINE_MS)
if busy_deadline_ms is None:
return 0
@ -199,13 +203,20 @@ def get_features() -> Features:
return f
async def handle_Initialize(msg: Initialize) -> Features:
session_id = storage_cache.start_session(msg.session_id)
@check_thp_is_not_used
async def handle_Initialize(
msg: Initialize,
) -> Features:
session_id = cache_codec.start_session(msg.session_id)
# TODO change cardano derivation
# ctx = context.get_context()
if not utils.BITCOIN_ONLY:
derive_cardano = storage_cache.get_bool(storage_cache.APP_COMMON_DERIVE_CARDANO)
have_seed = storage_cache.is_set(storage_cache.APP_COMMON_SEED)
from storage.cache_common import APP_COMMON_DERIVE_CARDANO
derive_cardano = context.cache_get_bool(APP_COMMON_DERIVE_CARDANO)
have_seed = context.cache_is_set(APP_COMMON_SEED)
if (
have_seed
and msg.derive_cardano is not None
@ -213,14 +224,12 @@ async def handle_Initialize(msg: Initialize) -> Features:
):
# seed is already derived, and host wants to change derive_cardano setting
# => create a new session
storage_cache.end_current_session()
session_id = storage_cache.start_session()
cache_codec.end_current_session()
session_id = cache_codec.start_session()
have_seed = False
if not have_seed:
storage_cache.set_bool(
storage_cache.APP_COMMON_DERIVE_CARDANO, bool(msg.derive_cardano)
)
context.cache_set_bool(APP_COMMON_DERIVE_CARDANO, bool(msg.derive_cardano))
features = get_features()
features.session_id = session_id
@ -249,16 +258,16 @@ async def handle_SetBusy(msg: SetBusy) -> Success:
import utime
deadline = utime.ticks_add(utime.ticks_ms(), msg.expiry_ms)
storage_cache.set_int(storage_cache.APP_COMMON_BUSY_DEADLINE_MS, deadline)
context.cache_set_int(APP_COMMON_BUSY_DEADLINE_MS, deadline)
else:
storage_cache.delete(storage_cache.APP_COMMON_BUSY_DEADLINE_MS)
context.cache_delete(APP_COMMON_BUSY_DEADLINE_MS)
set_homescreen()
workflow.close_others()
return Success()
async def handle_EndSession(msg: EndSession) -> Success:
storage_cache.end_current_session()
cache_codec.end_current_session()
return Success()
@ -273,7 +282,7 @@ async def handle_Ping(msg: Ping) -> Success:
async def handle_DoPreauthorized(msg: DoPreauthorized) -> protobuf.MessageType:
from trezor.messages import PreauthorizedRequest
from trezor.wire.context import call_any, get_context
from trezor.wire.context import call_any
from apps.common import authorization
@ -286,11 +295,9 @@ async def handle_DoPreauthorized(msg: DoPreauthorized) -> protobuf.MessageType:
req = await call_any(PreauthorizedRequest(), *wire_types)
assert req.MESSAGE_WIRE_TYPE is not None
handler = workflow_handlers.find_registered_handler(
get_context().iface, req.MESSAGE_WIRE_TYPE
)
handler = workflow_handlers.find_registered_handler(req.MESSAGE_WIRE_TYPE)
if handler is None:
return wire.unexpected_message()
return wire.message_handler.unexpected_message()
return await handler(req, authorization.get()) # type: ignore [Expected 1 positional argument]
@ -298,7 +305,7 @@ async def handle_DoPreauthorized(msg: DoPreauthorized) -> protobuf.MessageType:
async def handle_UnlockPath(msg: UnlockPath) -> protobuf.MessageType:
from trezor.crypto import hmac
from trezor.messages import UnlockedPathRequest
from trezor.wire.context import call_any, get_context
from trezor.wire.context import call_any
from apps.common.paths import SLIP25_PURPOSE
from apps.common.seed import Slip21Node, get_seed
@ -339,9 +346,7 @@ async def handle_UnlockPath(msg: UnlockPath) -> protobuf.MessageType:
req = await call_any(UnlockedPathRequest(mac=expected_mac), *wire_types)
assert req.MESSAGE_WIRE_TYPE in wire_types
handler = workflow_handlers.find_registered_handler(
get_context().iface, req.MESSAGE_WIRE_TYPE
)
handler = workflow_handlers.find_registered_handler(req.MESSAGE_WIRE_TYPE)
assert handler is not None
return await handler(req, msg) # type: ignore [Expected 1 positional argument]
@ -361,7 +366,7 @@ def set_homescreen() -> None:
set_default = workflow.set_default # local_cache_attribute
if storage_cache.is_set(storage_cache.APP_COMMON_BUSY_DEADLINE_MS):
if context.cache_is_set(APP_COMMON_BUSY_DEADLINE_MS):
from apps.homescreen import busyscreen
set_default(busyscreen)
@ -390,7 +395,7 @@ def set_homescreen() -> None:
def lock_device(interrupt_workflow: bool = True) -> None:
if config.has_pin():
config.lock()
wire.filters.append(_pinlock_filter)
filters.append(_pinlock_filter)
set_homescreen()
if interrupt_workflow:
workflow.close_others()
@ -426,7 +431,7 @@ async def unlock_device() -> None:
_SCREENSAVER_IS_ON = False
set_homescreen()
wire.remove_filter(_pinlock_filter)
remove_filter(_pinlock_filter)
def _pinlock_filter(msg_type: int, prev_handler: Handler[Msg]) -> Handler[Msg]:
@ -479,4 +484,4 @@ def boot() -> None:
backup.activate_repeated_backup()
if not config.is_unlocked():
# pinlocked handler should always be the last one
wire.filters.append(_pinlock_filter)
filters.append(_pinlock_filter)

View File

@ -1,7 +1,7 @@
from micropython import const
from typing import TYPE_CHECKING
from trezor.wire import DataError
from trezor.wire import DataError, context
from .. import writers
@ -26,7 +26,7 @@ class PaymentRequestVerifier:
def __init__(
self, msg: TxAckPaymentRequest, coin: coininfo.CoinInfo, keychain: Keychain
) -> None:
from storage import cache
from storage.cache_common import APP_COMMON_NONCE
from trezor.crypto.hashlib import sha256
from trezor.utils import HashWriter
@ -42,9 +42,9 @@ class PaymentRequestVerifier:
if msg.nonce:
nonce = bytes(msg.nonce)
if cache.get(cache.APP_COMMON_NONCE) != nonce:
if context.cache_get(APP_COMMON_NONCE) != nonce:
raise DataError("Invalid nonce in payment request.")
cache.delete(cache.APP_COMMON_NONCE)
context.cache_delete(APP_COMMON_NONCE)
else:
nonce = b""
if msg.memos:

View File

@ -1,6 +1,11 @@
from typing import TYPE_CHECKING
from storage import cache, device
import storage.device as device
from storage.cache_common import (
APP_CARDANO_ICARUS_SECRET,
APP_CARDANO_ICARUS_TREZOR_SECRET,
APP_COMMON_DERIVE_CARDANO,
)
from trezor import wire
from trezor.crypto import cardano
@ -15,6 +20,7 @@ if TYPE_CHECKING:
from trezor import messages
from trezor.crypto import bip32
from trezor.enums import CardanoDerivationType
from trezor.wire.protocol_common import Context
from apps.common.keychain import Handler, MsgOut
from apps.common.paths import Bip32Path
@ -110,9 +116,9 @@ def is_minting_path(path: Bip32Path) -> bool:
return path[: len(MINTING_ROOT)] == MINTING_ROOT
def derive_and_store_secrets(passphrase: str) -> None:
def derive_and_store_secrets(ctx: Context, passphrase: str) -> None:
assert device.is_initialized()
assert cache.get_bool(cache.APP_COMMON_DERIVE_CARDANO)
assert ctx.cache.get_bool(APP_COMMON_DERIVE_CARDANO)
if not mnemonic.is_bip39():
# nothing to do for SLIP-39, where we can derive the root from the main seed
@ -132,14 +138,15 @@ def derive_and_store_secrets(passphrase: str) -> None:
else:
icarus_trezor_secret = icarus_secret
cache.set(cache.APP_CARDANO_ICARUS_SECRET, icarus_secret)
cache.set(cache.APP_CARDANO_ICARUS_TREZOR_SECRET, icarus_trezor_secret)
ctx.cache.set(APP_CARDANO_ICARUS_SECRET, icarus_secret)
ctx.cache.set(APP_CARDANO_ICARUS_TREZOR_SECRET, icarus_trezor_secret)
async def _get_keychain_bip39(derivation_type: CardanoDerivationType) -> Keychain:
from trezor.enums import CardanoDerivationType
from trezor.wire import context
from apps.common.seed import derive_and_store_roots
from apps.common.seed import derive_and_store_roots_legacy
if not device.is_initialized():
raise wire.NotInitialized("Device is not initialized")
@ -148,19 +155,19 @@ async def _get_keychain_bip39(derivation_type: CardanoDerivationType) -> Keychai
seed = await get_seed()
return Keychain(cardano.from_seed_ledger(seed))
if not cache.get_bool(cache.APP_COMMON_DERIVE_CARDANO):
if not context.cache_get_bool(APP_COMMON_DERIVE_CARDANO):
raise wire.ProcessError("Cardano derivation is not enabled for this session")
if derivation_type == CardanoDerivationType.ICARUS:
cache_entry = cache.APP_CARDANO_ICARUS_SECRET
cache_entry = APP_CARDANO_ICARUS_SECRET
else:
cache_entry = cache.APP_CARDANO_ICARUS_TREZOR_SECRET
cache_entry = APP_CARDANO_ICARUS_TREZOR_SECRET
# _get_secret
secret = cache.get(cache_entry)
secret = context.cache_get(cache_entry)
if secret is None:
await derive_and_store_roots()
secret = cache.get(cache_entry)
await derive_and_store_roots_legacy()
secret = context.cache_get(cache_entry)
assert secret is not None
root = cardano.from_secret(secret)

View File

@ -1,23 +1,21 @@
from typing import Iterable
import storage.cache as storage_cache
from storage.cache_common import (
APP_COMMON_AUTHORIZATION_DATA,
APP_COMMON_AUTHORIZATION_TYPE,
)
from trezor import protobuf
from trezor.enums import MessageType
from trezor.wire import context
WIRE_TYPES: dict[int, tuple[int, ...]] = {
MessageType.AuthorizeCoinJoin: (MessageType.SignTx, MessageType.GetOwnershipProof),
}
APP_COMMON_AUTHORIZATION_DATA = (
storage_cache.APP_COMMON_AUTHORIZATION_DATA
) # global_import_cache
APP_COMMON_AUTHORIZATION_TYPE = (
storage_cache.APP_COMMON_AUTHORIZATION_TYPE
) # global_import_cache
def is_set() -> bool:
return storage_cache.get(APP_COMMON_AUTHORIZATION_TYPE) is not None
return context.cache_get(APP_COMMON_AUTHORIZATION_TYPE) is not None
def set(auth_message: protobuf.MessageType) -> None:
@ -29,16 +27,16 @@ def set(auth_message: protobuf.MessageType) -> None:
# (because only wire-level messages have wire_type, which we use as identifier)
ensure(auth_message.MESSAGE_WIRE_TYPE is not None)
assert auth_message.MESSAGE_WIRE_TYPE is not None # so that typechecker knows too
storage_cache.set_int(APP_COMMON_AUTHORIZATION_TYPE, auth_message.MESSAGE_WIRE_TYPE)
storage_cache.set(APP_COMMON_AUTHORIZATION_DATA, buffer)
context.cache_set_int(APP_COMMON_AUTHORIZATION_TYPE, auth_message.MESSAGE_WIRE_TYPE)
context.cache_set(APP_COMMON_AUTHORIZATION_DATA, buffer)
def get() -> protobuf.MessageType | None:
stored_auth_type = storage_cache.get_int(APP_COMMON_AUTHORIZATION_TYPE)
stored_auth_type = context.cache_get_int(APP_COMMON_AUTHORIZATION_TYPE)
if not stored_auth_type:
return None
buffer = storage_cache.get(APP_COMMON_AUTHORIZATION_DATA, b"")
buffer = context.cache_get(APP_COMMON_AUTHORIZATION_DATA, b"")
return protobuf.load_message_buffer(buffer, stored_auth_type)
@ -49,7 +47,7 @@ def is_set_any_session(auth_type: MessageType) -> bool:
def get_wire_types() -> Iterable[int]:
stored_auth_type = storage_cache.get_int(APP_COMMON_AUTHORIZATION_TYPE)
stored_auth_type = context.cache_get_int(APP_COMMON_AUTHORIZATION_TYPE)
if stored_auth_type is None:
return ()
@ -57,5 +55,5 @@ def get_wire_types() -> Iterable[int]:
def clear() -> None:
storage_cache.delete(APP_COMMON_AUTHORIZATION_TYPE)
storage_cache.delete(APP_COMMON_AUTHORIZATION_DATA)
context.cache_delete(APP_COMMON_AUTHORIZATION_TYPE)
context.cache_delete(APP_COMMON_AUTHORIZATION_DATA)

View File

@ -1,25 +1,27 @@
from typing import TYPE_CHECKING
import storage.cache as storage_cache
import storage.cache_common as storage_cache
from trezor import wire
from trezor.enums import MessageType
from trezor.wire import context
from trezor.wire.message_handler import filters, remove_filter
if TYPE_CHECKING:
from trezor.wire import Handler, Msg
def repeated_backup_enabled() -> bool:
return storage_cache.get_bool(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED)
return context.cache_get_bool(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED)
def activate_repeated_backup():
storage_cache.set_bool(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED, True)
wire.filters.append(_repeated_backup_filter)
context.cache_set_bool(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED, True)
filters.append(_repeated_backup_filter)
def deactivate_repeated_backup():
storage_cache.delete(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED)
wire.remove_filter(_repeated_backup_filter)
context.cache_delete(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED)
remove_filter(_repeated_backup_filter)
_ALLOW_WHILE_REPEATED_BACKUP_UNLOCKED = (

View File

@ -0,0 +1,23 @@
from typing import TYPE_CHECKING
from trezor.wire import context
if TYPE_CHECKING:
from typing import Callable, ParamSpec
P = ParamSpec("P")
ByteFunc = Callable[P, bytes]
def stored(key: int) -> Callable[[ByteFunc[P]], ByteFunc[P]]:
def decorator(func: ByteFunc[P]) -> ByteFunc[P]:
def wrapper(*args: P.args, **kwargs: P.kwargs):
value = context.cache_get(key)
if value is None:
value = func(*args, **kwargs)
context.cache_set(key, value)
return value
return wrapper
return decorator

View File

@ -1,15 +1,48 @@
from micropython import const
from typing import TYPE_CHECKING
import storage.device as storage_device
from storage.cache import check_thp_is_not_used
from trezor.wire import DataError
_MAX_PASSPHRASE_LEN = const(50)
if TYPE_CHECKING:
from trezor.messages import ThpCreateNewSession
def is_enabled() -> bool:
return storage_device.is_passphrase_enabled()
async def get_passphrase(msg: ThpCreateNewSession) -> str:
if not is_enabled():
return ""
if msg.on_device or storage_device.get_passphrase_always_on_device():
passphrase = await _get_on_device()
else:
passphrase = msg.passphrase or ""
if passphrase:
await _handle_displaying_passphrase_from_host(passphrase)
if len(passphrase.encode()) > _MAX_PASSPHRASE_LEN:
raise DataError(f"Maximum passphrase length is {_MAX_PASSPHRASE_LEN} bytes")
return passphrase
async def _get_on_device() -> str:
from trezor import workflow
from trezor.ui.layouts import request_passphrase_on_device
workflow.close_others() # request exclusive UI access
passphrase = await request_passphrase_on_device(_MAX_PASSPHRASE_LEN)
return passphrase
@check_thp_is_not_used
async def get() -> str:
from trezor import workflow
@ -29,8 +62,8 @@ async def get() -> str:
return passphrase
@check_thp_is_not_used
async def _request_on_host() -> str:
from trezor import TR
from trezor.messages import PassphraseAck, PassphraseRequest
from trezor.ui.layouts import request_passphrase_on_host
from trezor.wire.context import call
@ -55,29 +88,34 @@ async def _request_on_host() -> str:
# non-empty passphrase
if passphrase:
from trezor.ui.layouts import confirm_action, confirm_blob
# We want to hide the passphrase, or show it, according to settings.
if storage_device.get_hide_passphrase_from_host():
await confirm_action(
"passphrase_host1_hidden",
TR.passphrase__wallet,
description=TR.passphrase__from_host_not_shown,
prompt_screen=True,
prompt_title=TR.passphrase__access_wallet,
)
else:
await confirm_action(
"passphrase_host1",
TR.passphrase__wallet,
description=TR.passphrase__next_screen_will_show_passphrase,
verb=TR.buttons__continue,
)
await confirm_blob(
"passphrase_host2",
TR.passphrase__title_confirm,
passphrase,
)
await _handle_displaying_passphrase_from_host(passphrase)
return passphrase
async def _handle_displaying_passphrase_from_host(passphrase: str) -> None:
from trezor import TR
from trezor.ui.layouts import confirm_action, confirm_blob
# We want to hide the passphrase, or show it, according to settings.
if storage_device.get_hide_passphrase_from_host():
await confirm_action(
"passphrase_host1_hidden",
TR.passphrase__wallet,
description=TR.passphrase__from_host_not_shown,
prompt_screen=True,
prompt_title=TR.passphrase__access_wallet,
)
else:
await confirm_action(
"passphrase_host1",
TR.passphrase__wallet,
description=TR.passphrase__next_screen_will_show_passphrase,
verb=TR.buttons__continue,
)
await confirm_blob(
"passphrase_host2",
TR.passphrase__title_confirm,
passphrase,
)

View File

@ -1,9 +1,10 @@
import utime
from typing import Any, NoReturn
import storage.cache as storage_cache
from storage.cache_common import APP_COMMON_REQUEST_PIN_LAST_UNLOCK
from trezor import TR, config, utils, wire
from trezor.ui.layouts import show_error_and_raise
from trezor.wire import context
async def _request_sd_salt(
@ -77,7 +78,7 @@ async def request_pin_and_sd_salt(
def _set_last_unlock_time() -> None:
now = utime.ticks_ms()
storage_cache.set_int(storage_cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK, now)
context.cache_set_int(APP_COMMON_REQUEST_PIN_LAST_UNLOCK, now)
_DEF_ARG_PIN_ENTER: str = TR.pin__enter
@ -91,7 +92,7 @@ async def verify_user_pin(
) -> None:
# _get_last_unlock_time
last_unlock = int.from_bytes(
storage_cache.get(storage_cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK, b""), "big"
context.cache_get(APP_COMMON_REQUEST_PIN_LAST_UNLOCK, b""), "big"
)
if (

View File

@ -1,15 +1,15 @@
import storage.cache as storage_cache
import storage.device as storage_device
from storage.cache import APP_COMMON_SAFETY_CHECKS_TEMPORARY
from storage.cache_common import APP_COMMON_SAFETY_CHECKS_TEMPORARY
from storage.device import SAFETY_CHECK_LEVEL_PROMPT, SAFETY_CHECK_LEVEL_STRICT
from trezor.enums import SafetyCheckLevel
from trezor.wire import context
def read_setting() -> SafetyCheckLevel:
"""
Returns the effective safety check level.
"""
temporary_safety_check_level = storage_cache.get(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
temporary_safety_check_level = context.cache_get(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
if temporary_safety_check_level:
return int.from_bytes(temporary_safety_check_level, "big") # type: ignore [int-into-enum]
else:
@ -27,14 +27,14 @@ def apply_setting(level: SafetyCheckLevel) -> None:
Changes the safety level settings.
"""
if level == SafetyCheckLevel.Strict:
storage_cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
context.cache_delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT)
elif level == SafetyCheckLevel.PromptAlways:
storage_cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
context.cache_delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_PROMPT)
elif level == SafetyCheckLevel.PromptTemporarily:
storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT)
storage_cache.set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, level.to_bytes(1, "big"))
context.cache_set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, level.to_bytes(1, "big"))
else:
raise ValueError("Unknown SafetyCheckLevel")

View File

@ -1,18 +1,33 @@
from typing import TYPE_CHECKING
import storage.cache as storage_cache
import storage.device as storage_device
from storage.cache import check_thp_is_not_used
from storage.cache_common import APP_COMMON_SEED, APP_COMMON_SEED_WITHOUT_PASSPHRASE
from trezor import utils
from trezor.crypto import hmac
from trezor.wire import context
from trezor.wire.context import get_context
from trezor.wire.errors import DataError
from apps.common import cache
from . import mnemonic
from .passphrase import get as get_passphrase
from .passphrase import get as get_passphrase_legacy
from .passphrase import get_passphrase as get_passphrase
if TYPE_CHECKING:
from trezor.crypto import bip32
from trezor.messages import ThpCreateNewSession
from trezor.wire.protocol_common import Context
from .paths import Bip32Path, Slip21Path
if not utils.BITCOIN_ONLY:
from storage.cache_common import (
APP_CARDANO_ICARUS_SECRET,
APP_COMMON_DERIVE_CARDANO,
)
class Slip21Node:
"""
@ -45,54 +60,89 @@ class Slip21Node:
return Slip21Node(data=self.data)
if not utils.BITCOIN_ONLY:
# === Cardano variant ===
# We want to derive both the normal seed and the Cardano seed together, AND
# expose a method for Cardano to do the same
async def get_seed() -> bytes:
common_seed = context.cache_get(APP_COMMON_SEED)
assert common_seed is not None
return common_seed
if utils.BITCOIN_ONLY:
# === Bitcoin_only variant ===
# We want to derive the normal seed ONLY
async def derive_and_store_roots(ctx: Context, msg: ThpCreateNewSession) -> None:
if msg.passphrase is not None and msg.on_device:
raise DataError("Passphrase provided when it shouldn't be!")
if ctx.cache.is_set(APP_COMMON_SEED):
raise Exception("Seed is already set!")
async def derive_and_store_roots() -> None:
from trezor import wire
if not storage_device.is_initialized():
raise wire.NotInitialized("Device is not initialized")
need_seed = not storage_cache.is_set(storage_cache.APP_COMMON_SEED)
need_cardano_secret = storage_cache.get_bool(
storage_cache.APP_COMMON_DERIVE_CARDANO
) and not storage_cache.is_set(storage_cache.APP_CARDANO_ICARUS_SECRET)
passphrase = await get_passphrase(msg)
common_seed = mnemonic.get_seed(passphrase)
ctx.cache.set(APP_COMMON_SEED, common_seed)
else:
# === Cardano variant ===
# We want to derive both the normal seed and the Cardano seed together, AND
# expose a method for Cardano to do the same
async def derive_and_store_roots(ctx: Context, msg: ThpCreateNewSession) -> None:
if msg.passphrase is not None and msg.on_device:
raise DataError("Passphrase provided when it shouldn't be!")
from trezor import wire
if not storage_device.is_initialized():
raise wire.NotInitialized("Device is not initialized")
if ctx.cache.is_set(APP_CARDANO_ICARUS_SECRET):
raise Exception("Cardano icarus secret is already set!")
passphrase = await get_passphrase(msg)
common_seed = mnemonic.get_seed(passphrase)
ctx.cache.set(APP_COMMON_SEED, common_seed)
if msg.derive_cardano:
from apps.cardano.seed import derive_and_store_secrets
derive_and_store_secrets(ctx, passphrase)
@check_thp_is_not_used
async def derive_and_store_roots_legacy() -> None:
from trezor import wire
if not storage_device.is_initialized():
raise wire.NotInitialized("Device is not initialized")
ctx = get_context()
need_seed = not ctx.cache.is_set(APP_COMMON_SEED)
need_cardano_secret = ctx.cache.get_bool(
APP_COMMON_DERIVE_CARDANO
) and not ctx.cache.is_set(APP_CARDANO_ICARUS_SECRET)
if not need_seed and not need_cardano_secret:
return
passphrase = await get_passphrase()
passphrase = await get_passphrase_legacy()
if need_seed:
common_seed = mnemonic.get_seed(passphrase)
storage_cache.set(storage_cache.APP_COMMON_SEED, common_seed)
ctx.cache.set(APP_COMMON_SEED, common_seed)
if need_cardano_secret:
from apps.cardano.seed import derive_and_store_secrets
derive_and_store_secrets(passphrase)
@storage_cache.stored_async(storage_cache.APP_COMMON_SEED)
async def get_seed() -> bytes:
await derive_and_store_roots()
common_seed = storage_cache.get(storage_cache.APP_COMMON_SEED)
assert common_seed is not None
return common_seed
else:
# === Bitcoin-only variant ===
# We use the simple version of `get_seed` that never needs to derive anything else.
@storage_cache.stored_async(storage_cache.APP_COMMON_SEED)
async def get_seed() -> bytes:
passphrase = await get_passphrase()
return mnemonic.get_seed(passphrase)
derive_and_store_secrets(ctx, passphrase)
@storage_cache.stored(storage_cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE)
@cache.stored(APP_COMMON_SEED_WITHOUT_PASSPHRASE)
def _get_seed_without_passphrase() -> bytes:
if not storage_device.is_initialized():
raise Exception("Device is not initialized")

View File

@ -34,7 +34,7 @@ if __debug__:
layout_change_chan = loop.mailbox()
DEBUG_CONTEXT: context.Context | None = None
DEBUG_CONTEXT: context.CodecContext | None = None
REFRESH_INDEX = 0
@ -441,4 +441,4 @@ if __debug__:
def boot() -> None:
import usb
loop.schedule(handle_session(usb.iface_debug))#
loop.schedule(handle_session(usb.iface_debug)) #

View File

@ -5,10 +5,11 @@ if TYPE_CHECKING:
async def get_nonce(msg: GetNonce) -> Nonce:
from storage import cache
from storage.cache_common import APP_COMMON_NONCE
from trezor.crypto import random
from trezor.messages import Nonce
from trezor.wire.context import cache_set
nonce = random.bytes(32)
cache.set(cache.APP_COMMON_NONCE, nonce)
cache_set(APP_COMMON_NONCE, nonce)
return Nonce(nonce=nonce)

View File

@ -5,6 +5,7 @@ import storage.recovery as storage_recovery
import storage.recovery_shares as storage_recovery_shares
from trezor import TR, wire
from trezor.messages import Success
from trezor.wire import message_handler
from apps.common import backup_types
@ -36,7 +37,7 @@ async def recovery_process() -> Success:
recovery_type = storage_recovery.get_type()
wire.AVOID_RESTARTING_FOR = (
message_handler.AVOID_RESTARTING_FOR = (
MessageType.Initialize,
MessageType.GetFeatures,
MessageType.EndSession,
@ -57,7 +58,7 @@ async def _continue_repeated_backup() -> None:
from apps.common import backup
from apps.management.backup_device import perform_backup
wire.AVOID_RESTARTING_FOR = (
message_handler.AVOID_RESTARTING_FOR = (
MessageType.Initialize,
MessageType.GetFeatures,
MessageType.EndSession,

View File

@ -59,14 +59,15 @@ async def _init_step(
) -> MoneroLiveRefreshStartAck:
import storage.cache as storage_cache
from trezor.messages import MoneroLiveRefreshStartAck
from trezor.wire import context
from apps.common import paths
await paths.validate_path(keychain, msg.address_n)
if not storage_cache.get_bool(storage_cache.APP_MONERO_LIVE_REFRESH):
if not context.cache_get_bool(storage_cache.APP_MONERO_LIVE_REFRESH):
await layout.require_confirm_live_refresh()
storage_cache.set_bool(storage_cache.APP_MONERO_LIVE_REFRESH, True)
context.cache_set_bool(storage_cache.APP_MONERO_LIVE_REFRESH, b"\x01")
s.creds = misc.get_creds(keychain, msg.address_n, msg.network_type)

View File

View File

@ -0,0 +1,51 @@
from trezor import log, loop
from trezor.enums import FailureType
from trezor.messages import Failure, ThpCreateNewSession, ThpNewSession
from trezor.wire.context import get_context
from trezor.wire.errors import ActionCancelled, DataError
from trezor.wire.thp import SessionState
async def create_new_session(message: ThpCreateNewSession) -> ThpNewSession | Failure:
from trezor.wire.thp.session_manager import create_new_session
from apps.common.seed import derive_and_store_roots
ctx = get_context()
# Assert that context `ctx` is ManagementSessionContext
from trezor.wire.thp.session_context import ManagementSessionContext
assert isinstance(ctx, ManagementSessionContext)
channel = ctx.channel
# Do not use `ctx` beyond this point, as it is techically
# allowed to change inbetween await statements
new_session = create_new_session(channel)
try:
await derive_and_store_roots(new_session, message)
except DataError as e:
return Failure(code=FailureType.DataError, message=e.message)
except ActionCancelled as e:
return Failure(code=FailureType.ActionCancelled, message=e.message)
# TODO handle other errors
# TODO handle BITCOIN_ONLY
new_session.set_session_state(SessionState.ALLOCATED)
channel.sessions[new_session.session_id] = new_session
loop.schedule(new_session.handle())
new_session_id: int = new_session.session_id
# await get_seed() TODO
if __debug__:
log.debug(
__name__,
"create_new_session - new session created. Passphrase: %s, Session id: %d\n%s",
message.passphrase if message.passphrase is not None else "",
new_session.session_id,
str(channel.sessions),
)
return ThpNewSession(new_session_id=new_session_id)

View File

@ -0,0 +1,110 @@
from typing import TYPE_CHECKING
from trezor import protobuf
from trezor.crypto import hmac
from trezor.messages import (
ThpAuthenticatedCredentialData,
ThpCredentialMetadata,
ThpPairingCredential,
)
from trezor.wire import message_handler
if TYPE_CHECKING:
from apps.common.paths import Slip21Path
def derive_cred_auth_key() -> bytes:
"""
Derive current credential authentication mac-ing key from device secret.
"""
from storage.device import get_cred_auth_key_counter, get_device_secret
from apps.common.seed import Slip21Node
# Derive the key using SLIP-21 https://github.com/satoshilabs/slips/blob/master/slip-0021.md,
# the derivation path is m/"Credential authentication key"/(counter 4-byte BE)
thp_secret = get_device_secret()
label = b"Credential authentication key"
counter = get_cred_auth_key_counter()
path: Slip21Path = [label, counter]
symmetric_key_node: Slip21Node = Slip21Node(thp_secret)
symmetric_key_node.derive_path(path)
cred_auth_key = symmetric_key_node.key()
return cred_auth_key
def invalidate_cred_auth_key() -> None:
from storage.device import increment_cred_auth_key_counter
increment_cred_auth_key_counter()
def issue_credential(
host_static_pubkey: bytes,
credential_metadata: ThpCredentialMetadata,
) -> bytes:
"""
Issue a pairing credential binded to the provided host static public key
and credential metadata.
"""
return _issue_credential(
derive_cred_auth_key(), host_static_pubkey, credential_metadata
)
def validate_credential(
encoded_pairing_credential_message: bytes,
host_static_pubkey: bytes,
) -> bool:
"""
Validate a pairing credential binded to the provided host static public key.
"""
return _validate_credential(
derive_cred_auth_key(), encoded_pairing_credential_message, host_static_pubkey
)
def _issue_credential(
cred_auth_key: bytes,
host_static_pubkey: bytes,
credential_metadata: ThpCredentialMetadata,
) -> bytes:
proto_msg = ThpAuthenticatedCredentialData(
host_static_pubkey=host_static_pubkey,
cred_metadata=credential_metadata,
)
authenticated_credential_data = _encode_message_into_new_buffer(proto_msg)
mac = hmac(hmac.SHA256, cred_auth_key, authenticated_credential_data).digest()
proto_msg = ThpPairingCredential(cred_metadata=credential_metadata, mac=mac)
credential_raw = _encode_message_into_new_buffer(proto_msg)
return credential_raw
def _validate_credential(
cred_auth_key: bytes,
encoded_pairing_credential_message: bytes,
host_static_pubkey: bytes,
) -> bool:
expected_type = protobuf.type_for_name("ThpPairingCredential")
credential = message_handler.wrap_protobuf_load(
encoded_pairing_credential_message, expected_type
)
assert ThpPairingCredential.is_type_of(credential)
proto_msg = ThpAuthenticatedCredentialData(
host_static_pubkey=host_static_pubkey,
cred_metadata=credential.cred_metadata,
)
authenticated_credential_data = _encode_message_into_new_buffer(proto_msg)
mac = hmac(hmac.SHA256, cred_auth_key, authenticated_credential_data).digest()
return mac == credential.mac
def _encode_message_into_new_buffer(msg: protobuf.MessageType) -> bytes:
msg_len = protobuf.encoded_length(msg)
new_buffer = bytearray(msg_len)
protobuf.encode(new_buffer, msg)
return new_buffer

View File

@ -0,0 +1,405 @@
from typing import TYPE_CHECKING
from ubinascii import hexlify
from trezor import loop, protobuf
from trezor.crypto.hashlib import sha256
from trezor.enums import MessageType, ThpPairingMethod
from trezor.messages import (
Cancel,
ThpCodeEntryChallenge,
ThpCodeEntryCommitment,
ThpCodeEntryCpaceHost,
ThpCodeEntryCpaceTrezor,
ThpCodeEntrySecret,
ThpCodeEntryTag,
ThpCredentialMetadata,
ThpCredentialRequest,
ThpCredentialResponse,
ThpEndRequest,
ThpEndResponse,
ThpNfcUnidirectionalSecret,
ThpNfcUnidirectionalTag,
ThpPairingPreparationsFinished,
ThpQrCodeSecret,
ThpQrCodeTag,
ThpStartPairingRequest,
)
from trezor.wire.errors import ActionCancelled, SilentError, UnexpectedMessage
from trezor.wire.thp import ChannelState, ThpError, crypto
from trezor.wire.thp.pairing_context import PairingContext
from .credential_manager import issue_credential
if __debug__:
from trezor import log
if TYPE_CHECKING:
from typing import Any, Callable, Concatenate, Container, ParamSpec, Tuple
P = ParamSpec("P")
FuncWithContext = Callable[Concatenate[PairingContext, P], Any]
#
# Helpers - decorators
def check_state_and_log(
*allowed_states: ChannelState,
) -> Callable[[FuncWithContext], FuncWithContext]:
def decorator(f: FuncWithContext) -> FuncWithContext:
def inner(context: PairingContext, *args: P.args, **kwargs: P.kwargs) -> object:
_check_state(context, *allowed_states)
if __debug__:
try:
log.debug(__name__, "started %s", f.__name__)
except AttributeError:
log.debug(
__name__,
"started a function that cannot be named, because it raises AttributeError, eg. closure",
)
return f(context, *args, **kwargs)
return inner
return decorator
def check_method_is_allowed(
pairing_method: ThpPairingMethod,
) -> Callable[[FuncWithContext], FuncWithContext]:
def decorator(f: FuncWithContext) -> FuncWithContext:
def inner(context: PairingContext, *args: P.args, **kwargs: P.kwargs) -> object:
_check_method_is_allowed(context, pairing_method)
return f(context, *args, **kwargs)
return inner
return decorator
#
# Pairing handlers
@check_state_and_log(ChannelState.TP1)
async def handle_pairing_request(
ctx: PairingContext, message: protobuf.MessageType
) -> ThpEndResponse:
if not ThpStartPairingRequest.is_type_of(message):
raise UnexpectedMessage("Unexpected message")
ctx.host_name = message.host_name or ""
skip_pairing = _is_method_included(ctx, ThpPairingMethod.NoMethod)
if skip_pairing:
return await _end_pairing(ctx)
await _prepare_pairing(ctx)
await ctx.write(ThpPairingPreparationsFinished())
ctx.channel_ctx.set_channel_state(ChannelState.TP3)
response = await show_display_data(
ctx, _get_possible_pairing_methods_and_cancel(ctx)
)
if __debug__:
from trezor.messages import DebugLinkGetState
while DebugLinkGetState.is_type_of(response):
from apps.debug import dispatch_DebugLinkGetState
dl_state = await dispatch_DebugLinkGetState(response)
assert dl_state is not None
await ctx.write(dl_state)
response = await show_display_data(
ctx, _get_possible_pairing_methods_and_cancel(ctx)
)
if Cancel.is_type_of(response):
ctx.channel_ctx.clear()
raise SilentError("Action was cancelled by the Host")
# TODO disable NFC (if enabled)
response = await _handle_different_pairing_methods(ctx, response)
while ThpCredentialRequest.is_type_of(response):
response = await _handle_credential_request(ctx, response)
return await _handle_end_request(ctx, response)
async def _prepare_pairing(ctx: PairingContext) -> None:
if _is_method_included(ctx, ThpPairingMethod.CodeEntry):
await _handle_code_entry_is_included(ctx)
if _is_method_included(ctx, ThpPairingMethod.QrCode):
_handle_qr_code_is_included(ctx)
if _is_method_included(ctx, ThpPairingMethod.NFC_Unidirectional):
_handle_nfc_unidirectional_is_included(ctx)
async def show_display_data(ctx: PairingContext, expected_types: Container[int] = ()):
read_task = ctx.read(expected_types)
cancel_task = ctx.display_data.get_display_layout()
race = loop.race(read_task, cancel_task)
result = await race
if read_task in race.finished:
return result
if cancel_task in race.finished:
raise ActionCancelled
else:
return Exception("Should not happen") # TODO what to do here?
@check_state_and_log(ChannelState.TP1)
async def _handle_code_entry_is_included(ctx: PairingContext) -> None:
commitment = sha256(ctx.secret).digest()
challenge_message = await ctx.call( # noqa: F841
ThpCodeEntryCommitment(commitment=commitment), ThpCodeEntryChallenge
)
ctx.channel_ctx.set_channel_state(ChannelState.TP2)
if not ThpCodeEntryChallenge.is_type_of(challenge_message):
raise UnexpectedMessage("Unexpected message")
if challenge_message.challenge is None:
raise Exception("Invalid message")
sha_ctx = sha256(ctx.channel_ctx.get_handshake_hash())
sha_ctx.update(ctx.secret)
sha_ctx.update(challenge_message.challenge)
sha_ctx.update(bytes("PairingMethod_CodeEntry", "utf-8"))
code_code_entry_hash = sha_ctx.digest()
ctx.display_data.code_code_entry = (
int.from_bytes(code_code_entry_hash, "big") % 1000000
)
@check_state_and_log(ChannelState.TP1, ChannelState.TP2)
def _handle_qr_code_is_included(ctx: PairingContext) -> None:
sha_ctx = sha256(ctx.channel_ctx.get_handshake_hash())
sha_ctx.update(ctx.secret)
sha_ctx.update(bytes("PairingMethod_QrCode", "utf-8"))
ctx.display_data.code_qr_code = sha_ctx.digest()[:16]
@check_state_and_log(ChannelState.TP1, ChannelState.TP2)
def _handle_nfc_unidirectional_is_included(ctx: PairingContext) -> None:
sha_ctx = sha256(ctx.channel_ctx.get_handshake_hash())
sha_ctx.update(ctx.secret)
sha_ctx.update(bytes("PairingMethod_NfcUnidirectional", "utf-8"))
ctx.display_data.code_nfc_unidirectional = sha_ctx.digest()[:16]
@check_state_and_log(ChannelState.TP3)
async def _handle_different_pairing_methods(
ctx: PairingContext, response: protobuf.MessageType
) -> protobuf.MessageType:
if ThpCodeEntryCpaceHost.is_type_of(response):
return await _handle_code_entry_cpace(ctx, response)
if ThpQrCodeTag.is_type_of(response):
return await _handle_qr_code_tag(ctx, response)
if ThpNfcUnidirectionalTag.is_type_of(response):
return await _handle_nfc_unidirectional_tag(ctx, response)
raise UnexpectedMessage("Unexpected message")
@check_state_and_log(ChannelState.TP3)
@check_method_is_allowed(ThpPairingMethod.CodeEntry)
async def _handle_code_entry_cpace(
ctx: PairingContext, message: protobuf.MessageType
) -> protobuf.MessageType:
from trezor.wire.thp.cpace import Cpace
# TODO check that ThpCodeEntryCpaceHost message is valid
if TYPE_CHECKING:
assert isinstance(message, ThpCodeEntryCpaceHost)
if message.cpace_host_public_key is None:
raise ThpError("Message ThpCodeEntryCpaceHost has no public key")
ctx.cpace = Cpace(
message.cpace_host_public_key,
ctx.channel_ctx.get_handshake_hash(),
)
assert ctx.display_data.code_code_entry is not None
ctx.cpace.generate_keys_and_secret(
ctx.display_data.code_code_entry.to_bytes(6, "big")
)
ctx.channel_ctx.set_channel_state(ChannelState.TP4)
response = await ctx.call(
ThpCodeEntryCpaceTrezor(cpace_trezor_public_key=ctx.cpace.trezor_public_key),
ThpCodeEntryTag,
)
return await _handle_code_entry_tag(ctx, response)
@check_state_and_log(ChannelState.TP4)
@check_method_is_allowed(ThpPairingMethod.CodeEntry)
async def _handle_code_entry_tag(
ctx: PairingContext, message: protobuf.MessageType
) -> protobuf.MessageType:
if TYPE_CHECKING:
assert isinstance(message, ThpCodeEntryTag)
expected_tag = sha256(ctx.cpace.shared_secret).digest()
if expected_tag != message.tag:
print(
"expected code entry tag:", hexlify(expected_tag).decode()
) # TODO remove after testing
print(
"expected code entry shared secret:",
hexlify(ctx.cpace.shared_secret).decode(),
) # TODO remove after testing
raise ThpError("Unexpected Code Entry Tag")
return await _handle_secret_reveal(
ctx,
msg=ThpCodeEntrySecret(secret=ctx.secret),
)
@check_state_and_log(ChannelState.TP3)
@check_method_is_allowed(ThpPairingMethod.QrCode)
async def _handle_qr_code_tag(
ctx: PairingContext, message: protobuf.MessageType
) -> protobuf.MessageType:
if TYPE_CHECKING:
assert isinstance(message, ThpQrCodeTag)
expected_tag = sha256(ctx.display_data.code_qr_code).digest()
if expected_tag != message.tag:
print(
"expected qr code tag:", hexlify(expected_tag).decode()
) # TODO remove after testing
raise ThpError("Unexpected QR Code Tag")
return await _handle_secret_reveal(
ctx,
msg=ThpQrCodeSecret(secret=ctx.secret),
)
@check_state_and_log(ChannelState.TP3)
@check_method_is_allowed(ThpPairingMethod.NFC_Unidirectional)
async def _handle_nfc_unidirectional_tag(
ctx: PairingContext, message: protobuf.MessageType
) -> protobuf.MessageType:
if TYPE_CHECKING:
assert isinstance(message, ThpNfcUnidirectionalTag)
expected_tag = sha256(ctx.display_data.code_nfc_unidirectional).digest()
if expected_tag != message.tag:
print(
"expected nfc tag:", hexlify(expected_tag).decode()
) # TODO remove after testing
raise ThpError("Unexpected NFC Unidirectional Tag")
return await _handle_secret_reveal(
ctx,
msg=ThpNfcUnidirectionalSecret(secret=ctx.secret),
)
@check_state_and_log(ChannelState.TP3, ChannelState.TP4)
async def _handle_secret_reveal(
ctx: PairingContext,
msg: protobuf.MessageType,
) -> protobuf.MessageType:
ctx.channel_ctx.set_channel_state(ChannelState.TC1)
return await ctx.call_any(
msg,
MessageType.ThpCredentialRequest,
MessageType.ThpEndRequest,
)
@check_state_and_log(ChannelState.TC1)
async def _handle_credential_request(
ctx: PairingContext, message: protobuf.MessageType
) -> protobuf.MessageType:
ctx.secret
if not ThpCredentialRequest.is_type_of(message):
raise UnexpectedMessage("Unexpected message")
if message.host_static_pubkey is None:
raise Exception("Invalid message") # TODO change failure type
trezor_static_pubkey = crypto.get_trezor_static_pubkey()
credential_metadata = ThpCredentialMetadata(host_name=ctx.host_name)
credential = issue_credential(message.host_static_pubkey, credential_metadata)
return await ctx.call_any(
ThpCredentialResponse(
trezor_static_pubkey=trezor_static_pubkey, credential=credential
),
MessageType.ThpCredentialRequest,
MessageType.ThpEndRequest,
)
@check_state_and_log(ChannelState.TC1)
async def _handle_end_request(
ctx: PairingContext, message: protobuf.MessageType
) -> ThpEndResponse:
if not ThpEndRequest.is_type_of(message):
raise UnexpectedMessage("Unexpected message")
return await _end_pairing(ctx)
async def _end_pairing(ctx: PairingContext) -> ThpEndResponse:
ctx.channel_ctx.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
return ThpEndResponse()
#
# Helpers - checkers
def _check_state(ctx: PairingContext, *allowed_states: ChannelState) -> None:
if ctx.channel_ctx.get_channel_state() not in allowed_states:
raise UnexpectedMessage("Unexpected message")
def _check_method_is_allowed(ctx: PairingContext, method: ThpPairingMethod) -> None:
if not _is_method_included(ctx, method):
raise ThpError("Unexpected pairing method")
def _is_method_included(ctx: PairingContext, method: ThpPairingMethod) -> bool:
return method in ctx.channel_ctx.selected_pairing_methods
#
# Helpers - getters
def _get_possible_pairing_methods_and_cancel(ctx: PairingContext) -> Tuple[int, ...]:
r = _get_possible_pairing_methods(ctx)
mtype = Cancel.MESSAGE_WIRE_TYPE
return r + ((mtype,) if mtype is not None else ())
def _get_possible_pairing_methods(ctx: PairingContext) -> Tuple[int, ...]:
r = tuple(
_get_message_type_for_method(method)
for method in ctx.channel_ctx.selected_pairing_methods
)
if __debug__:
from trezor.messages import DebugLinkGetState
mtype = DebugLinkGetState.MESSAGE_WIRE_TYPE
return r + ((mtype,) if mtype is not None else ())
return r
def _get_message_type_for_method(method: int) -> int:
if method is ThpPairingMethod.CodeEntry:
return MessageType.ThpCodeEntryCpaceHost
if method is ThpPairingMethod.NFC_Unidirectional:
return MessageType.ThpNfcUnidirectionalTag
if method is ThpPairingMethod.QrCode:
return MessageType.ThpQrCodeTag
raise ValueError("Unexpected pairing method - no message type available")

View File

@ -375,7 +375,7 @@ async def _read_cmd(iface: HID) -> Cmd | None:
desc_cont = frame_cont()
read = loop.wait(iface.iface_num() | io.POLL_READ)
# wait for incoming comand indefinitely
# wait for incoming command indefinitely
buf = await read
while True:
ifrm = overlay_struct(bytearray(buf), desc_init)

View File

@ -1,8 +1,6 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from trezorio import WireInterface
from trezor.wire import Handler, Msg
@ -207,7 +205,7 @@ def _find_message_handler_module(msg_type: int) -> str:
raise ValueError
def find_registered_handler(iface: WireInterface, msg_type: int) -> Handler | None:
def find_registered_handler(msg_type: int) -> Handler | None:
if msg_type in workflow_handlers:
# Message has a handler available, return it directly.
return workflow_handlers[msg_type]

View File

@ -108,8 +108,9 @@ if not utils.USE_OPTIGA or (optiga.get_sec() or 0) < 150:
config.init(show_pin_timeout)
translations.init()
if __debug__ and not utils.EMULATOR:
config.wipe()
# TODO return after testing
# if __debug__ and not utils.EMULATOR:
# config.wipe()
loop.schedule(bootscreen())
loop.run()

View File

@ -1,153 +1,10 @@
import builtins
import gc
from micropython import const
from typing import TYPE_CHECKING
from storage.cache_common import SESSIONLESS_FLAG, SessionlessCache
from trezor import utils
if TYPE_CHECKING:
from typing import Sequence, TypeVar, overload
T = TypeVar("T")
_MAX_SESSIONS_COUNT = const(10)
_SESSIONLESS_FLAG = const(128)
_SESSION_ID_LENGTH = const(32)
# Traditional cache keys
APP_COMMON_SEED = const(0)
APP_COMMON_AUTHORIZATION_TYPE = const(1)
APP_COMMON_AUTHORIZATION_DATA = const(2)
APP_COMMON_NONCE = const(3)
if not utils.BITCOIN_ONLY:
APP_COMMON_DERIVE_CARDANO = const(4)
APP_CARDANO_ICARUS_SECRET = const(5)
APP_CARDANO_ICARUS_TREZOR_SECRET = const(6)
APP_MONERO_LIVE_REFRESH = const(7)
# Keys that are valid across sessions
APP_COMMON_SEED_WITHOUT_PASSPHRASE = const(0 | _SESSIONLESS_FLAG)
APP_COMMON_SAFETY_CHECKS_TEMPORARY = const(1 | _SESSIONLESS_FLAG)
APP_COMMON_REQUEST_PIN_LAST_UNLOCK = const(2 | _SESSIONLESS_FLAG)
APP_COMMON_BUSY_DEADLINE_MS = const(3 | _SESSIONLESS_FLAG)
APP_MISC_COSI_NONCE = const(4 | _SESSIONLESS_FLAG)
APP_MISC_COSI_COMMITMENT = const(5 | _SESSIONLESS_FLAG)
APP_RECOVERY_REPEATED_BACKUP_UNLOCKED = const(6 | _SESSIONLESS_FLAG)
# === Homescreen storage ===
# This does not logically belong to the "cache" functionality, but the cache module is
# a convenient place to put this.
# When a Homescreen layout is instantiated, it checks the value of `homescreen_shown`
# to know whether it should render itself or whether the result of a previous instance
# is still on. This way we can avoid unnecessary fadeins/fadeouts when a workflow ends.
HOMESCREEN_ON = object()
LOCKSCREEN_ON = object()
BUSYSCREEN_ON = object()
homescreen_shown: object | None = None
# Timestamp of last autolock activity.
# Here to persist across main loop restart between workflows.
autolock_last_touch: int | None = None
class InvalidSessionError(Exception):
pass
class DataCache:
fields: Sequence[int] # field sizes
def __init__(self) -> None:
self.data = [bytearray(f + 1) for f in self.fields]
def set(self, key: int, value: bytes) -> None:
utils.ensure(key < len(self.fields))
utils.ensure(len(value) <= self.fields[key])
self.data[key][0] = 1
self.data[key][1:] = value
if TYPE_CHECKING:
@overload
def get(self, key: int) -> bytes | None: ...
@overload
def get(self, key: int, default: T) -> bytes | T: # noqa: F811
...
def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
utils.ensure(key < len(self.fields))
if self.data[key][0] != 1:
return default
return bytes(self.data[key][1:])
def is_set(self, key: int) -> bool:
utils.ensure(key < len(self.fields))
return self.data[key][0] == 1
def delete(self, key: int) -> None:
utils.ensure(key < len(self.fields))
self.data[key][:] = b"\x00"
def clear(self) -> None:
for i in range(len(self.fields)):
self.delete(i)
class SessionCache(DataCache):
def __init__(self) -> None:
self.session_id = bytearray(_SESSION_ID_LENGTH)
if utils.BITCOIN_ONLY:
self.fields = (
64, # APP_COMMON_SEED
2, # APP_COMMON_AUTHORIZATION_TYPE
128, # APP_COMMON_AUTHORIZATION_DATA
32, # APP_COMMON_NONCE
)
else:
self.fields = (
64, # APP_COMMON_SEED
2, # APP_COMMON_AUTHORIZATION_TYPE
128, # APP_COMMON_AUTHORIZATION_DATA
32, # APP_COMMON_NONCE
0, # APP_COMMON_DERIVE_CARDANO
96, # APP_CARDANO_ICARUS_SECRET
96, # APP_CARDANO_ICARUS_TREZOR_SECRET
0, # APP_MONERO_LIVE_REFRESH
)
self.last_usage = 0
super().__init__()
def export_session_id(self) -> bytes:
from trezorcrypto import random # avoid pulling in trezor.crypto
# generate a new session id if we don't have it yet
if not self.session_id:
self.session_id[:] = random.bytes(_SESSION_ID_LENGTH)
# export it as immutable bytes
return bytes(self.session_id)
def clear(self) -> None:
super().clear()
self.last_usage = 0
self.session_id[:] = b""
class SessionlessCache(DataCache):
def __init__(self) -> None:
self.fields = (
64, # APP_COMMON_SEED_WITHOUT_PASSPHRASE
1, # APP_COMMON_SAFETY_CHECKS_TEMPORARY
8, # APP_COMMON_REQUEST_PIN_LAST_UNLOCK
8, # APP_COMMON_BUSY_DEADLINE_MS
32, # APP_MISC_COSI_NONCE
32, # APP_MISC_COSI_COMMITMENT
0, # APP_RECOVERY_REPEATED_BACKUP_UNLOCKED
)
super().__init__()
# XXX
# Allocation notes:
# Instantiation of a DataCache subclass should make as little garbage as possible, so
@ -156,210 +13,61 @@ class SessionlessCache(DataCache):
# bytearrays, then later call `clear()` on all the existing objects, which resets them
# to zero length. This is producing some trash - `b[:]` allocates a slice.
_SESSIONS: list[SessionCache] = []
for _ in range(_MAX_SESSIONS_COUNT):
_SESSIONS.append(SessionCache())
_SESSIONLESS_CACHE = SessionlessCache()
for session in _SESSIONS:
session.clear()
if utils.USE_THP:
from storage import cache_thp
_PROTOCOL_CACHE = cache_thp
else:
from storage import cache_codec
_PROTOCOL_CACHE = cache_codec
_PROTOCOL_CACHE.initialize()
_SESSIONLESS_CACHE.clear()
gc.collect()
_active_session_idx: int | None = None
_session_usage_counter = 0
def start_session(received_session_id: bytes | None = None) -> bytes:
global _active_session_idx
global _session_usage_counter
if (
received_session_id is not None
and len(received_session_id) != _SESSION_ID_LENGTH
):
# Prevent the caller from setting received_session_id=b"" and finding a cleared
# session. More generally, short-circuit the session id search, because we know
# that wrong-length session ids should not be in cache.
# Reduce to "session id not provided" case because that's what we do when
# caller supplies an id that is not found.
received_session_id = None
_session_usage_counter += 1
# attempt to find specified session id
if received_session_id:
for i in range(_MAX_SESSIONS_COUNT):
if _SESSIONS[i].session_id == received_session_id:
_active_session_idx = i
_SESSIONS[i].last_usage = _session_usage_counter
return received_session_id
# allocate least recently used session
lru_counter = _session_usage_counter
lru_session_idx = 0
for i in range(_MAX_SESSIONS_COUNT):
if _SESSIONS[i].last_usage < lru_counter:
lru_counter = _SESSIONS[i].last_usage
lru_session_idx = i
_active_session_idx = lru_session_idx
selected_session = _SESSIONS[lru_session_idx]
selected_session.clear()
selected_session.last_usage = _session_usage_counter
return selected_session.export_session_id()
def end_current_session() -> None:
global _active_session_idx
if _active_session_idx is None:
return
_SESSIONS[_active_session_idx].clear()
_active_session_idx = None
def set(key: int, value: bytes) -> None:
if key & _SESSIONLESS_FLAG:
_SESSIONLESS_CACHE.set(key ^ _SESSIONLESS_FLAG, value)
return
if _active_session_idx is None:
raise InvalidSessionError
_SESSIONS[_active_session_idx].set(key, value)
def _get_length(key: int) -> int:
if key & _SESSIONLESS_FLAG:
return _SESSIONLESS_CACHE.fields[key ^ _SESSIONLESS_FLAG]
elif _active_session_idx is None:
raise InvalidSessionError
else:
return _SESSIONS[_active_session_idx].fields[key]
def set_int(key: int, value: int) -> None:
length = _get_length(key)
encoded = value.to_bytes(length, "big")
# Ensure that the value fits within the length. Micropython's int.to_bytes()
# doesn't raise OverflowError.
assert int.from_bytes(encoded, "big") == value
set(key, encoded)
def set_bool(key: int, value: bool) -> None:
assert _get_length(key) == 0 # skipping get_length in production build
if value:
set(key, b"")
else:
delete(key)
if TYPE_CHECKING:
@overload
def get(key: int) -> bytes | None: ...
@overload
def get(key: int, default: T) -> bytes | T: # noqa: F811
...
def get(key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
if key & _SESSIONLESS_FLAG:
return _SESSIONLESS_CACHE.get(key ^ _SESSIONLESS_FLAG, default)
if _active_session_idx is None:
raise InvalidSessionError
return _SESSIONS[_active_session_idx].get(key, default)
def get_int(key: int, default: T | None = None) -> int | T | None: # noqa: F811
encoded = get(key)
if encoded is None:
return default
else:
return int.from_bytes(encoded, "big")
def get_bool(key: int) -> bool: # noqa: F811
return get(key) is not None
def clear_all() -> None:
global autolock_last_touch
autolock_last_touch = None
_SESSIONLESS_CACHE.clear()
_PROTOCOL_CACHE.clear_all()
def get_int_all_sessions(key: int) -> builtins.set[int]:
sessions = [_SESSIONLESS_CACHE] if key & _SESSIONLESS_FLAG else _SESSIONS
values = builtins.set()
for session in sessions:
encoded = session.get(key)
if key & SESSIONLESS_FLAG:
values = builtins.set()
encoded = _SESSIONLESS_CACHE.get(key)
if encoded is not None:
values.add(int.from_bytes(encoded, "big"))
return values
return values
return _PROTOCOL_CACHE.get_int_all_sessions(key)
def is_set(key: int) -> bool:
if key & _SESSIONLESS_FLAG:
return _SESSIONLESS_CACHE.is_set(key ^ _SESSIONLESS_FLAG)
if _active_session_idx is None:
raise InvalidSessionError
return _SESSIONS[_active_session_idx].is_set(key)
def delete(key: int) -> None:
if key & _SESSIONLESS_FLAG:
return _SESSIONLESS_CACHE.delete(key ^ _SESSIONLESS_FLAG)
if _active_session_idx is None:
raise InvalidSessionError
return _SESSIONS[_active_session_idx].delete(key)
def get_sessionless_cache() -> SessionlessCache:
return _SESSIONLESS_CACHE
if TYPE_CHECKING:
from typing import Awaitable, Callable, ParamSpec, TypeVar
from typing import Callable, ParamSpec, TypeVar
T = TypeVar("T")
P = ParamSpec("P")
ByteFunc = Callable[P, bytes]
AsyncByteFunc = Callable[P, Awaitable[bytes]]
def stored(key: int) -> Callable[[ByteFunc[P]], ByteFunc[P]]:
def decorator(func: ByteFunc[P]) -> ByteFunc[P]:
def wrapper(*args: P.args, **kwargs: P.kwargs):
value = get(key)
if value is None:
value = func(*args, **kwargs)
set(key, value)
return value
def check_thp_is_not_used(f: Callable[P, T]) -> Callable[P, T]:
"""A type-safe decorator to raise an exception when the function is called with THP enabled.
return wrapper
This decorator should be removed after the caches for Codec_v1 and THP are properly refactored and separated.
"""
return decorator
def inner(*args: P.args, **kwargs: P.kwargs) -> T:
if utils.USE_THP:
raise Exception("Cannot call this function with the new THP enabled")
return f(*args, **kwargs)
def stored_async(key: int) -> Callable[[AsyncByteFunc[P]], AsyncByteFunc[P]]:
def decorator(func: AsyncByteFunc[P]) -> AsyncByteFunc[P]:
async def wrapper(*args: P.args, **kwargs: P.kwargs):
value = get(key)
if value is None:
value = await func(*args, **kwargs)
set(key, value)
return value
return wrapper
return decorator
def clear_all() -> None:
global _active_session_idx
global autolock_last_touch
_active_session_idx = None
_SESSIONLESS_CACHE.clear()
for session in _SESSIONS:
session.clear()
autolock_last_touch = None
return inner

View File

@ -0,0 +1,142 @@
import builtins
from micropython import const
from typing import TYPE_CHECKING
from storage.cache_common import DataCache
from trezor import utils
if TYPE_CHECKING:
from typing import TypeVar
T = TypeVar("T")
_MAX_SESSIONS_COUNT = const(10)
SESSION_ID_LENGTH = const(32)
class SessionCache(DataCache):
def __init__(self) -> None:
self.session_id = bytearray(SESSION_ID_LENGTH)
if utils.BITCOIN_ONLY:
self.fields = (
64, # APP_COMMON_SEED
2, # APP_COMMON_AUTHORIZATION_TYPE
128, # APP_COMMON_AUTHORIZATION_DATA
32, # APP_COMMON_NONCE
)
else:
self.fields = (
64, # APP_COMMON_SEED
2, # APP_COMMON_AUTHORIZATION_TYPE
128, # APP_COMMON_AUTHORIZATION_DATA
32, # APP_COMMON_NONCE
0, # APP_COMMON_DERIVE_CARDANO
96, # APP_CARDANO_ICARUS_SECRET
96, # APP_CARDANO_ICARUS_TREZOR_SECRET
0, # APP_MONERO_LIVE_REFRESH
)
self.last_usage = 0
super().__init__()
def export_session_id(self) -> bytes:
from trezorcrypto import random # avoid pulling in trezor.crypto
# generate a new session id if we don't have it yet
if not self.session_id:
self.session_id[:] = random.bytes(SESSION_ID_LENGTH)
# export it as immutable bytes
return bytes(self.session_id)
def clear(self) -> None:
super().clear()
self.last_usage = 0
self.session_id[:] = b""
_SESSIONS: list[SessionCache] = []
def initialize() -> None:
global _SESSIONS
for _ in range(_MAX_SESSIONS_COUNT):
_SESSIONS.append(SessionCache())
for session in _SESSIONS:
session.clear()
_active_session_idx: int | None = None
_session_usage_counter = 0
def get_active_session() -> SessionCache | None:
if _active_session_idx is None:
return None
return _SESSIONS[_active_session_idx]
def start_session(received_session_id: bytes | None = None) -> bytes:
global _active_session_idx
global _session_usage_counter
if (
received_session_id is not None
and len(received_session_id) != SESSION_ID_LENGTH
):
# Prevent the caller from setting received_session_id=b"" and finding a cleared
# session. More generally, short-circuit the session id search, because we know
# that wrong-length session ids should not be in cache.
# Reduce to "session id not provided" case because that's what we do when
# caller supplies an id that is not found.
received_session_id = None
_session_usage_counter += 1
# attempt to find specified session id
if received_session_id:
for i in range(_MAX_SESSIONS_COUNT):
if _SESSIONS[i].session_id == received_session_id:
_active_session_idx = i
_SESSIONS[i].last_usage = _session_usage_counter
return received_session_id
# allocate least recently used session
lru_counter = _session_usage_counter
lru_session_idx = 0
for i in range(_MAX_SESSIONS_COUNT):
if _SESSIONS[i].last_usage < lru_counter:
lru_counter = _SESSIONS[i].last_usage
lru_session_idx = i
_active_session_idx = lru_session_idx
selected_session = _SESSIONS[lru_session_idx]
selected_session.clear()
selected_session.last_usage = _session_usage_counter
return selected_session.export_session_id()
def end_current_session() -> None:
global _active_session_idx
if _active_session_idx is None:
return
_SESSIONS[_active_session_idx].clear()
_active_session_idx = None
def get_int_all_sessions(key: int) -> builtins.set[int]:
values = builtins.set()
for session in _SESSIONS:
encoded = session.get(key)
if encoded is not None:
values.add(int.from_bytes(encoded, "big"))
return values
def clear_all() -> None:
global _active_session_idx
_active_session_idx = None
for session in _SESSIONS:
session.clear()

View File

@ -0,0 +1,179 @@
from micropython import const
from typing import TYPE_CHECKING
from trezor import utils
# Traditional cache keys
APP_COMMON_SEED = const(0)
APP_COMMON_AUTHORIZATION_TYPE = const(1)
APP_COMMON_AUTHORIZATION_DATA = const(2)
APP_COMMON_NONCE = const(3)
if not utils.BITCOIN_ONLY:
APP_COMMON_DERIVE_CARDANO = const(4)
APP_CARDANO_ICARUS_SECRET = const(5)
APP_CARDANO_ICARUS_TREZOR_SECRET = const(6)
APP_MONERO_LIVE_REFRESH = const(7)
# Cache keys for THP channel
if utils.USE_THP:
CHANNEL_HANDSHAKE_HASH = const(0)
CHANNEL_KEY_RECEIVE = const(1)
CHANNEL_KEY_SEND = const(2)
CHANNEL_NONCE_RECEIVE = const(3)
CHANNEL_NONCE_SEND = const(4)
# Keys that are valid across sessions
SESSIONLESS_FLAG = const(128)
APP_COMMON_SEED_WITHOUT_PASSPHRASE = const(0 | SESSIONLESS_FLAG)
APP_COMMON_SAFETY_CHECKS_TEMPORARY = const(1 | SESSIONLESS_FLAG)
APP_COMMON_REQUEST_PIN_LAST_UNLOCK = const(2 | SESSIONLESS_FLAG)
APP_COMMON_BUSY_DEADLINE_MS = const(3 | SESSIONLESS_FLAG)
APP_MISC_COSI_NONCE = const(4 | SESSIONLESS_FLAG)
APP_MISC_COSI_COMMITMENT = const(5 | SESSIONLESS_FLAG)
APP_RECOVERY_REPEATED_BACKUP_UNLOCKED = const(6 | SESSIONLESS_FLAG)
# === Homescreen storage ===
# This does not logically belong to the "cache" functionality, but the cache module is
# a convenient place to put this.
# When a Homescreen layout is instantiated, it checks the value of `homescreen_shown`
# to know whether it should render itself or whether the result of a previous instance
# is still on. This way we can avoid unnecessary fadeins/fadeouts when a workflow ends.
HOMESCREEN_ON = object()
LOCKSCREEN_ON = object()
BUSYSCREEN_ON = object()
homescreen_shown: object | None = None
# Timestamp of last autolock activity.
# Here to persist across main loop restart between workflows.
autolock_last_touch: int | None = None
if TYPE_CHECKING:
from typing import Sequence, TypeVar, overload
T = TypeVar("T")
class InvalidSessionError(Exception):
pass
class DataCache:
fields: Sequence[int]
def __init__(self) -> None:
self.data = [bytearray(f + 1) for f in self.fields]
if TYPE_CHECKING:
@overload
def get(self, key: int) -> bytes | None: # noqa: F811
...
@overload
def get(self, key: int, default: T) -> bytes | T: # noqa: F811
...
def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
utils.ensure(key < len(self.fields))
if self.data[key][0] != 1:
return default
return bytes(self.data[key][1:])
def get_bool(self, key: int) -> bool: # noqa: F811
return self.get(key) is not None
def get_int(
self, key: int, default: T | None = None
) -> int | T | None: # noqa: F811
encoded = self.get(key)
if encoded is None:
return default
else:
return int.from_bytes(encoded, "big")
def is_set(self, key: int) -> bool:
utils.ensure(key < len(self.fields))
return self.data[key][0] == 1
def set(self, key: int, value: bytes) -> None:
utils.ensure(key < len(self.fields))
utils.ensure(len(value) <= self.fields[key])
self.data[key][0] = 1
self.data[key][1:] = value
def set_bool(self, key: int, value: bool) -> None:
utils.ensure(
self._get_length(key) == 0, "Field does not have zero length!"
) # skipping get_length in production build
if value:
self.set(key, b"")
else:
self.delete(key)
def set_int(self, key: int, value: int) -> None:
length = self.fields[key]
encoded = value.to_bytes(length, "big")
# Ensure that the value fits within the length. Micropython's int.to_bytes()
# doesn't raise OverflowError.
assert int.from_bytes(encoded, "big") == value
self.set(key, encoded)
def delete(self, key: int) -> None:
utils.ensure(key < len(self.fields))
self.data[key][:] = b"\x00"
def clear(self) -> None:
for i in range(len(self.fields)):
self.delete(i)
def _get_length(self, key: int) -> int:
utils.ensure(key < len(self.fields))
return self.fields[key]
class SessionlessCache(DataCache):
def __init__(self) -> None:
self.fields = (
64, # APP_COMMON_SEED_WITHOUT_PASSPHRASE
1, # APP_COMMON_SAFETY_CHECKS_TEMPORARY
8, # APP_COMMON_REQUEST_PIN_LAST_UNLOCK
8, # APP_COMMON_BUSY_DEADLINE_MS
32, # APP_MISC_COSI_NONCE
32, # APP_MISC_COSI_COMMITMENT
0, # APP_RECOVERY_REPEATED_BACKUP_UNLOCKED
)
super().__init__()
def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
return super().get(key & ~SESSIONLESS_FLAG, default)
def get_bool(self, key: int) -> bool: # noqa: F811
return super().get_bool(key & ~SESSIONLESS_FLAG)
def get_int(
self, key: int, default: T | None = None
) -> int | T | None: # noqa: F811
return super().get_int(key & ~SESSIONLESS_FLAG, default)
def is_set(self, key: int) -> bool:
return super().is_set(key & ~SESSIONLESS_FLAG)
def set(self, key: int, value: bytes) -> None:
super().set(key & ~SESSIONLESS_FLAG, value)
def set_bool(self, key: int, value: bool) -> None:
super().set_bool(key & ~SESSIONLESS_FLAG, value)
def set_int(self, key: int, value: int) -> None:
super().set_int(key & ~SESSIONLESS_FLAG, value)
def delete(self, key: int) -> None:
super().delete(key & ~SESSIONLESS_FLAG)
def clear(self) -> None:
for i in range(len(self.fields)):
self.delete(i)

View File

@ -0,0 +1,334 @@
import builtins
from micropython import const
from typing import TYPE_CHECKING
from storage.cache_common import DataCache
from trezor import utils
if TYPE_CHECKING:
from typing import TypeVar
T = TypeVar("T")
if __debug__:
from trezor import log
# THP specific constants
_MAX_CHANNELS_COUNT = const(10)
_MAX_SESSIONS_COUNT = const(20)
_CHANNEL_STATE_LENGTH = const(1)
_WIRE_INTERFACE_LENGTH = const(1)
_SESSION_STATE_LENGTH = const(1)
_CHANNEL_ID_LENGTH = const(2)
SESSION_ID_LENGTH = const(1)
BROADCAST_CHANNEL_ID = const(0xFFFF)
KEY_LENGTH = const(32)
TAG_LENGTH = const(16)
_UNALLOCATED_STATE = const(0)
MANAGEMENT_SESSION_ID = const(0)
class ConnectionCache(DataCache):
def __init__(self) -> None:
self.channel_id = bytearray(_CHANNEL_ID_LENGTH)
self.last_usage = 0
super().__init__()
def clear(self) -> None:
self.channel_id[:] = b""
self.last_usage = 0
super().clear()
class ChannelCache(ConnectionCache):
def __init__(self) -> None:
self.host_ephemeral_pubkey = bytearray(KEY_LENGTH)
self.state = bytearray(_CHANNEL_STATE_LENGTH)
self.iface = bytearray(1) # TODO add decoding
self.sync = 0x80 # can_send_bit | sync_receive_bit | sync_send_bit | rfu(5)
self.session_id_counter = 0x00
self.fields = (
32, # CHANNEL_HANDSHAKE_HASH
32, # CHANNEL_KEY_RECEIVE
32, # CHANNEL_KEY_SEND
8, # CHANNEL_NONCE_RECEIVE
8, # CHANNEL_NONCE_SEND
)
super().__init__()
def clear(self) -> None:
self.state[:] = bytearray(
int.to_bytes(0, _CHANNEL_STATE_LENGTH, "big")
) # Set state to UNALLOCATED
# TODO clear all keys
super().clear()
class SessionThpCache(ConnectionCache):
def __init__(self) -> None:
self.session_id = bytearray(SESSION_ID_LENGTH)
self.state = bytearray(_SESSION_STATE_LENGTH)
if utils.BITCOIN_ONLY:
self.fields = (
64, # APP_COMMON_SEED
2, # APP_COMMON_AUTHORIZATION_TYPE
128, # APP_COMMON_AUTHORIZATION_DATA
32, # APP_COMMON_NONCE
)
else:
self.fields = (
64, # APP_COMMON_SEED
2, # APP_COMMON_AUTHORIZATION_TYPE
128, # APP_COMMON_AUTHORIZATION_DATA
32, # APP_COMMON_NONCE
0, # APP_COMMON_DERIVE_CARDANO
96, # APP_CARDANO_ICARUS_SECRET
96, # APP_CARDANO_ICARUS_TREZOR_SECRET
0, # APP_MONERO_LIVE_REFRESH
)
super().__init__()
def clear(self) -> None:
self.state[:] = bytearray(int.to_bytes(0, 1, "big")) # Set state to UNALLOCATED
self.session_id[:] = b""
super().clear()
_CHANNELS: list[ChannelCache] = []
_SESSIONS: list[SessionThpCache] = []
def initialize() -> None:
global _CHANNELS
global _SESSIONS
for _ in range(_MAX_CHANNELS_COUNT):
_CHANNELS.append(ChannelCache())
for _ in range(_MAX_SESSIONS_COUNT):
_SESSIONS.append(SessionThpCache())
for channel in _CHANNELS:
channel.clear()
for session in _SESSIONS:
session.clear()
# First unauthenticated channel will have index 0
_usage_counter = 0
# with this (arbitrary) value=4659, the first allocated channel will have cid=1234 (hex)
cid_counter: int = 4659 # TODO change to random value on start
def get_new_channel(iface: bytes) -> ChannelCache:
if len(iface) != _WIRE_INTERFACE_LENGTH:
raise Exception("Invalid WireInterface (encoded) length")
new_cid = get_next_channel_id()
index = _get_next_unauthenticated_channel_index()
# clear sessions from replaced channel
if _get_channel_state(_CHANNELS[index]) != _UNALLOCATED_STATE:
old_cid = _CHANNELS[index].channel_id
clear_sessions_with_channel_id(old_cid)
_CHANNELS[index] = ChannelCache()
_CHANNELS[index].channel_id[:] = new_cid
_CHANNELS[index].last_usage = _get_usage_counter_and_increment()
_CHANNELS[index].state[:] = bytearray(
_UNALLOCATED_STATE.to_bytes(_CHANNEL_STATE_LENGTH, "big")
)
_CHANNELS[index].iface[:] = bytearray(iface)
return _CHANNELS[index]
def update_channel_last_used(channel_id):
for channel in _CHANNELS:
if channel.channel_id == channel_id:
channel.last_usage = _get_usage_counter_and_increment()
return
def update_session_last_used(channel_id, session_id):
for session in _SESSIONS:
if session.channel_id == channel_id and session.session_id == session_id:
session.last_usage = _get_usage_counter_and_increment()
update_channel_last_used(channel_id)
return
def get_all_allocated_channels() -> list[ChannelCache]:
_list: list[ChannelCache] = []
for channel in _CHANNELS:
if _get_channel_state(channel) != _UNALLOCATED_STATE:
_list.append(channel)
return _list
def get_allocated_sessions(channel_id: bytes) -> list[SessionThpCache]:
if __debug__:
from trezor.utils import get_bytes_as_str
_list: list[SessionThpCache] = []
for session in _SESSIONS:
if _get_session_state(session) == _UNALLOCATED_STATE:
continue
if session.channel_id != channel_id:
continue
_list.append(session)
if __debug__:
log.debug(
__name__,
"session with channel_id: %s and session_id: %s is in ALLOCATED state",
get_bytes_as_str(session.channel_id),
get_bytes_as_str(session.session_id),
)
return _list
def set_channel_host_ephemeral_key(channel: ChannelCache, key: bytearray) -> None:
if len(key) != KEY_LENGTH:
raise Exception("Invalid key length")
channel.host_ephemeral_pubkey = key
def get_new_session(channel: ChannelCache):
new_sid = get_next_session_id(channel)
index = _get_next_session_index()
_SESSIONS[index] = SessionThpCache()
_SESSIONS[index].channel_id[:] = channel.channel_id
_SESSIONS[index].session_id[:] = new_sid
_SESSIONS[index].last_usage = _get_usage_counter_and_increment()
channel.last_usage = (
_get_usage_counter_and_increment()
) # increment also use of the channel so it does not get replaced
_SESSIONS[index].state[:] = bytearray(
_UNALLOCATED_STATE.to_bytes(_SESSION_STATE_LENGTH, "big")
)
return _SESSIONS[index]
def _get_usage_counter() -> int:
global _usage_counter
return _usage_counter
def _get_usage_counter_and_increment() -> int:
global _usage_counter
_usage_counter += 1
return _usage_counter
def _get_next_unauthenticated_channel_index() -> int:
idx = _get_unallocated_channel_index()
if idx is not None:
return idx
return get_least_recently_used_item(_CHANNELS, max_count=_MAX_CHANNELS_COUNT)
def _get_next_session_index() -> int:
idx = _get_unallocated_session_index()
if idx is not None:
return idx
return get_least_recently_used_item(_SESSIONS, _MAX_SESSIONS_COUNT)
def _get_unallocated_channel_index() -> int | None:
for i in range(_MAX_CHANNELS_COUNT):
if _get_channel_state(_CHANNELS[i]) is _UNALLOCATED_STATE:
return i
return None
def _get_unallocated_session_index() -> int | None:
for i in range(_MAX_SESSIONS_COUNT):
if (_SESSIONS[i]) is _UNALLOCATED_STATE:
return i
return None
def _get_channel_state(channel: ChannelCache) -> int:
return int.from_bytes(channel.state, "big")
def _get_session_state(session: SessionThpCache) -> int:
return int.from_bytes(session.state, "big")
def get_next_channel_id() -> bytes:
global cid_counter
while True:
cid_counter += 1
if cid_counter >= BROADCAST_CHANNEL_ID:
cid_counter = 1
if _is_cid_unique():
break
return cid_counter.to_bytes(_CHANNEL_ID_LENGTH, "big")
def get_next_session_id(channel: ChannelCache) -> bytes:
while True:
if channel.session_id_counter >= 255:
channel.session_id_counter = 1
else:
channel.session_id_counter += 1
if _is_session_id_unique(channel):
break
new_sid = channel.session_id_counter
return new_sid.to_bytes(SESSION_ID_LENGTH, "big")
def _is_session_id_unique(channel: ChannelCache) -> bool:
for session in _SESSIONS:
if session.channel_id == channel.channel_id:
if session.session_id == channel.session_id_counter:
return False
return True
def _is_cid_unique() -> bool:
for session in _SESSIONS:
if cid_counter == _get_cid(session):
return False
return True
def _get_cid(session: SessionThpCache) -> int:
return int.from_bytes(session.session_id[2:], "big")
def get_least_recently_used_item(
list: list[ChannelCache] | list[SessionThpCache], max_count: int
):
lru_counter = _get_usage_counter()
lru_item_index = 0
for i in range(max_count):
if list[i].last_usage < lru_counter:
lru_counter = list[i].last_usage
lru_item_index = i
return lru_item_index
def get_int_all_sessions(key: int) -> builtins.set[int]:
values = builtins.set()
for session in _SESSIONS:
encoded = session.get(key)
if encoded is not None:
values.add(int.from_bytes(encoded, "big"))
return values
def clear_sessions_with_channel_id(channel_id: bytes):
for session in _SESSIONS:
if session.channel_id == channel_id:
session.clear()
def clear_all() -> None:
for session in _SESSIONS:
session.clear()
for channel in _CHANNELS:
channel.clear()

View File

@ -16,4 +16,5 @@ NotInitialized = 11
PinMismatch = 12
WipeCodeMismatch = 13
InvalidSession = 14
ThpUnallocatedSession = 15
FirmwareError = 99

View File

@ -95,6 +95,24 @@ DebugLinkEraseSdCard = 9005
DebugLinkWatchLayout = 9006
DebugLinkResetDebugEvents = 9007
DebugLinkOptigaSetSecMax = 9008
ThpCreateNewSession = 1000
ThpNewSession = 1001
ThpStartPairingRequest = 1008
ThpPairingPreparationsFinished = 1009
ThpCredentialRequest = 1010
ThpCredentialResponse = 1011
ThpEndRequest = 1012
ThpEndResponse = 1013
ThpCodeEntryCommitment = 1016
ThpCodeEntryChallenge = 1017
ThpCodeEntryCpaceHost = 1018
ThpCodeEntryCpaceTrezor = 1019
ThpCodeEntryTag = 1020
ThpCodeEntrySecret = 1021
ThpQrCodeTag = 1024
ThpQrCodeSecret = 1025
ThpNfcUnidirectionalTag = 1032
ThpNfcUnidirectionalSecret = 1033
if not utils.BITCOIN_ONLY:
SetU2FCounter = 63
GetNextU2FCounter = 80

View File

@ -0,0 +1,8 @@
# Automatically generated by pb2py
# fmt: off
# isort:skip_file
NoMethod = 1
CodeEntry = 2
QrCode = 3
NFC_Unidirectional = 4

View File

@ -262,6 +262,24 @@ if TYPE_CHECKING:
SolanaAddress = 903
SolanaSignTx = 904
SolanaTxSignature = 905
ThpCreateNewSession = 1000
ThpNewSession = 1001
ThpStartPairingRequest = 1008
ThpPairingPreparationsFinished = 1009
ThpCredentialRequest = 1010
ThpCredentialResponse = 1011
ThpEndRequest = 1012
ThpEndResponse = 1013
ThpCodeEntryCommitment = 1016
ThpCodeEntryChallenge = 1017
ThpCodeEntryCpaceHost = 1018
ThpCodeEntryCpaceTrezor = 1019
ThpCodeEntryTag = 1020
ThpCodeEntrySecret = 1021
ThpQrCodeTag = 1024
ThpQrCodeSecret = 1025
ThpNfcUnidirectionalTag = 1032
ThpNfcUnidirectionalSecret = 1033
class FailureType(IntEnum):
UnexpectedMessage = 1
@ -278,6 +296,7 @@ if TYPE_CHECKING:
PinMismatch = 12
WipeCodeMismatch = 13
InvalidSession = 14
ThpUnallocatedSession = 15
FirmwareError = 99
class ButtonRequestType(IntEnum):
@ -575,3 +594,9 @@ if TYPE_CHECKING:
Yay = 0
Nay = 1
Pass = 2
class ThpPairingMethod(IntEnum):
NoMethod = 1
CodeEntry = 2
QrCode = 3
NFC_Unidirectional = 4

View File

@ -66,6 +66,7 @@ if TYPE_CHECKING:
from trezor.enums import StellarSignerType # noqa: F401
from trezor.enums import TezosBallotType # noqa: F401
from trezor.enums import TezosContractType # noqa: F401
from trezor.enums import ThpPairingMethod # noqa: F401
from trezor.enums import WordRequestType # noqa: F401
class BinanceGetAddress(protobuf.MessageType):
@ -2814,11 +2815,13 @@ if TYPE_CHECKING:
class DebugLinkGetState(protobuf.MessageType):
wait_layout: "DebugWaitType"
thp_channel_id: "int | None"
def __init__(
self,
*,
wait_layout: "DebugWaitType | None" = None,
thp_channel_id: "int | None" = None,
) -> None:
pass
@ -2840,6 +2843,8 @@ if TYPE_CHECKING:
reset_word_pos: "int | None"
mnemonic_type: "BackupType | None"
tokens: "list[str]"
thp_pairing_code_entry_code: "int | None"
thp_pairing_secret: "bytes | None"
def __init__(
self,
@ -2857,6 +2862,8 @@ if TYPE_CHECKING:
recovery_word_pos: "int | None" = None,
reset_word_pos: "int | None" = None,
mnemonic_type: "BackupType | None" = None,
thp_pairing_code_entry_code: "int | None" = None,
thp_pairing_secret: "bytes | None" = None,
) -> None:
pass
@ -6078,6 +6085,324 @@ if TYPE_CHECKING:
def is_type_of(cls, msg: Any) -> TypeGuard["TezosManagerTransfer"]:
return isinstance(msg, cls)
class ThpDeviceProperties(protobuf.MessageType):
internal_model: "str | None"
model_variant: "int | None"
bootloader_mode: "bool | None"
protocol_version: "int | None"
pairing_methods: "list[ThpPairingMethod]"
def __init__(
self,
*,
pairing_methods: "list[ThpPairingMethod] | None" = None,
internal_model: "str | None" = None,
model_variant: "int | None" = None,
bootloader_mode: "bool | None" = None,
protocol_version: "int | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpDeviceProperties"]:
return isinstance(msg, cls)
class ThpHandshakeCompletionReqNoisePayload(protobuf.MessageType):
host_pairing_credential: "bytes | None"
pairing_methods: "list[ThpPairingMethod]"
def __init__(
self,
*,
pairing_methods: "list[ThpPairingMethod] | None" = None,
host_pairing_credential: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpHandshakeCompletionReqNoisePayload"]:
return isinstance(msg, cls)
class ThpCreateNewSession(protobuf.MessageType):
passphrase: "str | None"
on_device: "bool | None"
derive_cardano: "bool | None"
def __init__(
self,
*,
passphrase: "str | None" = None,
on_device: "bool | None" = None,
derive_cardano: "bool | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCreateNewSession"]:
return isinstance(msg, cls)
class ThpNewSession(protobuf.MessageType):
new_session_id: "int | None"
def __init__(
self,
*,
new_session_id: "int | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpNewSession"]:
return isinstance(msg, cls)
class ThpStartPairingRequest(protobuf.MessageType):
host_name: "str | None"
def __init__(
self,
*,
host_name: "str | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpStartPairingRequest"]:
return isinstance(msg, cls)
class ThpPairingPreparationsFinished(protobuf.MessageType):
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpPairingPreparationsFinished"]:
return isinstance(msg, cls)
class ThpCodeEntryCommitment(protobuf.MessageType):
commitment: "bytes | None"
def __init__(
self,
*,
commitment: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntryCommitment"]:
return isinstance(msg, cls)
class ThpCodeEntryChallenge(protobuf.MessageType):
challenge: "bytes | None"
def __init__(
self,
*,
challenge: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntryChallenge"]:
return isinstance(msg, cls)
class ThpCodeEntryCpaceHost(protobuf.MessageType):
cpace_host_public_key: "bytes | None"
def __init__(
self,
*,
cpace_host_public_key: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntryCpaceHost"]:
return isinstance(msg, cls)
class ThpCodeEntryCpaceTrezor(protobuf.MessageType):
cpace_trezor_public_key: "bytes | None"
def __init__(
self,
*,
cpace_trezor_public_key: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntryCpaceTrezor"]:
return isinstance(msg, cls)
class ThpCodeEntryTag(protobuf.MessageType):
tag: "bytes | None"
def __init__(
self,
*,
tag: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntryTag"]:
return isinstance(msg, cls)
class ThpCodeEntrySecret(protobuf.MessageType):
secret: "bytes | None"
def __init__(
self,
*,
secret: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntrySecret"]:
return isinstance(msg, cls)
class ThpQrCodeTag(protobuf.MessageType):
tag: "bytes | None"
def __init__(
self,
*,
tag: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpQrCodeTag"]:
return isinstance(msg, cls)
class ThpQrCodeSecret(protobuf.MessageType):
secret: "bytes | None"
def __init__(
self,
*,
secret: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpQrCodeSecret"]:
return isinstance(msg, cls)
class ThpNfcUnidirectionalTag(protobuf.MessageType):
tag: "bytes | None"
def __init__(
self,
*,
tag: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpNfcUnidirectionalTag"]:
return isinstance(msg, cls)
class ThpNfcUnidirectionalSecret(protobuf.MessageType):
secret: "bytes | None"
def __init__(
self,
*,
secret: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpNfcUnidirectionalSecret"]:
return isinstance(msg, cls)
class ThpCredentialRequest(protobuf.MessageType):
host_static_pubkey: "bytes | None"
def __init__(
self,
*,
host_static_pubkey: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCredentialRequest"]:
return isinstance(msg, cls)
class ThpCredentialResponse(protobuf.MessageType):
trezor_static_pubkey: "bytes | None"
credential: "bytes | None"
def __init__(
self,
*,
trezor_static_pubkey: "bytes | None" = None,
credential: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCredentialResponse"]:
return isinstance(msg, cls)
class ThpEndRequest(protobuf.MessageType):
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpEndRequest"]:
return isinstance(msg, cls)
class ThpEndResponse(protobuf.MessageType):
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpEndResponse"]:
return isinstance(msg, cls)
class ThpCredentialMetadata(protobuf.MessageType):
host_name: "str | None"
def __init__(
self,
*,
host_name: "str | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCredentialMetadata"]:
return isinstance(msg, cls)
class ThpPairingCredential(protobuf.MessageType):
cred_metadata: "ThpCredentialMetadata | None"
mac: "bytes | None"
def __init__(
self,
*,
cred_metadata: "ThpCredentialMetadata | None" = None,
mac: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpPairingCredential"]:
return isinstance(msg, cls)
class ThpAuthenticatedCredentialData(protobuf.MessageType):
host_static_pubkey: "bytes | None"
cred_metadata: "ThpCredentialMetadata | None"
def __init__(
self,
*,
host_static_pubkey: "bytes | None" = None,
cred_metadata: "ThpCredentialMetadata | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpAuthenticatedCredentialData"]:
return isinstance(msg, cls)
class WebAuthnListResidentCredentials(protobuf.MessageType):
@classmethod

View File

@ -35,6 +35,8 @@ from typing import TYPE_CHECKING
DISABLE_ANIMATION = 0
DISABLE_ENCRYPTION: bool = False
if __debug__:
if EMULATOR:
import uos
@ -45,7 +47,13 @@ if __debug__:
LOG_MEMORY = 0
if TYPE_CHECKING:
from typing import Any, Iterator, Protocol, Sequence, TypeVar
from typing import ( # pyright: ignore[reportShadowedImports]
Any,
Iterator,
Protocol,
Sequence,
TypeVar,
)
from trezor.protobuf import MessageType
@ -111,6 +119,7 @@ def presize_module(modname: str, size: int) -> None:
if __debug__:
from ubinascii import hexlify
def mem_dump(filename: str) -> None:
from micropython import mem_info
@ -127,6 +136,9 @@ if __debug__:
else:
mem_info(True)
def get_bytes_as_str(a):
return hexlify(a).decode("utf-8")
def ensure(cond: bool, msg: str | None = None) -> None:
if not cond:

View File

@ -5,7 +5,7 @@ Handles on-the-wire communication with a host computer. The communication is:
- Request / response.
- Protobuf-encoded, see `protobuf.py`.
- Wrapped in a simple envelope format, see `trezor/wire/codec_v1.py`.
- Wrapped in a simple envelope format, see `trezor/wire/codec_v1.py` or `trezor/wire/thp_main.py`.
- Transferred over USB interface, or UDP in case of Unix emulation.
This module:
@ -16,22 +16,28 @@ This module:
## Session handler
When the `wire.setup` is called the `handle_session` coroutine is scheduled. The
When the `wire.setup` is called the `handle_session` (or `handle_thp_session`) coroutine is scheduled. The
`handle_session` waits for some messages to be received on some particular interface and
reads the message's header. When the message type is known the first handler is called. This way the
`handle_session` goes through all the workflows.
"""
from micropython import const
from typing import TYPE_CHECKING
from storage.cache import InvalidSessionError
from trezor import log, loop, protobuf, utils, workflow
from trezor.enums import FailureType
from trezor.messages import Failure
from trezor.wire import codec_v1, context
from trezor.wire.errors import ActionCancelled, DataError, Error, UnexpectedMessage
from trezor import log, loop, protobuf, utils
from trezor.wire import context, message_handler, protocol_common
if utils.USE_THP:
from trezor.wire import thp_main
from trezor.wire.message_handler import WIRE_BUFFER_2
from trezor.wire.context import UnexpectedMessageException
from trezor.wire.message_handler import (
WIRE_BUFFER,
WIRE_BUFFER_DEBUG,
failure,
find_handler,
)
# Import all errors into namespace, so that `wire.Error` is available from
# other packages.
@ -40,12 +46,11 @@ from trezor.wire.errors import * # isort:skip # noqa: F401,F403
if TYPE_CHECKING:
from trezorio import WireInterface
from typing import Any, Callable, Container, Coroutine, TypeVar
from typing import Any, Callable, Coroutine, TypeVar
Msg = TypeVar("Msg", bound=protobuf.MessageType)
HandlerTask = Coroutine[Any, Any, protobuf.MessageType]
Handler = Callable[[Msg], HandlerTask]
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
@ -53,147 +58,57 @@ if TYPE_CHECKING:
EXPERIMENTAL_ENABLED = False
def setup(iface: WireInterface) -> None:
"""Initialize the wire stack on passed USB interface."""
loop.schedule(handle_session(iface, codec_v1.SESSION_ID))
def setup(iface: WireInterface, is_debug_session: bool = False) -> None:
"""Initialize the wire stack on passed WireInterface."""
if utils.USE_THP and not is_debug_session:
loop.schedule(handle_thp_session(iface, is_debug_session))
else:
loop.schedule(handle_session(iface, is_debug_session))
def wrap_protobuf_load(
buffer: bytes,
expected_type: type[LoadedMessageType],
) -> LoadedMessageType:
try:
msg = protobuf.decode(buffer, expected_type, EXPERIMENTAL_ENABLED)
if __debug__ and utils.EMULATOR:
log.debug(
__name__, "received message contents:\n%s", utils.dump_protobuf(msg)
)
return msg
except Exception as e:
if __debug__:
log.exception(__name__, e)
if e.args:
raise DataError("Failed to decode message: " + " ".join(e.args))
if utils.USE_THP:
async def handle_thp_session(iface: WireInterface, is_debug_session: bool = False):
if __debug__ and is_debug_session:
ctx_buffer = WIRE_BUFFER_DEBUG
else:
raise DataError("Failed to decode message")
ctx_buffer = WIRE_BUFFER
thp_main.set_read_buffer(ctx_buffer)
thp_main.set_write_buffer(WIRE_BUFFER_2)
# Take a mark of modules that are imported at this point, so we can
# roll back and un-import any others.
modules = utils.unimport_begin()
while True:
try:
await thp_main.thp_main_loop(iface, is_debug_session)
if not __debug__ or not is_debug_session:
# Unload modules imported by the workflow. Should not raise.
# This is not done for the debug session because the snapshot taken
# in a debug session would clear modules which are in use by the
# workflow running on wire.
utils.unimport_end(modules)
loop.clear()
return
except Exception as exc:
# Log and try again. The session handler can only exit explicitly via
# loop.clear() above.
if __debug__:
log.exception(__name__, exc)
_PROTOBUF_BUFFER_SIZE = const(8192)
WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
if __debug__:
PROTOBUF_BUFFER_SIZE_DEBUG = 1024
WIRE_BUFFER_DEBUG = bytearray(PROTOBUF_BUFFER_SIZE_DEBUG)
async def _handle_single_message(
ctx: context.Context, msg: codec_v1.Message
) -> bool:
"""Handle a message that was loaded from USB by the caller.
Find the appropriate handler, run it and write its result on the wire. In case
a problem is encountered at any point, write the appropriate error on the wire.
The return value indicates whether to override the default restarting behavior. If
`False` is returned, the caller is allowed to clear the loop and restart the
MicroPython machine (see `session.py`). This would lose all state and incurs a cost
in terms of repeated startup time. When handling the message didn't cause any
significant fragmentation (e.g., if decoding the message was skipped), or if
the type of message is supposed to be optimized and not disrupt the running state,
this function will return `True`.
"""
if __debug__:
try:
msg_type = protobuf.type_for_wire(msg.type).MESSAGE_NAME
except Exception:
msg_type = f"{msg.type} - unknown message type"
log.debug(
__name__,
"%s:%x receive: <%s>",
ctx.iface.iface_num(),
ctx.sid,
msg_type,
)
res_msg: protobuf.MessageType | None = None
# We need to find a handler for this message type.
try:
handler = find_handler(ctx.iface, msg.type)
except Error as exc:
# Handlers are allowed to exception out. In that case, we can skip decoding
# and return the error.
await ctx.write(failure(exc))
return True
if msg.type in workflow.ALLOW_WHILE_LOCKED:
workflow.autolock_interrupts_workflow = False
# Here we make sure we always respond with a Failure response
# in case of any errors.
try:
# Find a protobuf.MessageType subclass that describes this
# message. Raises if the type is not found.
req_type = protobuf.type_for_wire(msg.type)
# Try to decode the message according to schema from
# `req_type`. Raises if the message is malformed.
req_msg = wrap_protobuf_load(msg.data, req_type)
# Create the handler task.
task = handler(req_msg)
# Run the workflow task. Workflow can do more on-the-wire
# communication inside, but it should eventually return a
# response message, or raise an exception (a rather common
# thing to do). Exceptions are handled in the code below.
res_msg = await workflow.spawn(context.with_context(ctx, task))
except context.UnexpectedMessage:
# Workflow was trying to read a message from the wire, and
# something unexpected came in. See Context.read() for
# example, which expects some particular message and raises
# UnexpectedMessage if another one comes in.
#
# We process the unexpected message by aborting the current workflow and
# possibly starting a new one, initiated by that message. (The main usecase
# being, the host does not finish the workflow, we want other callers to
# be able to do their own thing.)
#
# The message is stored in the exception, which we re-raise for the caller
# to process. It is not a standard exception that should be logged and a result
# sent to the wire.
raise
except BaseException as exc:
# Either:
# - the message had a type that has a registered handler, but does not have
# a protobuf class
# - the message was not valid protobuf
# - workflow raised some kind of an exception while running
# - something canceled the workflow from the outside
if __debug__:
if isinstance(exc, ActionCancelled):
log.debug(__name__, "cancelled: %s", exc.message)
elif isinstance(exc, loop.TaskClosed):
log.debug(__name__, "cancelled: loop task was closed")
else:
log.exception(__name__, exc)
res_msg = failure(exc)
if res_msg is not None:
# perform the write outside the big try-except block, so that usb write
# problem bubbles up
await ctx.write(res_msg)
# Look into `AVOID_RESTARTING_FOR` to see if this message should avoid restarting.
return msg.type in AVOID_RESTARTING_FOR
async def handle_session(iface: WireInterface, session_id: int) -> None:
ctx = context.Context(iface, session_id, WIRE_BUFFER)
next_msg: codec_v1.Message | None = None
async def handle_session(iface: WireInterface, is_debug_session: bool = False) -> None:
if __debug__ and is_debug_session:
ctx_buffer = WIRE_BUFFER_DEBUG
else:
ctx_buffer = WIRE_BUFFER
ctx = context.CodecContext(iface, ctx_buffer)
next_msg: protocol_common.Message | None = None
# Take a mark of modules that are imported at this point, so we can
# roll back and un-import any others.
@ -205,7 +120,7 @@ async def handle_session(iface: WireInterface, session_id: int) -> None:
# wait for a new one coming from the wire.
try:
msg = await ctx.read_from_wire()
except codec_v1.CodecError as exc:
except protocol_common.WireError as exc:
if __debug__:
log.exception(__name__, exc)
await ctx.write(failure(exc))
@ -218,8 +133,10 @@ async def handle_session(iface: WireInterface, session_id: int) -> None:
do_not_restart = False
try:
do_not_restart = await _handle_single_message(ctx, msg)
except context.UnexpectedMessage as unexpected:
do_not_restart = await message_handler.handle_single_message(
ctx, msg, handler_finder=find_handler
)
except UnexpectedMessageException as unexpected:
# The workflow was interrupted by an unexpected message. We need to
# process it as if it was a new message...
next_msg = unexpected.msg
@ -245,81 +162,3 @@ async def handle_session(iface: WireInterface, session_id: int) -> None:
# loop.clear() above.
if __debug__:
log.exception(__name__, exc)
def find_handler(iface: WireInterface, msg_type: int) -> Handler:
import usb
from apps import workflow_handlers
handler = workflow_handlers.find_registered_handler(iface, msg_type)
if handler is None:
raise UnexpectedMessage("Unexpected message")
if __debug__ and iface is usb.iface_debug:
# no filtering allowed for debuglink
return handler
for filter in filters:
handler = filter(msg_type, handler)
return handler
filters: list[Callable[[int, Handler], Handler]] = []
"""Filters for the wire handler.
Filters are applied in order. Each filter gets a message id and a preceding handler. It
must either return a handler (the same one or a modified one), or raise an exception
that gets sent to wire directly.
Filters are not applied to debug sessions.
The filters are designed for:
* rejecting messages -- while in Recovery mode, most messages are not allowed
* adding additional behavior -- while device is soft-locked, a PIN screen will be shown
before allowing a message to trigger its original behavior.
For this, the filters are effectively deny-first. If an earlier filter rejects the
message, the later filters are not called. But if a filter adds behavior, the latest
filter "wins" and the latest behavior triggers first.
Please note that this behavior is really unsuited to anything other than what we are
using it for now. It might be necessary to modify the semantics if we need more complex
usecases.
NB: `filters` is currently public so callers can have control over where they insert
new filters, but removal should be done using `remove_filter`!
We should, however, change it such that filters must be added using an `add_filter`
and `filters` becomes private!
"""
def remove_filter(filter):
try:
filters.remove(filter)
except ValueError:
pass
AVOID_RESTARTING_FOR: Container[int] = ()
def failure(exc: BaseException) -> Failure:
if isinstance(exc, Error):
return Failure(code=exc.code, message=exc.message)
elif isinstance(exc, loop.TaskClosed):
return Failure(code=FailureType.ActionCancelled, message="Cancelled")
elif isinstance(exc, InvalidSessionError):
return Failure(code=FailureType.InvalidSession, message="Invalid session")
else:
# NOTE: when receiving generic `FirmwareError` on non-debug build,
# change the `if __debug__` to `if True` to get the full error message.
if __debug__:
message = str(exc)
else:
message = "Firmware error"
return Failure(code=FailureType.FirmwareError, message=message)
def unexpected_message() -> Failure:
return Failure(code=FailureType.UnexpectedMessage, message="Unexpected message")

View File

@ -3,6 +3,7 @@ from micropython import const
from typing import TYPE_CHECKING
from trezor import io, loop, utils
from trezor.wire.protocol_common import Message, WireError
if TYPE_CHECKING:
from trezorio import WireInterface
@ -18,16 +19,10 @@ _REP_CONT_DATA = const(1) # offset of data in the continuation report
SESSION_ID = const(0)
class CodecError(Exception):
class CodecError(WireError):
pass
class Message:
def __init__(self, mtype: int, mdata: bytes) -> None:
self.type = mtype
self.data = mdata
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message:
read = loop.wait(iface.iface_num() | io.POLL_READ)

View File

@ -15,9 +15,12 @@ for ButtonRequests. Of course, `context.wait()` transparently works in such situ
from typing import TYPE_CHECKING
from storage import cache, cache_codec
from storage.cache_common import SESSIONLESS_FLAG, InvalidSessionError
from trezor import log, loop, protobuf
from trezor.wire import codec_v1
from . import codec_v1
from .protocol_common import Context, Message
if TYPE_CHECKING:
from trezorio import WireInterface
@ -32,6 +35,8 @@ if TYPE_CHECKING:
overload,
)
from storage.cache_common import DataCache
Msg = TypeVar("Msg", bound=protobuf.MessageType)
HandlerTask = Coroutine[Any, Any, protobuf.MessageType]
Handler = Callable[["Context", Msg], HandlerTask]
@ -41,31 +46,35 @@ if TYPE_CHECKING:
T = TypeVar("T")
class UnexpectedMessage(Exception):
class UnexpectedMessageException(Exception):
"""A message was received that is not part of the current workflow.
Utility exception to inform the session handler that the current workflow
should be aborted and a new one started as if `msg` was the first message.
"""
def __init__(self, msg: codec_v1.Message) -> None:
def __init__(self, msg: Message) -> None:
super().__init__()
self.msg = msg
class Context:
class CodecContext(Context):
"""Wire context.
Represents USB communication inside a particular session on a particular interface
Represents USB communication inside a particular session (channel) on a particular interface
(i.e., wire, debug, single BT connection, etc.)
"""
def __init__(self, iface: WireInterface, sid: int, buffer: bytearray) -> None:
def __init__(
self,
iface: WireInterface,
buffer: bytearray,
) -> None:
self.iface = iface
self.sid = sid
self.buffer = buffer
super().__init__(iface, codec_v1.SESSION_ID.to_bytes(2, "big"))
def read_from_wire(self) -> Awaitable[codec_v1.Message]:
def read_from_wire(self) -> Awaitable[Message]:
"""Read a whole message from the wire without parsing it."""
return codec_v1.read_message(self.iface, self.buffer)
@ -95,9 +104,8 @@ class Context:
if __debug__:
log.debug(
__name__,
"%s:%x expect: %s",
"%s: expect: %s",
self.iface.iface_num(),
self.sid,
expected_type.MESSAGE_NAME if expected_type else expected_types,
)
@ -107,7 +115,7 @@ class Context:
# If we got a message with unexpected type, raise the message via
# `UnexpectedMessageError` and let the session handler deal with it.
if msg.type not in expected_types:
raise UnexpectedMessage(msg)
raise UnexpectedMessageException(msg)
if expected_type is None:
expected_type = protobuf.type_for_wire(msg.type)
@ -115,14 +123,14 @@ class Context:
if __debug__:
log.debug(
__name__,
"%s:%x read: %s",
"%s: read: %s",
self.iface.iface_num(),
self.sid,
expected_type.MESSAGE_NAME,
)
# look up the protobuf class and parse the message
from . import wrap_protobuf_load
from . import message_handler # noqa: F401
from .message_handler import wrap_protobuf_load
return wrap_protobuf_load(msg.data, expected_type)
@ -131,9 +139,8 @@ class Context:
if __debug__:
log.debug(
__name__,
"%s:%x write: %s",
"%s: write: %s",
self.iface.iface_num(),
self.sid,
msg.MESSAGE_NAME,
)
@ -150,23 +157,19 @@ class Context:
buffer = bytearray(msg_size)
msg_size = protobuf.encode(buffer, msg)
await codec_v1.write_message(
self.iface,
msg.MESSAGE_WIRE_TYPE,
memoryview(buffer)[:msg_size],
)
async def call(
self,
msg: protobuf.MessageType,
expected_type: type[LoadedMessageType],
) -> LoadedMessageType:
assert expected_type.MESSAGE_WIRE_TYPE is not None
await self.write(msg)
del msg
return await self.read((expected_type.MESSAGE_WIRE_TYPE,), expected_type)
# ACCESS TO CACHE
@property
def cache(self) -> DataCache:
c = cache_codec.get_active_session()
if c is None:
raise InvalidSessionError()
return c
CURRENT_CONTEXT: Context | None = None
@ -273,3 +276,65 @@ def with_context(ctx: Context, workflow: loop.Task) -> Generator:
send_exc = e
else:
send_exc = None
# ACCESS TO CACHE
if TYPE_CHECKING:
T = TypeVar("T")
@overload
def cache_get(key: int) -> bytes | None: # noqa: F811
...
@overload
def cache_get(key: int, default: T) -> bytes | T: # noqa: F811
...
def cache_get(key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
cache = _get_cache_for_key(key)
return cache.get(key, default)
def cache_get_bool(key: int) -> bool: # noqa: F811
cache = _get_cache_for_key(key)
return cache.get_bool(key)
def cache_get_int(key: int, default: T | None = None) -> int | T | None: # noqa: F811
cache = _get_cache_for_key(key)
return cache.get_int(key, default)
def cache_is_set(key: int) -> bool:
cache = _get_cache_for_key(key)
return cache.is_set(key)
def cache_set(key: int, value: bytes) -> None:
cache = _get_cache_for_key(key)
cache.set(key, value)
def cache_set_bool(key: int, value: bool) -> None:
cache = _get_cache_for_key(key)
cache.set_bool(key, value)
def cache_set_int(key: int, value: int) -> None:
cache = _get_cache_for_key(key)
cache.set_int(key, value)
def cache_delete(key: int) -> None:
cache = _get_cache_for_key(key)
cache.delete(key)
def _get_cache_for_key(key) -> DataCache:
if key & SESSIONLESS_FLAG:
return cache.get_sessionless_cache()
if CURRENT_CONTEXT:
return CURRENT_CONTEXT.cache
raise Exception("No wire context")

View File

@ -8,6 +8,12 @@ class Error(Exception):
self.message = message
class SilentError(Exception):
def __init__(self, message: str) -> None:
super().__init__()
self.message = message
class UnexpectedMessage(Error):
def __init__(self, message: str) -> None:
super().__init__(FailureType.UnexpectedMessage, message)

View File

@ -0,0 +1,262 @@
from micropython import const
from typing import TYPE_CHECKING
from storage.cache_common import InvalidSessionError
from trezor import log, loop, protobuf, utils, workflow
from trezor.enums import FailureType
from trezor.messages import Failure
from trezor.wire.context import Context, UnexpectedMessageException, with_context
from trezor.wire.errors import ActionCancelled, DataError, Error, UnexpectedMessage
from trezor.wire.protocol_common import Message
# Import all errors into namespace, so that `wire.Error` is available from
# other packages.
from trezor.wire.errors import * # isort:skip # noqa: F401,F403
if TYPE_CHECKING:
from typing import Any, Callable, Container
from trezor.wire import Handler, LoadedMessageType
HandlerFinder = Callable[[Any], Handler | None]
# If set to False protobuf messages marked with "experimental_message" option are rejected.
EXPERIMENTAL_ENABLED = False
def wrap_protobuf_load(
buffer: bytes,
expected_type: type[LoadedMessageType],
) -> LoadedMessageType:
try:
if __debug__:
log.debug(
__name__,
"Buffer to be parsed to a LoadedMessage: %s",
utils.get_bytes_as_str(buffer),
)
msg = protobuf.decode(buffer, expected_type, EXPERIMENTAL_ENABLED)
if __debug__ and utils.EMULATOR:
log.debug(
__name__, "received message contents:\n%s", utils.dump_protobuf(msg)
)
return msg
except Exception as e:
if __debug__:
log.exception(__name__, e)
if e.args:
raise DataError("Failed to decode message: " + " ".join(e.args))
else:
raise DataError("Failed to decode message")
_PROTOBUF_BUFFER_SIZE = const(8192)
WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
if utils.USE_THP:
WIRE_BUFFER_2 = bytearray(_PROTOBUF_BUFFER_SIZE)
if __debug__:
PROTOBUF_BUFFER_SIZE_DEBUG = 1024
WIRE_BUFFER_DEBUG = bytearray(PROTOBUF_BUFFER_SIZE_DEBUG)
async def handle_single_message(
ctx: Context,
msg: Message,
handler_finder: HandlerFinder,
) -> bool:
"""Handle a message that was loaded from USB by the caller.
Find the appropriate handler, run it and write its result on the wire. In case
a problem is encountered at any point, write the appropriate error on the wire.
The return value indicates whether to override the default restarting behavior. If
`False` is returned, the caller is allowed to clear the loop and restart the
MicroPython machine (see `session.py`). This would lose all state and incurs a cost
in terms of repeated startup time. When handling the message didn't cause any
significant fragmentation (e.g., if decoding the message was skipped), or if
the type of message is supposed to be optimized and not disrupt the running state,
this function will return `True`.
"""
if __debug__:
try:
msg_type = protobuf.type_for_wire(msg.type).MESSAGE_NAME
except Exception:
msg_type = f"{msg.type} - unknown message type"
if ctx.channel_id is not None:
sid = int.from_bytes(ctx.channel_id, "big")
log.debug(
__name__,
"%s:%x receive: <%s>",
ctx.iface.iface_num(),
sid,
msg_type,
)
else:
log.debug(
__name__,
"%s:unknown_sid receive: <%s>",
ctx.iface.iface_num(),
msg_type,
)
res_msg: protobuf.MessageType | None = None
# # We need to find a handler for this message type.
# try:
# handler = find_handler(ctx.iface, msg.type)
# except Error as exc:
# # Handlers are allowed to exception out. In that case, we can skip decoding
# # and return the error.
# await ctx.write(failure(exc))
# return True
# We need to find a handler for this message type. Should not raise.
handler: Handler | None = handler_finder(msg.type)
if handler is None:
# If no handler is found, we can skip decoding and directly
# respond with failure.
await ctx.write(unexpected_message())
return True
if msg.type in workflow.ALLOW_WHILE_LOCKED:
workflow.autolock_interrupts_workflow = False
# Here we make sure we always respond with a Failure response
# in case of any errors.
try:
# Find a protobuf.MessageType subclass that describes this
# message. Raises if the type is not found.
req_type = protobuf.type_for_wire(msg.type)
# Try to decode the message according to schema from
# `req_type`. Raises if the message is malformed.
req_msg = wrap_protobuf_load(msg.data, req_type)
# Create the handler task.
task = handler(req_msg)
# Run the workflow task. Workflow can do more on-the-wire
# communication inside, but it should eventually return a
# response message, or raise an exception (a rather common
# thing to do). Exceptions are handled in the code below.
# Spawn a workflow around the task. This ensures that concurrent
# workflows are shut down.
res_msg = await workflow.spawn(with_context(ctx, task))
except UnexpectedMessageException:
# Workflow was trying to read a message from the wire, and
# something unexpected came in. See Context.read() for
# example, which expects some particular message and raises
# UnexpectedMessage if another one comes in.
# In order not to lose the message, we return it to the caller.
# We process the unexpected message by aborting the current workflow and
# possibly starting a new one, initiated by that message. (The main usecase
# being, the host does not finish the workflow, we want other callers to
# be able to do their own thing.)
#
# The message is stored in the exception, which we re-raise for the caller
# to process. It is not a standard exception that should be logged and a result
# sent to the wire.
raise
except BaseException as exc:
# Either:
# - the message had a type that has a registered handler, but does not have
# a protobuf class
# - the message was not valid protobuf
# - workflow raised some kind of an exception while running
# - something canceled the workflow from the outside
if __debug__:
if isinstance(exc, ActionCancelled):
log.debug(__name__, "cancelled: %s", exc.message)
elif isinstance(exc, loop.TaskClosed):
log.debug(__name__, "cancelled: loop task was closed")
else:
log.exception(__name__, exc)
res_msg = failure(exc)
if res_msg is not None:
# perform the write outside the big try-except block, so that usb write
# problem bubbles up
await ctx.write(res_msg)
# Look into `AVOID_RESTARTING_FOR` to see if this message should avoid restarting.
return msg.type in AVOID_RESTARTING_FOR
AVOID_RESTARTING_FOR: Container[int] = ()
def failure(exc: BaseException) -> Failure:
if isinstance(exc, Error):
return Failure(code=exc.code, message=exc.message)
elif isinstance(exc, loop.TaskClosed):
return Failure(code=FailureType.ActionCancelled, message="Cancelled")
elif isinstance(exc, InvalidSessionError):
return Failure(code=FailureType.InvalidSession, message="Invalid session")
else:
# NOTE: when receiving generic `FirmwareError` on non-debug build,
# change the `if __debug__` to `if True` to get the full error message.
if __debug__:
message = str(exc)
else:
message = "Firmware error"
return Failure(code=FailureType.FirmwareError, message=message)
def unexpected_message() -> Failure:
return Failure(code=FailureType.UnexpectedMessage, message="Unexpected message")
def find_handler(msg_type: int) -> Handler:
from apps import workflow_handlers
handler = workflow_handlers.find_registered_handler(msg_type)
if handler is None:
raise UnexpectedMessage("Unexpected message")
for filter in filters:
handler = filter(msg_type, handler)
return handler
filters: list[Callable[[int, Handler], Handler]] = []
"""Filters for the wire handler.
Filters are applied in order. Each filter gets a message id and a preceding handler. It
must either return a handler (the same one or a modified one), or raise an exception
that gets sent to wire directly.
Filters are not applied to debug sessions.
The filters are designed for:
* rejecting messages -- while in Recovery mode, most messages are not allowed
* adding additional behavior -- while device is soft-locked, a PIN screen will be shown
before allowing a message to trigger its original behavior.
For this, the filters are effectively deny-first. If an earlier filter rejects the
message, the later filters are not called. But if a filter adds behavior, the latest
filter "wins" and the latest behavior triggers first.
Please note that this behavior is really unsuited to anything other than what we are
using it for now. It might be necessary to modify the semantics if we need more complex
usecases.
NB: `filters` is currently public so callers can have control over where they insert
new filters, but removal should be done using `remove_filter`!
We should, however, change it such that filters must be added using an `add_filter`
and `filters` becomes private!
"""
def remove_filter(filter):
try:
filters.remove(filter)
except ValueError:
pass

View File

@ -0,0 +1,70 @@
from typing import TYPE_CHECKING
from trezor import protobuf
if TYPE_CHECKING:
from trezorio import WireInterface
from typing import Container, TypeVar, overload
from storage.cache_common import DataCache
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
T = TypeVar("T")
class Message:
def __init__(
self,
message_type: int,
message_data: bytes,
) -> None:
self.data = message_data
self.type = message_type
def to_bytes(self):
return self.type.to_bytes(2, "big") + self.data
class Context:
def __init__(self, iface: WireInterface, channel_id: bytes) -> None:
self.iface: WireInterface = iface
self.channel_id: bytes = channel_id
if TYPE_CHECKING:
@overload
async def read(
self, expected_types: Container[int]
) -> protobuf.MessageType: ...
@overload
async def read(
self, expected_types: Container[int], expected_type: type[LoadedMessageType]
) -> LoadedMessageType: ...
async def read(
self,
expected_types: Container[int],
expected_type: type[protobuf.MessageType] | None = None,
) -> protobuf.MessageType: ...
async def write(self, msg: protobuf.MessageType) -> None: ...
async def call(
self,
msg: protobuf.MessageType,
expected_type: type[LoadedMessageType],
) -> LoadedMessageType:
assert expected_type.MESSAGE_WIRE_TYPE is not None
await self.write(msg)
del msg
return await self.read((expected_type.MESSAGE_WIRE_TYPE,), expected_type)
@property
def cache(self) -> DataCache: ...
class WireError(Exception):
pass

View File

@ -0,0 +1,80 @@
from typing import TYPE_CHECKING
from trezor.wire.protocol_common import WireError
class ThpError(WireError):
pass
class ThpDecryptionError(ThpError):
pass
class ThpInvalidDataError(ThpError):
pass
class ThpUnallocatedSessionError(ThpError):
def __init__(self, session_id: int):
self.session_id = session_id
if TYPE_CHECKING:
from enum import IntEnum
else:
IntEnum = object
class ThpErrorType(IntEnum):
TRANSPORT_BUSY = 1
UNALLOCATED_CHANNEL = 2
DECRYPTION_FAILED = 3
INVALID_DATA = 4
class ChannelState(IntEnum):
UNALLOCATED = 0
TH1 = 1
TH2 = 2
TP1 = 3
TP2 = 4
TP3 = 5
TP4 = 6
TC1 = 7
ENCRYPTED_TRANSPORT = 8
class SessionState(IntEnum):
UNALLOCATED = 0
ALLOCATED = 1
MANAGEMENT = 2
class WireInterfaceType(IntEnum):
MOCK = 0
USB = 1
BLE = 2
def is_channel_state_pairing(state: int) -> bool:
if state in (
ChannelState.TP1,
ChannelState.TP2,
ChannelState.TP3,
ChannelState.TP4,
ChannelState.TC1,
):
return True
return False
if __debug__:
def state_to_str(state: int) -> str:
name = {
v: k for k, v in ChannelState.__dict__.items() if not k.startswith("__")
}.get(state)
if name is not None:
return name
return "UNKNOWN_STATE"

View File

@ -0,0 +1,102 @@
from storage.cache_thp import ChannelCache
from trezor import log
from trezor.wire.thp import ThpError
def is_ack_valid(cache: ChannelCache, ack_bit: int) -> bool:
"""
Checks if:
- an ACK message is expected
- the received ACK message acknowledges correct sequence number (bit)
"""
if not _is_ack_expected(cache):
return False
if not _has_ack_correct_sync_bit(cache, ack_bit):
return False
return True
def _is_ack_expected(cache: ChannelCache) -> bool:
is_expected: bool = not is_sending_allowed(cache)
if __debug__ and not is_expected:
log.debug(__name__, "Received unexpected ACK message")
return is_expected
def _has_ack_correct_sync_bit(cache: ChannelCache, sync_bit: int) -> bool:
is_correct: bool = get_send_seq_bit(cache) == sync_bit
if __debug__ and not is_correct:
log.debug(__name__, "Received ACK message with wrong ack bit")
return is_correct
def is_sending_allowed(cache: ChannelCache) -> bool:
"""
Checks whether sending a message in the provided channel is allowed.
Note: Sending a message in a channel before receipt of ACK message for the previously
sent message (in the channel) is prohibited, as it can lead to desynchronization.
"""
return bool(cache.sync >> 7)
def get_send_seq_bit(cache: ChannelCache) -> int:
"""
Returns the sequential number (bit) of the next message to be sent
in the provided channel.
"""
return (cache.sync & 0x20) >> 5
def get_expected_receive_seq_bit(cache: ChannelCache) -> int:
"""
Returns the (expected) sequential number (bit) of the next message
to be received in the provided channel.
"""
return (cache.sync & 0x40) >> 6
def set_sending_allowed(cache: ChannelCache, sending_allowed: bool) -> None:
"""
Set the flag whether sending a message in this channel is allowed or not.
"""
cache.sync &= 0x7F
if sending_allowed:
cache.sync |= 0x80
def set_expected_receive_seq_bit(cache: ChannelCache, seq_bit: int) -> None:
"""
Set the expected sequential number (bit) of the next message to be received
in the provided channel
"""
if __debug__:
log.debug(__name__, "Set sync receive expected seq bit to %d", seq_bit)
if seq_bit not in (0, 1):
raise ThpError("Unexpected receive sync bit")
# set second bit to "seq_bit" value
cache.sync &= 0xBF
if seq_bit:
cache.sync |= 0x40
def _set_send_seq_bit(cache: ChannelCache, seq_bit: int) -> None:
if seq_bit not in (0, 1):
raise ThpError("Unexpected send seq bit")
if __debug__:
log.debug(__name__, "setting sync send seq bit to %d", seq_bit)
# set third bit to "seq_bit" value
cache.sync &= 0xDF
if seq_bit:
cache.sync |= 0x20
def set_send_seq_bit_to_opposite(cache: ChannelCache) -> None:
"""
Set the sequential bit of the "next message to be send" to the opposite value,
i.e. 1 -> 0 and 0 -> 1
"""
_set_send_seq_bit(cache=cache, seq_bit=1 - get_send_seq_bit(cache))

View File

@ -0,0 +1,391 @@
import ustruct
from typing import TYPE_CHECKING
from storage.cache_common import (
CHANNEL_HANDSHAKE_HASH,
CHANNEL_KEY_RECEIVE,
CHANNEL_KEY_SEND,
CHANNEL_NONCE_RECEIVE,
CHANNEL_NONCE_SEND,
)
from storage.cache_thp import TAG_LENGTH, ChannelCache, clear_sessions_with_channel_id
from trezor import log, loop, protobuf, utils, workflow
from trezor.wire.thp.transmission_loop import TransmissionLoop
from . import ChannelState, ThpDecryptionError, ThpError
from . import alternating_bit_protocol as ABP
from . import (
control_byte,
crypto,
interface_manager,
memory_manager,
received_message_handler,
session_manager,
)
from .checksum import CHECKSUM_LENGTH
from .thp_messages import ENCRYPTED_TRANSPORT, PacketHeader
from .writer import (
CONT_HEADER_LENGTH,
INIT_HEADER_LENGTH,
write_payload_to_wire_and_add_checksum,
)
if __debug__:
from ubinascii import hexlify
from . import state_to_str
if TYPE_CHECKING:
from trezorio import WireInterface
from .pairing_context import PairingContext
from .session_context import GenericSessionContext
class Channel:
def __init__(self, channel_cache: ChannelCache) -> None:
if __debug__:
log.debug(__name__, "channel initialization")
self.iface: WireInterface = interface_manager.decode_iface(channel_cache.iface)
self.channel_cache: ChannelCache = channel_cache
self.is_cont_packet_expected: bool = False
self.expected_payload_length: int = 0
self.bytes_read: int = 0
self.buffer: utils.BufferType
self.channel_id: bytes = channel_cache.channel_id
self.selected_pairing_methods = []
self.sessions: dict[int, GenericSessionContext] = {}
self.write_task_spawn: loop.spawn | None = None
self.connection_context: PairingContext | None = None
self.transmission_loop: TransmissionLoop | None = None
self.handshake: crypto.Handshake | None = None
self._create_management_session()
def clear(self):
clear_sessions_with_channel_id(self.channel_id)
self.channel_cache.clear()
# ACCESS TO CHANNEL_DATA
def get_channel_id_int(self) -> int:
return int.from_bytes(self.channel_id, "big")
def get_channel_state(self) -> int:
state = int.from_bytes(self.channel_cache.state, "big")
if __debug__:
log.debug(
__name__,
"(cid: %s) get_channel_state: %s",
utils.get_bytes_as_str(self.channel_id),
state_to_str(state),
)
return state
def get_handshake_hash(self) -> bytes:
h = self.channel_cache.get(CHANNEL_HANDSHAKE_HASH)
assert h is not None
return h
def set_channel_state(self, state: ChannelState) -> None:
self.channel_cache.state = bytearray(state.to_bytes(1, "big"))
if __debug__:
log.debug(
__name__,
"(cid: %s) set_channel_state: %s",
utils.get_bytes_as_str(self.channel_id),
state_to_str(state),
)
def set_buffer(self, buffer: utils.BufferType) -> None:
self.buffer = buffer
if __debug__:
log.debug(
__name__,
"(cid: %s) set_buffer: %s",
utils.get_bytes_as_str(self.channel_id),
type(self.buffer),
)
def _create_management_session(self) -> None:
session = session_manager.create_new_management_session(self)
self.sessions[session.session_id] = session
loop.schedule(session.handle())
# CALLED BY THP_MAIN_LOOP
async def receive_packet(self, packet: utils.BufferType):
if __debug__:
log.debug(
__name__,
"(cid: %s) receive_packet",
utils.get_bytes_as_str(self.channel_id),
)
await self._handle_received_packet(packet)
if __debug__:
log.debug(
__name__,
"(cid: %s) self.buffer: %s",
utils.get_bytes_as_str(self.channel_id),
utils.get_bytes_as_str(self.buffer),
)
if self.expected_payload_length + INIT_HEADER_LENGTH == self.bytes_read:
self._finish_message()
await received_message_handler.handle_received_message(self, self.buffer)
elif self.expected_payload_length + INIT_HEADER_LENGTH > self.bytes_read:
self.is_cont_packet_expected = True
else:
raise ThpError(
"Read more bytes than is the expected length of the message!"
)
async def _handle_received_packet(self, packet: utils.BufferType) -> None:
ctrl_byte = packet[0]
if control_byte.is_continuation(ctrl_byte):
await self._handle_cont_packet(packet)
else:
await self._handle_init_packet(packet)
async def _handle_init_packet(self, packet: utils.BufferType) -> None:
if __debug__:
log.debug(
__name__,
"(cid: %s) handle_init_packet",
utils.get_bytes_as_str(self.channel_id),
)
# ctrl_byte, _, payload_length = ustruct.unpack(">BHH", packet) # TODO use this with single packet decryption
_, _, payload_length = ustruct.unpack(">BHH", packet)
self.expected_payload_length = payload_length
packet_payload = memoryview(packet)[INIT_HEADER_LENGTH:]
# If the channel does not "own" the buffer lock, decrypt first packet
# TODO do it only when needed!
# TODO FIX: If "_decrypt_single_packet_payload" is implemented, it will (possibly) break "decrypt_buffer" and nonces incrementation.
# On the other hand, without the single packet decryption, the "advanced" buffer selection cannot be implemented
# in "memory_manager.select_buffer", because the session id is unknown (encrypted).
# if control_byte.is_encrypted_transport(ctrl_byte):
# packet_payload = self._decrypt_single_packet_payload(packet_payload)
self.buffer = memory_manager.select_buffer(
self.get_channel_state(),
self.buffer,
packet_payload,
payload_length,
)
await self._buffer_packet_data(self.buffer, packet, 0)
if __debug__:
log.debug(
__name__,
"(cid: %s) handle_init_packet - payload len: %d",
utils.get_bytes_as_str(self.channel_id),
payload_length,
)
log.debug(
__name__,
"(cid: %s) handle_init_packet - buffer len: %d",
utils.get_bytes_as_str(self.channel_id),
len(self.buffer),
)
async def _handle_cont_packet(self, packet: utils.BufferType) -> None:
if __debug__:
log.debug(
__name__,
"(cid: %s) handle_cont_packet",
utils.get_bytes_as_str(self.channel_id),
)
if not self.is_cont_packet_expected:
raise ThpError("Continuation packet is not expected, ignoring")
await self._buffer_packet_data(self.buffer, packet, CONT_HEADER_LENGTH)
def _decrypt_single_packet_payload(
self, payload: utils.BufferType
) -> utils.BufferType:
# crypto.decrypt(b"\x00", b"\x00", payload_buffer, INIT_DATA_OFFSET, len(payload))
return payload
def decrypt_buffer(
self, message_length: int, offset: int = INIT_HEADER_LENGTH
) -> None:
noise_buffer = memoryview(self.buffer)[
offset : message_length - CHECKSUM_LENGTH - TAG_LENGTH
]
tag = self.buffer[
message_length
- CHECKSUM_LENGTH
- TAG_LENGTH : message_length
- CHECKSUM_LENGTH
]
if utils.DISABLE_ENCRYPTION:
is_tag_valid = tag == crypto.DUMMY_TAG
else:
key_receive = self.channel_cache.get(CHANNEL_KEY_RECEIVE)
nonce_receive = self.channel_cache.get_int(CHANNEL_NONCE_RECEIVE)
assert key_receive is not None
assert nonce_receive is not None
if __debug__:
log.debug(
__name__,
"(cid: %s) Buffer before decryption: %s",
utils.get_bytes_as_str(self.channel_id),
hexlify(noise_buffer),
)
is_tag_valid = crypto.dec(
noise_buffer, tag, key_receive, nonce_receive, b""
)
if __debug__:
log.debug(
__name__,
"(cid: %s) Buffer after decryption: %s",
utils.get_bytes_as_str(self.channel_id),
hexlify(noise_buffer),
)
self.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, nonce_receive + 1)
if __debug__:
log.debug(
__name__,
"(cid: %s) Is decrypted tag valid? %s",
utils.get_bytes_as_str(self.channel_id),
str(is_tag_valid),
)
log.debug(
__name__,
"(cid: %s) Received tag: %s",
utils.get_bytes_as_str(self.channel_id),
(hexlify(tag).decode()),
)
log.debug(
__name__,
"(cid: %s) New nonce_receive: %i",
utils.get_bytes_as_str(self.channel_id),
nonce_receive + 1,
)
if not is_tag_valid:
raise ThpDecryptionError()
def _encrypt(self, buffer: utils.BufferType, noise_payload_len: int) -> None:
if __debug__:
log.debug(
__name__, "(cid: %s) encrypt", utils.get_bytes_as_str(self.channel_id)
)
assert len(buffer) >= noise_payload_len + TAG_LENGTH + CHECKSUM_LENGTH
noise_buffer = memoryview(buffer)[0:noise_payload_len]
if utils.DISABLE_ENCRYPTION:
tag = crypto.DUMMY_TAG
else:
key_send = self.channel_cache.get(CHANNEL_KEY_SEND)
nonce_send = self.channel_cache.get_int(CHANNEL_NONCE_SEND)
assert key_send is not None
assert nonce_send is not None
tag = crypto.enc(noise_buffer, key_send, nonce_send, b"")
self.channel_cache.set_int(CHANNEL_NONCE_SEND, nonce_send + 1)
if __debug__:
log.debug(__name__, "New nonce_send: %i", nonce_send + 1)
buffer[noise_payload_len : noise_payload_len + TAG_LENGTH] = tag
async def _buffer_packet_data(
self, payload_buffer: utils.BufferType, packet: utils.BufferType, offset: int
):
self.bytes_read += utils.memcpy(payload_buffer, self.bytes_read, packet, offset)
def _finish_message(self):
self.bytes_read = 0
self.expected_payload_length = 0
self.is_cont_packet_expected = False
# CALLED BY WORKFLOW / SESSION CONTEXT
async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None:
if __debug__ and utils.EMULATOR:
log.debug(
__name__,
"(cid: %s) write message: %s\n%s",
utils.get_bytes_as_str(self.channel_id),
msg.MESSAGE_NAME,
utils.dump_protobuf(msg),
)
self.buffer = memory_manager.get_write_buffer(self.buffer, msg)
noise_payload_len = memory_manager.encode_into_buffer(
self.buffer, msg, session_id
)
await self.write_and_encrypt(self.buffer[:noise_payload_len])
async def write_error(self, err_type: int):
msg_data = err_type.to_bytes(1, "big")
length = len(msg_data) + CHECKSUM_LENGTH
header = PacketHeader.get_error_header(self.get_channel_id_int(), length)
await write_payload_to_wire_and_add_checksum(self.iface, header, msg_data)
async def write_and_encrypt(self, payload: bytes) -> None:
payload_length = len(payload)
self._encrypt(self.buffer, payload_length)
payload_length = payload_length + TAG_LENGTH
if self.write_task_spawn is not None:
self.write_task_spawn.close() # UPS TODO might break something
print("\nCLOSED\n")
self._prepare_write()
self.write_task_spawn = loop.spawn(
self._write_encrypted_payload_loop(
ENCRYPTED_TRANSPORT, memoryview(self.buffer[:payload_length])
)
)
async def write_handshake_message(self, ctrl_byte: int, payload: bytes) -> None:
self._prepare_write()
self.write_task_spawn = loop.spawn(
self._write_encrypted_payload_loop(ctrl_byte, payload)
)
def _prepare_write(self) -> None:
# TODO add condition that disallows to write when can_send_message is false
ABP.set_sending_allowed(self.channel_cache, False)
async def _write_encrypted_payload_loop(
self, ctrl_byte: int, payload: bytes
) -> None:
if __debug__:
log.debug(
__name__,
"(cid %s) write_encrypted_payload_loop",
utils.get_bytes_as_str(self.channel_id),
)
payload_len = len(payload) + CHECKSUM_LENGTH
sync_bit = ABP.get_send_seq_bit(self.channel_cache)
ctrl_byte = control_byte.add_seq_bit_to_ctrl_byte(ctrl_byte, sync_bit)
header = PacketHeader(ctrl_byte, self.get_channel_id_int(), payload_len)
self.transmission_loop = TransmissionLoop(self, header, payload)
await self.transmission_loop.start()
ABP.set_send_seq_bit_to_opposite(self.channel_cache)
# Let the main loop be restarted and clear loop, if there is no other
# workflow and the state is ENCRYPTED_TRANSPORT
if self._can_clear_loop():
if __debug__:
log.debug(
__name__,
"(cid: %s) clearing loop from channel",
utils.get_bytes_as_str(self.channel_id),
)
loop.clear()
def _can_clear_loop(self) -> bool:
return (
not workflow.tasks
) and self.get_channel_state() is ChannelState.ENCRYPTED_TRANSPORT

View File

@ -0,0 +1,34 @@
from typing import TYPE_CHECKING
from storage import cache_thp
from trezor import utils
from . import ChannelState, interface_manager
from .channel import Channel
if TYPE_CHECKING:
from trezorio import WireInterface
def create_new_channel(iface: WireInterface, buffer: utils.BufferType) -> Channel:
"""
Creates a new channel for the interface `iface` with the buffer `buffer`.
"""
channel_cache = cache_thp.get_new_channel(interface_manager.encode_iface(iface))
r = Channel(channel_cache)
r.set_buffer(buffer)
r.set_channel_state(ChannelState.TH1)
return r
def load_cached_channels(buffer: utils.BufferType) -> dict[int, Channel]:
"""
Returns all allocated channels from cache.
"""
channels: dict[int, Channel] = {}
cached_channels = cache_thp.get_all_allocated_channels()
for c in cached_channels:
channels[int.from_bytes(c.channel_id, "big")] = Channel(c)
for c in channels.values():
c.set_buffer(buffer)
return channels

View File

@ -0,0 +1,22 @@
from micropython import const
from trezor import utils
from trezor.crypto import crc
CHECKSUM_LENGTH = const(4)
def compute(data: bytes | utils.BufferType) -> bytes:
"""
Returns a CRC-32 checksum of the provided `data`.
"""
return crc.crc32(data).to_bytes(CHECKSUM_LENGTH, "big")
def is_valid(checksum: bytes | utils.BufferType, data: bytes) -> bool:
"""
Checks whether the CRC-32 checksum of the `data` is the same
as the checksum provided in `checksum`.
"""
data_checksum = compute(data)
return checksum == data_checksum

View File

@ -0,0 +1,47 @@
from trezor.wire.thp import ThpError
from trezor.wire.thp.thp_messages import (
ACK_MASK,
ACK_MESSAGE,
CONTINUATION_PACKET,
CONTINUATION_PACKET_MASK,
DATA_MASK,
ENCRYPTED_TRANSPORT,
HANDSHAKE_COMP_REQ,
HANDSHAKE_INIT_REQ,
)
def add_seq_bit_to_ctrl_byte(ctrl_byte: int, seq_bit: int) -> int:
if seq_bit == 0:
return ctrl_byte & 0xEF
if seq_bit == 1:
return ctrl_byte | 0x10
raise ThpError("Unexpected sequence bit")
def add_ack_bit_to_ctrl_byte(ctrl_byte: int, ack_bit: int) -> int:
if ack_bit == 0:
return ctrl_byte & 0xF7
if ack_bit == 1:
return ctrl_byte | 0x08
raise ThpError("Unexpected acknowledgement bit")
def is_ack(ctrl_byte: int) -> bool:
return ctrl_byte & ACK_MASK == ACK_MESSAGE
def is_continuation(ctrl_byte: int) -> bool:
return ctrl_byte & CONTINUATION_PACKET_MASK == CONTINUATION_PACKET
def is_encrypted_transport(ctrl_byte: int) -> bool:
return ctrl_byte & DATA_MASK == ENCRYPTED_TRANSPORT
def is_handshake_init_req(ctrl_byte: int) -> bool:
return ctrl_byte & DATA_MASK == HANDSHAKE_INIT_REQ
def is_handshake_comp_req(ctrl_byte: int) -> bool:
return ctrl_byte & DATA_MASK == HANDSHAKE_COMP_REQ

View File

@ -0,0 +1,36 @@
from trezor.crypto import elligator2, random
from trezor.crypto.curve import curve25519
from trezor.crypto.hashlib import sha512
_PREFIX = b"\x08\x43\x50\x61\x63\x65\x32\x35\x35\x06"
_PADDING = b"\x6f\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x20"
class Cpace:
"""
CPace, a balanced composable PAKE: https://datatracker.ietf.org/doc/draft-irtf-cfrg-cpace/
"""
def __init__(self, cpace_host_public_key: bytes, handshake_hash: bytes) -> None:
self.handshake_hash: bytes = handshake_hash
self.host_public_key: bytes = cpace_host_public_key
self.shared_secret: bytes
self.trezor_private_key: bytes
self.trezor_public_key: bytes
def generate_keys_and_secret(self, code_code_entry: bytes) -> None:
"""
Generate ephemeral key pair and a shared secret using Elligator2 with X25519.
"""
sha_ctx = sha512(_PREFIX)
sha_ctx.update(code_code_entry)
sha_ctx.update(_PADDING)
sha_ctx.update(self.handshake_hash)
sha_ctx.update(b"\x00")
pregenerator = sha_ctx.digest()[:32]
generator = elligator2.map_to_curve25519(pregenerator)
self.trezor_private_key = random.bytes(32)
self.trezor_public_key = curve25519.multiply(self.trezor_private_key, generator)
self.shared_secret = curve25519.multiply(
self.trezor_private_key, self.host_public_key
)

View File

@ -0,0 +1,211 @@
from micropython import const
from trezorcrypto import aesgcm, bip32, curve25519, hmac
from storage import device
from trezor import log, utils
from trezor.crypto.hashlib import sha256
from trezor.wire.thp import ThpDecryptionError
# The HARDENED flag is taken from apps.common.paths
# It is not imported to save on resources
HARDENED = const(0x8000_0000)
PUBKEY_LENGTH = const(32)
if utils.DISABLE_ENCRYPTION:
DUMMY_TAG = b"\xA0\xA1\xA2\xA3\xA4\xA5\xA6\xA7\xA8\xA9\xB0\xB1\xB2\xB3\xB4\xB5"
if __debug__:
from ubinascii import hexlify
def enc(buffer: utils.BufferType, key: bytes, nonce: int, auth_data: bytes) -> bytes:
"""
Encrypts the provided `buffer` with AES-GCM (in place).
Returns a 16-byte long encryption tag.
"""
if __debug__:
log.debug(__name__, "enc (key: %s, nonce: %d)", hexlify(key), nonce)
iv = _get_iv_from_nonce(nonce)
aes_ctx = aesgcm(key, iv)
aes_ctx.auth(auth_data)
aes_ctx.encrypt_in_place(buffer)
return aes_ctx.finish()
def dec(
buffer: utils.BufferType, tag: bytes, key: bytes, nonce: int, auth_data: bytes
) -> bool:
"""
Decrypts the provided buffer (in place). Returns `True` if the provided authentication `tag` is the same as
the tag computed in decryption, otherwise it returns `False`.
"""
iv = _get_iv_from_nonce(nonce)
if __debug__:
log.debug(__name__, "dec (key: %s, nonce: %d)", hexlify(key), nonce)
aes_ctx = aesgcm(key, iv)
aes_ctx.auth(auth_data)
aes_ctx.decrypt_in_place(buffer)
computed_tag = aes_ctx.finish()
return computed_tag == tag
class BusyDecoder:
def __init__(self, key: bytes, nonce: int, auth_data: bytes) -> None:
iv = _get_iv_from_nonce(nonce)
self.aes_ctx = aesgcm(key, iv)
self.aes_ctx.auth(auth_data)
def decrypt_part(self, part: utils.BufferType) -> None:
self.aes_ctx.decrypt_in_place(part)
def finish_and_check_tag(self, tag: bytes) -> bool:
computed_tag = self.aes_ctx.finish()
return computed_tag == tag
PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00"
IV_1 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
IV_2 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
class Handshake:
"""
`Handshake` holds (temporary) values and keys that are used during the creation of an encrypted channel.
The following values should be saved for future use before disposing of this object:
- `h` (handshake hash, can be used to bind other values to the channel)
- `key_receive` (key for decrypting incoming communication)
- `key_send` (key for encrypting outgoing communication)
"""
def __init__(self) -> None:
self.trezor_ephemeral_privkey: bytes
self.ck: bytes
self.k: bytes
self.h: bytes
self.key_receive: bytes
self.key_send: bytes
def handle_th1_crypto(
self,
device_properties: bytes,
host_ephemeral_pubkey: bytes,
) -> tuple[bytes, bytes, bytes]:
trezor_static_privkey, trezor_static_pubkey = _derive_static_key_pair()
self.trezor_ephemeral_privkey = curve25519.generate_secret()
trezor_ephemeral_pubkey = curve25519.publickey(self.trezor_ephemeral_privkey)
self.h = _hash_of_two(PROTOCOL_NAME, device_properties)
self.h = _hash_of_two(self.h, host_ephemeral_pubkey)
self.h = _hash_of_two(self.h, trezor_ephemeral_pubkey)
point = curve25519.multiply(
self.trezor_ephemeral_privkey, host_ephemeral_pubkey
)
self.ck, self.k = _hkdf(PROTOCOL_NAME, point)
mask = _hash_of_two(trezor_static_pubkey, trezor_ephemeral_pubkey)
trezor_masked_static_pubkey = curve25519.multiply(mask, trezor_static_pubkey)
aes_ctx = aesgcm(self.k, IV_1)
encrypted_trezor_static_pubkey = aes_ctx.encrypt(trezor_masked_static_pubkey)
if __debug__:
log.debug(__name__, "th1 - enc (key: %s, nonce: %d)", hexlify(self.k), 0)
aes_ctx.auth(self.h)
tag_to_encrypted_key = aes_ctx.finish()
encrypted_trezor_static_pubkey = (
encrypted_trezor_static_pubkey + tag_to_encrypted_key
)
self.h = _hash_of_two(self.h, encrypted_trezor_static_pubkey)
point = curve25519.multiply(trezor_static_privkey, host_ephemeral_pubkey)
self.ck, self.k = _hkdf(self.ck, curve25519.multiply(mask, point))
aes_ctx = aesgcm(self.k, IV_1)
aes_ctx.auth(self.h)
tag = aes_ctx.finish()
self.h = _hash_of_two(self.h, tag)
return (trezor_ephemeral_pubkey, encrypted_trezor_static_pubkey, tag)
def handle_th2_crypto(
self,
encrypted_host_static_pubkey: utils.BufferType,
encrypted_payload: utils.BufferType,
):
aes_ctx = aesgcm(self.k, IV_2)
# The new value of hash `h` MUST be computed before the `encrypted_host_static_pubkey` is decrypted.
# However, decryption of `encrypted_host_static_pubkey` MUST use the previous value of `h` for
# authentication of the gcm tag.
aes_ctx.auth(self.h) # Authenticate with the previous value of `h`
self.h = _hash_of_two(self.h, encrypted_host_static_pubkey) # Compute new value
aes_ctx.decrypt_in_place(
memoryview(encrypted_host_static_pubkey)[:PUBKEY_LENGTH]
)
if __debug__:
log.debug(__name__, "th2 - dec (key: %s, nonce: %d)", hexlify(self.k), 1)
host_static_pubkey = memoryview(encrypted_host_static_pubkey)[:PUBKEY_LENGTH]
tag = aes_ctx.finish()
if tag != encrypted_host_static_pubkey[-16:]:
raise ThpDecryptionError()
self.ck, self.k = _hkdf(
self.ck,
curve25519.multiply(self.trezor_ephemeral_privkey, host_static_pubkey),
)
aes_ctx = aesgcm(self.k, IV_1)
aes_ctx.auth(self.h)
aes_ctx.decrypt_in_place(memoryview(encrypted_payload)[:-16])
if __debug__:
log.debug(__name__, "th2 - dec (key: %s, nonce: %d)", hexlify(self.k), 0)
tag = aes_ctx.finish()
if tag != encrypted_payload[-16:]:
raise ThpDecryptionError()
self.h = _hash_of_two(self.h, memoryview(encrypted_payload)[:-16])
self.key_receive, self.key_send = _hkdf(self.ck, b"")
if __debug__:
log.debug(
__name__,
"(key_receive: %s, key_send: %s)",
hexlify(self.key_receive),
hexlify(self.key_send),
)
def get_handshake_completion_response(self, trezor_state: bytes) -> bytes:
aes_ctx = aesgcm(self.key_send, IV_1)
encrypted_trezor_state = aes_ctx.encrypt(trezor_state)
tag = aes_ctx.finish()
return encrypted_trezor_state + tag
def _derive_static_key_pair() -> tuple[bytes, bytes]:
node_int = HARDENED | int.from_bytes(b"\x00THP", "big")
node = bip32.from_seed(device.get_device_secret(), "curve25519")
node.derive(node_int)
trezor_static_privkey = node.private_key()
trezor_static_pubkey = node.public_key()[1:33]
# Note: the first byte (\x01) of the public key is removed, as it
# only indicates the type of the elliptic curve used
return trezor_static_privkey, trezor_static_pubkey
def get_trezor_static_pubkey() -> bytes:
_, pubkey = _derive_static_key_pair()
return pubkey
def _hkdf(chaining_key, input: bytes):
temp_key = hmac(hmac.SHA256, chaining_key, input).digest()
output_1 = hmac(hmac.SHA256, temp_key, b"\x01").digest()
ctx_output_2 = hmac(hmac.SHA256, temp_key, output_1)
ctx_output_2.update(b"\x02")
output_2 = ctx_output_2.digest()
return (output_1, output_2)
def _hash_of_two(part_1: bytes, part_2: bytes) -> bytes:
ctx = sha256(part_1)
ctx.update(part_2)
return ctx.digest()
def _get_iv_from_nonce(nonce: int) -> bytes:
utils.ensure(nonce <= 0xFFFFFFFFFFFFFFFF, "Nonce overflow, terminate the channel")
return bytes(4) + nonce.to_bytes(8, "big")

View File

@ -0,0 +1,33 @@
from typing import TYPE_CHECKING
from trezor import protobuf
from trezor.enums import MessageType
from trezor.wire.errors import UnexpectedMessage
if TYPE_CHECKING:
from typing import Any, Callable, Coroutine
from trezor.messages import Features, GetFeatures
def find_management_session_message_handler(
msg_type: int,
) -> Callable[[Any], Coroutine[Any, Any, protobuf.MessageType]]:
if msg_type is MessageType.ThpCreateNewSession:
from apps.thp.create_session import create_new_session
return create_new_session
if msg_type is MessageType.GetFeatures:
return handle_GetFeatures
if __debug__:
if msg_type is MessageType.LoadDevice:
from apps.debug.load_device import load_device
return load_device
raise UnexpectedMessage("There is no handler available for this message")
async def handle_GetFeatures(msg: GetFeatures) -> Features:
from apps.base import get_features
return get_features()

View File

@ -0,0 +1,30 @@
from typing import TYPE_CHECKING
import usb
_MOCK_INTERFACE_HID = b"\x00"
_WIRE_INTERFACE_USB = b"\x01"
if TYPE_CHECKING:
from trezorio import WireInterface
def decode_iface(cached_iface: bytes) -> WireInterface:
"""Decode the cached wire interface."""
if cached_iface == _WIRE_INTERFACE_USB:
iface = usb.iface_wire
if iface is None:
raise RuntimeError("There is no valid USB WireInterface")
return iface
# TODO implement bluetooth interface
raise Exception("Unknown WireInterface")
def encode_iface(iface: WireInterface) -> bytes:
"""Encode wire interface into bytes."""
if iface is usb.iface_wire:
return _WIRE_INTERFACE_USB
# TODO implement bluetooth interface
if __debug__:
return _MOCK_INTERFACE_HID
raise Exception("Unknown WireInterface")

View File

@ -0,0 +1,167 @@
from storage.cache_thp import SESSION_ID_LENGTH, TAG_LENGTH
from trezor import log, protobuf, utils
from . import ChannelState, ThpError
from .checksum import CHECKSUM_LENGTH
from .writer import (
INIT_HEADER_LENGTH,
MAX_PAYLOAD_LEN,
MESSAGE_TYPE_LENGTH,
PACKET_LENGTH,
)
def select_buffer(
channel_state: int,
channel_buffer: utils.BufferType,
packet_payload: utils.BufferType,
payload_length: int,
) -> utils.BufferType:
if channel_state is ChannelState.ENCRYPTED_TRANSPORT:
session_id = packet_payload[0]
if session_id == 0:
pass
# TODO use small buffer
else:
pass
# TODO use big buffer but only if the channel owns the buffer lock.
# Otherwise send BUSY message and return
else:
pass
# TODO use small buffer
try:
# TODO for now, we create a new big buffer every time. It should be changed
buffer: utils.BufferType = _get_buffer_for_read(payload_length, channel_buffer)
return buffer
except Exception as e:
if __debug__:
log.exception(__name__, e)
raise Exception("Failed to create a buffer for channel") # TODO handle better
def get_write_buffer(
buffer: utils.BufferType, msg: protobuf.MessageType
) -> utils.BufferType:
msg_size = protobuf.encoded_length(msg)
payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size
required_min_size = payload_size + CHECKSUM_LENGTH + TAG_LENGTH
if required_min_size > len(buffer):
return _get_buffer_for_write(required_min_size, buffer)
return buffer
def encode_into_buffer(
buffer: utils.BufferType, msg: protobuf.MessageType, session_id: int
) -> int:
# cannot write message without wire type
assert msg.MESSAGE_WIRE_TYPE is not None
msg_size = protobuf.encoded_length(msg)
payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size
_encode_session_into_buffer(memoryview(buffer), session_id)
_encode_message_type_into_buffer(
memoryview(buffer), msg.MESSAGE_WIRE_TYPE, SESSION_ID_LENGTH
)
_encode_message_into_buffer(
memoryview(buffer), msg, SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH
)
return payload_size
def _encode_session_into_buffer(
buffer: memoryview, session_id: int, buffer_offset: int = 0
) -> None:
session_id_bytes = int.to_bytes(session_id, SESSION_ID_LENGTH, "big")
utils.memcpy(buffer, buffer_offset, session_id_bytes, 0)
def _encode_message_type_into_buffer(
buffer: memoryview, message_type: int, offset: int = 0
) -> None:
msg_type_bytes = int.to_bytes(message_type, MESSAGE_TYPE_LENGTH, "big")
utils.memcpy(buffer, offset, msg_type_bytes, 0)
def _encode_message_into_buffer(
buffer: memoryview, message: protobuf.MessageType, buffer_offset: int = 0
) -> None:
protobuf.encode(memoryview(buffer[buffer_offset:]), message)
def _get_buffer_for_read(
payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN
) -> utils.BufferType:
length = payload_length + INIT_HEADER_LENGTH
if __debug__:
log.debug(
__name__,
"get_buffer_for_read - length: %d, %s %s",
length,
"existing buffer type:",
type(existing_buffer),
)
if length > max_length:
raise ThpError("Message too large")
if length > len(existing_buffer):
if __debug__:
log.debug(__name__, "Allocating a new buffer")
from ..thp_main import get_raw_read_buffer
if length > len(get_raw_read_buffer()):
raise ThpError("Message is too large")
try:
payload: utils.BufferType = memoryview(get_raw_read_buffer())[:length]
except MemoryError:
payload = memoryview(get_raw_read_buffer())[:PACKET_LENGTH]
raise ThpError("Message is too large")
return payload
# reuse a part of the supplied buffer
if __debug__:
log.debug(__name__, "Reusing already allocated buffer")
return memoryview(existing_buffer)[:length]
def _get_buffer_for_write(
payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN
) -> utils.BufferType:
length = payload_length + INIT_HEADER_LENGTH
if __debug__:
log.debug(
__name__,
"get_buffer_for_write - length: %d, %s %s",
length,
"existing buffer type:",
type(existing_buffer),
)
if length > max_length:
raise ThpError("Message too large")
if length > len(existing_buffer):
if __debug__:
log.debug(__name__, "Creating a new write buffer from raw write buffer")
from ..thp_main import get_raw_write_buffer
if length > len(get_raw_write_buffer()):
raise ThpError("Message is too large")
try:
payload: utils.BufferType = memoryview(get_raw_write_buffer())[:length]
except MemoryError:
payload = memoryview(get_raw_write_buffer())[:PACKET_LENGTH]
raise ThpError("Message is too large")
return payload
# reuse a part of the supplied buffer
if __debug__:
log.debug(__name__, "Reusing already allocated buffer")
return memoryview(existing_buffer)[:length]

View File

@ -0,0 +1,259 @@
from typing import TYPE_CHECKING
from ubinascii import hexlify
import trezorui2
from trezor import loop, protobuf, workflow
from trezor.crypto import random
from trezor.ui.layouts.tt import RustLayout
from trezor.wire import context, message_handler, protocol_common
from trezor.wire.context import UnexpectedMessageException
from trezor.wire.errors import ActionCancelled, SilentError
from trezor.wire.protocol_common import Context, Message
if TYPE_CHECKING:
from typing import Container
from .channel import Channel
from .cpace import Cpace
pass
if __debug__:
from trezor import log
class PairingDisplayData:
def __init__(self) -> None:
self.code_code_entry: int | None = None
self.code_qr_code: bytes | None = None
self.code_nfc_unidirectional: bytes | None = None
def get_display_layout(self) -> RustLayout:
# TODO have different layouts when there is only QR code or only Code Entry
qr_str = ""
code_str = ""
if self.code_qr_code is not None:
qr_str = self._get_code_qr_code_str()
if self.code_code_entry is not None:
code_str = self._get_code_code_entry_str()
return RustLayout(
trezorui2.show_address_details( # noqa
qr_title="Scan QR code to pair",
address=qr_str,
case_sensitive=True,
details_title="",
account="Code to rewrite:\n" + code_str,
path="",
xpubs=[],
)
)
def _get_code_code_entry_str(self) -> str:
if self.code_code_entry is not None:
code_str = f"{self.code_code_entry:06}"
if __debug__:
log.debug(__name__, "code_code_entry: %s", code_str)
return code_str[:3] + " " + code_str[3:]
raise Exception("Code entry string is not available")
def _get_code_qr_code_str(self) -> str:
if self.code_qr_code is not None:
code_str = (hexlify(self.code_qr_code)).decode("utf-8")
if __debug__:
log.debug(__name__, "code_qr_code_hexlified: %s", code_str)
return code_str
raise Exception("QR code string is not available")
class PairingContext(Context):
def __init__(self, channel_ctx: Channel) -> None:
super().__init__(channel_ctx.iface, channel_ctx.channel_id)
self.channel_ctx: Channel = channel_ctx
self.incoming_message = loop.chan()
self.secret: bytes = random.bytes(16)
self.display_data: PairingDisplayData = PairingDisplayData()
self.cpace: Cpace
self.host_name: str
async def handle(self, is_debug_session: bool = False) -> None:
if __debug__:
log.debug(__name__, "handle - start")
if is_debug_session:
import apps.debug
apps.debug.DEBUG_CONTEXT = self
take = self.incoming_message.take()
next_message: Message | None = None
while True:
try:
if next_message is None:
# If the previous run did not keep an unprocessed message for us,
# wait for a new one.
try:
message: Message = await take
except protocol_common.WireError as e:
if __debug__:
log.exception(__name__, e)
await self.write(message_handler.failure(e))
continue
else:
# Process the message from previous run.
message = next_message
next_message = None
try:
next_message = await handle_pairing_request_message(
self, message, use_workflow=not is_debug_session
)
except Exception as exc:
# Log and ignore. The session handler can only exit explicitly in the
# following finally block.
if __debug__:
log.exception(__name__, exc)
finally:
if not __debug__ or not is_debug_session:
# Unload modules imported by the workflow. Should not raise.
# This is not done for the debug session because the snapshot taken
# in a debug session would clear modules which are in use by the
# workflow running on wire.
# TODO utils.unimport_end(modules)
if next_message is None:
# Shut down the loop if there is no next message waiting.
return # pylint: disable=lost-exception
except Exception as exc:
# Log and try again. The session handler can only exit explicitly via
# loop.clear() above.
if __debug__:
log.exception(__name__, exc)
async def read(
self,
expected_types: Container[int],
expected_type: type[protobuf.MessageType] | None = None,
) -> protobuf.MessageType:
if __debug__:
exp_type: str = str(expected_type)
if expected_type is not None:
exp_type = expected_type.MESSAGE_NAME
log.debug(
__name__,
"Read - with expected types %s and expected type %s",
str(expected_types),
exp_type,
)
message: Message = await self.incoming_message.take()
if message.type not in expected_types:
raise UnexpectedMessageException(message)
if expected_type is None:
expected_type = protobuf.type_for_wire(message.type)
return message_handler.wrap_protobuf_load(message.data, expected_type)
async def write(self, msg: protobuf.MessageType) -> None:
return await self.channel_ctx.write(msg)
async def call(
self, msg: protobuf.MessageType, expected_type: type[protobuf.MessageType]
) -> protobuf.MessageType:
assert expected_type.MESSAGE_WIRE_TYPE is not None
await self.write(msg)
del msg
return await self.read((expected_type.MESSAGE_WIRE_TYPE,), expected_type)
async def call_any(
self, msg: protobuf.MessageType, *expected_types: int
) -> protobuf.MessageType:
await self.write(msg)
del msg
return await self.read(expected_types)
async def handle_pairing_request_message(
pairing_ctx: PairingContext,
msg: protocol_common.Message,
use_workflow: bool,
) -> protocol_common.Message | None:
res_msg: protobuf.MessageType | None = None
from apps.thp.pairing import handle_pairing_request
if msg.type in workflow.ALLOW_WHILE_LOCKED:
workflow.autolock_interrupts_workflow = False
# Here we make sure we always respond with a Failure response
# in case of any errors.
try:
# Find a protobuf.MessageType subclass that describes this
# message. Raises if the type is not found.
req_type = protobuf.type_for_wire(msg.type)
# Try to decode the message according to schema from
# `req_type`. Raises if the message is malformed.
req_msg = message_handler.wrap_protobuf_load(msg.data, req_type)
# Create the handler task.
task = handle_pairing_request(pairing_ctx, req_msg)
# Run the workflow task. Workflow can do more on-the-wire
# communication inside, but it should eventually return a
# response message, or raise an exception (a rather common
# thing to do). Exceptions are handled in the code below.
if use_workflow:
# Spawn a workflow around the task. This ensures that concurrent
# workflows are shut down.
res_msg = await workflow.spawn(context.with_context(pairing_ctx, task))
else:
# For debug messages, ignore workflow processing and just await
# results of the handler.
res_msg = await task
except UnexpectedMessageException as exc:
# Workflow was trying to read a message from the wire, and
# something unexpected came in. See Context.read() for
# example, which expects some particular message and raises
# UnexpectedMessage if another one comes in.
# In order not to lose the message, we return it to the caller.
# TODO:
# We might handle only the few common cases here, like
# Initialize and Cancel.
return exc.msg
except SilentError as exc:
if __debug__:
log.error(__name__, "SilentError: %s", exc.message)
except BaseException as exc:
# Either:
# - the message had a type that has a registered handler, but does not have
# a protobuf class
# - the message was not valid protobuf
# - workflow raised some kind of an exception while running
# - something canceled the workflow from the outside
if __debug__:
if isinstance(exc, ActionCancelled):
log.debug(__name__, "cancelled: %s", exc.message)
elif isinstance(exc, loop.TaskClosed):
log.debug(__name__, "cancelled: loop task was closed")
else:
log.exception(__name__, exc)
res_msg = message_handler.failure(exc)
if res_msg is not None:
# perform the write outside the big try-except block, so that usb write
# problem bubbles up
await pairing_ctx.write(res_msg)
return None

View File

@ -0,0 +1,394 @@
import ustruct
from typing import TYPE_CHECKING
from storage.cache_common import (
CHANNEL_HANDSHAKE_HASH,
CHANNEL_KEY_RECEIVE,
CHANNEL_KEY_SEND,
CHANNEL_NONCE_RECEIVE,
CHANNEL_NONCE_SEND,
)
from storage.cache_thp import (
KEY_LENGTH,
SESSION_ID_LENGTH,
TAG_LENGTH,
update_channel_last_used,
update_session_last_used,
)
from trezor import log, loop, utils
from trezor.enums import FailureType
from trezor.messages import Failure
from ..errors import DataError
from ..protocol_common import Message
from . import (
ChannelState,
SessionState,
ThpDecryptionError,
ThpError,
ThpErrorType,
ThpInvalidDataError,
ThpUnallocatedSessionError,
)
from . import alternating_bit_protocol as ABP
from . import checksum, control_byte, is_channel_state_pairing, thp_messages
from .checksum import CHECKSUM_LENGTH
from .crypto import PUBKEY_LENGTH, Handshake
from .thp_messages import (
ACK_MESSAGE,
HANDSHAKE_COMP_RES,
HANDSHAKE_INIT_RES,
PacketHeader,
)
from .writer import (
INIT_HEADER_LENGTH,
MESSAGE_TYPE_LENGTH,
write_payload_to_wire_and_add_checksum,
)
if TYPE_CHECKING:
from trezor.messages import ThpHandshakeCompletionReqNoisePayload
from .channel import Channel
if __debug__:
from ubinascii import hexlify
from . import state_to_str
async def handle_received_message(
ctx: Channel, message_buffer: utils.BufferType
) -> None:
"""Handle a message received from the channel."""
if __debug__:
log.debug(__name__, "handle_received_message")
ctrl_byte, _, payload_length = ustruct.unpack(">BHH", message_buffer)
message_length = payload_length + INIT_HEADER_LENGTH
_check_checksum(message_length, message_buffer)
# Synchronization process
seq_bit = (ctrl_byte & 0x10) >> 4
ack_bit = (ctrl_byte & 0x08) >> 3
if __debug__:
log.debug(
__name__,
"handle_completed_message - seq bit of message: %d, ack bit of message: %d",
seq_bit,
ack_bit,
)
# 0: Update "last-time used"
update_channel_last_used(ctx.channel_id)
# 1: Handle ACKs
if control_byte.is_ack(ctrl_byte):
await _handle_ack(ctx, ack_bit)
return
if _should_have_ctrl_byte_encrypted_transport(
ctx
) and not control_byte.is_encrypted_transport(ctrl_byte):
raise ThpError("Message is not encrypted. Ignoring")
# 2: Handle message with unexpected sequential bit
if seq_bit != ABP.get_expected_receive_seq_bit(ctx.channel_cache):
if __debug__:
log.debug(__name__, "Received message with an unexpected sequential bit")
await _send_ack(ctx, ack_bit=seq_bit)
raise ThpError("Received message with an unexpected sequential bit")
# 3: Send ACK in response
await _send_ack(ctx, ack_bit=seq_bit)
ABP.set_expected_receive_seq_bit(ctx.channel_cache, 1 - seq_bit)
try:
await _handle_message_to_app_or_channel(
ctx, payload_length, message_length, ctrl_byte
)
except ThpUnallocatedSessionError as e:
error_message = Failure(code=FailureType.ThpUnallocatedSession)
await ctx.write(error_message, e.session_id)
except ThpDecryptionError:
await ctx.write_error(ThpErrorType.DECRYPTION_FAILED)
ctx.clear()
except ThpInvalidDataError:
await ctx.write_error(ThpErrorType.INVALID_DATA)
ctx.clear()
if __debug__:
log.debug(__name__, "handle_received_message - end")
async def _send_ack(ctx: Channel, ack_bit: int) -> None:
ctrl_byte = control_byte.add_ack_bit_to_ctrl_byte(ACK_MESSAGE, ack_bit)
header = PacketHeader(ctrl_byte, ctx.get_channel_id_int(), CHECKSUM_LENGTH)
if __debug__:
log.debug(
__name__,
"Writing ACK message to a channel with id: %d, ack_bit: %d",
ctx.get_channel_id_int(),
ack_bit,
)
await write_payload_to_wire_and_add_checksum(ctx.iface, header, b"")
def _check_checksum(message_length: int, message_buffer: utils.BufferType):
if __debug__:
log.debug(__name__, "check_checksum")
if not checksum.is_valid(
checksum=message_buffer[message_length - CHECKSUM_LENGTH : message_length],
data=memoryview(message_buffer)[: message_length - CHECKSUM_LENGTH],
):
if __debug__:
log.debug(__name__, "Invalid checksum, ignoring message.")
raise ThpError("Invalid checksum, ignoring message.")
async def _handle_ack(ctx: Channel, ack_bit: int):
if not ABP.is_ack_valid(ctx.channel_cache, ack_bit):
return
# ACK is expected and it has correct sync bit
if __debug__:
log.debug(__name__, "Received ACK message with correct ack bit")
if ctx.transmission_loop is not None:
ctx.transmission_loop.stop_immediately()
if __debug__:
log.debug(__name__, "Stopped transmission loop")
ABP.set_sending_allowed(ctx.channel_cache, True)
if ctx.write_task_spawn is not None:
if __debug__:
log.debug(__name__, 'Control to "write_encrypted_payload_loop" task')
await ctx.write_task_spawn
# Note that no the write_task_spawn could result in loop.clear(),
# which will result in termination of this function - any code after
# this await might not be executed
async def _handle_message_to_app_or_channel(
ctx: Channel,
payload_length: int,
message_length: int,
ctrl_byte: int,
) -> None:
state = ctx.get_channel_state()
if __debug__:
log.debug(__name__, "state: %s", state_to_str(state))
if state is ChannelState.ENCRYPTED_TRANSPORT:
await _handle_state_ENCRYPTED_TRANSPORT(ctx, message_length)
return
if state is ChannelState.TH1:
await _handle_state_TH1(ctx, payload_length, message_length, ctrl_byte)
return
if state is ChannelState.TH2:
await _handle_state_TH2(ctx, message_length, ctrl_byte)
return
if is_channel_state_pairing(state):
await _handle_pairing(ctx, message_length)
return
raise ThpError("Unimplemented channel state")
async def _handle_state_TH1(
ctx: Channel,
payload_length: int,
message_length: int,
ctrl_byte: int,
) -> None:
if __debug__:
log.debug(__name__, "handle_state_TH1")
if not control_byte.is_handshake_init_req(ctrl_byte):
raise ThpError("Message received is not a handshake init request!")
if not payload_length == PUBKEY_LENGTH + CHECKSUM_LENGTH:
raise ThpError("Message received is not a valid handshake init request!")
ctx.handshake = Handshake()
host_ephemeral_pubkey = bytearray(
ctx.buffer[INIT_HEADER_LENGTH : message_length - CHECKSUM_LENGTH]
)
trezor_ephemeral_pubkey, encrypted_trezor_static_pubkey, tag = (
ctx.handshake.handle_th1_crypto(
thp_messages.get_encoded_device_properties(ctx.iface), host_ephemeral_pubkey
)
)
if __debug__:
log.debug(
__name__,
"trezor ephemeral pubkey: %s",
hexlify(trezor_ephemeral_pubkey).decode(),
)
log.debug(
__name__,
"encrypted trezor masked static pubkey: %s",
hexlify(encrypted_trezor_static_pubkey).decode(),
)
log.debug(__name__, "tag: %s", hexlify(tag))
payload = trezor_ephemeral_pubkey + encrypted_trezor_static_pubkey + tag
# send handshake init response message
await ctx.write_handshake_message(HANDSHAKE_INIT_RES, payload)
ctx.set_channel_state(ChannelState.TH2)
return
async def _handle_state_TH2(ctx: Channel, message_length: int, ctrl_byte: int) -> None:
from apps.thp.credential_manager import validate_credential
if __debug__:
log.debug(__name__, "handle_state_TH2")
if not control_byte.is_handshake_comp_req(ctrl_byte):
raise ThpError("Message received is not a handshake completion request!")
if ctx.handshake is None:
raise Exception("Handshake object is not prepared. Retry handshake.")
host_encrypted_static_pubkey = memoryview(ctx.buffer)[
INIT_HEADER_LENGTH : INIT_HEADER_LENGTH + KEY_LENGTH + TAG_LENGTH
]
handshake_completion_request_noise_payload = memoryview(ctx.buffer)[
INIT_HEADER_LENGTH + KEY_LENGTH + TAG_LENGTH : message_length - CHECKSUM_LENGTH
]
ctx.handshake.handle_th2_crypto(
host_encrypted_static_pubkey, handshake_completion_request_noise_payload
)
ctx.channel_cache.set(CHANNEL_KEY_RECEIVE, ctx.handshake.key_receive)
ctx.channel_cache.set(CHANNEL_KEY_SEND, ctx.handshake.key_send)
ctx.channel_cache.set(CHANNEL_HANDSHAKE_HASH, ctx.handshake.h)
ctx.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, 0)
ctx.channel_cache.set_int(CHANNEL_NONCE_SEND, 1)
noise_payload = thp_messages.decode_message(
ctx.buffer[
INIT_HEADER_LENGTH
+ KEY_LENGTH
+ TAG_LENGTH : message_length
- CHECKSUM_LENGTH
- TAG_LENGTH
],
0,
"ThpHandshakeCompletionReqNoisePayload",
)
if TYPE_CHECKING:
assert ThpHandshakeCompletionReqNoisePayload.is_type_of(noise_payload)
enabled_methods = thp_messages.get_enabled_pairing_methods(ctx.iface)
for method in noise_payload.pairing_methods:
if method not in enabled_methods:
raise ThpInvalidDataError()
if method not in ctx.selected_pairing_methods:
ctx.selected_pairing_methods.append(method)
if __debug__:
log.debug(
__name__,
"host static pubkey: %s, noise payload: %s",
utils.get_bytes_as_str(host_encrypted_static_pubkey),
utils.get_bytes_as_str(handshake_completion_request_noise_payload),
)
# key is decoded in handshake._handle_th2_crypto
host_static_pubkey = host_encrypted_static_pubkey[:PUBKEY_LENGTH]
paired: bool = False
if noise_payload.host_pairing_credential is not None:
try: # TODO change try-except for something better
paired = validate_credential(
noise_payload.host_pairing_credential,
host_static_pubkey,
)
except DataError as e:
if __debug__:
log.exception(__name__, e)
pass
trezor_state = thp_messages.TREZOR_STATE_UNPAIRED
if paired:
trezor_state = thp_messages.TREZOR_STATE_PAIRED
# send hanshake completion response
await ctx.write_handshake_message(
HANDSHAKE_COMP_RES,
ctx.handshake.get_handshake_completion_response(trezor_state),
)
ctx.handshake = None
if paired:
ctx.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
else:
ctx.set_channel_state(ChannelState.TP1)
async def _handle_state_ENCRYPTED_TRANSPORT(ctx: Channel, message_length: int) -> None:
if __debug__:
log.debug(__name__, "handle_state_ENCRYPTED_TRANSPORT")
ctx.decrypt_buffer(message_length)
session_id, message_type = ustruct.unpack(
">BH", memoryview(ctx.buffer)[INIT_HEADER_LENGTH:]
)
if session_id not in ctx.sessions:
raise ThpUnallocatedSessionError(session_id)
session_state = ctx.sessions[session_id].get_session_state()
if session_state is SessionState.UNALLOCATED:
raise ThpUnallocatedSessionError(session_id)
update_session_last_used(ctx.channel_id, session_id)
ctx.sessions[session_id].incoming_message.publish(
Message(
message_type,
ctx.buffer[
INIT_HEADER_LENGTH
+ MESSAGE_TYPE_LENGTH
+ SESSION_ID_LENGTH : message_length
- CHECKSUM_LENGTH
- TAG_LENGTH
],
)
)
async def _handle_pairing(ctx: Channel, message_length: int) -> None:
from .pairing_context import PairingContext
if ctx.connection_context is None:
ctx.connection_context = PairingContext(ctx)
loop.schedule(ctx.connection_context.handle())
ctx.decrypt_buffer(message_length)
message_type = ustruct.unpack(
">H", ctx.buffer[INIT_HEADER_LENGTH + SESSION_ID_LENGTH :]
)[0]
ctx.connection_context.incoming_message.publish(
Message(
message_type,
ctx.buffer[
INIT_HEADER_LENGTH
+ MESSAGE_TYPE_LENGTH
+ SESSION_ID_LENGTH : message_length
- CHECKSUM_LENGTH
- TAG_LENGTH
],
)
)
def _should_have_ctrl_byte_encrypted_transport(ctx: Channel) -> bool:
if ctx.get_channel_state() in [
ChannelState.UNALLOCATED,
ChannelState.TH1,
ChannelState.TH2,
]:
return False
return True

View File

@ -0,0 +1,223 @@
from typing import TYPE_CHECKING
from storage.cache_thp import MANAGEMENT_SESSION_ID, SessionThpCache
from trezor import log, loop, protobuf, utils
from trezor.wire import message_handler, protocol_common
from trezor.wire.context import UnexpectedMessageException
from trezor.wire.message_handler import AVOID_RESTARTING_FOR, failure, find_handler
from ..protocol_common import Context, Message
from . import SessionState
if TYPE_CHECKING:
from typing import Any, Awaitable, Container, TypeVar, overload
from storage.cache_common import DataCache
from ..message_handler import HandlerFinder
from .channel import Channel
pass
_EXIT_LOOP = True
_REPEAT_LOOP = False
if __debug__:
from trezor.utils import get_bytes_as_str
class GenericSessionContext(Context):
def __init__(self, channel: Channel, session_id: int) -> None:
super().__init__(channel.iface, channel.channel_id)
self.channel: Channel = channel
self.session_id: int = session_id
self.incoming_message = loop.chan()
self.handler_finder: HandlerFinder = find_handler
async def handle(self, is_debug_session: bool = False) -> None:
if __debug__:
self._handle_debug(is_debug_session)
take = self.incoming_message.take()
next_message: Message | None = None
# Take a mark of modules that are imported at this point, so we can
# roll back and un-import any others.
# TODO modules = utils.unimport_begin()
while True:
try:
if await self._handle_message(take, next_message, is_debug_session):
return
except UnexpectedMessageException as unexpected:
# The workflow was interrupted by an unexpected message. We need to
# process it as if it was a new message...
next_message = unexpected.msg
continue
except Exception as exc:
# Log and try again.
if __debug__:
log.exception(__name__, exc)
def _handle_debug(self, is_debug_session: bool) -> None:
log.debug(
__name__,
"handle - start (channel_id (bytes): %s, session_id: %d)",
get_bytes_as_str(self.channel_id),
self.session_id,
)
if is_debug_session:
import apps.debug
apps.debug.DEBUG_CONTEXT = self
async def _handle_message(
self,
take: Awaitable[Any],
next_message: Message | None,
is_debug_session: bool,
) -> bool:
try:
message = await self._get_message(take, next_message)
except protocol_common.WireError as e:
if __debug__:
log.exception(__name__, e)
await self.write(failure(e))
return _REPEAT_LOOP
try:
await message_handler.handle_single_message(
self,
message,
self.handler_finder,
)
except UnexpectedMessageException:
raise
except Exception as exc:
# Log and ignore. The session handler can only exit explicitly in the
# following finally block.
if __debug__:
log.exception(__name__, exc)
finally:
if not __debug__ or not is_debug_session:
# Unload modules imported by the workflow. Should not raise.
# This is not done for the debug session because the snapshot taken
# in a debug session would clear modules which are in use by the
# workflow running on wire.
# TODO utils.unimport_end(modules)
if next_message is None and message.type not in AVOID_RESTARTING_FOR:
# Shut down the loop if there is no next message waiting.
return _EXIT_LOOP # pylint: disable=lost-exception
return _REPEAT_LOOP # pylint: disable=lost-exception
async def _get_message(
self, take: Awaitable[Any], next_message: Message | None
) -> Message:
if next_message is None:
# If the previous run did not keep an unprocessed message for us,
# wait for a new one.
message: Message = await take
else:
# Process the message from previous run.
message = next_message
next_message = None
return message
async def read(
self,
expected_types: Container[int],
expected_type: type[protobuf.MessageType] | None = None,
) -> protobuf.MessageType:
if __debug__:
exp_type: str = str(expected_type)
if expected_type is not None:
exp_type = expected_type.MESSAGE_NAME
log.debug(
__name__,
"Read - with expected types %s and expected type %s",
str(expected_types),
exp_type,
)
message: Message = await self.incoming_message.take()
if message.type not in expected_types:
raise UnexpectedMessageException(message)
if expected_type is None:
expected_type = protobuf.type_for_wire(message.type)
return message_handler.wrap_protobuf_load(message.data, expected_type)
async def write(self, msg: protobuf.MessageType) -> None:
return await self.channel.write(msg, self.session_id)
def get_session_state(self) -> SessionState: ...
class ManagementSessionContext(GenericSessionContext):
def __init__(self, channel_ctx: Channel) -> None:
super().__init__(channel_ctx, MANAGEMENT_SESSION_ID)
from trezor.wire.thp.handler_provider import (
find_management_session_message_handler,
)
self.handler_finder = find_management_session_message_handler
def get_session_state(self) -> SessionState:
return SessionState.MANAGEMENT
class SessionContext(GenericSessionContext):
def __init__(self, channel_ctx: Channel, session_cache: SessionThpCache) -> None:
if channel_ctx.channel_id != session_cache.channel_id:
raise Exception(
"The session has different channel id than the provided channel context!"
)
session_id = int.from_bytes(session_cache.session_id, "big")
super().__init__(channel_ctx, session_id)
self.session_cache = session_cache
# ACCESS TO SESSION DATA
def get_session_state(self) -> SessionState:
state = int.from_bytes(self.session_cache.state, "big")
return SessionState(state)
def set_session_state(self, state: SessionState) -> None:
self.session_cache.state = bytearray(state.to_bytes(1, "big"))
# ACCESS TO CACHE
@property
def cache(self) -> DataCache:
return self.session_cache
if TYPE_CHECKING:
T = TypeVar("T")
@overload
def cache_get(self, key: int) -> bytes | None: # noqa: F811
...
@overload
def cache_get(self, key: int, default: T) -> bytes | T: # noqa: F811
...
def cache_get(
self, key: int, default: T | None = None
) -> bytes | T | None: # noqa: F811
utils.ensure(key < len(self.session_cache.fields))
if self.session_cache.data[key][0] != 1:
return default
return bytes(self.session_cache.data[key][1:])
def cache_is_set(self, key: int) -> bool:
return self.session_cache.is_set(key)
def cache_set(self, key: int, value: bytes) -> None:
utils.ensure(key < len(self.session_cache.fields))
utils.ensure(len(value) <= self.session_cache.fields[key])
self.session_cache.data[key][0] = 1
self.session_cache.data[key][1:] = value

View File

@ -0,0 +1,48 @@
from typing import TYPE_CHECKING
from storage import cache_thp
from trezor import loop
from .session_context import (
GenericSessionContext,
ManagementSessionContext,
SessionContext,
)
if __debug__:
from trezor import log
if TYPE_CHECKING:
from .channel import Channel
def create_new_session(channel_ctx: Channel) -> SessionContext:
session_cache = cache_thp.get_new_session(channel_ctx.channel_cache)
return SessionContext(channel_ctx, session_cache)
def create_new_management_session(
channel_ctx: Channel,
) -> ManagementSessionContext:
return ManagementSessionContext(channel_ctx)
def load_cached_sessions(
channel_ctx: Channel,
) -> dict[int, GenericSessionContext]:
if __debug__:
log.debug(__name__, "load_cached_sessions - start")
sessions: dict[int, GenericSessionContext] = {}
cached_sessions = cache_thp.get_allocated_sessions(channel_ctx.channel_id)
if __debug__:
log.debug(
__name__,
"load_cached_sessions - loaded a total of %d sessions from cache",
len(cached_sessions),
)
for session in cached_sessions:
if session.channel_id == channel_ctx.channel_id:
sid = int.from_bytes(session.session_id, "big")
sessions[sid] = SessionContext(channel_ctx, session)
loop.schedule(sessions[sid].handle())
return sessions

View File

@ -0,0 +1,136 @@
import ustruct
from micropython import const
from typing import TYPE_CHECKING
from storage.cache_thp import BROADCAST_CHANNEL_ID
from trezor import protobuf, utils
from trezor.enums import ThpPairingMethod
from trezor.messages import ThpDeviceProperties
from .. import message_handler
CODEC_V1 = const(0x3F)
CONTINUATION_PACKET = const(0x80)
HANDSHAKE_INIT_REQ = const(0x00)
HANDSHAKE_INIT_RES = const(0x01)
HANDSHAKE_COMP_REQ = const(0x02)
HANDSHAKE_COMP_RES = const(0x03)
ENCRYPTED_TRANSPORT = const(0x04)
CONTINUATION_PACKET_MASK = const(0x80)
ACK_MASK = const(0xF7)
DATA_MASK = const(0xE7)
ACK_MESSAGE = const(0x20)
_ERROR = const(0x42)
CHANNEL_ALLOCATION_REQ = const(0x40)
_CHANNEL_ALLOCATION_RES = const(0x41)
TREZOR_STATE_UNPAIRED = b"\x00"
TREZOR_STATE_PAIRED = b"\x01"
if __debug__:
from trezor import log
if TYPE_CHECKING:
from trezor.wire import WireInterface
class PacketHeader:
format_str_init = ">BHH"
format_str_cont = ">BH"
def __init__(self, ctrl_byte: int, cid: int, length: int) -> None:
self.ctrl_byte = ctrl_byte
self.cid = cid
self.length = length
def to_bytes(self) -> bytes:
return ustruct.pack(self.format_str_init, self.ctrl_byte, self.cid, self.length)
def pack_to_init_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None:
ustruct.pack_into(
self.format_str_init,
buffer,
buffer_offset,
self.ctrl_byte,
self.cid,
self.length,
)
def pack_to_cont_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None:
ustruct.pack_into(
self.format_str_cont, buffer, buffer_offset, CONTINUATION_PACKET, self.cid
)
@classmethod
def get_error_header(cls, cid: int, length: int):
return cls(_ERROR, cid, length)
@classmethod
def get_channel_allocation_response_header(cls, length: int):
return cls(_CHANNEL_ALLOCATION_RES, BROADCAST_CHANNEL_ID, length)
_DEFAULT_ENABLED_PAIRING_METHODS = [
ThpPairingMethod.CodeEntry,
ThpPairingMethod.QrCode,
ThpPairingMethod.NFC_Unidirectional,
]
def get_enabled_pairing_methods(
iface: WireInterface | None = None,
) -> list[ThpPairingMethod]:
import usb
l = _DEFAULT_ENABLED_PAIRING_METHODS.copy()
if iface is not None and iface is usb.iface_wire:
l.append(ThpPairingMethod.NoMethod)
return l
def _get_device_properties(iface: WireInterface) -> ThpDeviceProperties:
# TODO define model variants
return ThpDeviceProperties(
pairing_methods=get_enabled_pairing_methods(iface),
internal_model=utils.INTERNAL_MODEL,
model_variant=0,
bootloader_mode=False,
protocol_version=2,
)
def get_encoded_device_properties(iface: WireInterface) -> bytes:
props = _get_device_properties(iface)
length = protobuf.encoded_length(props)
encoded_properties = bytearray(length)
protobuf.encode(encoded_properties, props)
return encoded_properties
def get_channel_allocation_response(
nonce: bytes, new_cid: bytes, iface: WireInterface
) -> bytes:
props_msg = get_encoded_device_properties(iface)
return nonce + new_cid + props_msg
def get_codec_v1_error_message() -> bytes:
# Codec_v1 magic constant "?##" + Failure message type + msg_size
# + msg_data (code = "Failure_UnexpectedMessage", message = "Invalid protocol")
ERROR_MSG = b"\x3f\x23\x23\x00\x03\x00\x00\x00\x14\x08\x01\x12\x10\x49\x6e\x76\x61\x6c\x69\x64\x20\x70\x72\x6f\x74\x6f\x63\x6f\x6c\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
return ERROR_MSG
def decode_message(
buffer: bytes, msg_type: int, message_name: str | None = None
) -> protobuf.MessageType:
if __debug__:
log.debug(__name__, "decode message")
if message_name is not None:
expected_type = protobuf.type_for_name(message_name)
else:
expected_type = protobuf.type_for_wire(msg_type)
x = message_handler.wrap_protobuf_load(buffer, expected_type)
return x

View File

@ -0,0 +1,59 @@
from micropython import const
from typing import TYPE_CHECKING
from trezor import loop
from trezor.wire.thp.thp_messages import PacketHeader
from trezor.wire.thp.writer import write_payload_to_wire_and_add_checksum
if TYPE_CHECKING:
from trezor.wire.thp.channel import Channel
MAX_RETRANSMISSION_COUNT = const(50)
MIN_RETRANSMISSION_COUNT = const(2)
class TransmissionLoop:
def __init__(
self, channel: Channel, header: PacketHeader, transport_payload: bytes
) -> None:
self.channel: Channel = channel
self.header: PacketHeader = header
self.transport_payload: bytes = transport_payload
self.wait_task: loop.spawn | None = None
self.min_retransmisson_count_achieved: bool = False
async def start(self):
self.min_retransmisson_count_achieved = False
for i in range(MAX_RETRANSMISSION_COUNT):
if i >= MIN_RETRANSMISSION_COUNT:
self.min_retransmisson_count_achieved = True
await write_payload_to_wire_and_add_checksum(
self.channel.iface, self.header, self.transport_payload
)
self.wait_task = loop.spawn(self._wait(i))
try:
await self.wait_task
except loop.TaskClosed:
self.wait_task = None
break
def stop_immediately(self):
if self.wait_task is not None:
self.wait_task.close()
self.wait_task = None
async def stop_after_min_retransmission(self):
while not self.min_retransmisson_count_achieved and self.wait_task is not None:
await self._short_wait()
self.stop_immediately()
async def _wait(self, counter: int = 0) -> None:
timeout_ms = round(10200 - 1010000 / (counter + 100))
await loop.sleep(timeout_ms)
async def _short_wait(self):
loop.wait(50)
def __del__(self):
self.stop_immediately()

View File

@ -0,0 +1,82 @@
from micropython import const
from trezorcrypto import crc
from typing import TYPE_CHECKING
from trezor import io, log, loop, utils
from trezor.wire.thp.thp_messages import PacketHeader
INIT_HEADER_LENGTH = const(5)
CONT_HEADER_LENGTH = const(3)
PACKET_LENGTH = const(64)
CHECKSUM_LENGTH = const(4)
MAX_PAYLOAD_LEN = const(60000)
MESSAGE_TYPE_LENGTH = const(2)
if TYPE_CHECKING:
from trezorio import WireInterface
from typing import Sequence
async def write_payload_to_wire_and_add_checksum(
iface: WireInterface, header: PacketHeader, transport_payload: bytes
):
header_checksum: int = crc.crc32(header.to_bytes())
checksum: bytes = crc.crc32(transport_payload, header_checksum).to_bytes(
CHECKSUM_LENGTH, "big"
)
data = (transport_payload, checksum)
await write_payloads_to_wire(iface, header, data)
async def write_payloads_to_wire(
iface: WireInterface, header: PacketHeader, data: Sequence[bytes]
):
n_of_data = len(data)
total_length = sum(len(item) for item in data)
current_data_idx = 0
current_data_offset = 0
packet = bytearray(PACKET_LENGTH)
header.pack_to_init_buffer(packet)
packet_offset: int = INIT_HEADER_LENGTH
packet_number = 0
nwritten = 0
while nwritten < total_length:
if packet_number == 1:
header.pack_to_cont_buffer(packet)
if packet_number >= 1 and nwritten >= total_length - PACKET_LENGTH:
packet[:] = bytearray(PACKET_LENGTH)
header.pack_to_cont_buffer(packet)
while True:
n = utils.memcpy(
packet, packet_offset, data[current_data_idx], current_data_offset
)
packet_offset += n
current_data_offset += n
nwritten += n
if packet_offset < PACKET_LENGTH:
current_data_idx += 1
current_data_offset = 0
if current_data_idx >= n_of_data:
break
elif packet_offset == PACKET_LENGTH:
break
else:
raise Exception("Should not happen!!!")
packet_number += 1
packet_offset = CONT_HEADER_LENGTH
await write_packet_to_wire(iface, packet)
async def write_packet_to_wire(iface: WireInterface, packet: bytes) -> None:
while True:
await loop.wait(iface.iface_num() | io.POLL_WRITE)
if __debug__:
log.debug(
__name__, "write_packet_to_wire: %s", utils.get_bytes_as_str(packet)
)
n = iface.write(packet)
if n == len(packet):
return

View File

@ -0,0 +1,177 @@
import ustruct
from micropython import const
from typing import TYPE_CHECKING
from storage.cache_thp import BROADCAST_CHANNEL_ID
from trezor import io, log, loop, utils
from trezor.wire.thp import writer
from .thp import (
ChannelState,
ThpError,
ThpErrorType,
channel_manager,
checksum,
session_manager,
thp_messages,
)
from .thp.channel import Channel
from .thp.checksum import CHECKSUM_LENGTH
from .thp.thp_messages import CHANNEL_ALLOCATION_REQ, CODEC_V1, PacketHeader
from .thp.writer import (
INIT_HEADER_LENGTH,
MAX_PAYLOAD_LEN,
PACKET_LENGTH,
write_payload_to_wire_and_add_checksum,
)
if TYPE_CHECKING:
from trezorio import WireInterface
_CID_REQ_PAYLOAD_LENGTH = const(12)
_READ_BUFFER: bytearray
_WRITE_BUFFER: bytearray
_CHANNELS: dict[int, Channel] = {}
def set_read_buffer(buffer: bytearray):
global _READ_BUFFER
_READ_BUFFER = buffer
def set_write_buffer(buffer: bytearray):
global _WRITE_BUFFER
_WRITE_BUFFER = buffer
def get_raw_read_buffer() -> bytearray:
global _WRITE_BUFFER
return _READ_BUFFER
def get_raw_write_buffer() -> bytearray:
global _WRITE_BUFFER
return _WRITE_BUFFER
async def thp_main_loop(iface: WireInterface, is_debug_session=False):
global _CHANNELS
global _READ_BUFFER
_CHANNELS = channel_manager.load_cached_channels(_READ_BUFFER)
for ch in _CHANNELS.values():
ch.sessions.update(session_manager.load_cached_sessions(ch))
read = loop.wait(iface.iface_num() | io.POLL_READ)
while True:
try:
if __debug__:
log.debug(__name__, "thp_main_loop")
packet = await read
ctrl_byte, cid = ustruct.unpack(">BH", packet)
if ctrl_byte == CODEC_V1:
await _handle_codec_v1(iface, packet)
continue
if cid == BROADCAST_CHANNEL_ID:
await _handle_broadcast(iface, ctrl_byte, packet)
continue
if cid in _CHANNELS:
await _handle_allocated(iface, cid, packet)
else:
await _handle_unallocated(iface, cid)
except ThpError as e:
if __debug__:
log.exception(__name__, e)
async def _handle_codec_v1(iface: WireInterface, packet):
# If the received packet is not initial codec_v1 packet, do not send error message
if not packet[1:3] == b"##":
return
if __debug__:
log.debug(__name__, "Received codec_v1 message, returning error")
error_message = thp_messages.get_codec_v1_error_message()
await writer.write_packet_to_wire(iface, error_message)
async def _handle_broadcast(
iface: WireInterface, ctrl_byte: int, packet: utils.BufferType
) -> None:
global _READ_BUFFER
if ctrl_byte != CHANNEL_ALLOCATION_REQ:
raise ThpError("Unexpected ctrl_byte in a broadcast channel packet")
if __debug__:
log.debug(__name__, "Received valid message on the broadcast channel")
length, nonce = ustruct.unpack(">H8s", packet[3:])
payload = _get_buffer_for_payload(length, packet[5:], _CID_REQ_PAYLOAD_LENGTH)
if not checksum.is_valid(
payload[-4:],
packet[: _CID_REQ_PAYLOAD_LENGTH + INIT_HEADER_LENGTH - CHECKSUM_LENGTH],
):
raise ThpError("Checksum is not valid")
new_channel: Channel = channel_manager.create_new_channel(iface, _READ_BUFFER)
cid = int.from_bytes(new_channel.channel_id, "big")
_CHANNELS[cid] = new_channel
response_data = thp_messages.get_channel_allocation_response(
nonce, new_channel.channel_id, iface
)
response_header = PacketHeader.get_channel_allocation_response_header(
len(response_data) + CHECKSUM_LENGTH,
)
if __debug__:
log.debug(__name__, "New channel allocated with id %d", cid)
await write_payload_to_wire_and_add_checksum(iface, response_header, response_data)
async def _handle_allocated(
iface: WireInterface, cid: int, packet: utils.BufferType
) -> None:
channel = _CHANNELS[cid]
if channel is None:
await _handle_unallocated(iface, cid)
raise ThpError("Invalid state of a channel")
if channel.iface is not iface:
# TODO send error message to wire
raise ThpError("Channel has different WireInterface")
if channel.get_channel_state() != ChannelState.UNALLOCATED:
await channel.receive_packet(packet)
async def _handle_unallocated(iface, cid) -> None:
data = (ThpErrorType.UNALLOCATED_CHANNEL).to_bytes(1, "big")
header = PacketHeader.get_error_header(cid, len(data) + CHECKSUM_LENGTH)
await write_payload_to_wire_and_add_checksum(iface, header, data)
def _get_buffer_for_payload(
payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN
) -> utils.BufferType:
if payload_length > max_length:
raise ThpError("Message too large")
if payload_length > len(existing_buffer):
return _try_allocate_new_buffer(payload_length)
return _reuse_existing_buffer(payload_length, existing_buffer)
def _try_allocate_new_buffer(payload_length: int) -> utils.BufferType:
try:
payload: utils.BufferType = bytearray(payload_length)
except MemoryError:
payload = bytearray(PACKET_LENGTH)
raise ThpError("Message too large")
return payload
def _reuse_existing_buffer(
payload_length: int, existing_buffer: utils.BufferType
) -> utils.BufferType:
return memoryview(existing_buffer)[:payload_length]

View File

@ -1,7 +1,7 @@
import utime
from typing import TYPE_CHECKING
import storage.cache
import storage.cache_common as cache_common
from trezor import log, loop
from trezor.enums import MessageType
@ -153,7 +153,7 @@ def close_others() -> None:
if not task.is_running():
task.close()
storage.cache.homescreen_shown = None
cache_common.homescreen_shown = None
# if tasks were running, closing the last of them will run start_default
@ -211,11 +211,11 @@ class IdleTimer:
time and saves it to storage.cache. This is done to avoid losing an
active timer when workflow restart happens and tasks are lost.
"""
if _restore_from_cache and storage.cache.autolock_last_touch is not None:
now = storage.cache.autolock_last_touch
if _restore_from_cache and cache_common.autolock_last_touch is not None:
now = cache_common.autolock_last_touch
else:
now = utime.ticks_ms()
storage.cache.autolock_last_touch = now
cache_common.autolock_last_touch = now
for callback, task in self.tasks.items():
timeout_us = self.timeouts[callback]

View File

@ -0,0 +1,17 @@
from trezor.loop import wait
class MockHID:
def __init__(self, num):
self.num = num
self.data = []
def iface_num(self):
return self.num
def write(self, msg):
self.data.append(bytearray(msg))
return len(msg)
def wait_object(self, mode):
return wait(mode | self.num)

42
core/tests/myTests.sh Executable file
View File

@ -0,0 +1,42 @@
#!/usr/bin/env bash
declare -a results
declare -i passed=0 failed=0 exit_code=0
declare COLOR_GREEN='\e[32m' COLOR_RED='\e[91m' COLOR_RESET='\e[39m'
MICROPYTHON="${MICROPYTHON:-../build/unix/trezor-emu-core -X heapsize=2M}"
print_summary() {
echo
echo 'Summary:'
echo '-------------------'
printf '%b\n' "${results[@]}"
if [ $exit_code == 0 ]; then
echo -e "${COLOR_GREEN}PASSED:${COLOR_RESET} $passed/$num_of_tests tests OK!"
else
echo -e "${COLOR_RED}FAILED:${COLOR_RESET} $failed/$num_of_tests tests failed!"
fi
}
trap 'print_summary; echo -e "${COLOR_RED}Interrupted by user!${COLOR_RESET}"; exit 1' SIGINT
cd $(dirname $0)
[ -z "$*" ] && tests=(test_trezor.wire.t*.py ) || tests=($*)
declare -i num_of_tests=${#tests[@]}
for test_case in ${tests[@]}; do
echo ${MICROPYTHON}
echo ${test_case}
echo
if $MICROPYTHON $test_case; then
results+=("${COLOR_GREEN}OK:${COLOR_RESET} $test_case")
((passed++))
else
results+=("${COLOR_RED}FAIL:${COLOR_RESET} $test_case")
((failed++))
exit_code=1
fi
done
print_summary
exit $exit_code

View File

@ -1,4 +1,4 @@
from common import H_, await_result, unittest # isort:skip
from common import * # isort:skip
import storage.cache
from trezor import wire
@ -11,6 +11,7 @@ from trezor.messages import (
TxInput,
TxOutput,
)
from trezor.wire import context
from apps.bitcoin.authorization import FEE_RATE_DECIMALS, CoinJoinAuthorization
from apps.bitcoin.sign_tx.approvers import CoinJoinApprover
@ -18,8 +19,26 @@ from apps.bitcoin.sign_tx.bitcoin import Bitcoin
from apps.bitcoin.sign_tx.tx_info import TxInfo
from apps.common import coins
if utils.USE_THP:
import thp_common
else:
import storage.cache_codec
class TestApprover(unittest.TestCase):
if utils.USE_THP:
def __init__(self):
if __debug__:
thp_common.suppres_debug_log()
thp_common.prepare_context()
super().__init__()
else:
def __init__(self):
context.CURRENT_CONTEXT = context.CodecContext(None, bytearray(64))
super().__init__()
def setUp(self):
self.coin = coins.by_name("Bitcoin")
self.fee_rate_percent = 0.3
@ -47,7 +66,8 @@ class TestApprover(unittest.TestCase):
coin_name=self.coin.coin_name,
script_type=InputScriptType.SPENDTAPROOT,
)
storage.cache.start_session()
if not utils.USE_THP:
storage.cache_codec.start_session()
def make_coinjoin_request(self, inputs):
return CoinJoinRequest(

View File

@ -1,16 +1,35 @@
from common import H_, unittest # isort:skip
from common import * # isort:skip
import storage.cache
from trezor.enums import InputScriptType
from trezor.messages import AuthorizeCoinJoin, GetOwnershipProof, SignTx
from trezor.wire import context
from apps.bitcoin.authorization import CoinJoinAuthorization
from apps.common import coins
_ROUND_ID_LEN = 32
if utils.USE_THP:
import thp_common
else:
import storage.cache_codec
class TestAuthorization(unittest.TestCase):
if utils.USE_THP:
def __init__(self):
if __debug__:
thp_common.suppres_debug_log()
thp_common.prepare_context()
super().__init__()
else:
def __init__(self):
context.CURRENT_CONTEXT = context.CodecContext(None, bytearray(64))
super().__init__()
coin = coins.by_name("Bitcoin")
@ -26,7 +45,8 @@ class TestAuthorization(unittest.TestCase):
)
self.authorization = CoinJoinAuthorization(self.msg_auth)
storage.cache.start_session()
if not utils.USE_THP:
storage.cache_codec.start_session()
def test_ownership_proof_account_depth_mismatch(self):
# Account depth mismatch.

View File

@ -1,17 +1,41 @@
from common import * # isort:skip
from storage import cache
from storage import cache_common
from trezor import wire
from trezor.crypto import bip39
from trezor.wire import context
from apps.bitcoin.keychain import _get_coin_by_name, _get_keychain_for_coin
if utils.USE_THP:
import thp_common
else:
from storage import cache_codec
class TestBitcoinKeychain(unittest.TestCase):
def setUp(self):
cache.start_session()
seed = bip39.seed(" ".join(["all"] * 12), "")
cache.set(cache.APP_COMMON_SEED, seed)
if utils.USE_THP:
def __init__(self):
if __debug__:
thp_common.suppres_debug_log()
thp_common.prepare_context()
super().__init__()
def setUp(self):
seed = bip39.seed(" ".join(["all"] * 12), "")
context.cache_set(cache_common.APP_COMMON_SEED, seed)
else:
def __init__(self):
context.CURRENT_CONTEXT = context.CodecContext(None, bytearray(64))
super().__init__()
def setUp(self):
cache_codec.start_session()
seed = bip39.seed(" ".join(["all"] * 12), "")
cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed)
def test_bitcoin(self):
coin = _get_coin_by_name("Bitcoin")
@ -88,10 +112,20 @@ class TestBitcoinKeychain(unittest.TestCase):
@unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin")
class TestAltcoinKeychains(unittest.TestCase):
def setUp(self):
cache.start_session()
seed = bip39.seed(" ".join(["all"] * 12), "")
cache.set(cache.APP_COMMON_SEED, seed)
if not utils.USE_THP:
def __init__(self):
# Context is needed to test decorators and handleInitialize
# It allows access to codec cache from different parts of the code
from trezor.wire import context
context.CURRENT_CONTEXT = context.CodecContext(None, bytearray(64))
super().__init__()
def setUp(self):
cache_codec.start_session()
seed = bip39.seed(" ".join(["all"] * 12), "")
cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed)
def test_bcash(self):
coin = _get_coin_by_name("Bcash")

View File

@ -1,19 +1,43 @@
from common import * # isort:skip
from mock_storage import mock_storage
from storage import cache
from trezor import wire
from storage import cache, cache_common
from trezor import utils, wire
from trezor.crypto import bip39
from trezor.enums import SafetyCheckLevel
from trezor.wire import context
from apps.common import safety_checks
from apps.common.keychain import Keychain, LRUCache, get_keychain, with_slip44_keychain
from apps.common.paths import PATTERN_SEP5, PathSchema
if utils.USE_THP:
import thp_common
if not utils.USE_THP:
from storage import cache_codec
class TestKeychain(unittest.TestCase):
def setUp(self):
cache.start_session()
if utils.USE_THP:
def __init__(self):
if __debug__:
thp_common.suppres_debug_log()
thp_common.prepare_context()
super().__init__()
else:
def __init__(self):
context.CURRENT_CONTEXT = context.CodecContext(None, bytearray(64))
super().__init__()
def setUp(self):
cache_codec.start_session()
def cache_set(self, key: int, value: bytes) -> None:
context.cache_set(key, value)
def tearDown(self):
cache.clear_all()
@ -71,7 +95,7 @@ class TestKeychain(unittest.TestCase):
def test_get_keychain(self):
seed = bip39.seed(" ".join(["all"] * 12), "")
cache.set(cache.APP_COMMON_SEED, seed)
self.cache_set(cache_common.APP_COMMON_SEED, seed)
schema = PathSchema.parse("m/44'/1'", 0)
keychain = await_result(get_keychain("secp256k1", [schema]))
@ -85,7 +109,7 @@ class TestKeychain(unittest.TestCase):
def test_with_slip44(self):
seed = bip39.seed(" ".join(["all"] * 12), "")
cache.set(cache.APP_COMMON_SEED, seed)
self.cache_set(cache_common.APP_COMMON_SEED, seed)
slip44_id = 42
valid_path = [H_(44), H_(slip44_id), H_(0)]

View File

@ -2,13 +2,20 @@ from common import * # isort:skip
import unittest
from storage import cache
from trezor import utils, wire
from storage import cache_common
from trezor import wire
from trezor.crypto import bip39
from trezor.wire import context
from apps.common.keychain import get_keychain
from apps.common.paths import HARDENED
if utils.USE_THP:
import thp_common
else:
from storage import cache_codec
if not utils.BITCOIN_ONLY:
from ethereum_common import encode_network, make_network
from trezor.messages import (
@ -71,10 +78,27 @@ class TestEthereumKeychain(unittest.TestCase):
addr,
)
def setUp(self):
cache.start_session()
seed = bip39.seed(" ".join(["all"] * 12), "")
cache.set(cache.APP_COMMON_SEED, seed)
if utils.USE_THP:
def __init__(self):
if __debug__:
thp_common.suppres_debug_log()
thp_common.prepare_context()
super().__init__()
def setUp(self):
seed = bip39.seed(" ".join(["all"] * 12), "")
context.cache_set(cache_common.APP_COMMON_SEED, seed)
else:
def __init__(self):
context.CURRENT_CONTEXT = context.CodecContext(None, bytearray(64))
super().__init__()
def setUp(self):
cache_codec.start_session()
seed = bip39.seed(" ".join(["all"] * 12), "")
cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed)
def from_address_n(self, address_n):
slip44 = _slip44_from_address_n(address_n)

View File

@ -1,230 +1,518 @@
from common import * # isort:skip
from common import * # isort:skip # noqa: F403
from mock_storage import mock_storage
from storage import cache
from storage import cache, cache_codec
from trezor.messages import EndSession, Initialize
from apps.base import handle_EndSession, handle_Initialize
KEY = 0
if utils.USE_THP:
import thp_common
from mock_wire_interface import MockHID
from storage import cache_thp
from trezor.wire.thp import ChannelState
from trezor.wire.thp.session_context import ManagementSessionContext, SessionContext
# Function moved from cache.py, as it was not used there
def is_session_started() -> bool:
return cache._active_session_idx is not None
_PROTOCOL_CACHE = cache_thp
else:
_PROTOCOL_CACHE = cache_codec
def is_session_started() -> bool:
return cache_codec.get_active_session() is not None
def get_active_session():
return cache_codec.get_active_session()
class TestStorageCache(unittest.TestCase):
class TestStorageCache(
unittest.TestCase
): # noqa: F405 # pyright: ignore[reportUndefinedVariable]
def setUp(self):
cache.clear_all()
def test_start_session(self):
session_id_a = cache.start_session()
self.assertIsNotNone(session_id_a)
session_id_b = cache.start_session()
self.assertNotEqual(session_id_a, session_id_b)
if utils.USE_THP:
cache.clear_all()
with self.assertRaises(cache.InvalidSessionError):
cache.set(KEY, "something")
with self.assertRaises(cache.InvalidSessionError):
cache.get(KEY)
def __init__(self):
thp_common.suppres_debug_log()
# xthp_common.prepare_context()
# config.init()
super().__init__()
def test_end_session(self):
session_id = cache.start_session()
self.assertTrue(is_session_started())
cache.set(KEY, b"A")
cache.end_current_session()
self.assertFalse(is_session_started())
self.assertRaises(cache.InvalidSessionError, cache.get, KEY)
def setUp(self):
self.interface = MockHID(0xDEADBEEF)
cache.clear_all()
# ending an ended session should be a no-op
cache.end_current_session()
self.assertFalse(is_session_started())
def test_new_channel_and_session(self):
channel = thp_common.get_new_channel(self.interface)
session_id_a = cache.start_session(session_id)
# original session no longer exists
self.assertNotEqual(session_id_a, session_id)
# original session data no longer exists
self.assertIsNone(cache.get(KEY))
# Assert that channel is created with one management session
self.assertEqual(len(channel.sessions), 1)
self.assertIsInstance(channel.sessions[0], ManagementSessionContext)
# create a new session
session_id_b = cache.start_session()
# switch back to original session
session_id = cache.start_session(session_id_a)
self.assertEqual(session_id, session_id_a)
# end original session
cache.end_current_session()
# switch back to B
session_id = cache.start_session(session_id_b)
self.assertEqual(session_id, session_id_b)
cid_1 = channel.channel_id
session_cache_1 = cache_thp.get_new_session(channel.channel_cache)
session_1 = SessionContext(channel, session_cache_1)
self.assertEqual(session_1.channel_id, cid_1)
def test_session_queue(self):
session_id = cache.start_session()
self.assertEqual(cache.start_session(session_id), session_id)
cache.set(KEY, b"A")
for i in range(cache._MAX_SESSIONS_COUNT):
cache.start_session()
self.assertNotEqual(cache.start_session(session_id), session_id)
self.assertIsNone(cache.get(KEY))
session_cache_2 = cache_thp.get_new_session(channel.channel_cache)
session_2 = SessionContext(channel, session_cache_2)
self.assertEqual(session_2.channel_id, cid_1)
self.assertEqual(session_1.channel_id, session_2.channel_id)
self.assertNotEqual(session_1.session_id, session_2.session_id)
def test_get_set(self):
session_id1 = cache.start_session()
cache.set(KEY, b"hello")
self.assertEqual(cache.get(KEY), b"hello")
channel_2 = thp_common.get_new_channel(self.interface)
cid_2 = channel_2.channel_id
self.assertNotEqual(cid_1, cid_2)
session_id2 = cache.start_session()
cache.set(KEY, b"world")
self.assertEqual(cache.get(KEY), b"world")
session_cache_3 = cache_thp.get_new_session(channel_2.channel_cache)
session_3 = SessionContext(channel_2, session_cache_3)
self.assertEqual(session_3.channel_id, cid_2)
cache.start_session(session_id2)
self.assertEqual(cache.get(KEY), b"world")
cache.start_session(session_id1)
self.assertEqual(cache.get(KEY), b"hello")
# Sessions 1 and 3 should have different channel_id, but the same session_id
self.assertNotEqual(session_1.channel_id, session_3.channel_id)
self.assertEqual(session_1.session_id, session_3.session_id)
cache.clear_all()
with self.assertRaises(cache.InvalidSessionError):
cache.get(KEY)
self.assertEqual(cache_thp._SESSIONS[0], session_cache_1)
self.assertNotEqual(cache_thp._SESSIONS[0], session_cache_2)
self.assertEqual(cache_thp._SESSIONS[0].channel_id, session_1.channel_id)
def test_get_set_int(self):
session_id1 = cache.start_session()
cache.set_int(KEY, 1234)
self.assertEqual(cache.get_int(KEY), 1234)
# Check that session data IS in cache for created sessions ONLY
for i in range(3):
self.assertNotEqual(cache_thp._SESSIONS[i].channel_id, b"")
self.assertNotEqual(cache_thp._SESSIONS[i].session_id, b"")
self.assertNotEqual(cache_thp._SESSIONS[i].last_usage, 0)
for i in range(3, cache_thp._MAX_SESSIONS_COUNT):
self.assertEqual(cache_thp._SESSIONS[i].channel_id, b"")
self.assertEqual(cache_thp._SESSIONS[i].session_id, b"")
self.assertEqual(cache_thp._SESSIONS[i].last_usage, 0)
session_id2 = cache.start_session()
cache.set_int(KEY, 5678)
self.assertEqual(cache.get_int(KEY), 5678)
# Check that session data IS NOT in cache after cache.clear_all()
cache.clear_all()
for session in cache_thp._SESSIONS:
self.assertEqual(session.channel_id, b"")
self.assertEqual(session.session_id, b"")
self.assertEqual(session.last_usage, 0)
self.assertEqual(session.state, b"\x00")
cache.start_session(session_id2)
self.assertEqual(cache.get_int(KEY), 5678)
cache.start_session(session_id1)
self.assertEqual(cache.get_int(KEY), 1234)
def test_channel_capacity_in_cache(self):
self.assertTrue(cache_thp._MAX_CHANNELS_COUNT >= 3)
channels = []
for i in range(cache_thp._MAX_CHANNELS_COUNT):
channels.append(thp_common.get_new_channel(self.interface))
channel_ids = [channel.channel_cache.channel_id for channel in channels]
cache.clear_all()
with self.assertRaises(cache.InvalidSessionError):
cache.get_int(KEY)
# Assert that each channel_id is unique and that cache and list of channels
# have the same "channels" on the same indexes
for i in range(len(channel_ids)):
self.assertEqual(cache_thp._CHANNELS[i].channel_id, channel_ids[i])
for j in range(i + 1, len(channel_ids)):
self.assertNotEqual(channel_ids[i], channel_ids[j])
def test_delete(self):
session_id1 = cache.start_session()
self.assertIsNone(cache.get(KEY))
cache.set(KEY, b"hello")
self.assertEqual(cache.get(KEY), b"hello")
cache.delete(KEY)
self.assertIsNone(cache.get(KEY))
# Create a new channel that is over the capacity
new_channel = thp_common.get_new_channel(self.interface)
for c in channels:
self.assertNotEqual(c.channel_id, new_channel.channel_id)
cache.set(KEY, b"hello")
cache.start_session()
self.assertIsNone(cache.get(KEY))
cache.set(KEY, b"hello")
self.assertEqual(cache.get(KEY), b"hello")
cache.delete(KEY)
self.assertIsNone(cache.get(KEY))
# Test that the oldest (least used) channel was replaced (_CHANNELS[0])
self.assertNotEqual(cache_thp._CHANNELS[0].channel_id, channel_ids[0])
self.assertEqual(cache_thp._CHANNELS[0].channel_id, new_channel.channel_id)
cache.start_session(session_id1)
self.assertEqual(cache.get(KEY), b"hello")
# Update the "last used" value of the second channel in cache (_CHANNELS[1]) and
# assert that it is not replaced when creating a new channel
cache_thp.update_channel_last_used(channel_ids[1])
new_new_channel = thp_common.get_new_channel(self.interface)
self.assertEqual(cache_thp._CHANNELS[1].channel_id, channel_ids[1])
def test_decorators(self):
run_count = 0
cache.start_session()
# Assert that it was in fact the _CHANNEL[2] that was replaced
self.assertNotEqual(cache_thp._CHANNELS[2].channel_id, channel_ids[2])
self.assertEqual(
cache_thp._CHANNELS[2].channel_id, new_new_channel.channel_id
)
@cache.stored(KEY)
def func():
nonlocal run_count
run_count += 1
return b"foo"
def test_session_capacity_in_cache(self):
self.assertTrue(cache_thp._MAX_SESSIONS_COUNT >= 4)
channel_cache_A = thp_common.get_new_channel(self.interface).channel_cache
channel_cache_B = thp_common.get_new_channel(self.interface).channel_cache
# cache is empty
self.assertIsNone(cache.get(KEY))
self.assertEqual(run_count, 0)
self.assertEqual(func(), b"foo")
# function was run
self.assertEqual(run_count, 1)
self.assertEqual(cache.get(KEY), b"foo")
# function does not run again but returns cached value
self.assertEqual(func(), b"foo")
self.assertEqual(run_count, 1)
sesions_A = []
cid = []
sid = []
for i in range(3):
sesions_A.append(cache_thp.get_new_session(channel_cache_A))
cid.append(sesions_A[i].channel_id)
sid.append(sesions_A[i].session_id)
@cache.stored_async(KEY)
async def async_func():
nonlocal run_count
run_count += 1
return b"bar"
sessions_B = []
for i in range(cache_thp._MAX_SESSIONS_COUNT - 3):
sessions_B.append(cache_thp.get_new_session(channel_cache_B))
# cache is still full
self.assertEqual(await_result(async_func()), b"foo")
self.assertEqual(run_count, 1)
for i in range(3):
self.assertEqual(sesions_A[i], cache_thp._SESSIONS[i])
self.assertEqual(cid[i], cache_thp._SESSIONS[i].channel_id)
self.assertEqual(sid[i], cache_thp._SESSIONS[i].session_id)
for i in range(3, cache_thp._MAX_SESSIONS_COUNT):
self.assertEqual(sessions_B[i - 3], cache_thp._SESSIONS[i])
cache.start_session()
self.assertEqual(await_result(async_func()), b"bar")
self.assertEqual(run_count, 2)
# awaitable is also run only once
self.assertEqual(await_result(async_func()), b"bar")
self.assertEqual(run_count, 2)
# Assert that new session replaces the oldest (least used) one (_SESSOIONS[0])
new_session = cache_thp.get_new_session(channel_cache_B)
self.assertEqual(new_session, cache_thp._SESSIONS[0])
self.assertNotEqual(new_session.channel_id, cid[0])
self.assertNotEqual(new_session.session_id, sid[0])
def test_empty_value(self):
cache.start_session()
# Assert that updating "last used" for session on channel A increases also
# the "last usage" of channel A.
self.assertTrue(channel_cache_A.last_usage < channel_cache_B.last_usage)
cache_thp.update_session_last_used(
channel_cache_A.channel_id, sesions_A[1].session_id
)
self.assertTrue(channel_cache_A.last_usage > channel_cache_B.last_usage)
self.assertIsNone(cache.get(KEY))
cache.set(KEY, b"")
self.assertEqual(cache.get(KEY), b"")
new_new_session = cache_thp.get_new_session(channel_cache_B)
cache.delete(KEY)
run_count = 0
# Assert that creating a new session on channel B shifts the "last usage" again
# and that _SESSIONS[1] was not replaced, but that _SESSIONS[2] was replaced
self.assertTrue(channel_cache_A.last_usage < channel_cache_B.last_usage)
self.assertEqual(sesions_A[1], cache_thp._SESSIONS[1])
self.assertNotEqual(sesions_A[2], cache_thp._SESSIONS[2])
self.assertEqual(new_new_session, cache_thp._SESSIONS[2])
@cache.stored(KEY)
def func():
nonlocal run_count
run_count += 1
return b""
def test_clear(self):
channel_A = thp_common.get_new_channel(self.interface)
channel_B = thp_common.get_new_channel(self.interface)
cid_A = channel_A.channel_id
cid_B = channel_B.channel_id
sessions = []
self.assertEqual(func(), b"")
# function gets called once
self.assertEqual(run_count, 1)
self.assertEqual(func(), b"")
# function is not called for a second time
self.assertEqual(run_count, 1)
for i in range(3):
sessions.append(cache_thp.get_new_session(channel_A.channel_cache))
sessions.append(cache_thp.get_new_session(channel_B.channel_cache))
@mock_storage
def test_Initialize(self):
def call_Initialize(**kwargs):
msg = Initialize(**kwargs)
return await_result(handle_Initialize(msg))
self.assertEqual(cache_thp._SESSIONS[2 * i].channel_id, cid_A)
self.assertNotEqual(cache_thp._SESSIONS[2 * i].last_usage, 0)
# calling Initialize without an ID allocates a new one
session_id = cache.start_session()
features = call_Initialize()
self.assertNotEqual(session_id, features.session_id)
self.assertEqual(cache_thp._SESSIONS[2 * i + 1].channel_id, cid_B)
self.assertNotEqual(cache_thp._SESSIONS[2 * i + 1].last_usage, 0)
# calling Initialize with the current ID does not allocate a new one
features = call_Initialize(session_id=session_id)
self.assertEqual(session_id, features.session_id)
# Assert that clearing of channel A works
self.assertNotEqual(channel_A.channel_cache.channel_id, b"")
self.assertNotEqual(channel_A.channel_cache.last_usage, 0)
self.assertEqual(channel_A.get_channel_state(), ChannelState.TH1)
# store "hello"
cache.set(KEY, b"hello")
# check that it is cleared
features = call_Initialize()
session_id = features.session_id
self.assertIsNone(cache.get(KEY))
# store "hello" again
cache.set(KEY, b"hello")
self.assertEqual(cache.get(KEY), b"hello")
channel_A.clear()
# supplying a different session ID starts a new cache
call_Initialize(session_id=b"A" * cache._SESSION_ID_LENGTH)
self.assertIsNone(cache.get(KEY))
self.assertEqual(channel_A.channel_cache.channel_id, b"")
self.assertEqual(channel_A.channel_cache.last_usage, 0)
self.assertEqual(channel_A.get_channel_state(), ChannelState.UNALLOCATED)
# but resuming a session loads the previous one
call_Initialize(session_id=session_id)
self.assertEqual(cache.get(KEY), b"hello")
# Assert that clearing channel A also cleared all its sessions
for i in range(3):
self.assertEqual(cache_thp._SESSIONS[2 * i].last_usage, 0)
self.assertEqual(cache_thp._SESSIONS[2 * i].channel_id, b"")
def test_EndSession(self):
self.assertRaises(cache.InvalidSessionError, cache.get, KEY)
cache.start_session()
self.assertTrue(is_session_started())
self.assertIsNone(cache.get(KEY))
await_result(handle_EndSession(EndSession()))
self.assertFalse(is_session_started())
self.assertRaises(cache.InvalidSessionError, cache.get, KEY)
self.assertNotEqual(cache_thp._SESSIONS[2 * i + 1].last_usage, 0)
self.assertEqual(cache_thp._SESSIONS[2 * i + 1].channel_id, cid_B)
cache.clear_all()
for session in cache_thp._SESSIONS:
self.assertEqual(session.last_usage, 0)
self.assertEqual(session.channel_id, b"")
for channel in cache_thp._CHANNELS:
self.assertEqual(channel.channel_id, b"")
self.assertEqual(channel.last_usage, 0)
self.assertEqual(
cache_thp._get_channel_state(channel), ChannelState.UNALLOCATED
)
def test_get_set(self):
channel = thp_common.get_new_channel(self.interface)
session_1 = cache_thp.get_new_session(channel.channel_cache)
session_1.set(KEY, b"hello")
self.assertEqual(session_1.get(KEY), b"hello")
session_2 = cache_thp.get_new_session(channel.channel_cache)
session_2.set(KEY, b"world")
self.assertEqual(session_2.get(KEY), b"world")
self.assertEqual(session_1.get(KEY), b"hello")
cache.clear_all()
self.assertIsNone(session_1.get(KEY))
self.assertIsNone(session_2.get(KEY))
def test_get_set_int(self):
channel = thp_common.get_new_channel(self.interface)
session_1 = cache_thp.get_new_session(channel.channel_cache)
session_1.set_int(KEY, 1234)
self.assertEqual(session_1.get_int(KEY), 1234)
session_2 = cache_thp.get_new_session(channel.channel_cache)
session_2.set_int(KEY, 5678)
self.assertEqual(session_2.get_int(KEY), 5678)
self.assertEqual(session_1.get_int(KEY), 1234)
cache.clear_all()
self.assertIsNone(session_1.get_int(KEY))
self.assertIsNone(session_2.get_int(KEY))
def test_get_set_bool(self):
channel = thp_common.get_new_channel(self.interface)
session_1 = cache_thp.get_new_session(channel.channel_cache)
with self.assertRaises(AssertionError) as e:
session_1.set_bool(KEY, True)
self.assertEqual(e.value.value, "Field does not have zero length!")
# Change length of first session field to 0 so that the length check passes
session_1.fields = (0,) + session_1.fields[1:]
# with self.assertRaises(AssertionError) as e:
session_1.set_bool(KEY, True)
self.assertEqual(session_1.get_bool(KEY), True)
session_2 = cache_thp.get_new_session(channel.channel_cache)
session_2.fields = session_2.fields = (0,) + session_2.fields[1:]
session_2.set_bool(KEY, False)
self.assertEqual(session_2.get_bool(KEY), False)
self.assertEqual(session_1.get_bool(KEY), True)
cache.clear_all()
# Default value is False
self.assertFalse(session_1.get_bool(KEY))
self.assertFalse(session_2.get_bool(KEY))
def test_delete(self):
channel = thp_common.get_new_channel(self.interface)
session_1 = cache_thp.get_new_session(channel.channel_cache)
self.assertIsNone(session_1.get(KEY))
session_1.set(KEY, b"hello")
self.assertEqual(session_1.get(KEY), b"hello")
session_1.delete(KEY)
self.assertIsNone(session_1.get(KEY))
session_1.set(KEY, b"hello")
session_2 = cache_thp.get_new_session(channel.channel_cache)
self.assertIsNone(session_2.get(KEY))
session_2.set(KEY, b"hello")
self.assertEqual(session_2.get(KEY), b"hello")
session_2.delete(KEY)
self.assertIsNone(session_2.get(KEY))
self.assertEqual(session_1.get(KEY), b"hello")
else:
def __init__(self):
# Context is needed to test decorators and handleInitialize
# It allows access to codec cache from different parts of the code
from trezor.wire import context
context.CURRENT_CONTEXT = context.CodecContext(None, bytearray(64))
super().__init__()
def test_start_session(self):
session_id_a = cache_codec.start_session()
self.assertIsNotNone(session_id_a)
session_id_b = cache_codec.start_session()
self.assertNotEqual(session_id_a, session_id_b)
cache.clear_all()
self.assertIsNone(get_active_session())
for session in cache_codec._SESSIONS:
self.assertEqual(session.session_id, b"")
self.assertEqual(session.last_usage, 0)
def test_end_session(self):
session_id = cache_codec.start_session()
self.assertTrue(is_session_started())
get_active_session().set(KEY, b"A")
cache_codec.end_current_session()
self.assertFalse(is_session_started())
self.assertIsNone(get_active_session())
# ending an ended session should be a no-op
cache_codec.end_current_session()
self.assertFalse(is_session_started())
session_id_a = cache_codec.start_session(session_id)
# original session no longer exists
self.assertNotEqual(session_id_a, session_id)
# original session data no longer exists
self.assertIsNone(get_active_session().get(KEY))
# create a new session
session_id_b = cache_codec.start_session()
# switch back to original session
session_id = cache_codec.start_session(session_id_a)
self.assertEqual(session_id, session_id_a)
# end original session
cache_codec.end_current_session()
# switch back to B
session_id = cache_codec.start_session(session_id_b)
self.assertEqual(session_id, session_id_b)
def test_session_queue(self):
session_id = cache_codec.start_session()
self.assertEqual(cache_codec.start_session(session_id), session_id)
get_active_session().set(KEY, b"A")
for i in range(_PROTOCOL_CACHE._MAX_SESSIONS_COUNT):
cache_codec.start_session()
self.assertNotEqual(cache_codec.start_session(session_id), session_id)
self.assertIsNone(get_active_session().get(KEY))
def test_get_set(self):
session_id1 = cache_codec.start_session()
cache_codec.get_active_session().set(KEY, b"hello")
self.assertEqual(cache_codec.get_active_session().get(KEY), b"hello")
session_id2 = cache_codec.start_session()
cache_codec.get_active_session().set(KEY, b"world")
self.assertEqual(cache_codec.get_active_session().get(KEY), b"world")
cache_codec.start_session(session_id2)
self.assertEqual(cache_codec.get_active_session().get(KEY), b"world")
cache_codec.start_session(session_id1)
self.assertEqual(cache_codec.get_active_session().get(KEY), b"hello")
cache.clear_all()
self.assertIsNone(cache_codec.get_active_session())
def test_get_set_int(self):
session_id1 = cache_codec.start_session()
get_active_session().set_int(KEY, 1234)
self.assertEqual(get_active_session().get_int(KEY), 1234)
session_id2 = cache_codec.start_session()
get_active_session().set_int(KEY, 5678)
self.assertEqual(get_active_session().get_int(KEY), 5678)
cache_codec.start_session(session_id2)
self.assertEqual(get_active_session().get_int(KEY), 5678)
cache_codec.start_session(session_id1)
self.assertEqual(get_active_session().get_int(KEY), 1234)
cache.clear_all()
self.assertIsNone(get_active_session())
def test_delete(self):
session_id1 = cache_codec.start_session()
self.assertIsNone(get_active_session().get(KEY))
get_active_session().set(KEY, b"hello")
self.assertEqual(get_active_session().get(KEY), b"hello")
get_active_session().delete(KEY)
self.assertIsNone(get_active_session().get(KEY))
get_active_session().set(KEY, b"hello")
cache_codec.start_session()
self.assertIsNone(get_active_session().get(KEY))
get_active_session().set(KEY, b"hello")
self.assertEqual(get_active_session().get(KEY), b"hello")
get_active_session().delete(KEY)
self.assertIsNone(get_active_session().get(KEY))
cache_codec.start_session(session_id1)
self.assertEqual(get_active_session().get(KEY), b"hello")
def test_decorators(self):
run_count = 0
cache_codec.start_session()
from apps.common.cache import stored
@stored(KEY)
def func():
nonlocal run_count
run_count += 1
return b"foo"
# cache is empty
self.assertIsNone(get_active_session().get(KEY))
self.assertEqual(run_count, 0)
self.assertEqual(func(), b"foo")
# function was run
self.assertEqual(run_count, 1)
self.assertEqual(get_active_session().get(KEY), b"foo")
# function does not run again but returns cached value
self.assertEqual(func(), b"foo")
self.assertEqual(run_count, 1)
def test_empty_value(self):
cache_codec.start_session()
self.assertIsNone(get_active_session().get(KEY))
get_active_session().set(KEY, b"")
self.assertEqual(get_active_session().get(KEY), b"")
get_active_session().delete(KEY)
run_count = 0
from apps.common.cache import stored
@stored(KEY)
def func():
nonlocal run_count
run_count += 1
return b""
self.assertEqual(func(), b"")
# function gets called once
self.assertEqual(run_count, 1)
self.assertEqual(func(), b"")
# function is not called for a second time
self.assertEqual(run_count, 1)
@mock_storage
def test_Initialize(self):
def call_Initialize(**kwargs):
msg = Initialize(**kwargs)
return await_result(handle_Initialize(msg))
# calling Initialize without an ID allocates a new one
session_id = cache_codec.start_session()
features = call_Initialize()
self.assertNotEqual(session_id, features.session_id)
# calling Initialize with the current ID does not allocate a new one
features = call_Initialize(session_id=session_id)
self.assertEqual(session_id, features.session_id)
# store "hello"
get_active_session().set(KEY, b"hello")
# check that it is cleared
features = call_Initialize()
session_id = features.session_id
self.assertIsNone(get_active_session().get(KEY))
# store "hello" again
get_active_session().set(KEY, b"hello")
self.assertEqual(get_active_session().get(KEY), b"hello")
# supplying a different session ID starts a new session
call_Initialize(session_id=b"A" * _PROTOCOL_CACHE.SESSION_ID_LENGTH)
self.assertIsNone(get_active_session().get(KEY))
# but resuming a session loads the previous one
call_Initialize(session_id=session_id)
self.assertEqual(get_active_session().get(KEY), b"hello")
def test_EndSession(self):
self.assertIsNone(get_active_session())
cache_codec.start_session()
self.assertTrue(is_session_started())
self.assertIsNone(get_active_session().get(KEY))
await_result(handle_EndSession(EndSession()))
self.assertFalse(is_session_started())
self.assertIsNone(cache_codec.get_active_session())
if __name__ == "__main__":

View File

@ -2,28 +2,11 @@ from common import * # isort:skip
import ustruct
from mock_wire_interface import MockHID
from trezor import io
from trezor.loop import wait
from trezor.utils import chunks
from trezor.wire import codec_v1
class MockHID:
def __init__(self, num):
self.num = num
self.data = []
def iface_num(self):
return self.num
def write(self, msg):
self.data.append(bytearray(msg))
return len(msg)
def wait_object(self, mode):
return wait(mode | self.num)
MESSAGE_TYPE = 0x4242
HEADER_PAYLOAD_LENGTH = codec_v1._REP_LEN - 3 - ustruct.calcsize(">HL")

View File

@ -0,0 +1,94 @@
from common import * # isort:skip
if utils.USE_THP:
from trezor.wire.thp import checksum
@unittest.skipUnless(utils.USE_THP, "only needed for THP")
class TestTrezorHostProtocolChecksum(unittest.TestCase):
vectors_correct = [
(
b"",
b"\x00\x00\x00\x00",
),
(
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
b"\x19\x0A\x55\xAD",
),
(
bytes("a", "ascii"),
b"\xE8\xB7\xBE\x43",
),
(
bytes("abc", "ascii"),
b"\x35\x24\x41\xC2",
),
(
bytes("123456789", "ascii"),
b"\xCB\xF4\x39\x26",
),
(
bytes(
"12345678901234567890123456789012345678901234567890123456789012345678901234567890",
"ascii",
),
b"\x7C\xA9\x4A\x72",
),
(
b"\x76\x61\x72\x69\x6F\x75\x73\x20\x43\x52\x43\x20\x61\x6C\x67\x6F\x72\x69\x74\x68\x6D\x73\x20\x69\x6E\x70\x75\x74\x20\x64\x61\x74\x61",
b"\x9B\xD3\x66\xAE",
),
(
b"\x67\x3a\x5f\x0e\x39\xc0\x3c\x79\x58\x22\x74\x76\x64\x9e\x36\xe9\x0b\x04\x8c\xd2\xc0\x4d\x76\x63\x1a\xa2\x17\x85\xe8\x50\xa7\x14\x18\xfb\x86\xed\xa3\x59\x2d\x62\x62\x49\x64\x62\x26\x12\xdb\x95\x3d\xd6\xb5\xca\x4b\x22\x0d\xc5\x78\xb2\x12\x97\x8e\x54\x4e\x06\xb7\x9c\x90\xf5\xa0\x21\xa6\xc7\xd8\x39\xfd\xea\x3a\xf1\x7b\xa2\xe8\x71\x41\xd6\xcb\x1e\x5b\x0e\x29\xf7\x0c\xc7\x57\x8b\x53\x20\x1d\x2b\x41\x1c\x25\xf9\x07\xbb\xb4\x37\x79\x6a\x13\x1f\x6c\x43\x71\xc1\x1e\x70\xe6\x74\xd3\x9c\xbf\x32\x15\xee\xf2\xa7\x86\xbe\x59\x99\xc4\x10\x09\x8a\x6a\xaa\xd4\xd1\xd0\x71\xd2\x06\x1a\xdd\x2a\xa0\x08\xeb\x08\x6c\xfb\xd2\x2d\xfb\xaa\x72\x56\xeb\xd1\x92\x92\xe5\x0e\x95\x67\xf8\x38\xc3\xab\x59\x37\xe6\xfd\x42\xb0\xd0\x31\xd0\xcb\x8a\x66\xce\x2d\x53\x72\x1e\x72\xd3\x84\x25\xb0\xb8\x93\xd2\x61\x5b\x32\xd5\xe7\xe4\x0e\x31\x11\xaf\xdc\xb4\xb8\xee\xa4\x55\x16\x5f\x78\x86\x8b\x50\x4d\xc5\x6d\x6e\xfc\xe1\x6b\x06\x5b\x37\x84\x2a\x67\x95\x28\x00\xa4\xd1\x32\x9f\xbf\xe1\x64\xf8\x17\x47\xe1\xad\x8b\x72\xd2\xd9\x45\x5b\x73\x43\x3c\xe6\x21\xf7\x53\xa3\x73\xf9\x2a\xb0\xe9\x75\x5e\xa6\xbe\x9a\xad\xfc\xed\xb5\x46\x5b\x9f\xa9\x5a\x4f\xcb\xb6\x60\x96\x31\x91\x42\xca\xaf\xee\xa5\x0c\xe0\xab\x3e\x83\xb8\xac\x88\x10\x2c\x63\xd3\xc9\xd2\xf2\x44\xef\xea\x3d\x19\x24\x3c\x5b\xe7\x0c\x52\xfd\xfe\x47\x41\x14\xd5\x4c\x67\x8d\xdb\xe5\xd9\xfa\x67\x9c\x06\x31\x01\x92\xba\x96\xc4\x0d\xef\xf7\xc1\xe9\x23\x28\x0f\xae\x27\x9b\xff\x28\x0b\x3e\x85\x0c\xae\x02\xda\x27\xb6\x04\x51\x04\x43\x04\x99\x8c\xa3\x97\x1d\x84\xec\x55\x59\xfb\xf3\x84\xe5\xf8\x40\xf8\x5f\x81\x65\x92\x4c\x92\x7a\x07\x51\x8d\x6f\xff\x8d\x15\x36\x5c\x57\x7a\x5b\x3a\x63\x1c\x87\x65\xee\x54\xd5\x96\x50\x73\x1a\x9c\xff\x59\xe5\xea\x6f\x89\xd2\xbb\xa9\x6a\x12\x21\xf5\x08\x8e\x8a\xc0\xd8\xf5\x14\xe9\x9d\x7e\x99\x13\x88\x29\xa8\xb4\x22\x2a\x41\x7c\xc5\x10\xdf\x11\x5e\xf8\x8d\x0e\xd9\x98\xd5\xaf\xa8\xf9\x55\x1e\xe3\x29\xcd\x2c\x51\x7b\x8a\x8d\x52\xaa\x8b\x87\xae\x8e\xb2\xfa\x31\x27\x60\x90\xcb\x01\x6f\x7a\x79\x38\x04\x05\x7c\x11\x79\x10\x40\x33\x70\x75\xfd\x0b\x88\xa5\xcd\x35\xd8\xa6\x3b\xb0\x45\x82\x64\xd1\xb5\xdc\x06\xc9\x89\xf4\x16\x3e\xc7\xb3\xf1\x9d\xd3\xc5\xe3\xaf\xe8\x25\x86\x7a\x4a\xfd\x10\x5d\x20\xe5\x76\x5a\x22\x5f\x8f\xbc\xaa\x97\xee\xf2\xc2\x4c\x0e\xdc\x7b\xc4\xee\x53\xa3\xe0\xfa\xcd\x1e\x4e\x54\x1d\x5e\xe1\x51\x17\x1f\x1a\x75\x7f\xed\x12\xd7\xf7\xe3\x18\x56\x24\xcf\xc6\x96\x30\x77\x0d\x73\x98\x9c\x09\x69\xa3\xbc\x96\x5e\xaf\xde\x76\xa4\x66\x04\x6b\x36\x2a\xac\x6d\x37\xf8\x1e\xe1\x2a\x3e\x42\x2d\x1d\xe6\x46\xdd\x28\xb9\x08\x44\xa1\x9e\xb2\x22\x7a\x45\x8a\x37\x39\x74\xb4\xae\xc8\x3b\x40\xf7\xec\xbf\xfd\xe5\xde\xb2\x83\x5e\xa4\x46\x19\xa6\x9d\xb0\xe8\x76\x80\xbd\xc1\x80\x7a\xd9\xeb\xe7\x90\x5b\x81\x25\x21\xd9\x5b\x4a\x80\x48\x92\x71\x77\x04\xb2\xac\x05\xc9\xdf\x5e\x44\x5a\xae\x6e\xb3\xd8\x30\x5e\xdc\x77\x2f\x79\xc2\x8e\x8b\x28\x24\x06\x1b\x6f\x8d\x88\x53\x80\x55\x0c\x3a\x7b\x85\xb8\x96\x85\xe9\xf0\x57\x63\xfe\x32\x80\xff\x57\xc9\x3c\xdb\xf6\xcd\x67\x14\x47\x6c\x43\x3d\x6d\x48\x3f\x9c\x00\x60\x0e\xf5\x94\xe4\x52\x97\x86\xcd\xac\xbc\xe4\xe3\xe7\xee\xa2\x91\x6e\x92\xbb\xd1\x55\x0c\x5c\x0d\x63\xdb\x6b\xb8\x6e\x45\x48\x0f\xdf\x44\x48\xd2\xf5\xf7\x4d\x7b\xd4\x4d\xd3\xcd\xcd\x5b\x40\x60\xb1\xb2\x8e\xc9\x9a\x65\xc5\x06\x24\xcf\xe9\xcc\x5e\x2c\x49\x47\x38\x45\x5d\xc5\xc0\x0d\x8a\x07\x1c\xb3\xbb\xb1\x69\xf5\x6d\x0e\x9c\x96\x14\x93\x58\x0c\xc9\x48\x74\xfc\x35\xda\x7d\x4e\x32\x73\xa3\x77\x4a\x9e\xc5\xd1\x08\xfe\xa6\xa0\xf1\x66\x72\xea\xc7\xae\x21\x81\x0e\x8a\xba\x99\x06\x97\xfc\xc6\x2b\x69\x53\xc6\x67\xec\x5d\xa1\xfc\xa1\x3b\xdd\x2a\xd6\x8f\x31\xa7\x8d\xec\xfe\x0a\x3b\x6b\x39\x70\x70\x09\x72\x12\xbc\x84\x67\xca\xd2\x4a\x17\x33\x94\x45\x25\xc7\xfd\x1e\xa2\x4a\x9e\x27\x9d\xfb\x87\xea\xe4\xfd\xb0\x11\x06\x9d\x72\xb9\x1d\xea\x9b\x81\x2e\x6a\x36\x76\x62\xfa\xbe\x96\x67\x7d\x35\xdd\x5e\x5c\x4f\x41\x0d\xce\xdb\x13\xb0\x46\x89\x92\x45\x02\x39\x0f\xe6\xd1\x20\x96\x1c\x34\x00\x8c\xc9\xdf\xe3\xf0\xb6\x92\x3a\xda\x5c\x96\xd9\x0b\x7d\x57\xf5\x78\x11\xc0\xcf\xbf\xb0\x92\x3d\xe5\x6a\x67\x34\xce\xd9\x16\x08\xa0\x09\x42\x0b\x07\x13\x7c\x73\x0c\xc6\x50\x17\x42\xcf\xd9\x85\xd9\x23\x3c\xb1\x40\x40\x0f\x94\x20\xed\x2d\xbf\x10\x44\x6e\x64\x65\xe5\x1d\x5f\xec\x24\xd8\x4b\xe8\xc2\xfb\x06\x11\x24\x3f\xdf\x54\x2d\xe8\x4d\xc2\x1c\x27\x11\xb8\xb3\xd4",
b"\x6B\xA4\xEC\x92",
),
]
vectors_incorrect = [
(
b"",
b"\x00\x00\x00\x00\x00",
),
(
b"",
b"",
),
(
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
b"\x19\x0A\x55\xAE",
),
(
bytes("A", "ascii"),
b"\xE8\xB7\xBE\x43",
),
(
bytes("abc ", "ascii"),
b"\x35\x24\x41\xC2",
),
(
bytes("1234567890", "ascii"),
b"\xCB\xF4\x39\x26",
),
(
bytes(
"1234567890123456789012345678901234567890123456789012345678901234567890123456789",
"ascii",
),
b"\x7C\xA9\x4A\x72",
),
]
def test_computation(self):
for data, chksum in self.vectors_correct:
self.assertEqual(checksum.compute(data), chksum)
def test_validation_correct(self):
for data, chksum in self.vectors_correct:
self.assertTrue(checksum.is_valid(chksum, data))
def test_validation_incorrect(self):
for data, chksum in self.vectors_incorrect:
self.assertFalse(checksum.is_valid(chksum, data))
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,66 @@
from common import * # isort:skip
if utils.USE_THP:
import thp_common
from trezor import config
from trezor.messages import ThpCredentialMetadata
from apps.thp import credential_manager
def _issue_credential(host_name: str, host_static_pubkey: bytes) -> bytes:
metadata = ThpCredentialMetadata(host_name=host_name)
return credential_manager.issue_credential(host_static_pubkey, metadata)
@unittest.skipUnless(utils.USE_THP, "only needed for THP")
class TestTrezorHostProtocolCredentialManager(unittest.TestCase):
def __init__(self):
if __debug__:
thp_common.suppres_debug_log()
super().__init__()
def setUp(self):
config.init()
config.wipe()
def test_derive_cred_auth_key(self):
key1 = credential_manager.derive_cred_auth_key()
key2 = credential_manager.derive_cred_auth_key()
self.assertEqual(len(key1), 32)
self.assertEqual(key1, key2)
def test_invalidate_cred_auth_key(self):
key1 = credential_manager.derive_cred_auth_key()
credential_manager.invalidate_cred_auth_key()
key2 = credential_manager.derive_cred_auth_key()
self.assertNotEqual(key1, key2)
def test_credentials(self):
DUMMY_KEY_1 = b"\x00\x00"
DUMMY_KEY_2 = b"\xff\xff"
HOST_NAME_1 = "host_name"
HOST_NAME_2 = "different host_name"
cred_1 = _issue_credential(HOST_NAME_1, DUMMY_KEY_1)
cred_2 = _issue_credential(HOST_NAME_1, DUMMY_KEY_1)
self.assertEqual(cred_1, cred_2)
cred_3 = _issue_credential(HOST_NAME_2, DUMMY_KEY_1)
self.assertNotEqual(cred_1, cred_3)
self.assertTrue(credential_manager.validate_credential(cred_1, DUMMY_KEY_1))
self.assertTrue(credential_manager.validate_credential(cred_3, DUMMY_KEY_1))
self.assertFalse(credential_manager.validate_credential(cred_1, DUMMY_KEY_2))
credential_manager.invalidate_cred_auth_key()
cred_4 = _issue_credential(HOST_NAME_1, DUMMY_KEY_1)
self.assertNotEqual(cred_1, cred_4)
self.assertFalse(credential_manager.validate_credential(cred_1, DUMMY_KEY_1))
self.assertFalse(credential_manager.validate_credential(cred_3, DUMMY_KEY_1))
self.assertTrue(credential_manager.validate_credential(cred_4, DUMMY_KEY_1))
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,156 @@
from common import * # isort:skip
from trezorcrypto import aesgcm, curve25519
import storage
if utils.USE_THP:
import thp_common
from trezor.wire.thp import crypto
from trezor.wire.thp.crypto import IV_1, IV_2, Handshake
def get_dummy_device_secret():
return b"\x01\x02\x03\x04\x05\x06\x07\x08\x01\x02\x03\x04\x05\x06\x07\x08"
@unittest.skipUnless(utils.USE_THP, "only needed for THP")
class TestTrezorHostProtocolCrypto(unittest.TestCase):
if utils.USE_THP:
handshake = Handshake()
key_1 = b"\x00\x01\x02\x03\x04\x05\x06\x07\x00\x01\x02\x03\x04\x05\x06\x07\x00\x01\x02\x03\x04\x05\x06\x07\x00\x01\x02\x03\x04\x05\x06\x07"
# 0:key, 1:nonce, 2:auth_data, 3:plaintext, 4:expected_ciphertext, 5:expected_tag
vectors_enc = [
(
key_1,
0,
b"\x55\x64",
b"\x00\x01\x02\x03\x04\05\x06\x07\x08\x09",
b"e2c9dd152fbee5821ea7",
b"10625812de81b14a46b9f1e5100a6d0c",
),
(
key_1,
1,
b"\x55\x64",
b"\x00\x01\x02\x03\x04\05\x06\x07\x08\x09",
b"79811619ddb07c2b99f8",
b"71c6b872cdc499a7e9a3c7441f053214",
),
(
key_1,
369,
b"\x55\x64",
b"\x00\x01\x02\x03\x04\05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
b"03bd030390f2dfe815a61c2b157a064f",
b"c1200f8a7ae9a6d32cef0fff878d55c2",
),
(
key_1,
369,
b"\x55\x64\x73\x82\x91",
b"\x00\x01\x02\x03\x04\05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
b"03bd030390f2dfe815a61c2b157a064f",
b"693ac160cd93a20f7fc255f049d808d0",
),
]
# 0:chaining key, 1:input, 2:output_1, 3:output:2
vectors_hkdf = [
(
crypto.PROTOCOL_NAME,
b"\x01\x02",
b"c784373a217d6be057cddc6068e6748f255fc8beb6f99b7b90cbc64aad947514",
b"12695451e29bf08ffe5e4e6ab734b0c3d7cdd99b16cd409f57bd4eaa874944ba",
),
(
b"\xc7\x84\x37\x3a\x21\x7d\x6b\xe0\x57\xcd\xdc\x60\x68\xe6\x74\x8f\x25\x5f\xc8\xbe\xb6\xf9\x9b\x7b\x90\xcb\xc6\x4a\xad\x94\x75\x14",
b"\x31\x41\x59\x26\x52\x12\x34\x56\x78\x89\x04\xaa",
b"f88c1e08d5c3bae8f6e4a3d3324c8cbc60a805603e399e69c4bf4eacb27c2f48",
b"5f0216bdb7110ee05372286974da8c9c8b96e2efa15b4af430755f462bd79a76",
),
]
vectors_iv = [
(0, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"),
(1, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"),
(7, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x07"),
(1025, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04\x01"),
(4294967295, b"\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff"),
(0xFFFFFFFFFFFFFFFF, b"\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff"),
]
def __init__(self):
if __debug__:
thp_common.suppres_debug_log()
super().__init__()
def setUp(self):
utils.DISABLE_ENCRYPTION = False
def test_encryption(self):
for v in self.vectors_enc:
buffer = bytearray(v[3])
tag = crypto.enc(buffer, v[0], v[1], v[2])
self.assertEqual(hexlify(buffer), v[4])
self.assertEqual(hexlify(tag), v[5])
self.assertTrue(crypto.dec(buffer, tag, v[0], v[1], v[2]))
self.assertEqual(buffer, v[3])
def test_hkdf(self):
for v in self.vectors_hkdf:
ck, k = crypto._hkdf(v[0], v[1])
self.assertEqual(hexlify(ck), v[2])
self.assertEqual(hexlify(k), v[3])
def test_iv_from_nonce(self):
for v in self.vectors_iv:
x = v[0]
y = x.to_bytes(8, "big")
iv = crypto._get_iv_from_nonce(v[0])
self.assertEqual(iv, v[1])
with self.assertRaises(AssertionError) as e:
iv = crypto._get_iv_from_nonce(0xFFFFFFFFFFFFFFFF + 1)
self.assertEqual(e.value.value, "Nonce overflow, terminate the channel")
def test_incorrect_vectors(self):
pass
def test_th1_crypto(self):
storage.device.get_device_secret = get_dummy_device_secret
handshake = self.handshake
host_ephemeral_privkey = curve25519.generate_secret()
host_ephemeral_pubkey = curve25519.publickey(host_ephemeral_privkey)
handshake.handle_th1_crypto(b"", host_ephemeral_pubkey)
def test_th2_crypto(self):
handshake = self.handshake
host_static_privkey = curve25519.generate_secret()
host_static_pubkey = curve25519.publickey(host_static_privkey)
aes_ctx = aesgcm(handshake.k, IV_2)
aes_ctx.auth(handshake.h)
encrypted_host_static_pubkey = bytearray(
aes_ctx.encrypt(host_static_pubkey) + aes_ctx.finish()
)
# Code to encrypt Host's noise encrypted payload correctly:
protomsg = bytearray(b"\x10\x02\x10\x03")
temp_k = handshake.k
temp_h = handshake.h
temp_h = crypto._hash_of_two(temp_h, encrypted_host_static_pubkey)
_, temp_k = crypto._hkdf(
handshake.ck,
curve25519.multiply(handshake.trezor_ephemeral_privkey, host_static_pubkey),
)
aes_ctx = aesgcm(temp_k, IV_1)
aes_ctx.encrypt_in_place(protomsg)
aes_ctx.auth(temp_h)
tag = aes_ctx.finish()
encrypted_payload = bytearray(protomsg + tag)
# end of encrypted payload generation
handshake.handle_th2_crypto(encrypted_host_static_pubkey, encrypted_payload)
self.assertEqual(encrypted_payload[:4], b"\x10\x02\x10\x03")
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,370 @@
from common import * # isort:skip
from mock_wire_interface import MockHID
from trezor import config, io, protobuf
from trezor.crypto.curve import curve25519
from trezor.enums import MessageType
from trezor.wire.errors import UnexpectedMessage
from trezor.wire.protocol_common import Message
if utils.USE_THP:
from typing import TYPE_CHECKING
import thp_common
from storage import cache_thp
from storage.cache_common import (
CHANNEL_HANDSHAKE_HASH,
CHANNEL_KEY_RECEIVE,
CHANNEL_KEY_SEND,
CHANNEL_NONCE_RECEIVE,
CHANNEL_NONCE_SEND,
)
from trezor.crypto import elligator2
from trezor.enums import ThpPairingMethod
from trezor.messages import (
ThpCodeEntryChallenge,
ThpCodeEntryCpaceHost,
ThpCodeEntryTag,
ThpCredentialRequest,
ThpEndRequest,
ThpStartPairingRequest,
)
from trezor.wire import thp_main
from trezor.wire.thp import ChannelState, checksum, interface_manager
from trezor.wire.thp.crypto import Handshake
from trezor.wire.thp.pairing_context import PairingContext
from apps.thp import pairing
if TYPE_CHECKING:
from trezor.wire import WireInterface
def get_dummy_key() -> bytes:
return b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x01\x02\x03\x04\x05\x06\x07\x08\x09\x20\x01\x02\x03\x04\x05\x06\x07\x08\x09\x30\x31"
def send_channel_allocation_request(
interface: WireInterface, nonce: bytes | None = None
) -> bytes:
if nonce is None or len(nonce) != 8:
nonce = b"\x00\x11\x22\x33\x44\x55\x66\x77"
header = b"\x40\xff\xff\x00\x0c"
chksum = checksum.compute(header + nonce)
cid_req = header + nonce + chksum
gen = thp_main.thp_main_loop(interface, is_debug_session=True)
gen.send(None)
gen.send(cid_req)
gen.send(None)
response_data = (
b"\x0a\x04\x54\x32\x54\x31\x10\x00\x18\x00\x20\x02\x28\x02\x28\x03\x28\x04"
)
response_without_crc = (
b"\x41\xff\xff\x00\x20"
+ nonce
+ cache_thp.cid_counter.to_bytes(2, "big")
+ response_data
)
chkcsum = checksum.compute(response_without_crc)
expected_response = response_without_crc + chkcsum + b"\x00" * 27
return expected_response
def get_channel_id_from_response(channel_allocation_response: bytes) -> int:
return int.from_bytes(channel_allocation_response[13:15], "big")
def get_ack(channel_id: bytes) -> bytes:
if len(channel_id) != 2:
raise Exception("Channel id should by two bytes long")
return (
b"\x20"
+ channel_id
+ b"\x00\x04"
+ checksum.compute(b"\x20" + channel_id + b"\x00\x04")
+ b"\x00" * 55
)
@unittest.skipUnless(utils.USE_THP, "only needed for THP")
class TestTrezorHostProtocol(unittest.TestCase):
def __init__(self):
if __debug__:
thp_common.suppres_debug_log()
super().__init__()
def setUp(self):
self.interface = MockHID(0xDEADBEEF)
buffer = bytearray(64)
thp_main.set_read_buffer(buffer)
interface_manager.decode_iface = thp_common.dummy_decode_iface
def test_codec_message(self):
self.assertEqual(len(self.interface.data), 0)
gen = thp_main.thp_main_loop(self.interface, is_debug_session=True)
gen.send(None)
# There should be a failiure response to received init packet (starts with "?##")
test_codec_message = b"?## Some data"
gen.send(test_codec_message)
gen.send(None)
self.assertEqual(len(self.interface.data), 1)
expected_response = (
b"?##\x00\x03\x00\x00\x00\x14\x08\x01\x12\x10Invalid protocol"
)
self.assertEqual(
self.interface.data[-1][: len(expected_response)], expected_response
)
# There should be no response for continuation packet (starts with "?" only)
test_codec_message_2 = b"? Cont packet"
gen.send(test_codec_message_2)
with self.assertRaises(TypeError) as e:
gen.send(None)
self.assertEqual(e.value.value, "object with buffer protocol required")
self.assertEqual(len(self.interface.data), 1)
def test_message_on_unallocated_channel(self):
gen = thp_main.thp_main_loop(self.interface, is_debug_session=True)
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
message_to_channel_789a = (
b"\x04\x78\x9a\x00\x0c\x00\x11\x22\x33\x44\x55\x66\x77\x96\x64\x3c\x6c"
)
gen.send(message_to_channel_789a)
gen.send(None)
unallocated_chanel_error_on_channel_789a = "42789a0005027b743563000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
self.assertEqual(
utils.get_bytes_as_str(self.interface.data[-1]),
unallocated_chanel_error_on_channel_789a,
)
def test_channel_allocation(self):
test_counter = cache_thp.cid_counter + 1
self.assertEqual(len(thp_main._CHANNELS), 0)
self.assertFalse(test_counter in thp_main._CHANNELS)
expected_response = send_channel_allocation_request(self.interface)
self.assertEqual(self.interface.data[-1], expected_response)
self.assertTrue(test_counter in thp_main._CHANNELS)
self.assertEqual(len(thp_main._CHANNELS), 1)
# test channel's default state is TH1:
cid = get_channel_id_from_response(self.interface.data[-1])
self.assertEqual(thp_main._CHANNELS[cid].get_channel_state(), ChannelState.TH1)
def test_invalid_encrypted_tag(self):
gen = thp_main.thp_main_loop(self.interface, is_debug_session=True)
gen.send(None)
# prepare 2 new channels
expected_response_1 = send_channel_allocation_request(self.interface)
expected_response_2 = send_channel_allocation_request(self.interface)
self.assertEqual(self.interface.data[-2], expected_response_1)
self.assertEqual(self.interface.data[-1], expected_response_2)
# test invalid encryption tag
config.init()
config.wipe()
cid_1 = get_channel_id_from_response(expected_response_1)
channel = thp_main._CHANNELS[cid_1]
channel.iface = self.interface
channel.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
header = b"\x04" + channel.channel_id + b"\x00\x14"
tag = b"\x00" * 16
chksum = checksum.compute(header + tag)
message_with_invalid_tag = header + tag + chksum
channel.channel_cache.set(CHANNEL_KEY_RECEIVE, get_dummy_key())
channel.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, 0)
cid_1_bytes = int.to_bytes(cid_1, 2, "big")
expected_ack_on_received_message = get_ack(cid_1_bytes)
gen.send(message_with_invalid_tag)
gen.send(None)
self.assertEqual(
self.interface.data[-1],
expected_ack_on_received_message,
)
error_without_crc = b"\x42" + cid_1_bytes + b"\x00\x05\x03"
chksum_err = checksum.compute(error_without_crc)
gen.send(None)
decryption_failed_error = error_without_crc + chksum_err + b"\x00" * 54
self.assertEqual(
self.interface.data[-1],
decryption_failed_error,
)
def test_channel_errors(self):
gen = thp_main.thp_main_loop(self.interface, is_debug_session=True)
gen.send(None)
# prepare 2 new channels
expected_response_1 = send_channel_allocation_request(self.interface)
expected_response_2 = send_channel_allocation_request(self.interface)
self.assertEqual(self.interface.data[-2], expected_response_1)
self.assertEqual(self.interface.data[-1], expected_response_2)
# test invalid encryption tag
config.init()
config.wipe()
cid_1 = get_channel_id_from_response(expected_response_1)
channel = thp_main._CHANNELS[cid_1]
channel.iface = self.interface
channel.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
header = b"\x04" + channel.channel_id + b"\x00\x14"
tag = b"\x00" * 16
chksum = checksum.compute(header + tag)
message_with_invalid_tag = header + tag + chksum
channel.channel_cache.set(CHANNEL_KEY_RECEIVE, get_dummy_key())
channel.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, 0)
cid_1_bytes = int.to_bytes(cid_1, 2, "big")
expected_ack_on_received_message = get_ack(cid_1_bytes)
gen.send(message_with_invalid_tag)
gen.send(None)
self.assertEqual(
self.interface.data[-1],
expected_ack_on_received_message,
)
error_without_crc = b"\x42" + cid_1_bytes + b"\x00\x05\x03"
chksum_err = checksum.compute(error_without_crc)
gen.send(None)
decryption_failed_error = error_without_crc + chksum_err + b"\x00" * 54
self.assertEqual(
self.interface.data[-1],
decryption_failed_error,
)
# test invalid tag in handshake phase
cid_2 = get_channel_id_from_response(expected_response_1)
cid_2_bytes = cid_2.to_bytes(2, "big")
channel = thp_main._CHANNELS[cid_2]
channel.iface = self.interface
channel.set_channel_state(ChannelState.TH2)
message_with_invalid_tag = b"\x0a\x12\x36\x00\x14\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\x91\x65\x4c\xf9"
channel.channel_cache.set(CHANNEL_KEY_RECEIVE, get_dummy_key())
channel.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, 0)
# gen.send(message_with_invalid_tag)
# gen.send(None)
# gen.send(None)
# for i in self.interface.data:
# print(utils.get_bytes_as_str(i))
def test_skip_pairing(self):
config.init()
config.wipe()
channel = thp_main._CHANNELS[4660]
channel.selected_pairing_methods = [
ThpPairingMethod.NoMethod,
ThpPairingMethod.CodeEntry,
ThpPairingMethod.NFC_Unidirectional,
ThpPairingMethod.QrCode,
]
pairing_ctx = PairingContext(channel)
request_message = ThpStartPairingRequest()
channel.set_channel_state(ChannelState.TP1)
gen = pairing.handle_pairing_request(pairing_ctx, request_message)
with self.assertRaises(StopIteration):
gen.send(None)
self.assertEqual(channel.get_channel_state(), ChannelState.ENCRYPTED_TRANSPORT)
# Teardown: set back initial channel state value
channel.set_channel_state(ChannelState.TH1)
def test_pairing(self):
config.init()
config.wipe()
cid = get_channel_id_from_response(
send_channel_allocation_request(self.interface)
)
channel = thp_main._CHANNELS[cid]
channel.selected_pairing_methods = [
ThpPairingMethod.CodeEntry,
ThpPairingMethod.NFC_Unidirectional,
ThpPairingMethod.QrCode,
]
pairing_ctx = PairingContext(channel)
request_message = ThpStartPairingRequest()
with self.assertRaises(UnexpectedMessage) as e:
pairing.handle_pairing_request(pairing_ctx, request_message)
print(e.value.message)
channel.set_channel_state(ChannelState.TP1)
gen = pairing.handle_pairing_request(pairing_ctx, request_message)
channel.channel_cache.set(CHANNEL_KEY_SEND, get_dummy_key())
channel.channel_cache.set_int(CHANNEL_NONCE_SEND, 0)
channel.channel_cache.set(CHANNEL_HANDSHAKE_HASH, b"")
gen.send(None)
async def _dummy(ctx: PairingContext, expected_types):
return await ctx.read([1018, 1024])
pairing.show_display_data = _dummy
msg_code_entry = ThpCodeEntryChallenge(challenge=b"\x12\x34")
buffer: bytearray = bytearray(protobuf.encoded_length(msg_code_entry))
protobuf.encode(buffer, msg_code_entry)
code_entry_challenge = Message(MessageType.ThpCodeEntryChallenge, buffer)
gen.send(code_entry_challenge)
# tag_qrc = b"\x55\xdf\x6c\xba\x0b\xe9\x5e\xd1\x4b\x78\x61\xec\xfa\x07\x9b\x5d\x37\x60\xd8\x79\x9c\xd7\x89\xb4\x22\xc1\x6f\x39\xde\x8f\x3b\xc3"
# tag_nfc = b"\x8f\xf0\xfa\x37\x0a\x5b\xdb\x29\x32\x21\xd8\x2f\x95\xdd\xb6\xb8\xee\xfd\x28\x6f\x56\x9f\xa9\x0b\x64\x8c\xfc\x62\x46\x5a\xdd\xd0"
pregenerator_host = b"\xf6\x94\xc3\x6f\xb3\xbd\xfb\xba\x2f\xfd\x0c\xd0\x71\xed\x54\x76\x73\x64\x37\xfa\x25\x85\x12\x8d\xcf\xb5\x6c\x02\xaf\x9d\xe8\xbe"
generator_host = elligator2.map_to_curve25519(pregenerator_host)
cpace_host_private_key = b"\x02\x80\x70\x3c\x06\x45\x19\x75\x87\x0c\x82\xe1\x64\x11\xc0\x18\x13\xb2\x29\x04\xb3\xf0\xe4\x1e\x6b\xfd\x77\x63\x11\x73\x07\xa9"
cpace_host_public_key: bytes = curve25519.multiply(
cpace_host_private_key, generator_host
)
msg = ThpCodeEntryCpaceHost(cpace_host_public_key=cpace_host_public_key)
# msg = ThpQrCodeTag(tag=tag_qrc)
# msg = ThpNfcUnidirectionalTag(tag=tag_nfc)
buffer: bytearray = bytearray(protobuf.encoded_length(msg))
protobuf.encode(buffer, msg)
user_message = Message(MessageType.ThpCodeEntryCpaceHost, buffer)
gen.send(user_message)
tag_ent = b"\xd0\x15\xd6\x72\x7c\xa6\x9b\x2a\x07\xfa\x30\xee\x03\xf0\x2d\x04\xdc\x96\x06\x77\x0c\xbd\xb4\xaa\x77\xc7\x68\x6f\xae\xa9\xdd\x81"
msg = ThpCodeEntryTag(tag=tag_ent)
buffer: bytearray = bytearray(protobuf.encoded_length(msg))
protobuf.encode(buffer, msg)
user_message = Message(MessageType.ThpCodeEntryTag, buffer)
gen.send(user_message)
host_static_pubkey = b"\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77"
msg = ThpCredentialRequest(host_static_pubkey=host_static_pubkey)
buffer: bytearray = bytearray(protobuf.encoded_length(msg))
protobuf.encode(buffer, msg)
credential_request = Message(MessageType.ThpCredentialRequest, buffer)
gen.send(credential_request)
msg = ThpEndRequest()
buffer: bytearray = bytearray(protobuf.encoded_length(msg))
protobuf.encode(buffer, msg)
end_request = Message(1012, buffer)
with self.assertRaises(StopIteration) as e:
gen.send(end_request)
print("response message:", e.value.value.MESSAGE_NAME)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,150 @@
from common import * # isort:skip
from typing import Any, Awaitable
if utils.USE_THP:
import thp_common
from mock_wire_interface import MockHID
from trezor.wire.thp import writer
from trezor.wire.thp.thp_messages import ENCRYPTED_TRANSPORT, PacketHeader
@unittest.skipUnless(utils.USE_THP, "only needed for THP")
class TestTrezorHostProtocolWriter(unittest.TestCase):
short_payload_expected = b"04123400050700000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
longer_payload_expected = [
b"0412340100000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a",
b"8012343b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f7071727374757677",
b"80123478797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4",
b"801234b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1",
b"801234f2f3f4f5f6f7f8f9fafbfcfdfeff0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
]
eight_longer_payloads_expected = [
b"0412340800000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a",
b"8012343b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f7071727374757677",
b"80123478797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4",
b"801234b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1",
b"801234f2f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e",
b"8012342f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b",
b"8012346c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8",
b"801234a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5",
b"801234e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122",
b"801234232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f",
b"801234606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c",
b"8012349d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9",
b"801234dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a0b0c0d0e0f10111213141516",
b"8012341718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f50515253",
b"8012345455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f90",
b"8012349192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccd",
b"801234cecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a",
b"8012340b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f4041424344454647",
b"80123448494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f8081828384",
b"80123485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1",
b"801234c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfe",
b"801234ff000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b",
b"8012343c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f707172737475767778",
b"801234797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5",
b"801234b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2",
b"801234f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f",
b"801234303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c",
b"8012346d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9",
b"801234aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6",
b"801234e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20212223",
b"8012342425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f60",
b"8012346162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d",
b"8012349e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9da",
b"801234dbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000000000000000000000000000000000000000000000000",
]
empty_payload_with_checksum_expected = b"0412340004edbd479c00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
longer_payload_with_checksum_expected = [
b"0412340100000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a",
b"8012343b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f7071727374757677",
b"80123478797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4",
b"801234b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1",
b"801234f2f3f4f5f6f7f8f9fafbfcfdfefff40c65ee00000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
]
def await_until_result(self, task: Awaitable) -> Any:
with self.assertRaises(StopIteration):
while True:
task.send(None)
def __init__(self):
if __debug__:
thp_common.suppres_debug_log()
super().__init__()
def setUp(self):
self.interface = MockHID(0xDEADBEEF)
def test_write_empty_packet(self):
self.await_until_result(writer.write_packet_to_wire(self.interface, b""))
print(self.interface.data[0])
self.assertEqual(len(self.interface.data), 1)
self.assertEqual(self.interface.data[0], b"")
def test_write_empty_payload(self):
header = PacketHeader(ENCRYPTED_TRANSPORT, 4660, 4)
await_result(writer.write_payloads_to_wire(self.interface, header, (b"",)))
self.assertEqual(len(self.interface.data), 0)
def test_write_short_payload(self):
header = PacketHeader(ENCRYPTED_TRANSPORT, 4660, 5)
data = b"\x07"
self.await_until_result(
writer.write_payloads_to_wire(self.interface, header, (data,))
)
self.assertEqual(hexlify(self.interface.data[0]), self.short_payload_expected)
def test_write_longer_payload(self):
data = bytearray(range(256))
header = PacketHeader(ENCRYPTED_TRANSPORT, 4660, 256)
self.await_until_result(
writer.write_payloads_to_wire(self.interface, header, (data,))
)
for i in range(len(self.longer_payload_expected)):
self.assertEqual(
hexlify(self.interface.data[i]), self.longer_payload_expected[i]
)
def test_write_eight_longer_payloads(self):
data = bytearray(range(256))
header = PacketHeader(ENCRYPTED_TRANSPORT, 4660, 2048)
self.await_until_result(
writer.write_payloads_to_wire(
self.interface, header, (data, data, data, data, data, data, data, data)
)
)
for i in range(len(self.eight_longer_payloads_expected)):
self.assertEqual(
hexlify(self.interface.data[i]), self.eight_longer_payloads_expected[i]
)
def test_write_empty_payload_with_checksum(self):
header = PacketHeader(ENCRYPTED_TRANSPORT, 4660, 4)
self.await_until_result(
writer.write_payload_to_wire_and_add_checksum(self.interface, header, b"")
)
self.assertEqual(
hexlify(self.interface.data[0]), self.empty_payload_with_checksum_expected
)
def test_write_longer_payload_with_checksum(self):
data = bytearray(range(256))
header = PacketHeader(ENCRYPTED_TRANSPORT, 4660, 256)
self.await_until_result(
writer.write_payload_to_wire_and_add_checksum(self.interface, header, data)
)
for i in range(len(self.longer_payload_with_checksum_expected)):
self.assertEqual(
hexlify(self.interface.data[i]),
self.longer_payload_with_checksum_expected[i],
)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,338 @@
from common import * # isort:skip
import ustruct
from typing import TYPE_CHECKING
from mock_wire_interface import MockHID
from storage.cache_thp import BROADCAST_CHANNEL_ID
from trezor import io
from trezor.utils import chunks
from trezor.wire.protocol_common import Message
if utils.USE_THP:
import thp_common
import trezor.wire.thp
from trezor.wire import thp_main
from trezor.wire.thp import alternating_bit_protocol as ABP
from trezor.wire.thp import checksum
from trezor.wire.thp.checksum import CHECKSUM_LENGTH
from trezor.wire.thp.writer import PACKET_LENGTH
if TYPE_CHECKING:
from trezorio import WireInterface
MESSAGE_TYPE = 0x4242
MESSAGE_TYPE_BYTES = b"\x42\x42"
_MESSAGE_TYPE_LEN = 2
PLAINTEXT_0 = 0x01
PLAINTEXT_1 = 0x11
COMMON_CID = 4660
CONT = 0x80
HEADER_INIT_LENGTH = 5
HEADER_CONT_LENGTH = 3
if utils.USE_THP:
INIT_MESSAGE_DATA_LENGTH = PACKET_LENGTH - HEADER_INIT_LENGTH - _MESSAGE_TYPE_LEN
def make_header(ctrl_byte, cid, length):
return ustruct.pack(">BHH", ctrl_byte, cid, length)
def make_cont_header():
return ustruct.pack(">BH", CONT, COMMON_CID)
def makeSimpleMessage(header, message_type, message_data):
return header + ustruct.pack(">H", message_type) + message_data
def makeCidRequest(header, message_data):
return header + message_data
def getPlaintext() -> bytes:
if ABP.get_expected_receive_seq_bit(THP.get_active_session()) == 1:
return PLAINTEXT_1
return PLAINTEXT_0
async def deprecated_read_message(
iface: WireInterface, buffer: utils.BufferType
) -> Message:
return Message(-1, b"\x00")
async def deprecated_write_message(
iface: WireInterface, message: Message, is_retransmission: bool = False
) -> None:
pass
# This test suite is an adaptation of test_trezor.wire.codec_v1
@unittest.skipUnless(utils.USE_THP, "only needed for THP")
class TestWireTrezorHostProtocolV1(unittest.TestCase):
def __init__(self):
if __debug__:
thp_common.suppres_debug_log()
super().__init__()
def setUp(self):
self.interface = MockHID(0xDEADBEEF)
def _simple(self):
cid_req_header = make_header(
ctrl_byte=0x40, cid=BROADCAST_CHANNEL_ID, length=12
)
cid_request_dummy_data = b"\x00\x11\x22\x33\x44\x55\x66\x77\x96\x64\x3c\x6c"
cid_req_message = makeCidRequest(cid_req_header, cid_request_dummy_data)
message_header = make_header(ctrl_byte=0x01, cid=COMMON_CID, length=18)
cid_request_dummy_data_checksum = b"\x67\x8e\xac\xe0"
message = makeSimpleMessage(
message_header,
MESSAGE_TYPE,
cid_request_dummy_data + cid_request_dummy_data_checksum,
)
buffer = bytearray(64)
gen = deprecated_read_message(self.interface, buffer)
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
gen.send(cid_req_message)
gen.send(None)
gen.send(message)
with self.assertRaises(StopIteration) as e:
gen.send(None)
# e.value is StopIteration. e.value.value is the return value of the call
result = e.value.value
self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data, cid_request_dummy_data)
buffer_without_zeroes = buffer[: len(message) - 5]
message_without_header = message[5:]
# message should have been read into the buffer
self.assertEqual(buffer_without_zeroes, message_without_header)
def _read_one_packet(self):
# zero length message - just a header
PLAINTEXT = getPlaintext()
header = make_header(
PLAINTEXT, cid=COMMON_CID, length=_MESSAGE_TYPE_LEN + CHECKSUM_LENGTH
)
chksum = checksum.compute(header + MESSAGE_TYPE_BYTES)
message = header + MESSAGE_TYPE_BYTES + chksum
buffer = bytearray(64)
gen = deprecated_read_message(self.interface, buffer)
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
gen.send(message)
with self.assertRaises(StopIteration) as e:
gen.send(None)
# e.value is StopIteration. e.value.value is the return value of the call
result = e.value.value
self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data, b"")
# message should have been read into the buffer
self.assertEqual(buffer, MESSAGE_TYPE_BYTES + chksum + b"\x00" * 58)
def _read_many_packets(self):
message = bytes(range(256))
header = make_header(
getPlaintext(),
COMMON_CID,
len(message) + _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH,
)
chksum = checksum.compute(header + MESSAGE_TYPE_BYTES + message)
# message = MESSAGE_TYPE_BYTES + message + checksum
# first packet is init header + 59 bytes of data
# other packets are cont header + 61 bytes of data
cont_header = make_cont_header()
packets = [header + MESSAGE_TYPE_BYTES + message[:INIT_MESSAGE_DATA_LENGTH]] + [
cont_header + chunk
for chunk in chunks(
message[INIT_MESSAGE_DATA_LENGTH:] + chksum,
64 - HEADER_CONT_LENGTH,
)
]
buffer = bytearray(262)
gen = deprecated_read_message(self.interface, buffer)
query = gen.send(None)
for packet in packets:
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
query = gen.send(packet)
# last packet will stop
with self.assertRaises(StopIteration) as e:
gen.send(None)
# e.value is StopIteration. e.value.value is the return value of the call
result = e.value.value
self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data, message)
# message should have been read into the buffer )
self.assertEqual(buffer, MESSAGE_TYPE_BYTES + message + chksum)
def _read_large_message(self):
message = b"hello world"
header = make_header(
getPlaintext(),
COMMON_CID,
_MESSAGE_TYPE_LEN + len(message) + CHECKSUM_LENGTH,
)
packet = (
header
+ MESSAGE_TYPE_BYTES
+ message
+ checksum.compute(header + MESSAGE_TYPE_BYTES + message)
)
# make sure we fit into one packet, to make this easier
self.assertTrue(len(packet) <= thp_main.PACKET_LENGTH)
buffer = bytearray(1)
self.assertTrue(len(buffer) <= len(packet))
gen = deprecated_read_message(self.interface, buffer)
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
gen.send(packet)
with self.assertRaises(StopIteration) as e:
gen.send(None)
# e.value is StopIteration. e.value.value is the return value of the call
result = e.value.value
self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data, message)
# read should have allocated its own buffer and not touch ours
self.assertEqual(buffer, b"\x00")
def _roundtrip(self):
message_payload = bytes(range(256))
message = Message(
MESSAGE_TYPE, message_payload, 1
) # TODO use different session id
gen = deprecated_write_message(self.interface, message)
# exhaust the iterator:
# (XXX we can only do this because the iterator is only accepting None and returns None)
for query in gen:
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
buffer = bytearray(1024)
gen = deprecated_read_message(self.interface, buffer)
query = gen.send(None)
for packet in self.interface.data:
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
print(utils.get_bytes_as_str(packet))
query = gen.send(packet)
with self.assertRaises(StopIteration) as e:
gen.send(None)
result = e.value.value
self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data, message.data)
def _write_one_packet(self):
message = Message(MESSAGE_TYPE, b"")
gen = deprecated_write_message(self.interface, message)
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
with self.assertRaises(StopIteration):
gen.send(None)
header = make_header(
getPlaintext(), COMMON_CID, _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH
)
expected_message = (
header
+ MESSAGE_TYPE_BYTES
+ checksum.compute(header + MESSAGE_TYPE_BYTES)
+ b"\x00" * (INIT_MESSAGE_DATA_LENGTH - CHECKSUM_LENGTH)
)
self.assertTrue(self.interface.data == [expected_message])
def _write_multiple_packets(self):
message_payload = bytes(range(256))
message = Message(MESSAGE_TYPE, message_payload)
gen = deprecated_write_message(self.interface, message)
header = make_header(
PLAINTEXT_1,
COMMON_CID,
len(message.data) + _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH,
)
cont_header = make_cont_header()
chksum = checksum.compute(
header + message.type.to_bytes(2, "big") + message.data
)
packets = [
header + MESSAGE_TYPE_BYTES + message.data[:INIT_MESSAGE_DATA_LENGTH]
] + [
cont_header + chunk
for chunk in chunks(
message.data[INIT_MESSAGE_DATA_LENGTH:] + chksum,
thp_main.PACKET_LENGTH - HEADER_CONT_LENGTH,
)
]
for _ in packets:
# we receive as many queries as there are packets
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
# the first sent None only started the generator. the len(packets)-th None
# will finish writing and raise StopIteration
with self.assertRaises(StopIteration):
gen.send(None)
# packets must be identical up to the last one
self.assertListEqual(packets[:-1], self.interface.data[:-1])
# last packet must be identical up to message length. remaining bytes in
# the 64-byte packets are garbage -- in particular, it's the bytes of the
# previous packet
last_packet = packets[-1] + packets[-2][len(packets[-1]) :]
self.assertEqual(last_packet, self.interface.data[-1])
def _read_huge_packet(self):
PACKET_COUNT = 1180
# message that takes up 1 180 USB packets
message_size = (PACKET_COUNT - 1) * (
PACKET_LENGTH - HEADER_CONT_LENGTH - CHECKSUM_LENGTH - _MESSAGE_TYPE_LEN
) + INIT_MESSAGE_DATA_LENGTH
# ensure that a message this big won't fit into memory
# Note: this control is changed, because THP has only 2 byte length field
self.assertTrue(message_size > thp_main.MAX_PAYLOAD_LEN)
# self.assertRaises(MemoryError, bytearray, message_size)
header = make_header(PLAINTEXT_1, COMMON_CID, message_size)
packet = header + MESSAGE_TYPE_BYTES + (b"\x00" * INIT_MESSAGE_DATA_LENGTH)
buffer = bytearray(65536)
gen = deprecated_read_message(self.interface, buffer)
query = gen.send(None)
# THP returns "Message too large" error after reading the message size,
# it is different from codec_v1 as it does not allow big enough messages
# to raise MemoryError in this test
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
with self.assertRaises(trezor.wire.thp.ThpError) as e:
query = gen.send(packet)
self.assertEqual(e.value.args[0], "Message too large")
if __name__ == "__main__":
unittest.main()

43
core/tests/thp_common.py Normal file
View File

@ -0,0 +1,43 @@
from trezor import utils
from trezor.wire.thp import ChannelState
if utils.USE_THP:
import unittest
from typing import TYPE_CHECKING, Any, Awaitable
from mock_wire_interface import MockHID
from storage import cache_thp
from trezor.wire import context
from trezor.wire.thp import interface_manager
from trezor.wire.thp.channel import Channel
from trezor.wire.thp.interface_manager import _MOCK_INTERFACE_HID
from trezor.wire.thp.session_context import SessionContext
if TYPE_CHECKING:
from trezor.wire import WireInterface
def dummy_decode_iface(cached_iface: bytes):
return MockHID(0xDEADBEEF)
def get_new_channel(channel_iface: WireInterface | None = None) -> Channel:
interface_manager.decode_iface = dummy_decode_iface
channel_cache = cache_thp.get_new_channel(_MOCK_INTERFACE_HID)
channel = Channel(channel_cache)
channel.set_channel_state(ChannelState.TH1)
if channel_iface is not None:
channel.iface = channel_iface
return channel
def prepare_context() -> None:
channel = get_new_channel()
session_cache = cache_thp.get_new_session(channel.channel_cache)
session_ctx = SessionContext(channel, session_cache)
context.CURRENT_CONTEXT = session_ctx
if __debug__:
# Disable log.debug
def suppres_debug_log() -> None:
from trezor import log
log.debug = lambda name, msg, *args: None

View File

@ -33,7 +33,7 @@
},
"header": {
"language": "cs-CZ",
"version": "2.8.1"
"version": "2.8.2"
},
"translations": {
"addr_mismatch__contact_support_at": "Kontaktujte naši podporu na",
@ -519,7 +519,6 @@
"passphrase__access_wallet": "Otev. passphrase pen.?",
"passphrase__always_on_device": "Vždy zadávat passphrase na Trezoru?",
"passphrase__from_host_not_shown": "Použije se passphrase zadaná hostitelem, ale vzhledem k nastavení zařízení se nezobrazí.",
"passphrase__wallet": "Passphrase pen.",
"passphrase__hide": "Skrýt passphrase od hostitele?",
"passphrase__next_screen_will_show_passphrase": "Na další obrazovce se zobrazí vaše passphrase.",
"passphrase__please_enter": "Zadejte passphrase.",
@ -532,6 +531,7 @@
"passphrase__title_source": "Zdroj passphrase",
"passphrase__turn_off": "Vypnout ochranu s passphrase?",
"passphrase__turn_on": "Zapnout ochranu s passphrase?",
"passphrase__wallet": "Passphrase pen.",
"pin__cancel_description": "Pokračovat bez PIN kódu",
"pin__cancel_info": "Bez PIN kódu může k tomuto zařízení přistupovat kdokoli.",
"pin__cancel_setup": "Zrušit nastavení PIN kódu",

View File

@ -33,7 +33,7 @@
},
"header": {
"language": "de-DE",
"version": "2.8.1"
"version": "2.8.2"
},
"translations": {
"addr_mismatch__contact_support_at": "Kontaktiere den Trezor Support unter",
@ -519,7 +519,6 @@
"passphrase__access_wallet": "Passphr. Wall. öffnen?",
"passphrase__always_on_device": "Deine Passphrase immer auf dem Trezor eingeben?",
"passphrase__from_host_not_shown": "Passphrase vom Host wird verwendet, wegen Geräteeinstellungen aber nicht angezeigt.",
"passphrase__wallet": "Passphrase Wallet",
"passphrase__hide": "Passphrase vom Host ausblenden?",
"passphrase__next_screen_will_show_passphrase": "Der nächste Bildschirm zeigt deine Passphrase.",
"passphrase__please_enter": "Gib deine Passphrase ein.",
@ -532,6 +531,7 @@
"passphrase__title_source": "Passphrasen-quelle",
"passphrase__turn_off": "Passphrasenschutz deaktivieren?",
"passphrase__turn_on": "Passphrasenschutz aktivieren?",
"passphrase__wallet": "Passphrase Wallet",
"pin__cancel_description": "Ohne PIN fortfahren",
"pin__cancel_info": "Ohne PIN kann jeder auf dieses Gerät zugreifen.",
"pin__cancel_setup": "PIN-Einrichtung abbrechen",

View File

@ -1,7 +1,7 @@
{
"header": {
"language": "en-US",
"version": "2.8.1"
"version": "2.8.2"
},
"translations": {
"addr_mismatch__contact_support_at": "Please contact Trezor support at",
@ -81,9 +81,9 @@
"bitcoin__unverified_external_inputs": "The transaction contains unverified external inputs.",
"bitcoin__valid_signature": "The signature is valid.",
"bitcoin__voting_rights": "Voting rights to:",
"brightness__title": "Display brightness",
"brightness__change_title": "Change display brightness",
"brightness__changed_title": "Display brightness changed",
"brightness__title": "Display brightness",
"buttons__abort": "Abort",
"buttons__access": "Access",
"buttons__again": "Again",
@ -360,6 +360,9 @@
"haptic_feedback__title": "Haptic feedback",
"homescreen__click_to_connect": "Click to Connect",
"homescreen__click_to_unlock": "Click to Unlock",
"homescreen__set_default": "Change wallpaper to default image?",
"homescreen__settings_subtitle": "Settings",
"homescreen__settings_title": "Homescreen",
"homescreen__title_backup_failed": "Backup failed",
"homescreen__title_backup_needed": "Backup needed",
"homescreen__title_coinjoin_authorized": "Coinjoin authorized",
@ -368,9 +371,6 @@
"homescreen__title_pin_not_set": "PIN not set",
"homescreen__title_seedless": "Seedless",
"homescreen__title_set": "Change wallpaper",
"homescreen__settings_title": "Homescreen",
"homescreen__settings_subtitle": "Settings",
"homescreen__set_default": "Change wallpaper to default image?",
"inputs__back": "BACK",
"inputs__cancel": "CANCEL",
"inputs__delete": "DELETE",
@ -389,8 +389,8 @@
"instructions__learn_more": "Learn more",
"instructions__shares_continue_with_x_template": "Continue with Share #{0}",
"instructions__shares_start_with_1": "Start with share #1",
"instructions__swipe_up": "Swipe up",
"instructions__swipe_horizontally": "Swipe horizontally",
"instructions__swipe_up": "Swipe up",
"instructions__tap_to_confirm": "Tap to confirm",
"instructions__tap_to_start": "Tap to start",
"joint__title": "Joint transaction",
@ -490,7 +490,6 @@
"passphrase__access_wallet": "Access passphrase wallet?",
"passphrase__always_on_device": "Always enter your passphrase on Trezor?",
"passphrase__from_host_not_shown": "Passphrase provided by host will be used but will not be displayed due to the device settings.",
"passphrase__wallet": "Passphrase wallet",
"passphrase__hide": "Hide passphrase coming from host?",
"passphrase__next_screen_will_show_passphrase": "The next screen shows your passphrase.",
"passphrase__please_enter": "Please enter your passphrase.",
@ -503,6 +502,7 @@
"passphrase__title_source": "Passphrase source",
"passphrase__turn_off": "Turn off passphrase protection?",
"passphrase__turn_on": "Turn on passphrase protection?",
"passphrase__wallet": "Passphrase wallet",
"pin__cancel_description": "Continue without PIN",
"pin__cancel_info": "Without a PIN, anyone can access this device.",
"pin__cancel_setup": "Cancel PIN setup",
@ -859,8 +859,8 @@
"tutorial__did_you_know": "Did you know?",
"tutorial__exit": "Exit tutorial",
"tutorial__first_wallet": "The Trezor Model One, created in 2013,\nwas the world's first hardware wallet.",
"tutorial__lets_begin": "Learn how to use and navigate this device with ease.",
"tutorial__get_started": "Get started!",
"tutorial__lets_begin": "Learn how to use and navigate this device with ease.",
"tutorial__menu": "Find context-specific actions and options in the menu.",
"tutorial__middle_click": "Press both left and right at the same\ntime to confirm.",
"tutorial__press_and_hold": "Press and hold the right button to\napprove important operations.",
@ -870,11 +870,11 @@
"tutorial__scroll_down": "Press right to scroll down to read all content when text doesn't fit on one screen.\n\rPress left to scroll up.",
"tutorial__sure_you_want_skip": "Are you sure you\nwant to skip the tutorial?",
"tutorial__swipe_up_and_down": "Swipe up & down\nto move through screens.",
"tutorial__title_easy_navigation": "Easy navigation",
"tutorial__title_handy_menu": "Handy menu",
"tutorial__title_hello": "Hello",
"tutorial__title_hold": "Hold to confirm important actions",
"tutorial__title_lets_begin": "Let's begin",
"tutorial__title_easy_navigation": "Easy navigation",
"tutorial__title_screen_scroll": "Screen scroll",
"tutorial__title_skip": "Skip tutorial",
"tutorial__title_tutorial_complete": "Tutorial complete",
@ -942,13 +942,13 @@
"words__sign": "Sign",
"words__signer": "Signer",
"words__title_check": "Check",
"words__title_done": "Done",
"words__title_group": "Group",
"words__title_information": "Information",
"words__title_remember": "Remember",
"words__title_share": "Share",
"words__title_shares": "Shares",
"words__title_success": "Success",
"words__title_done": "Done",
"words__title_summary": "Summary",
"words__title_threshold": "Threshold",
"words__try_again": "Try again.",

View File

@ -33,7 +33,7 @@
},
"header": {
"language": "es-ES",
"version": "2.8.1"
"version": "2.8.2"
},
"translations": {
"addr_mismatch__contact_support_at": "Contacta con atención al cliente de Trezor en",
@ -519,7 +519,6 @@
"passphrase__access_wallet": "¿Ir al monedero frase contr.?",
"passphrase__always_on_device": "¿Introduces siempre la frase de contraseña en Trezor?",
"passphrase__from_host_not_shown": "Se usará la frase de contraseña dada por el host, pero no se verá debido a la configuración.",
"passphrase__wallet": "Monedero frase contr.",
"passphrase__hide": "¿Ocultar la frase de contraseña del host?",
"passphrase__next_screen_will_show_passphrase": "La siguiente pantalla muestra la frase de contraseña.",
"passphrase__please_enter": "Escribe la frase de contraseña.",
@ -532,6 +531,7 @@
"passphrase__title_source": "Origen frase contr.",
"passphrase__turn_off": "¿Desactivar la protección por frase de contraseña?",
"passphrase__turn_on": "¿Activar la protección por frase de contraseña?",
"passphrase__wallet": "Monedero frase contr.",
"pin__cancel_description": "Continuar sin PIN",
"pin__cancel_info": "Sin un PIN, cualquiera puede acceder al dispositivo.",
"pin__cancel_setup": "Cancelar configuración de PIN",

View File

@ -33,7 +33,7 @@
},
"header": {
"language": "fr-FR",
"version": "2.8.1"
"version": "2.8.2"
},
"translations": {
"addr_mismatch__contact_support_at": "Contactez l'assistance Trezor à l'adr.",
@ -519,7 +519,6 @@
"passphrase__access_wallet": "Accès portef. phrase secr. ?",
"passphrase__always_on_device": "Saisissez toujours votre phrase secrète sur Trezor ?",
"passphrase__from_host_not_shown": "La phrase secrète fournie par l'hôte sera utilisée, mais pas affichée en raison des paramètres du disp.",
"passphrase__wallet": "Portef. phrase secr.",
"passphrase__hide": "Masq. phrase secrète de l'hôte ?",
"passphrase__next_screen_will_show_passphrase": "L'écran suivant affiche votre phrase secrète.",
"passphrase__please_enter": "Saisissez votre phrase secrète.",
@ -532,6 +531,7 @@
"passphrase__title_source": "Source phrase secr.",
"passphrase__turn_off": "Désactiver la prot. par phrase secrète ?",
"passphrase__turn_on": "Activer la prot. par phrase secrète ?",
"passphrase__wallet": "Portef. phrase secr.",
"pin__cancel_description": "Continuer sans PIN",
"pin__cancel_info": "Sans PIN, tout le monde peut accéder à ce dispositif.",
"pin__cancel_setup": "Annuler la configuration du PIN",

View File

@ -1,8 +1,8 @@
{
"current": {
"merkle_root": "61467ac92c2678b624b2d9ab6d37d4a0ec01a44b5542261c19406cc674d71142",
"datetime": "2024-07-30T14:58:52.679738",
"commit": "39300036083657236e241da6f57f9a78269814e7"
"merkle_root": "4bb6cb2b14bb500c65018e01418d772f6753d41f3fb44879a033b0dc80591a48",
"datetime": "2024-08-13T11:39:28.337426",
"commit": "43ed9529001e29070afce32160e2facb23350722"
},
"history": [
{

View File

@ -106,44 +106,44 @@ Frozen version. That means you do not need any other files to run it,
it is just a single binary file that you can execute directly.
**Are you looking for a Trezor T emulator? This is most likely it.**
### [core unix frozen R debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L317)
### [core unix frozen R debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L318)
### [core unix frozen T3T1 debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L332)
### [core unix frozen T3T1 debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L333)
### [core unix frozen R debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L346)
### [core unix frozen R debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L347)
### [core unix frozen T3T1 debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L369)
### [core unix frozen T3T1 debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L370)
### [core unix frozen debug asan build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L392)
### [core unix frozen debug asan build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L393)
### [core unix frozen debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L408)
### [core unix frozen debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L409)
### [core macos frozen regular build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L430)
### [core macos frozen regular build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L431)
### [crypto build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L455)
### [crypto build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L456)
Build of our cryptographic library, which is then incorporated into the other builds.
### [legacy fw regular build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L485)
### [legacy fw regular build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L486)
### [legacy fw regular debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L501)
### [legacy fw regular debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L502)
### [legacy fw btconly build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L518)
### [legacy fw btconly build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L519)
### [legacy fw btconly debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L537)
### [legacy fw btconly debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L538)
### [legacy emu regular debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L558)
### [legacy emu regular debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L559)
Regular version (not only Bitcoin) of above.
**Are you looking for a Trezor One emulator? This is most likely it.**
### [legacy emu regular debug asan build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L573)
### [legacy emu regular debug asan build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L574)
### [legacy emu regular debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L591)
### [legacy emu regular debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L592)
### [legacy emu btconly debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L617)
### [legacy emu btconly debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L618)
Build of Legacy into UNIX emulator. Use keyboard arrows to emulate button presses.
Bitcoin-only version.
### [legacy emu btconly debug asan build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L634)
### [legacy emu btconly debug asan build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L635)
---
## TEST stage - [test.yml](https://github.com/trezor/trezor-firmware/blob/master/ci/test.yml)

View File

@ -191,6 +191,9 @@ void fsm_sendFailure(FailureType code, const char *text)
case FailureType_Failure_InvalidSession:
text = _("Invalid session");
break;
case FailureType_Failure_ThpUnallocatedSession:
text = _("Unallocated session");
break;
case FailureType_Failure_FirmwareError:
text = _("Firmware error");
break;

View File

@ -10,7 +10,7 @@ SKIPPED_MESSAGES := Binance Cardano DebugMonero Eos Monero Ontology Ripple SdPro
EthereumTypedDataValueRequest EthereumTypedDataValueAck ShowDeviceTutorial \
UnlockBootloader AuthenticateDevice AuthenticityProof \
Solana StellarClaimClaimableBalanceOp \
ChangeLanguage TranslationDataRequest TranslationDataAck \
ChangeLanguage TranslationDataRequest TranslationDataAck Thp \
SetBrightness DebugLinkOptigaSetSecMax \
ifeq ($(BITCOIN_ONLY), 1)

View File

@ -0,0 +1 @@
../../vendor/trezor-common/protob/messages-thp.proto

View File

@ -65,11 +65,11 @@ def send_bytes(
raise click.ClickException("Invalid hex data.") from e
transport = obj.get_transport()
transport.begin_session()
transport.deprecated_begin_session()
transport.write(message_type, message_data)
response_type, response_data = transport.read()
transport.end_session()
transport.deprecated_end_session()
click.echo(f"Response type: {response_type}")
click.echo(f"Response data: {response_data.hex()}")

Some files were not shown because too many files have changed in this diff Show More