1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-02-12 15:42:40 +00:00

feat: new THP

This commit is contained in:
M1nd3r 2024-11-15 16:06:55 +01:00
parent 6cbf5e4064
commit aaaeb3abca
272 changed files with 18371 additions and 6209 deletions

View File

@ -307,6 +307,7 @@ core unix frozen debug build:
needs: []
variables:
PYOPT: "0"
THP: "1"
script:
- $NIX_SHELL --run "poetry run make -C core build_unix_frozen"
artifacts:

View File

@ -39,6 +39,8 @@ message Failure {
Failure_PinMismatch = 12;
Failure_WipeCodeMismatch = 13;
Failure_InvalidSession = 14;
Failure_ThpUnallocatedSession=15;
Failure_InvalidProtocol=16;
Failure_FirmwareError = 99;
}
}

View File

@ -110,6 +110,8 @@ message DebugLinkGetState {
// trezor-core only - wait until current layout changes
// changed in 2.6.4: multiple wait types instead of true/false.
optional DebugWaitType wait_layout = 3 [default=IMMEDIATE];
// THP only - it is used to get information from specified channel
optional bytes thp_channel_id=4;
}
/**
@ -130,6 +132,9 @@ message DebugLinkState {
optional uint32 reset_word_pos = 11; // index of mnemonic word the device is expecting during ResetDevice workflow
optional management.BackupType mnemonic_type = 12; // current mnemonic type (BIP-39/SLIP-39)
repeated string tokens = 13; // current layout represented as a list of string tokens
optional uint32 thp_pairing_code_entry_code = 14;
optional bytes thp_pairing_code_qr_code = 15;
optional bytes thp_pairing_code_nfc_unidirectional = 16;
}
/**

View File

@ -9,6 +9,218 @@ import "options.proto";
option (include_in_bitcoin_only) = true;
/**
* Mapping between Trezor wire identifier (uint) and a Thp protobuf message
*/
enum ThpMessageType {
reserved 0 to 999; // Values reserved by other messages, see messages.proto
ThpMessageType_ThpCreateNewSession = 1000[(bitcoin_only)=true, (channel_in) = true];
ThpMessageType_ThpNewSession = 1001[(bitcoin_only)=true, (channel_out) = true];
ThpMessageType_ThpStartPairingRequest = 1008 [(bitcoin_only) = true, (pairing_in) = true];
ThpMessageType_ThpPairingPreparationsFinished = 1009 [(bitcoin_only) = true, (pairing_out) = true];
ThpMessageType_ThpCredentialRequest = 1010 [(bitcoin_only) = true, (pairing_in) = true];
ThpMessageType_ThpCredentialResponse = 1011 [(bitcoin_only) = true, (pairing_out) = true];
ThpMessageType_ThpEndRequest = 1012 [(bitcoin_only) = true, (pairing_in) = true];
ThpMessageType_ThpEndResponse = 1013[(bitcoin_only) = true, (pairing_out) = true];
ThpMessageType_ThpCodeEntryCommitment = 1016[(bitcoin_only)=true, (pairing_out) = true];
ThpMessageType_ThpCodeEntryChallenge = 1017[(bitcoin_only)=true, (pairing_in) = true];
ThpMessageType_ThpCodeEntryCpaceHost = 1018[(bitcoin_only)=true, (pairing_in) = true];
ThpMessageType_ThpCodeEntryCpaceTrezor = 1019[(bitcoin_only)=true, (pairing_out) = true];
ThpMessageType_ThpCodeEntryTag = 1020[(bitcoin_only)=true, (pairing_in) = true];
ThpMessageType_ThpCodeEntrySecret = 1021[(bitcoin_only)=true, (pairing_out) = true];
ThpMessageType_ThpQrCodeTag = 1024[(bitcoin_only)=true, (pairing_in) = true];
ThpMessageType_ThpQrCodeSecret = 1025[(bitcoin_only)=true, (pairing_out) = true];
ThpMessageType_ThpNfcUnidirectionalTag = 1032[(bitcoin_only)=true, (pairing_in) = true];
ThpMessageType_ThpNfcUnidirectionalSecret = 1033[(bitcoin_only)=true, (pairing_in) = true];
reserved 1100 to 2147483647; // Values reserved by other messages, see messages.proto
}
/**
* Numeric identifiers of pairing methods.
* @embed
*/
enum ThpPairingMethod {
NoMethod = 1; // Trust without MITM protection.
CodeEntry = 2; // User types code diplayed on Trezor into the host application.
QrCode = 3; // User scans code displayed on Trezor into host application.
NFC_Unidirectional = 4; // Trezor transmits an authentication key to the host device via NFC.
}
/**
* @embed
*/
message ThpDeviceProperties {
optional string internal_model = 1; // Internal model name e.g. "T2B1".
optional uint32 model_variant = 2; // Encodes the device properties such as color.
optional bool bootloader_mode = 3; // Indicates whether the device is in bootloader or firmware mode.
optional uint32 protocol_version = 4; // The communication protocol version supported by the firmware.
repeated ThpPairingMethod pairing_methods = 5; // The pairing methods supported by the Trezor.
}
/**
* @embed
*/
message ThpHandshakeCompletionReqNoisePayload {
optional bytes host_pairing_credential = 1; // Host's pairing credential
repeated ThpPairingMethod pairing_methods = 2; // The pairing methods chosen by the host
}
/**
* Request: Ask device for a new session with given passphrase.
* @start
* @next ThpNewSession
*/
message ThpCreateNewSession{
optional string passphrase = 1;
optional bool on_device = 2; // User wants to enter passphrase on the device
optional bool derive_cardano = 3; // If True, Cardano keys will be derived. Ignored with BTC-only
}
/**
* Response: Contains session_id of the newly created session.
* @end
*/
message ThpNewSession{
optional uint32 new_session_id = 1;
}
/**
* Request: Start pairing process.
* @start
* @next ThpCodeEntryCommitment
* @next ThpPairingPreparationsFinished
*/
message ThpStartPairingRequest{
optional string host_name = 1; // Human-readable host name
}
/**
* Response: Pairing is ready for user input / OOB communication.
* @next ThpCodeEntryCpace
* @next ThpQrCodeTag
* @next ThpNfcUnidirectionalTag
*/
message ThpPairingPreparationsFinished{
}
/**
* Response: If Code Entry is an allowed pairing option, Trezor responds with a commitment.
* @next ThpCodeEntryChallenge
*/
message ThpCodeEntryCommitment {
optional bytes commitment = 1; // SHA-256 of Trezor's random 32-byte secret
}
/**
* Response: Host responds to Trezor's Code Entry commitment with a challenge.
* @next ThpPairingPreparationsFinished
*/
message ThpCodeEntryChallenge {
optional bytes challenge = 1; // host's random 32-byte challenge
}
/**
* Request: User selected Code Entry option in Host. Host starts CPACE protocol with Trezor.
* @next ThpCodeEntryCpaceTrezor
*/
message ThpCodeEntryCpaceHost {
optional bytes cpace_host_public_key = 1; // Host's ephemeral CPace public key
}
/**
* Response: Trezor continues with the CPACE protocol.
* @next ThpCodeEntryTag
*/
message ThpCodeEntryCpaceTrezor {
optional bytes cpace_trezor_public_key = 1; // Trezor's ephemeral CPace public key
}
/**
* Response: Host continues with the CPACE protocol.
* @next ThpCodeEntrySecret
*/
message ThpCodeEntryTag {
optional bytes tag = 2; // SHA-256 of shared secret
}
/**
* Response: Trezor finishes the CPACE protocol.
* @next ThpCredentialRequest
* @next ThpEndRequest
*/
message ThpCodeEntrySecret {
optional bytes secret = 1; // Trezor's secret
}
/**
* Request: User selected QR Code pairing option. Host sends a QR Tag.
* @next ThpQrCodeSecret
*/
message ThpQrCodeTag {
optional bytes tag = 1; // SHA-256 of shared secret
}
/**
* Response: Trezor sends the QR secret.
* @next ThpCredentialRequest
* @next ThpEndRequest
*/
message ThpQrCodeSecret {
optional bytes secret = 1; // Trezor's secret
}
/**
* Request: User selected Unidirectional NFC pairing option. Host sends an Unidirectional NFC Tag.
* @next ThpNfcUnidirectionalSecret
*/
message ThpNfcUnidirectionalTag {
optional bytes tag = 1; // SHA-256 of shared secret
}
/**
* Response: Trezor sends the Unidirectioal NFC secret.
* @next ThpCredentialRequest
* @next ThpEndRequest
*/
message ThpNfcUnidirectionalSecret {
optional bytes secret = 1; // Trezor's secret
}
/**
* Request: Host requests issuance of a new pairing credential.
* @start
* @next ThpCredentialResponse
*/
message ThpCredentialRequest {
optional bytes host_static_pubkey = 1; // Host's static public key used in the handshake.
}
/**
* Response: Trezor issues a new pairing credential.
* @next ThpCredentialRequest
* @next ThpEndRequest
*/
message ThpCredentialResponse {
optional bytes trezor_static_pubkey = 1; // Trezor's static public key used in the handshake.
optional bytes credential = 2; // The pairing credential issued by the Trezor to the host.
}
/**
* Request: Host requests transition to the encrypted traffic phase.
* @start
* @next ThpEndResponse
*/
message ThpEndRequest {}
/**
* Response: Trezor approves transition to the encrypted traffic phase
* @end
*/
message ThpEndResponse {}
/**
* Only for internal use.
* @embed

View File

@ -37,6 +37,10 @@ The convention to achieve this is as follows:
optional bool wire_tiny = 50006; // message is handled by Trezor when the USB stack is in tiny mode
optional bool wire_bootloader = 50007; // message is only handled by Trezor Bootloader
optional bool wire_no_fsm = 50008; // message is not handled by Trezor unless the USB stack is in tiny mode
optional bool channel_in = 50009;
optional bool channel_out = 50010;
optional bool pairing_in = 50011;
optional bool pairing_out = 50012;
optional bool bitcoin_only = 60000; // enum value is available on BITCOIN_ONLY build
// (messages not marked bitcoin_only will be EXCLUDED)

View File

@ -62,6 +62,7 @@ INT_TYPES = (
)
MESSAGE_TYPE_ENUM = "MessageType"
THP_MESSAGE_TYPE_ENUM = "ThpMessageType"
LengthDelimited = c.Struct(
"len" / c.VarInt,
@ -239,6 +240,9 @@ class ProtoMessage:
@classmethod
def from_message(cls, descriptor: "Descriptor", message):
message_type = find_by_name(descriptor.message_type_enum.value, message.name)
thp_message_type = None
if not isinstance(descriptor.thp_message_type_enum,tuple):
thp_message_type = find_by_name(descriptor.thp_message_type_enum.value, message.name)
# use extensions set on the message_type entry (if any)
extensions = descriptor.get_extensions(message_type)
# override with extensions set on the message itself
@ -248,6 +252,8 @@ class ProtoMessage:
wire_type = extensions["wire_type"]
elif message_type is not None:
wire_type = message_type.number
elif thp_message_type is not None:
wire_type = thp_message_type.number
else:
wire_type = None
@ -351,10 +357,13 @@ class Descriptor:
]
logging.debug(f"found {len(self.files)} bitcoin-only files")
# find message_type enum
# find message_type and thp_message_type enum
top_level_enums = itertools.chain.from_iterable(f.enum_type for f in self.files)
self.message_type_enum = find_by_name(top_level_enums, MESSAGE_TYPE_ENUM, ())
top_level_enums = itertools.chain.from_iterable(f.enum_type for f in self.files)
self.thp_message_type_enum = find_by_name(top_level_enums, THP_MESSAGE_TYPE_ENUM, ())
self.convert_enum_value_names(self.message_type_enum)
self.convert_enum_value_names(self.thp_message_type_enum)
# find messages and enums
self.messages = []
@ -423,6 +432,8 @@ class Descriptor:
self._nested_types_from_message(nested.orig)
def convert_enum_value_names(self, enum):
if isinstance(enum,tuple):
return
for value in enum.value:
value.name = strip_enum_prefix(enum.name, value.name)
@ -558,6 +569,8 @@ class RustBlobRenderer:
enums = []
cursor = 0
for enum in sorted(self.descriptor.enums, key=lambda e: e.name):
if enum.name == "MessageType":
continue
self.enum_map[enum.name] = cursor
enum_blob = ENUM_ENTRY.build(sorted(v.number for v in enum.value))
enums.append(enum_blob)

View File

@ -309,6 +309,12 @@ build_unix_frozen: templates build_cross ## build unix port with frozen modules
TREZOR_MEMPERF="$(TREZOR_MEMPERF)" TREZOR_EMULATOR_FROZEN=1 \
BENCHMARK="$(BENCHMARK)"
build_unix_frozen_debug: templates build_cross ## build unix port with frozen modules and DEBUG (PYOPT="0")
$(SCONS) CFLAGS="$(CFLAGS)" $(UNIX_BUILD_DIR)/trezor-emu-core $(UNIX_PORT_OPTS) \
TREZOR_MODEL="$(TREZOR_MODEL)" CMAKELISTS="$(CMAKELISTS)" THP="$(THP)"\
PYOPT="0" BITCOIN_ONLY="$(BITCOIN_ONLY)" TREZOR_EMULATOR_ASAN="$(ADDRESS_SANITIZER)" \
TREZOR_MEMPERF="$(TREZOR_MEMPERF)" TREZOR_EMULATOR_FROZEN=1 BENCHMARK="$(BENCHMARK)"
build_unix_debug: templates ## build unix port
$(SCONS) --max-drift=1 CFLAGS="$(CFLAGS)" $(UNIX_BUILD_DIR)/trezor-emu-core $(UNIX_PORT_OPTS) \
TREZOR_MODEL="$(TREZOR_MODEL)" CMAKELISTS="$(CMAKELISTS)" \

View File

@ -561,6 +561,10 @@ if FROZEN:
] if not EVERYTHING else []
))
if not THP or PYOPT == '0':
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/codec/*.py'))
if THP:
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/thp/*.py'))
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/*.py'))
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'storage/*.py',

View File

@ -629,6 +629,10 @@ if FROZEN:
] if not EVERYTHING else []
))
if not THP or PYOPT == '0':
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/codec/*.py'))
if THP:
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/thp/*.py'))
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/*.py'))
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'storage/*.py',

View File

@ -1731,7 +1731,7 @@ pub static mp_module_trezorui2: Module = obj_module! {
/// """Calls drop on contents of the root component."""
///
/// class UiResult:
/// """Result of an UI operation."""
/// """Result of a UI operation."""
/// pass
///
/// mock:global

File diff suppressed because it is too large Load Diff

View File

@ -47,6 +47,12 @@ storage
import storage
storage.cache
import storage.cache
storage.cache_codec
import storage.cache_codec
storage.cache_common
import storage.cache_common
storage.cache_thp
import storage.cache_thp
storage.common
import storage.common
storage.debug
@ -201,12 +207,20 @@ trezor.utils
import trezor.utils
trezor.wire
import trezor.wire
trezor.wire.codec_v1
import trezor.wire.codec_v1
trezor.wire.codec
import trezor.wire.codec
trezor.wire.codec.codec_context
import trezor.wire.codec.codec_context
trezor.wire.codec.codec_v1
import trezor.wire.codec.codec_v1
trezor.wire.context
import trezor.wire.context
trezor.wire.errors
import trezor.wire.errors
trezor.wire.message_handler
import trezor.wire.message_handler
trezor.wire.protocol_common
import trezor.wire.protocol_common
trezor.workflow
import trezor.workflow
apps
@ -309,6 +323,8 @@ apps.common.backup
import apps.common.backup
apps.common.backup_types
import apps.common.backup_types
apps.common.cache
import apps.common.cache
apps.common.cbor
import apps.common.cbor
apps.common.coininfo
@ -401,10 +417,52 @@ apps.workflow_handlers
import apps.workflow_handlers
if utils.USE_THP:
trezor.enums.ThpMessageType
import trezor.enums.ThpMessageType
trezor.enums.ThpPairingMethod
import trezor.enums.ThpPairingMethod
trezor.wire.thp
import trezor.wire.thp
trezor.wire.thp.alternating_bit_protocol
import trezor.wire.thp.alternating_bit_protocol
trezor.wire.thp.channel
import trezor.wire.thp.channel
trezor.wire.thp.channel_manager
import trezor.wire.thp.channel_manager
trezor.wire.thp.checksum
import trezor.wire.thp.checksum
trezor.wire.thp.control_byte
import trezor.wire.thp.control_byte
trezor.wire.thp.cpace
import trezor.wire.thp.cpace
trezor.wire.thp.crypto
import trezor.wire.thp.crypto
trezor.wire.thp.interface_manager
import trezor.wire.thp.interface_manager
trezor.wire.thp.memory_manager
import trezor.wire.thp.memory_manager
trezor.wire.thp.pairing_context
import trezor.wire.thp.pairing_context
trezor.wire.thp.received_message_handler
import trezor.wire.thp.received_message_handler
trezor.wire.thp.session_context
import trezor.wire.thp.session_context
trezor.wire.thp.session_manager
import trezor.wire.thp.session_manager
trezor.wire.thp.thp_main
import trezor.wire.thp.thp_main
trezor.wire.thp.transmission_loop
import trezor.wire.thp.transmission_loop
trezor.wire.thp.writer
import trezor.wire.thp.writer
apps.thp
import apps.thp
apps.thp.create_new_session
import apps.thp.create_new_session
apps.thp.credential_manager
import apps.thp.credential_manager
apps.thp.pairing
import apps.thp.pairing
if not utils.BITCOIN_ONLY:
trezor.enums.BinanceOrderSide

View File

@ -1,23 +1,20 @@
from typing import Iterable
import storage.cache as storage_cache
from storage.cache_common import (
APP_COMMON_AUTHORIZATION_DATA,
APP_COMMON_AUTHORIZATION_TYPE,
)
from trezor import protobuf
from trezor.enums import MessageType
from trezor.wire import context
WIRE_TYPES: dict[int, tuple[int, ...]] = {
MessageType.AuthorizeCoinJoin: (MessageType.SignTx, MessageType.GetOwnershipProof),
}
APP_COMMON_AUTHORIZATION_DATA = (
storage_cache.APP_COMMON_AUTHORIZATION_DATA
) # global_import_cache
APP_COMMON_AUTHORIZATION_TYPE = (
storage_cache.APP_COMMON_AUTHORIZATION_TYPE
) # global_import_cache
def is_set() -> bool:
return storage_cache.get(APP_COMMON_AUTHORIZATION_TYPE) is not None
return context.cache_get(APP_COMMON_AUTHORIZATION_TYPE) is not None
def set(auth_message: protobuf.MessageType) -> None:
@ -29,27 +26,27 @@ def set(auth_message: protobuf.MessageType) -> None:
# (because only wire-level messages have wire_type, which we use as identifier)
ensure(auth_message.MESSAGE_WIRE_TYPE is not None)
assert auth_message.MESSAGE_WIRE_TYPE is not None # so that typechecker knows too
storage_cache.set_int(APP_COMMON_AUTHORIZATION_TYPE, auth_message.MESSAGE_WIRE_TYPE)
storage_cache.set(APP_COMMON_AUTHORIZATION_DATA, buffer)
context.cache_set_int(APP_COMMON_AUTHORIZATION_TYPE, auth_message.MESSAGE_WIRE_TYPE)
context.cache_set(APP_COMMON_AUTHORIZATION_DATA, buffer)
def get() -> protobuf.MessageType | None:
stored_auth_type = storage_cache.get_int(APP_COMMON_AUTHORIZATION_TYPE)
stored_auth_type = context.cache_get_int(APP_COMMON_AUTHORIZATION_TYPE)
if not stored_auth_type:
return None
buffer = storage_cache.get(APP_COMMON_AUTHORIZATION_DATA, b"")
buffer = context.cache_get(APP_COMMON_AUTHORIZATION_DATA, b"")
return protobuf.load_message_buffer(buffer, stored_auth_type)
def is_set_any_session(auth_type: MessageType) -> bool:
return auth_type in storage_cache.get_int_all_sessions(
return auth_type in context.cache_get_int_all_sessions(
APP_COMMON_AUTHORIZATION_TYPE
)
def get_wire_types() -> Iterable[int]:
stored_auth_type = storage_cache.get_int(APP_COMMON_AUTHORIZATION_TYPE)
stored_auth_type = context.cache_get_int(APP_COMMON_AUTHORIZATION_TYPE)
if stored_auth_type is None:
return ()
@ -57,5 +54,5 @@ def get_wire_types() -> Iterable[int]:
def clear() -> None:
storage_cache.delete(APP_COMMON_AUTHORIZATION_TYPE)
storage_cache.delete(APP_COMMON_AUTHORIZATION_DATA)
context.cache_delete(APP_COMMON_AUTHORIZATION_TYPE)
context.cache_delete(APP_COMMON_AUTHORIZATION_DATA)

View File

@ -1,5 +1,6 @@
from typing import TYPE_CHECKING
from trezor import utils
from trezor.crypto import bip32
from trezor.wire import DataError
@ -172,7 +173,10 @@ async def get_keychain(
) -> Keychain:
from .seed import get_seed
seed = await get_seed()
if not utils.USE_THP:
pass
# try to ask for passphrase here
seed = get_seed()
keychain = Keychain(seed, curve, schemas, slip21_namespaces)
return keychain

View File

@ -1,83 +1,122 @@
from micropython import const
from typing import TYPE_CHECKING
import storage.device as storage_device
from trezor import utils
from trezor.wire import DataError
_MAX_PASSPHRASE_LEN = const(50)
if TYPE_CHECKING:
from trezor.messages import ThpCreateNewSession
def is_enabled() -> bool:
return storage_device.is_passphrase_enabled()
async def get() -> str:
from trezor import workflow
async def get_passphrase(msg: ThpCreateNewSession) -> str:
if not is_enabled():
return ""
if msg.on_device or storage_device.get_passphrase_always_on_device():
passphrase = await _get_on_device()
else:
workflow.close_others() # request exclusive UI access
if storage_device.get_passphrase_always_on_device():
from trezor.ui.layouts import request_passphrase_on_device
passphrase = msg.passphrase or ""
if passphrase:
await _handle_displaying_passphrase_from_host(passphrase)
passphrase = await request_passphrase_on_device(_MAX_PASSPHRASE_LEN)
else:
passphrase = await _request_on_host()
if len(passphrase.encode()) > _MAX_PASSPHRASE_LEN:
raise DataError(f"Maximum passphrase length is {_MAX_PASSPHRASE_LEN} bytes")
return passphrase
async def _request_on_host() -> str:
from trezor import TR
from trezor.messages import PassphraseAck, PassphraseRequest
from trezor.ui.layouts import request_passphrase_on_host
from trezor.wire.context import call
request_passphrase_on_host()
request = PassphraseRequest()
ack = await call(request, PassphraseAck)
passphrase = ack.passphrase # local_cache_attribute
if ack.on_device:
from trezor.ui.layouts import request_passphrase_on_device
if passphrase is not None:
raise DataError("Passphrase provided when it should not be")
return await request_passphrase_on_device(_MAX_PASSPHRASE_LEN)
if passphrase is None:
raise DataError(
"Passphrase not provided and on_device is False. Use empty string to set an empty passphrase."
)
# non-empty passphrase
if passphrase:
from trezor.ui.layouts import confirm_action, confirm_blob
# We want to hide the passphrase, or show it, according to settings.
if storage_device.get_hide_passphrase_from_host():
await confirm_action(
"passphrase_host1_hidden",
TR.passphrase__wallet,
description=TR.passphrase__from_host_not_shown,
prompt_screen=True,
prompt_title=TR.passphrase__access_wallet,
)
else:
await confirm_action(
"passphrase_host1",
TR.passphrase__wallet,
description=TR.passphrase__next_screen_will_show_passphrase,
verb=TR.buttons__continue,
)
await confirm_blob(
"passphrase_host2",
TR.passphrase__title_confirm,
passphrase,
)
if len(passphrase.encode()) > _MAX_PASSPHRASE_LEN:
raise DataError(f"Maximum passphrase length is {_MAX_PASSPHRASE_LEN} bytes")
return passphrase
async def _get_on_device() -> str:
from trezor import workflow
from trezor.ui.layouts import request_passphrase_on_device
workflow.close_others() # request exclusive UI access
passphrase = await request_passphrase_on_device(_MAX_PASSPHRASE_LEN)
return passphrase
async def _handle_displaying_passphrase_from_host(passphrase: str) -> None:
from trezor import TR
from trezor.ui.layouts import confirm_action, confirm_blob
# We want to hide the passphrase, or show it, according to settings.
if storage_device.get_hide_passphrase_from_host():
await confirm_action(
"passphrase_host1_hidden",
TR.passphrase__wallet,
description=TR.passphrase__from_host_not_shown,
prompt_screen=True,
prompt_title=TR.passphrase__access_wallet,
)
else:
await confirm_action(
"passphrase_host1",
TR.passphrase__wallet,
description=TR.passphrase__next_screen_will_show_passphrase,
verb=TR.buttons__continue,
)
await confirm_blob(
"passphrase_host2",
TR.passphrase__title_confirm,
passphrase,
)
if not utils.USE_THP:
async def get() -> str:
from trezor import workflow
if not is_enabled():
return ""
else:
workflow.close_others() # request exclusive UI access
if storage_device.get_passphrase_always_on_device():
from trezor.ui.layouts import request_passphrase_on_device
passphrase = await request_passphrase_on_device(_MAX_PASSPHRASE_LEN)
else:
passphrase = await _request_on_host()
if len(passphrase.encode()) > _MAX_PASSPHRASE_LEN:
raise DataError(
f"Maximum passphrase length is {_MAX_PASSPHRASE_LEN} bytes"
)
return passphrase
async def _request_on_host() -> str:
from trezor.messages import PassphraseAck, PassphraseRequest
from trezor.ui.layouts import request_passphrase_on_host
from trezor.wire.context import call
request_passphrase_on_host()
request = PassphraseRequest()
ack = await call(request, PassphraseAck)
passphrase = ack.passphrase # local_cache_attribute
if ack.on_device:
from trezor.ui.layouts import request_passphrase_on_device
if passphrase is not None:
raise DataError("Passphrase provided when it should not be")
return await request_passphrase_on_device(_MAX_PASSPHRASE_LEN)
if passphrase is None:
raise DataError(
"Passphrase not provided and on_device is False. Use empty string to set an empty passphrase."
)
# non-empty passphrase
if passphrase:
await _handle_displaying_passphrase_from_host(passphrase)
return passphrase

View File

@ -1,3 +1,6 @@
from trezor.wire import message_handler
from trezor.wire.protocol_common import Context
if not __debug__:
from trezor.utils import halt
@ -70,9 +73,7 @@ if __debug__:
"layout deadlock detected (did you send a ButtonAck?)"
)
async def return_layout_change(
ctx: wire.context.Context, detect_deadlock: bool = False
) -> None:
async def return_layout_change(ctx: Context, detect_deadlock: bool = False) -> None:
# set up the wait
storage.layout_watcher = True
@ -244,7 +245,11 @@ if __debug__:
# If no exception was raised, the layout did not shut down. That means that it
# just updated itself. The update is already live for the caller to retrieve.
def _state() -> DebugLinkState:
def _state(
thp_pairing_code_entry_code: int | None = None,
thp_pairing_code_qr_code: bytes | None = None,
thp_pairing_code_nfc_unidirectional: bytes | None = None,
) -> DebugLinkState:
from trezor.messages import DebugLinkState
from apps.common import mnemonic, passphrase
@ -263,13 +268,45 @@ if __debug__:
passphrase_protection=passphrase.is_enabled(),
reset_entropy=storage.reset_internal_entropy,
tokens=tokens,
thp_pairing_code_entry_code=thp_pairing_code_entry_code,
thp_pairing_code_qr_code=thp_pairing_code_qr_code,
thp_pairing_code_nfc_unidirectional=thp_pairing_code_nfc_unidirectional,
)
async def dispatch_DebugLinkGetState(
msg: DebugLinkGetState,
) -> DebugLinkState | None:
thp_pairing_code_entry_code: int | None = None
thp_pairing_code_qr_code: bytes | None = None
thp_pairing_code_nfc_unidirectional: bytes | None = None
if utils.USE_THP and msg.thp_channel_id is not None:
channel_id = int.from_bytes(msg.thp_channel_id, "big")
from trezor.wire.thp.channel import Channel
from trezor.wire.thp.pairing_context import PairingContext
from trezor.wire.thp.thp_main import _CHANNELS
channel: Channel | None = None
ctx: PairingContext | None = None
try:
channel = _CHANNELS[channel_id]
ctx = channel.connection_context
except KeyError:
pass
if ctx is not None and isinstance(ctx, PairingContext):
thp_pairing_code_entry_code = ctx.display_data.code_code_entry
thp_pairing_code_qr_code = ctx.display_data.code_qr_code
thp_pairing_code_nfc_unidirectional = (
ctx.display_data.code_nfc_unidirectional
)
if msg.wait_layout == DebugWaitType.IMMEDIATE:
return _state()
return _state(
thp_pairing_code_entry_code,
thp_pairing_code_qr_code,
thp_pairing_code_nfc_unidirectional,
)
assert DEBUG_CONTEXT is not None
if msg.wait_layout == DebugWaitType.NEXT_LAYOUT:
@ -280,7 +317,11 @@ if __debug__:
if not layout_is_ready():
return await return_layout_change(DEBUG_CONTEXT, detect_deadlock=True)
else:
return _state()
return _state(
thp_pairing_code_entry_code,
thp_pairing_code_qr_code,
thp_pairing_code_nfc_unidirectional,
)
async def dispatch_DebugLinkRecordScreen(msg: DebugLinkRecordScreen) -> Success:
if msg.target_directory:
@ -356,11 +397,12 @@ if __debug__:
async def handle_session(iface: WireInterface) -> None:
from trezor import protobuf, wire
from trezor.wire import codec_v1, context
from trezor.wire.codec import codec_v1
from trezor.wire.codec.codec_context import CodecContext
global DEBUG_CONTEXT
DEBUG_CONTEXT = ctx = context.Context(iface, WIRE_BUFFER_DEBUG)
DEBUG_CONTEXT = ctx = CodecContext(iface, WIRE_BUFFER_DEBUG)
if storage.layout_watcher:
try:
@ -391,7 +433,7 @@ if __debug__:
)
if msg.type not in WORKFLOW_HANDLERS:
await ctx.write(wire.unexpected_message())
await ctx.write(message_handler.unexpected_message())
continue
elif req_type is None:
@ -402,7 +444,7 @@ if __debug__:
await ctx.write(Success())
continue
req_msg = wire.wrap_protobuf_load(msg.data, req_type)
req_msg = message_handler.wrap_protobuf_load(msg.data, req_type)
try:
res_msg = await WORKFLOW_HANDLERS[msg.type](req_msg)
except Exception as exc:

View File

@ -5,10 +5,11 @@ if TYPE_CHECKING:
async def get_nonce(msg: GetNonce) -> Nonce:
from storage import cache
from storage.cache_common import APP_COMMON_NONCE
from trezor.crypto import random
from trezor.messages import Nonce
from trezor.wire.context import cache_set
nonce = random.bytes(32)
cache.set(cache.APP_COMMON_NONCE, nonce)
cache_set(APP_COMMON_NONCE, nonce)
return Nonce(nonce=nonce)

View File

@ -89,7 +89,7 @@ async def reboot_to_bootloader(msg: RebootToBootloader) -> NoReturn:
boot_args = None
ctx = get_context()
await ctx.write(Success(message="Rebooting"))
await ctx.write_force(Success(message="Rebooting"))
# make sure the outgoing USB buffer is flushed
await loop.wait(ctx.iface.iface_num() | io.POLL_WRITE)
# reboot to the bootloader, pass the firmware header hash if any

View File

@ -24,6 +24,7 @@ async def recovery_device(msg: RecoveryDevice) -> Success:
from trezor import TR, config, wire, workflow
from trezor.enums import BackupType, ButtonRequestType
from trezor.ui.layouts import confirm_action, confirm_reset_device
from trezor.wire.context import try_get_ctx_ids
from apps.common import mnemonic
from apps.common.request_pin import (
@ -69,8 +70,8 @@ async def recovery_device(msg: RecoveryDevice) -> Success:
if recovery_type == RecoveryType.NormalRecovery:
await confirm_reset_device(TR.recovery__title_recover, recovery=True)
# wipe storage to make sure the device is in a clear state
storage.reset()
# wipe storage to make sure the device is in a clear state (except protocol cache)
storage.reset(excluded=try_get_ctx_ids())
# set up pin if requested
if msg.pin_protection:

View File

@ -3,8 +3,9 @@ from typing import TYPE_CHECKING
import storage.device as storage_device
import storage.recovery as storage_recovery
import storage.recovery_shares as storage_recovery_shares
from trezor import TR, wire
from trezor import TR, utils, wire
from trezor.messages import Success
from trezor.wire import message_handler
from apps.common import backup_types
@ -38,18 +39,26 @@ async def recovery_process() -> Success:
recovery_type = storage_recovery.get_type()
wire.AVOID_RESTARTING_FOR = (
MessageType.Initialize,
MessageType.GetFeatures,
MessageType.EndSession,
)
if utils.USE_THP:
message_handler.AVOID_RESTARTING_FOR = (
MessageType.GetFeatures,
MessageType.EndSession,
)
else:
message_handler.AVOID_RESTARTING_FOR = (
MessageType.Initialize,
MessageType.GetFeatures,
MessageType.EndSession,
)
try:
return await _continue_recovery_process()
except recover.RecoveryAborted:
storage_recovery.end_progress()
backup.deactivate_repeated_backup()
if recovery_type == RecoveryType.NormalRecovery:
storage.wipe()
from trezor.wire.context import try_get_ctx_ids
storage.wipe(excluded=try_get_ctx_ids())
raise wire.ActionCancelled
@ -59,11 +68,17 @@ async def _continue_repeated_backup() -> None:
from apps.common import backup
from apps.management.backup_device import perform_backup
wire.AVOID_RESTARTING_FOR = (
MessageType.Initialize,
MessageType.GetFeatures,
MessageType.EndSession,
)
if utils.USE_THP:
message_handler.AVOID_RESTARTING_FOR = (
MessageType.GetFeatures,
MessageType.EndSession,
)
else:
message_handler.AVOID_RESTARTING_FOR = (
MessageType.Initialize,
MessageType.GetFeatures,
MessageType.EndSession,
)
try:
await perform_backup(is_repeated_backup=True)

View File

@ -38,7 +38,7 @@ async def reset_device(msg: ResetDevice) -> Success:
prompt_backup,
show_wallet_created_success,
)
from trezor.wire.context import call
from trezor.wire.context import call, try_get_ctx_ids
from apps.common.request_pin import request_pin_confirm
@ -60,8 +60,8 @@ async def reset_device(msg: ResetDevice) -> Success:
# Rendering empty loader so users do not feel a freezing screen
render_empty_loader(config.StorageMessage.PROCESSING_MSG)
# wipe storage to make sure the device is in a clear state
storage.reset()
# wipe storage to make sure the device is in a clear state (except protocol cache)
storage.reset(excluded=try_get_ctx_ids())
# request and set new PIN
if msg.pin_protection:
@ -121,7 +121,7 @@ async def reset_device(msg: ResetDevice) -> Success:
if perform_backup:
await layout.show_backup_success()
return Success(message="Initialized")
return Success(message="Initialized") # TODO: Why "Initialized?"
async def _backup_bip39(mnemonic: str) -> None:

View File

@ -1,12 +1,19 @@
from typing import TYPE_CHECKING
from trezor.wire.context import get_context, try_get_ctx_ids
if TYPE_CHECKING:
from trezor.messages import Success, WipeDevice
from typing import NoReturn
from trezor.messages import WipeDevice
if __debug__:
from trezor import log
async def wipe_device(msg: WipeDevice) -> Success:
async def wipe_device(msg: WipeDevice) -> NoReturn:
import storage
from trezor import TR, config, translations
from trezor import TR, config, loop, translations
from trezor.enums import ButtonRequestType
from trezor.messages import Success
from trezor.pin import render_empty_loader
@ -26,16 +33,22 @@ async def wipe_device(msg: WipeDevice) -> Success:
br_code=ButtonRequestType.WipeDevice,
)
if __debug__:
log.debug(__name__, "Device wipe - start")
# start an empty progress screen so that the screen is not blank while waiting
render_empty_loader(config.StorageMessage.PROCESSING_MSG)
# wipe storage
storage.wipe()
storage.wipe(excluded=try_get_ctx_ids())
# erase translations
translations.deinit()
translations.erase()
await get_context().write_force(Success(message="Device wiped"))
storage.wipe_cache()
# reload settings
reload_settings_from_storage()
return Success(message="Device wiped")
loop.clear()
if __debug__:
log.debug(__name__, "Device wipe - finished")

View File

@ -0,0 +1,49 @@
from trezor import log, loop
from trezor.enums import FailureType
from trezor.messages import Failure, ThpCreateNewSession, ThpNewSession
from trezor.wire.context import get_context
from trezor.wire.errors import ActionCancelled, DataError
from trezor.wire.thp import SessionState
async def create_new_session(message: ThpCreateNewSession) -> ThpNewSession | Failure:
from trezor.wire.thp.session_manager import create_new_session
from apps.common.seed import derive_and_store_roots
ctx = get_context()
# Assert that context `ctx` is ManagementSessionContext
from trezor.wire.thp.session_context import ManagementSessionContext
assert isinstance(ctx, ManagementSessionContext)
channel = ctx.channel
# Do not use `ctx` beyond this point, as it is techically
# allowed to change in between await statements
new_session = create_new_session(channel)
try:
await derive_and_store_roots(new_session, message)
except DataError as e:
return Failure(code=FailureType.DataError, message=e.message)
except ActionCancelled as e:
return Failure(code=FailureType.ActionCancelled, message=e.message)
# TODO handle other errors
new_session.set_session_state(SessionState.ALLOCATED)
channel.sessions[new_session.session_id] = new_session
loop.schedule(new_session.handle())
new_session_id: int = new_session.session_id
if __debug__:
log.debug(
__name__,
"create_new_session - new session created. Passphrase: %s, Session id: %d\n%s",
message.passphrase if message.passphrase is not None else "",
new_session.session_id,
str(channel.sessions),
)
return ThpNewSession(new_session_id=new_session_id)

View File

@ -7,7 +7,7 @@ from trezor.messages import (
ThpCredentialMetadata,
ThpPairingCredential,
)
from trezor.wire import wrap_protobuf_load
from trezor.wire import message_handler
if TYPE_CHECKING:
from apps.common.paths import Slip21Path
@ -72,7 +72,9 @@ def validate_credential(
"""
cred_auth_key = derive_cred_auth_key()
expected_type = protobuf.type_for_name("ThpPairingCredential")
credential = wrap_protobuf_load(encoded_pairing_credential_message, expected_type)
credential = message_handler.wrap_protobuf_load(
encoded_pairing_credential_message, expected_type
)
assert ThpPairingCredential.is_type_of(credential)
proto_msg = ThpAuthenticatedCredentialData(
host_static_pubkey=host_static_pubkey,

View File

@ -0,0 +1,403 @@
from typing import TYPE_CHECKING
from ubinascii import hexlify
from trezor import loop, protobuf
from trezor.crypto.hashlib import sha256
from trezor.enums import ThpMessageType, ThpPairingMethod
from trezor.messages import (
Cancel,
ThpCodeEntryChallenge,
ThpCodeEntryCommitment,
ThpCodeEntryCpaceHost,
ThpCodeEntryCpaceTrezor,
ThpCodeEntrySecret,
ThpCodeEntryTag,
ThpCredentialMetadata,
ThpCredentialRequest,
ThpCredentialResponse,
ThpEndRequest,
ThpEndResponse,
ThpNfcUnidirectionalSecret,
ThpNfcUnidirectionalTag,
ThpPairingPreparationsFinished,
ThpQrCodeSecret,
ThpQrCodeTag,
ThpStartPairingRequest,
)
from trezor.wire.errors import ActionCancelled, SilentError, UnexpectedMessage
from trezor.wire.thp import ChannelState, ThpError, crypto
from trezor.wire.thp.pairing_context import PairingContext
from .credential_manager import issue_credential
if __debug__:
from trezor import log
if TYPE_CHECKING:
from typing import Any, Callable, Concatenate, Container, ParamSpec, Tuple
P = ParamSpec("P")
FuncWithContext = Callable[Concatenate[PairingContext, P], Any]
#
# Helpers - decorators
def check_state_and_log(
*allowed_states: ChannelState,
) -> Callable[[FuncWithContext], FuncWithContext]:
def decorator(f: FuncWithContext) -> FuncWithContext:
def inner(context: PairingContext, *args: P.args, **kwargs: P.kwargs) -> object:
_check_state(context, *allowed_states)
if __debug__:
try:
log.debug(__name__, "started %s", f.__name__)
except AttributeError:
log.debug(
__name__,
"started a function that cannot be named, because it raises AttributeError, eg. closure",
)
return f(context, *args, **kwargs)
return inner
return decorator
def check_method_is_allowed(
pairing_method: ThpPairingMethod,
) -> Callable[[FuncWithContext], FuncWithContext]:
def decorator(f: FuncWithContext) -> FuncWithContext:
def inner(context: PairingContext, *args: P.args, **kwargs: P.kwargs) -> object:
_check_method_is_allowed(context, pairing_method)
return f(context, *args, **kwargs)
return inner
return decorator
#
# Pairing handlers
@check_state_and_log(ChannelState.TP1)
async def handle_pairing_request(
ctx: PairingContext, message: protobuf.MessageType
) -> ThpEndResponse:
if not ThpStartPairingRequest.is_type_of(message):
raise UnexpectedMessage("Unexpected message")
ctx.host_name = message.host_name or ""
skip_pairing = _is_method_included(ctx, ThpPairingMethod.NoMethod)
if skip_pairing:
return await _end_pairing(ctx)
await _prepare_pairing(ctx)
await ctx.write(ThpPairingPreparationsFinished())
ctx.channel_ctx.set_channel_state(ChannelState.TP3)
response = await show_display_data(
ctx, _get_possible_pairing_methods_and_cancel(ctx)
)
if Cancel.is_type_of(response):
ctx.channel_ctx.clear()
raise SilentError("Action was cancelled by the Host")
# TODO disable NFC (if enabled)
response = await _handle_different_pairing_methods(ctx, response)
while ThpCredentialRequest.is_type_of(response):
response = await _handle_credential_request(ctx, response)
return await _handle_end_request(ctx, response)
async def _prepare_pairing(ctx: PairingContext) -> None:
if _is_method_included(ctx, ThpPairingMethod.CodeEntry):
await _handle_code_entry_is_included(ctx)
if _is_method_included(ctx, ThpPairingMethod.QrCode):
_handle_qr_code_is_included(ctx)
if _is_method_included(ctx, ThpPairingMethod.NFC_Unidirectional):
_handle_nfc_unidirectional_is_included(ctx)
async def show_display_data(
ctx: PairingContext, expected_types: Container[int] = ()
) -> type[protobuf.MessageType]:
from trezorui2 import CANCELLED
read_task = ctx.read(expected_types)
cancel_task = ctx.display_data.get_display_layout()
race = loop.race(read_task, cancel_task.get_result())
result: type[protobuf.MessageType] = await race
if result is CANCELLED:
raise ActionCancelled
return result
@check_state_and_log(ChannelState.TP1)
async def _handle_code_entry_is_included(ctx: PairingContext) -> None:
commitment = sha256(ctx.secret).digest()
challenge_message = await ctx.call( # noqa: F841
ThpCodeEntryCommitment(commitment=commitment), ThpCodeEntryChallenge
)
ctx.channel_ctx.set_channel_state(ChannelState.TP2)
if not ThpCodeEntryChallenge.is_type_of(challenge_message):
raise UnexpectedMessage("Unexpected message")
if challenge_message.challenge is None:
raise Exception("Invalid message")
sha_ctx = sha256(ctx.channel_ctx.get_handshake_hash())
sha_ctx.update(ctx.secret)
sha_ctx.update(challenge_message.challenge)
sha_ctx.update(bytes("PairingMethod_CodeEntry", "utf-8"))
code_code_entry_hash = sha_ctx.digest()
ctx.display_data.code_code_entry = (
int.from_bytes(code_code_entry_hash, "big") % 1000000
)
@check_state_and_log(ChannelState.TP1, ChannelState.TP2)
def _handle_qr_code_is_included(ctx: PairingContext) -> None:
sha_ctx = sha256(ctx.channel_ctx.get_handshake_hash())
sha_ctx.update(ctx.secret)
sha_ctx.update(bytes("PairingMethod_QrCode", "utf-8"))
ctx.display_data.code_qr_code = sha_ctx.digest()[:16]
@check_state_and_log(ChannelState.TP1, ChannelState.TP2)
def _handle_nfc_unidirectional_is_included(ctx: PairingContext) -> None:
sha_ctx = sha256(ctx.channel_ctx.get_handshake_hash())
sha_ctx.update(ctx.secret)
sha_ctx.update(bytes("PairingMethod_NfcUnidirectional", "utf-8"))
ctx.display_data.code_nfc_unidirectional = sha_ctx.digest()[:16]
@check_state_and_log(ChannelState.TP3)
async def _handle_different_pairing_methods(
ctx: PairingContext, response: protobuf.MessageType
) -> protobuf.MessageType:
if ThpCodeEntryCpaceHost.is_type_of(response):
return await _handle_code_entry_cpace(ctx, response)
if ThpQrCodeTag.is_type_of(response):
return await _handle_qr_code_tag(ctx, response)
if ThpNfcUnidirectionalTag.is_type_of(response):
return await _handle_nfc_unidirectional_tag(ctx, response)
raise UnexpectedMessage("Unexpected message")
@check_state_and_log(ChannelState.TP3)
@check_method_is_allowed(ThpPairingMethod.CodeEntry)
async def _handle_code_entry_cpace(
ctx: PairingContext, message: protobuf.MessageType
) -> protobuf.MessageType:
from trezor.wire.thp.cpace import Cpace
# TODO check that ThpCodeEntryCpaceHost message is valid
if TYPE_CHECKING:
assert isinstance(message, ThpCodeEntryCpaceHost)
if message.cpace_host_public_key is None:
raise ThpError("Message ThpCodeEntryCpaceHost has no public key")
ctx.cpace = Cpace(
message.cpace_host_public_key,
ctx.channel_ctx.get_handshake_hash(),
)
assert ctx.display_data.code_code_entry is not None
ctx.cpace.generate_keys_and_secret(
ctx.display_data.code_code_entry.to_bytes(6, "big")
)
ctx.channel_ctx.set_channel_state(ChannelState.TP4)
response = await ctx.call(
ThpCodeEntryCpaceTrezor(cpace_trezor_public_key=ctx.cpace.trezor_public_key),
ThpCodeEntryTag,
)
return await _handle_code_entry_tag(ctx, response)
@check_state_and_log(ChannelState.TP4)
@check_method_is_allowed(ThpPairingMethod.CodeEntry)
async def _handle_code_entry_tag(
ctx: PairingContext, message: protobuf.MessageType
) -> protobuf.MessageType:
if TYPE_CHECKING:
assert isinstance(message, ThpCodeEntryTag)
expected_tag = sha256(ctx.cpace.shared_secret).digest()
if expected_tag != message.tag:
print(
"expected code entry tag:", hexlify(expected_tag).decode()
) # TODO remove after testing
print(
"expected code entry shared secret:",
hexlify(ctx.cpace.shared_secret).decode(),
) # TODO remove after testing
raise ThpError("Unexpected Code Entry Tag")
return await _handle_secret_reveal(
ctx,
msg=ThpCodeEntrySecret(secret=ctx.secret),
)
@check_state_and_log(ChannelState.TP3)
@check_method_is_allowed(ThpPairingMethod.QrCode)
async def _handle_qr_code_tag(
ctx: PairingContext, message: protobuf.MessageType
) -> protobuf.MessageType:
if TYPE_CHECKING:
assert isinstance(message, ThpQrCodeTag)
assert ctx.display_data.code_qr_code is not None
expected_tag = sha256(ctx.display_data.code_qr_code).digest()
if expected_tag != message.tag:
print(
"expected qr code tag:", hexlify(expected_tag).decode()
) # TODO remove after testing
print(
"expected code qr code tag:",
hexlify(ctx.display_data.code_qr_code).decode(),
) # TODO remove after testing
print(
"expected secret:", hexlify(ctx.secret).decode()
) # TODO remove after testing
raise ThpError("Unexpected QR Code Tag")
return await _handle_secret_reveal(
ctx,
msg=ThpQrCodeSecret(secret=ctx.secret),
)
@check_state_and_log(ChannelState.TP3)
@check_method_is_allowed(ThpPairingMethod.NFC_Unidirectional)
async def _handle_nfc_unidirectional_tag(
ctx: PairingContext, message: protobuf.MessageType
) -> protobuf.MessageType:
if TYPE_CHECKING:
assert isinstance(message, ThpNfcUnidirectionalTag)
expected_tag = sha256(ctx.display_data.code_nfc_unidirectional).digest()
if expected_tag != message.tag:
print(
"expected nfc tag:", hexlify(expected_tag).decode()
) # TODO remove after testing
raise ThpError("Unexpected NFC Unidirectional Tag")
return await _handle_secret_reveal(
ctx,
msg=ThpNfcUnidirectionalSecret(secret=ctx.secret),
)
@check_state_and_log(ChannelState.TP3, ChannelState.TP4)
async def _handle_secret_reveal(
ctx: PairingContext,
msg: protobuf.MessageType,
) -> protobuf.MessageType:
ctx.channel_ctx.set_channel_state(ChannelState.TC1)
return await ctx.call_any(
msg,
ThpMessageType.ThpCredentialRequest,
ThpMessageType.ThpEndRequest,
)
@check_state_and_log(ChannelState.TC1)
async def _handle_credential_request(
ctx: PairingContext, message: protobuf.MessageType
) -> protobuf.MessageType:
ctx.secret
if not ThpCredentialRequest.is_type_of(message):
raise UnexpectedMessage("Unexpected message")
if message.host_static_pubkey is None:
raise Exception("Invalid message") # TODO change failure type
trezor_static_pubkey = crypto.get_trezor_static_pubkey()
credential_metadata = ThpCredentialMetadata(host_name=ctx.host_name)
credential = issue_credential(message.host_static_pubkey, credential_metadata)
return await ctx.call_any(
ThpCredentialResponse(
trezor_static_pubkey=trezor_static_pubkey, credential=credential
),
ThpMessageType.ThpCredentialRequest,
ThpMessageType.ThpEndRequest,
)
@check_state_and_log(ChannelState.TC1)
async def _handle_end_request(
ctx: PairingContext, message: protobuf.MessageType
) -> ThpEndResponse:
if not ThpEndRequest.is_type_of(message):
raise UnexpectedMessage("Unexpected message")
return await _end_pairing(ctx)
async def _end_pairing(ctx: PairingContext) -> ThpEndResponse:
ctx.channel_ctx.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
return ThpEndResponse()
#
# Helpers - checkers
def _check_state(ctx: PairingContext, *allowed_states: ChannelState) -> None:
if ctx.channel_ctx.get_channel_state() not in allowed_states:
raise UnexpectedMessage("Unexpected message")
def _check_method_is_allowed(ctx: PairingContext, method: ThpPairingMethod) -> None:
if not _is_method_included(ctx, method):
raise ThpError("Unexpected pairing method")
def _is_method_included(ctx: PairingContext, method: ThpPairingMethod) -> bool:
return method in ctx.channel_ctx.selected_pairing_methods
#
# Helpers - getters
def _get_possible_pairing_methods_and_cancel(ctx: PairingContext) -> Tuple[int, ...]:
r = _get_possible_pairing_methods(ctx)
mtype = Cancel.MESSAGE_WIRE_TYPE
return r + ((mtype,) if mtype is not None else ())
def _get_possible_pairing_methods(ctx: PairingContext) -> Tuple[int, ...]:
r = tuple(
_get_message_type_for_method(method)
for method in ctx.channel_ctx.selected_pairing_methods
)
if __debug__:
from trezor.messages import DebugLinkGetState
mtype = DebugLinkGetState.MESSAGE_WIRE_TYPE
return r + ((mtype,) if mtype is not None else ())
return r
def _get_message_type_for_method(method: int) -> int:
if method is ThpPairingMethod.CodeEntry:
return ThpMessageType.ThpCodeEntryCpaceHost
if method is ThpPairingMethod.NFC_Unidirectional:
return ThpMessageType.ThpNfcUnidirectionalTag
if method is ThpPairingMethod.QrCode:
return ThpMessageType.ThpQrCodeTag
raise ValueError("Unexpected pairing method - no message type available")

View File

@ -1,8 +1,6 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from trezorio import WireInterface
from trezor.wire import Handler, Msg
@ -37,6 +35,13 @@ def _find_message_handler_module(msg_type: int) -> str:
if __debug__ and msg_type == MessageType.BenchmarkRun:
return "apps.benchmark.run"
if utils.USE_THP:
from trezor.enums import ThpMessageType
# thp management
if msg_type == ThpMessageType.ThpCreateNewSession:
return "apps.thp.create_new_session"
# management
if msg_type == MessageType.ResetDevice:
return "apps.management.reset_device"
@ -215,7 +220,7 @@ def _find_message_handler_module(msg_type: int) -> str:
raise ValueError
def find_registered_handler(iface: WireInterface, msg_type: int) -> Handler | None:
def find_registered_handler(msg_type: int) -> Handler | None:
if msg_type in workflow_handlers:
# Message has a handler available, return it directly.
return workflow_handlers[msg_type]

View File

@ -29,7 +29,7 @@ if __debug__:
# trezor.pin imports trezor.utils
# We need it as an always-active module because trezor.pin.show_pin_timeout is used
# as an UI callback for storage, which can be invoked at any time
# as a UI callback for storage, which can be invoked at any time
import trezor.pin # noqa: F401
# === Prepare the USB interfaces first. Do not connect to the host yet.

View File

@ -1,11 +1,27 @@
# make sure to import cache unconditionally at top level so that it is imported (and retained) together with the storage module
from typing import TYPE_CHECKING
from storage import cache, common, device
if TYPE_CHECKING:
from typing import Tuple
def wipe() -> None:
pass
def wipe(excluded: Tuple[bytes, bytes] | None) -> None:
"""
TODO REPHRASE SO THAT IT IS TRUE! Wipes the storage. Using `exclude_protocol=False` destroys the THP communication channel.
If the device should communicate after wipe, use `exclude_protocol=True` and clear cache manually later using
`wipe_cache()`.
"""
from trezor import config
config.wipe()
cache.clear_all(excluded)
def wipe_cache() -> None:
cache.clear_all()
@ -21,12 +37,12 @@ def init_unlocked() -> None:
common.set_bool(common.APP_DEVICE, device.INITIALIZED, True, public=True)
def reset() -> None:
def reset(excluded: Tuple[bytes, bytes] | None) -> None:
"""
Wipes storage but keeps the device id unchanged.
"""
device_id = device.get_device_id()
wipe()
wipe(excluded)
common.set(common.APP_DEVICE, device.DEVICE_ID, device_id.encode(), public=True)

View File

@ -0,0 +1,361 @@
import builtins
from micropython import const
from typing import TYPE_CHECKING
from storage.cache_common import DataCache
if TYPE_CHECKING:
from typing import Tuple
pass
if __debug__:
from trezor import log
pass
# THP specific constants
_MAX_CHANNELS_COUNT = const(10)
_MAX_SESSIONS_COUNT = const(20)
_CHANNEL_STATE_LENGTH = const(1)
_WIRE_INTERFACE_LENGTH = const(1)
_SESSION_STATE_LENGTH = const(1)
_CHANNEL_ID_LENGTH = const(2)
SESSION_ID_LENGTH = const(1)
BROADCAST_CHANNEL_ID = const(0xFFFF)
KEY_LENGTH = const(32)
TAG_LENGTH = const(16)
_UNALLOCATED_STATE = const(0)
_MANAGEMENT_STATE = const(2)
MANAGEMENT_SESSION_ID = const(0)
class ThpDataCache(DataCache):
def __init__(self) -> None:
self.channel_id = bytearray(_CHANNEL_ID_LENGTH)
self.last_usage = 0
super().__init__()
def clear(self) -> None:
self.channel_id[:] = b""
self.last_usage = 0
super().clear()
class ChannelCache(ThpDataCache):
def __init__(self) -> None:
self.host_ephemeral_pubkey = bytearray(KEY_LENGTH)
self.state = bytearray(_CHANNEL_STATE_LENGTH)
self.iface = bytearray(1) # TODO add decoding
self.sync = 0x80 # can_send_bit | sync_receive_bit | sync_send_bit | rfu(5)
self.session_id_counter = 0x00
self.fields = (
32, # CHANNEL_HANDSHAKE_HASH
32, # CHANNEL_KEY_RECEIVE
32, # CHANNEL_KEY_SEND
8, # CHANNEL_NONCE_RECEIVE
8, # CHANNEL_NONCE_SEND
)
super().__init__()
def clear(self) -> None:
self.state[:] = bytearray(
int.to_bytes(0, _CHANNEL_STATE_LENGTH, "big")
) # Set state to UNALLOCATED
self.host_ephemeral_pubkey[:] = bytearray(KEY_LENGTH)
self.state[:] = bytearray(_CHANNEL_STATE_LENGTH)
self.iface[:] = bytearray(1)
super().clear()
class SessionThpCache(ThpDataCache):
def __init__(self) -> None:
from trezor import utils
self.session_id = bytearray(SESSION_ID_LENGTH)
self.state = bytearray(_SESSION_STATE_LENGTH)
if utils.BITCOIN_ONLY:
self.fields = (
64, # APP_COMMON_SEED
2, # APP_COMMON_AUTHORIZATION_TYPE
128, # APP_COMMON_AUTHORIZATION_DATA
32, # APP_COMMON_NONCE
)
else:
self.fields = (
64, # APP_COMMON_SEED
2, # APP_COMMON_AUTHORIZATION_TYPE
128, # APP_COMMON_AUTHORIZATION_DATA
32, # APP_COMMON_NONCE
0, # APP_COMMON_DERIVE_CARDANO
96, # APP_CARDANO_ICARUS_SECRET
96, # APP_CARDANO_ICARUS_TREZOR_SECRET
0, # APP_MONERO_LIVE_REFRESH
)
super().__init__()
def clear(self) -> None:
self.state[:] = bytearray(int.to_bytes(0, 1, "big")) # Set state to UNALLOCATED
self.session_id[:] = b""
super().clear()
_CHANNELS: list[ChannelCache] = []
_SESSIONS: list[SessionThpCache] = []
cid_counter: int = 0
# Last-used counter
_usage_counter = 0
def initialize() -> None:
global _CHANNELS
global _SESSIONS
global cid_counter
for _ in range(_MAX_CHANNELS_COUNT):
_CHANNELS.append(ChannelCache())
for _ in range(_MAX_SESSIONS_COUNT):
_SESSIONS.append(SessionThpCache())
for channel in _CHANNELS:
channel.clear()
for session in _SESSIONS:
session.clear()
from trezorcrypto import random
cid_counter = random.uniform(0xFFFE)
def get_new_channel(iface: bytes) -> ChannelCache:
if len(iface) != _WIRE_INTERFACE_LENGTH:
raise Exception("Invalid WireInterface (encoded) length")
new_cid = get_next_channel_id()
index = _get_next_channel_index()
# clear sessions from replaced channel
if _get_channel_state(_CHANNELS[index]) != _UNALLOCATED_STATE:
old_cid = _CHANNELS[index].channel_id
clear_sessions_with_channel_id(old_cid)
_CHANNELS[index] = ChannelCache()
_CHANNELS[index].channel_id[:] = new_cid
_CHANNELS[index].last_usage = _get_usage_counter_and_increment()
_CHANNELS[index].state[:] = bytearray(
_UNALLOCATED_STATE.to_bytes(_CHANNEL_STATE_LENGTH, "big")
)
_CHANNELS[index].iface[:] = bytearray(iface)
return _CHANNELS[index]
def update_channel_last_used(channel_id: bytes) -> None:
for channel in _CHANNELS:
if channel.channel_id == channel_id:
channel.last_usage = _get_usage_counter_and_increment()
return
def update_session_last_used(channel_id: bytes, session_id: bytes) -> None:
for session in _SESSIONS:
if session.channel_id == channel_id and session.session_id == session_id:
session.last_usage = _get_usage_counter_and_increment()
update_channel_last_used(channel_id)
return
def get_all_allocated_channels() -> list[ChannelCache]:
_list: list[ChannelCache] = []
for channel in _CHANNELS:
if _get_channel_state(channel) != _UNALLOCATED_STATE:
_list.append(channel)
return _list
def get_allocated_sessions(channel_id: bytes) -> list[SessionThpCache]:
if __debug__:
from trezor.utils import get_bytes_as_str
_list: list[SessionThpCache] = []
for session in _SESSIONS:
if _get_session_state(session) == _UNALLOCATED_STATE:
continue
if session.channel_id != channel_id:
continue
_list.append(session)
if __debug__:
log.debug(
__name__,
"session with channel_id: %s and session_id: %s is in ALLOCATED state",
get_bytes_as_str(session.channel_id),
get_bytes_as_str(session.session_id),
)
return _list
def set_channel_host_ephemeral_key(channel: ChannelCache, key: bytearray) -> None:
if len(key) != KEY_LENGTH:
raise Exception("Invalid key length")
channel.host_ephemeral_pubkey = key
def get_new_session(channel: ChannelCache) -> SessionThpCache:
new_sid = get_next_session_id(channel)
index = _get_next_session_index()
_SESSIONS[index] = SessionThpCache()
_SESSIONS[index].channel_id[:] = channel.channel_id
_SESSIONS[index].session_id[:] = new_sid
_SESSIONS[index].last_usage = _get_usage_counter_and_increment()
channel.last_usage = (
_get_usage_counter_and_increment()
) # increment also use of the channel so it does not get replaced
_SESSIONS[index].state[:] = bytearray(
_UNALLOCATED_STATE.to_bytes(_SESSION_STATE_LENGTH, "big")
)
return _SESSIONS[index]
def _get_usage_counter_and_increment() -> int:
global _usage_counter
_usage_counter += 1
return _usage_counter
def _get_next_channel_index() -> int:
idx = _get_unallocated_channel_index()
if idx is not None:
return idx
return get_least_recently_used_item(_CHANNELS, max_count=_MAX_CHANNELS_COUNT)
def _get_next_session_index() -> int:
idx = _get_unallocated_session_index()
if idx is not None:
return idx
return get_least_recently_used_item(_SESSIONS, max_count=_MAX_SESSIONS_COUNT)
def _get_unallocated_channel_index() -> int | None:
for i in range(_MAX_CHANNELS_COUNT):
if _get_channel_state(_CHANNELS[i]) is _UNALLOCATED_STATE:
return i
return None
def _get_unallocated_session_index() -> int | None:
for i in range(_MAX_SESSIONS_COUNT):
if (_SESSIONS[i]) is _UNALLOCATED_STATE:
return i
return None
def _get_channel_state(channel: ChannelCache) -> int:
return int.from_bytes(channel.state, "big")
def _get_session_state(session: SessionThpCache) -> int:
return int.from_bytes(session.state, "big")
def get_next_channel_id() -> bytes:
global cid_counter
while True:
cid_counter += 1
if cid_counter >= BROADCAST_CHANNEL_ID:
cid_counter = 1
if _is_cid_unique():
break
return cid_counter.to_bytes(_CHANNEL_ID_LENGTH, "big")
def get_next_session_id(channel: ChannelCache) -> bytes:
while True:
if channel.session_id_counter >= 255:
channel.session_id_counter = 1
else:
channel.session_id_counter += 1
if _is_session_id_unique(channel):
break
new_sid = channel.session_id_counter
return new_sid.to_bytes(SESSION_ID_LENGTH, "big")
def _is_session_id_unique(channel: ChannelCache) -> bool:
for session in _SESSIONS:
if session.channel_id == channel.channel_id:
if session.session_id == channel.session_id_counter:
return False
return True
def _is_cid_unique() -> bool:
global cid_counter
cid_counter_bytes = cid_counter.to_bytes(_CHANNEL_ID_LENGTH, "big")
for channel in _CHANNELS:
if channel.channel_id == cid_counter_bytes:
return False
return True
def get_least_recently_used_item(
list: list[ChannelCache] | list[SessionThpCache], max_count: int
) -> int:
global _usage_counter
lru_counter = _usage_counter + 1
lru_item_index = 0
for i in range(max_count):
if list[i].last_usage < lru_counter:
lru_counter = list[i].last_usage
lru_item_index = i
return lru_item_index
def get_int_all_sessions(key: int) -> builtins.set[int]:
values = builtins.set()
for session in _SESSIONS:
encoded = session.get(key)
if encoded is not None:
values.add(int.from_bytes(encoded, "big"))
return values
def clear_sessions_with_channel_id(channel_id: bytes) -> None:
for session in _SESSIONS:
if session.channel_id == channel_id:
session.clear()
def clear_session(session: SessionThpCache) -> None:
for s in _SESSIONS:
if s.channel_id == session.channel_id and s.session_id == session.session_id:
session.clear()
def clear_all() -> None:
for session in _SESSIONS:
session.clear()
for channel in _CHANNELS:
channel.clear()
def clear_all_except_one_session_keys(excluded: Tuple[bytes, bytes]) -> None:
cid, sid = excluded
for channel in _CHANNELS:
if channel.channel_id != cid:
channel.clear()
for session in _SESSIONS:
if session.channel_id != cid and session.session_id != sid:
session.clear()
else:
s_last_usage = session.last_usage
session.clear()
session.last_usage = s_last_usage
session.state = bytearray(_MANAGEMENT_STATE.to_bytes(1, "big"))
session.session_id[:] = bytearray(sid)
session.channel_id[:] = bytearray(cid)

View File

@ -16,4 +16,6 @@ NotInitialized = 11
PinMismatch = 12
WipeCodeMismatch = 13
InvalidSession = 14
ThpUnallocatedSession = 15
InvalidProtocol = 16
FirmwareError = 99

22
core/src/trezor/enums/ThpMessageType.py generated Normal file
View File

@ -0,0 +1,22 @@
# Automatically generated by pb2py
# fmt: off
# isort:skip_file
ThpCreateNewSession = 1000
ThpNewSession = 1001
ThpStartPairingRequest = 1008
ThpPairingPreparationsFinished = 1009
ThpCredentialRequest = 1010
ThpCredentialResponse = 1011
ThpEndRequest = 1012
ThpEndResponse = 1013
ThpCodeEntryCommitment = 1016
ThpCodeEntryChallenge = 1017
ThpCodeEntryCpaceHost = 1018
ThpCodeEntryCpaceTrezor = 1019
ThpCodeEntryTag = 1020
ThpCodeEntrySecret = 1021
ThpQrCodeTag = 1024
ThpQrCodeSecret = 1025
ThpNfcUnidirectionalTag = 1032
ThpNfcUnidirectionalSecret = 1033

View File

@ -0,0 +1,8 @@
# Automatically generated by pb2py
# fmt: off
# isort:skip_file
NoMethod = 1
CodeEntry = 2
QrCode = 3
NFC_Unidirectional = 4

View File

@ -39,6 +39,8 @@ if TYPE_CHECKING:
PinMismatch = 12
WipeCodeMismatch = 13
InvalidSession = 14
ThpUnallocatedSession = 15
InvalidProtocol = 16
FirmwareError = 99
class ButtonRequestType(IntEnum):
@ -343,6 +345,32 @@ if TYPE_CHECKING:
Nay = 1
Pass = 2
class ThpMessageType(IntEnum):
ThpCreateNewSession = 1000
ThpNewSession = 1001
ThpStartPairingRequest = 1008
ThpPairingPreparationsFinished = 1009
ThpCredentialRequest = 1010
ThpCredentialResponse = 1011
ThpEndRequest = 1012
ThpEndResponse = 1013
ThpCodeEntryCommitment = 1016
ThpCodeEntryChallenge = 1017
ThpCodeEntryCpaceHost = 1018
ThpCodeEntryCpaceTrezor = 1019
ThpCodeEntryTag = 1020
ThpCodeEntrySecret = 1021
ThpQrCodeTag = 1024
ThpQrCodeSecret = 1025
ThpNfcUnidirectionalTag = 1032
ThpNfcUnidirectionalSecret = 1033
class ThpPairingMethod(IntEnum):
NoMethod = 1
CodeEntry = 2
QrCode = 3
NFC_Unidirectional = 4
class MessageType(IntEnum):
Initialize = 0
Ping = 1

View File

@ -67,6 +67,8 @@ if TYPE_CHECKING:
from trezor.enums import StellarSignerType # noqa: F401
from trezor.enums import TezosBallotType # noqa: F401
from trezor.enums import TezosContractType # noqa: F401
from trezor.enums import ThpMessageType # noqa: F401
from trezor.enums import ThpPairingMethod # noqa: F401
from trezor.enums import WordRequestType # noqa: F401
class BenchmarkListNames(protobuf.MessageType):
@ -2863,11 +2865,13 @@ if TYPE_CHECKING:
class DebugLinkGetState(protobuf.MessageType):
wait_layout: "DebugWaitType"
thp_channel_id: "bytes | None"
def __init__(
self,
*,
wait_layout: "DebugWaitType | None" = None,
thp_channel_id: "bytes | None" = None,
) -> None:
pass
@ -2889,6 +2893,9 @@ if TYPE_CHECKING:
reset_word_pos: "int | None"
mnemonic_type: "BackupType | None"
tokens: "list[str]"
thp_pairing_code_entry_code: "int | None"
thp_pairing_code_qr_code: "bytes | None"
thp_pairing_code_nfc_unidirectional: "bytes | None"
def __init__(
self,
@ -2906,6 +2913,9 @@ if TYPE_CHECKING:
recovery_word_pos: "int | None" = None,
reset_word_pos: "int | None" = None,
mnemonic_type: "BackupType | None" = None,
thp_pairing_code_entry_code: "int | None" = None,
thp_pairing_code_qr_code: "bytes | None" = None,
thp_pairing_code_nfc_unidirectional: "bytes | None" = None,
) -> None:
pass
@ -6127,6 +6137,278 @@ if TYPE_CHECKING:
def is_type_of(cls, msg: Any) -> TypeGuard["TezosManagerTransfer"]:
return isinstance(msg, cls)
class ThpDeviceProperties(protobuf.MessageType):
internal_model: "str | None"
model_variant: "int | None"
bootloader_mode: "bool | None"
protocol_version: "int | None"
pairing_methods: "list[ThpPairingMethod]"
def __init__(
self,
*,
pairing_methods: "list[ThpPairingMethod] | None" = None,
internal_model: "str | None" = None,
model_variant: "int | None" = None,
bootloader_mode: "bool | None" = None,
protocol_version: "int | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpDeviceProperties"]:
return isinstance(msg, cls)
class ThpHandshakeCompletionReqNoisePayload(protobuf.MessageType):
host_pairing_credential: "bytes | None"
pairing_methods: "list[ThpPairingMethod]"
def __init__(
self,
*,
pairing_methods: "list[ThpPairingMethod] | None" = None,
host_pairing_credential: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpHandshakeCompletionReqNoisePayload"]:
return isinstance(msg, cls)
class ThpCreateNewSession(protobuf.MessageType):
passphrase: "str | None"
on_device: "bool | None"
derive_cardano: "bool | None"
def __init__(
self,
*,
passphrase: "str | None" = None,
on_device: "bool | None" = None,
derive_cardano: "bool | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCreateNewSession"]:
return isinstance(msg, cls)
class ThpNewSession(protobuf.MessageType):
new_session_id: "int | None"
def __init__(
self,
*,
new_session_id: "int | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpNewSession"]:
return isinstance(msg, cls)
class ThpStartPairingRequest(protobuf.MessageType):
host_name: "str | None"
def __init__(
self,
*,
host_name: "str | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpStartPairingRequest"]:
return isinstance(msg, cls)
class ThpPairingPreparationsFinished(protobuf.MessageType):
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpPairingPreparationsFinished"]:
return isinstance(msg, cls)
class ThpCodeEntryCommitment(protobuf.MessageType):
commitment: "bytes | None"
def __init__(
self,
*,
commitment: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntryCommitment"]:
return isinstance(msg, cls)
class ThpCodeEntryChallenge(protobuf.MessageType):
challenge: "bytes | None"
def __init__(
self,
*,
challenge: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntryChallenge"]:
return isinstance(msg, cls)
class ThpCodeEntryCpaceHost(protobuf.MessageType):
cpace_host_public_key: "bytes | None"
def __init__(
self,
*,
cpace_host_public_key: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntryCpaceHost"]:
return isinstance(msg, cls)
class ThpCodeEntryCpaceTrezor(protobuf.MessageType):
cpace_trezor_public_key: "bytes | None"
def __init__(
self,
*,
cpace_trezor_public_key: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntryCpaceTrezor"]:
return isinstance(msg, cls)
class ThpCodeEntryTag(protobuf.MessageType):
tag: "bytes | None"
def __init__(
self,
*,
tag: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntryTag"]:
return isinstance(msg, cls)
class ThpCodeEntrySecret(protobuf.MessageType):
secret: "bytes | None"
def __init__(
self,
*,
secret: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntrySecret"]:
return isinstance(msg, cls)
class ThpQrCodeTag(protobuf.MessageType):
tag: "bytes | None"
def __init__(
self,
*,
tag: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpQrCodeTag"]:
return isinstance(msg, cls)
class ThpQrCodeSecret(protobuf.MessageType):
secret: "bytes | None"
def __init__(
self,
*,
secret: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpQrCodeSecret"]:
return isinstance(msg, cls)
class ThpNfcUnidirectionalTag(protobuf.MessageType):
tag: "bytes | None"
def __init__(
self,
*,
tag: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpNfcUnidirectionalTag"]:
return isinstance(msg, cls)
class ThpNfcUnidirectionalSecret(protobuf.MessageType):
secret: "bytes | None"
def __init__(
self,
*,
secret: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpNfcUnidirectionalSecret"]:
return isinstance(msg, cls)
class ThpCredentialRequest(protobuf.MessageType):
host_static_pubkey: "bytes | None"
def __init__(
self,
*,
host_static_pubkey: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCredentialRequest"]:
return isinstance(msg, cls)
class ThpCredentialResponse(protobuf.MessageType):
trezor_static_pubkey: "bytes | None"
credential: "bytes | None"
def __init__(
self,
*,
trezor_static_pubkey: "bytes | None" = None,
credential: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCredentialResponse"]:
return isinstance(msg, cls)
class ThpEndRequest(protobuf.MessageType):
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpEndRequest"]:
return isinstance(msg, cls)
class ThpEndResponse(protobuf.MessageType):
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpEndResponse"]:
return isinstance(msg, cls)
class ThpCredentialMetadata(protobuf.MessageType):
host_name: "str | None"

View File

@ -328,7 +328,7 @@ class Layout(Generic[T]):
def _paint(self) -> None:
"""Paint the layout and ensure that homescreen cache is properly invalidated."""
import storage.cache as storage_cache
import storage.cache_common as storage_cache
painted = self.layout.paint()
if painted:

View File

@ -35,6 +35,10 @@ from typing import TYPE_CHECKING
DISABLE_ANIMATION = 0
DISABLE_ENCRYPTION: bool = False
ALLOW_DEBUG_MESSAGES: bool = True
if __debug__:
if EMULATOR:
import uos
@ -111,6 +115,7 @@ def presize_module(modname: str, size: int) -> None:
if __debug__:
from ubinascii import hexlify
def mem_dump(filename: str) -> None:
from micropython import mem_info
@ -127,6 +132,9 @@ if __debug__:
else:
mem_info(True)
def get_bytes_as_str(a: bytes) -> str:
return hexlify(a).decode("utf-8")
def ensure(cond: bool, msg: str | None = None) -> None:
if not cond:

View File

@ -8,6 +8,12 @@ class Error(Exception):
self.message = message
class SilentError(Exception):
def __init__(self, message: str) -> None:
super().__init__()
self.message = message
class UnexpectedMessage(Error):
def __init__(self, message: str) -> None:
super().__init__(FailureType.UnexpectedMessage, message)

View File

@ -0,0 +1,184 @@
import ustruct
from micropython import const
from typing import TYPE_CHECKING
from storage.cache_thp import BROADCAST_CHANNEL_ID
from trezor import protobuf, utils
from trezor.enums import ThpPairingMethod
from trezor.messages import ThpDeviceProperties
from ..protocol_common import WireError
if TYPE_CHECKING:
from enum import IntEnum
from trezor.wire import WireInterface
from typing_extensions import Self
else:
IntEnum = object
CODEC_V1 = const(0x3F)
HANDSHAKE_INIT_REQ = const(0x00)
HANDSHAKE_INIT_RES = const(0x01)
HANDSHAKE_COMP_REQ = const(0x02)
HANDSHAKE_COMP_RES = const(0x03)
ENCRYPTED = const(0x04)
ACK_MESSAGE = const(0x20)
CHANNEL_ALLOCATION_REQ = const(0x40)
_CHANNEL_ALLOCATION_RES = const(0x41)
_ERROR = const(0x42)
CONTINUATION_PACKET = const(0x80)
class ThpError(WireError):
pass
class ThpDecryptionError(ThpError):
pass
class ThpInvalidDataError(ThpError):
pass
class ThpUnallocatedSessionError(ThpError):
def __init__(self, session_id: int) -> None:
self.session_id = session_id
class ThpErrorType(IntEnum):
TRANSPORT_BUSY = 1
UNALLOCATED_CHANNEL = 2
DECRYPTION_FAILED = 3
INVALID_DATA = 4
class ChannelState(IntEnum):
UNALLOCATED = 0
TH1 = 1
TH2 = 2
TP1 = 3
TP2 = 4
TP3 = 5
TP4 = 6
TC1 = 7
ENCRYPTED_TRANSPORT = 8
class SessionState(IntEnum):
UNALLOCATED = 0
ALLOCATED = 1
MANAGEMENT = 2
class PacketHeader:
format_str_init = ">BHH"
format_str_cont = ">BH"
def __init__(self, ctrl_byte: int, cid: int, length: int) -> None:
self.ctrl_byte = ctrl_byte
self.cid = cid
self.length = length
def to_bytes(self) -> bytes:
return ustruct.pack(self.format_str_init, self.ctrl_byte, self.cid, self.length)
def pack_to_init_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None:
"""
Packs header information in the form of **intial** packet
into the provided buffer.
"""
ustruct.pack_into(
self.format_str_init,
buffer,
buffer_offset,
self.ctrl_byte,
self.cid,
self.length,
)
def pack_to_cont_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None:
"""
Packs header information in the form of **continuation** packet header
into the provided buffer.
"""
ustruct.pack_into(
self.format_str_cont, buffer, buffer_offset, CONTINUATION_PACKET, self.cid
)
@classmethod
def get_error_header(cls, cid: int, length: int) -> Self:
"""
Returns header for protocol-level error messages.
"""
return cls(_ERROR, cid, length)
@classmethod
def get_channel_allocation_response_header(cls, length: int) -> Self:
"""
Returns header for allocation response handshake message.
"""
return cls(_CHANNEL_ALLOCATION_RES, BROADCAST_CHANNEL_ID, length)
_DEFAULT_ENABLED_PAIRING_METHODS = [
ThpPairingMethod.CodeEntry,
ThpPairingMethod.QrCode,
ThpPairingMethod.NFC_Unidirectional,
]
def get_enabled_pairing_methods(
iface: WireInterface | None = None,
) -> list[ThpPairingMethod]:
"""
Returns pairing methods that are currently allowed by the device
with respect to the wire interface the host communicates on.
"""
import usb
methods = _DEFAULT_ENABLED_PAIRING_METHODS.copy()
if iface is not None and iface is usb.iface_wire:
methods.append(ThpPairingMethod.NoMethod)
return methods
def _get_device_properties(iface: WireInterface) -> ThpDeviceProperties:
# TODO define model variants
return ThpDeviceProperties(
pairing_methods=get_enabled_pairing_methods(iface),
internal_model=utils.INTERNAL_MODEL,
model_variant=0,
bootloader_mode=False,
protocol_version=2,
)
def get_encoded_device_properties(iface: WireInterface) -> bytes:
props = _get_device_properties(iface)
length = protobuf.encoded_length(props)
encoded_properties = bytearray(length)
protobuf.encode(encoded_properties, props)
return encoded_properties
def get_channel_allocation_response(
nonce: bytes, new_cid: bytes, iface: WireInterface
) -> bytes:
props_msg = get_encoded_device_properties(iface)
return nonce + new_cid + props_msg
if __debug__:
def state_to_str(state: int) -> str:
name = {
v: k for k, v in ChannelState.__dict__.items() if not k.startswith("__")
}.get(state)
if name is not None:
return name
return "UNKNOWN_STATE"

View File

@ -0,0 +1,102 @@
from storage.cache_thp import ChannelCache
from trezor import log, utils
from trezor.wire.thp import ThpError
def is_ack_valid(cache: ChannelCache, ack_bit: int) -> bool:
"""
Checks if:
- an ACK message is expected
- the received ACK message acknowledges correct sequence number (bit)
"""
if not _is_ack_expected(cache):
return False
if not _has_ack_correct_sync_bit(cache, ack_bit):
return False
return True
def _is_ack_expected(cache: ChannelCache) -> bool:
is_expected: bool = not is_sending_allowed(cache)
if __debug__ and utils.ALLOW_DEBUG_MESSAGES and not is_expected:
log.debug(__name__, "Received unexpected ACK message")
return is_expected
def _has_ack_correct_sync_bit(cache: ChannelCache, sync_bit: int) -> bool:
is_correct: bool = get_send_seq_bit(cache) == sync_bit
if __debug__ and utils.ALLOW_DEBUG_MESSAGES and not is_correct:
log.debug(__name__, "Received ACK message with wrong ack bit")
return is_correct
def is_sending_allowed(cache: ChannelCache) -> bool:
"""
Checks whether sending a message in the provided channel is allowed.
Note: Sending a message in a channel before receipt of ACK message for the previously
sent message (in the channel) is prohibited, as it can lead to desynchronization.
"""
return bool(cache.sync >> 7)
def get_send_seq_bit(cache: ChannelCache) -> int:
"""
Returns the sequential number (bit) of the next message to be sent
in the provided channel.
"""
return (cache.sync & 0x20) >> 5
def get_expected_receive_seq_bit(cache: ChannelCache) -> int:
"""
Returns the (expected) sequential number (bit) of the next message
to be received in the provided channel.
"""
return (cache.sync & 0x40) >> 6
def set_sending_allowed(cache: ChannelCache, sending_allowed: bool) -> None:
"""
Set the flag whether sending a message in this channel is allowed or not.
"""
cache.sync &= 0x7F
if sending_allowed:
cache.sync |= 0x80
def set_expected_receive_seq_bit(cache: ChannelCache, seq_bit: int) -> None:
"""
Set the expected sequential number (bit) of the next message to be received
in the provided channel
"""
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "Set sync receive expected seq bit to %d", seq_bit)
if seq_bit not in (0, 1):
raise ThpError("Unexpected receive sync bit")
# set second bit to "seq_bit" value
cache.sync &= 0xBF
if seq_bit:
cache.sync |= 0x40
def _set_send_seq_bit(cache: ChannelCache, seq_bit: int) -> None:
if seq_bit not in (0, 1):
raise ThpError("Unexpected send seq bit")
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "setting sync send seq bit to %d", seq_bit)
# set third bit to "seq_bit" value
cache.sync &= 0xDF
if seq_bit:
cache.sync |= 0x20
def set_send_seq_bit_to_opposite(cache: ChannelCache) -> None:
"""
Set the sequential bit of the "next message to be send" to the opposite value,
i.e. 1 -> 0 and 0 -> 1
"""
_set_send_seq_bit(cache=cache, seq_bit=1 - get_send_seq_bit(cache))

View File

@ -0,0 +1,405 @@
import ustruct
from typing import TYPE_CHECKING
from storage.cache_common import (
CHANNEL_HANDSHAKE_HASH,
CHANNEL_KEY_RECEIVE,
CHANNEL_KEY_SEND,
CHANNEL_NONCE_RECEIVE,
CHANNEL_NONCE_SEND,
)
from storage.cache_thp import TAG_LENGTH, ChannelCache, clear_sessions_with_channel_id
from trezor import log, loop, protobuf, utils, workflow
from . import ENCRYPTED, ChannelState, PacketHeader, ThpDecryptionError, ThpError
from . import alternating_bit_protocol as ABP
from . import (
control_byte,
crypto,
interface_manager,
memory_manager,
received_message_handler,
)
from .checksum import CHECKSUM_LENGTH
from .transmission_loop import TransmissionLoop
from .writer import (
CONT_HEADER_LENGTH,
INIT_HEADER_LENGTH,
write_payload_to_wire_and_add_checksum,
)
if __debug__:
from ubinascii import hexlify
from . import state_to_str
if TYPE_CHECKING:
from trezorio import WireInterface
from typing import Awaitable
from .pairing_context import PairingContext
from .session_context import GenericSessionContext
class Channel:
"""
THP protocol encrypted communication channel.
"""
def __init__(self, channel_cache: ChannelCache) -> None:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "channel initialization")
self.iface: WireInterface = interface_manager.decode_iface(channel_cache.iface)
self.channel_cache: ChannelCache = channel_cache
self.is_cont_packet_expected: bool = False
self.expected_payload_length: int = 0
self.bytes_read: int = 0
self.buffer: utils.BufferType
self.channel_id: bytes = channel_cache.channel_id
self.selected_pairing_methods = []
self.sessions: dict[int, GenericSessionContext] = {}
self.write_task_spawn: loop.spawn | None = None
self.connection_context: PairingContext | None = None
self.transmission_loop: TransmissionLoop | None = None
self.handshake: crypto.Handshake | None = None
def clear(self) -> None:
clear_sessions_with_channel_id(self.channel_id)
self.channel_cache.clear()
# ACCESS TO CHANNEL_DATA
def get_channel_id_int(self) -> int:
return int.from_bytes(self.channel_id, "big")
def get_channel_state(self) -> int:
state = int.from_bytes(self.channel_cache.state, "big")
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"(cid: %s) get_channel_state: %s",
utils.get_bytes_as_str(self.channel_id),
state_to_str(state),
)
return state
def get_handshake_hash(self) -> bytes:
h = self.channel_cache.get(CHANNEL_HANDSHAKE_HASH)
assert h is not None
return h
def set_channel_state(self, state: ChannelState) -> None:
self.channel_cache.state = bytearray(state.to_bytes(1, "big"))
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"(cid: %s) set_channel_state: %s",
utils.get_bytes_as_str(self.channel_id),
state_to_str(state),
)
def set_buffer(self, buffer: utils.BufferType) -> None:
self.buffer = buffer
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"(cid: %s) set_buffer: %s",
utils.get_bytes_as_str(self.channel_id),
type(self.buffer),
)
# CALLED BY THP_MAIN_LOOP
def receive_packet(self, packet: utils.BufferType) -> Awaitable[None] | None:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"(cid: %s) receive_packet",
utils.get_bytes_as_str(self.channel_id),
)
self._handle_received_packet(packet)
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"(cid: %s) self.buffer: %s",
utils.get_bytes_as_str(self.channel_id),
utils.get_bytes_as_str(self.buffer),
)
if self.expected_payload_length + INIT_HEADER_LENGTH == self.bytes_read:
self._finish_message()
return received_message_handler.handle_received_message(self, self.buffer)
elif self.expected_payload_length + INIT_HEADER_LENGTH > self.bytes_read:
self.is_cont_packet_expected = True
else:
raise ThpError(
"Read more bytes than is the expected length of the message!"
)
return None
def _handle_received_packet(self, packet: utils.BufferType) -> None:
ctrl_byte = packet[0]
if control_byte.is_continuation(ctrl_byte):
return self._handle_cont_packet(packet)
return self._handle_init_packet(packet)
def _handle_init_packet(self, packet: utils.BufferType) -> None:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"(cid: %s) handle_init_packet",
utils.get_bytes_as_str(self.channel_id),
)
# ctrl_byte, _, payload_length = ustruct.unpack(">BHH", packet) # TODO use this with single packet decryption
_, _, payload_length = ustruct.unpack(">BHH", packet)
self.expected_payload_length = payload_length
packet_payload = memoryview(packet)[INIT_HEADER_LENGTH:]
# If the channel does not "own" the buffer lock, decrypt first packet
# TODO do it only when needed!
# TODO FIX: If "_decrypt_single_packet_payload" is implemented, it will (possibly) break "decrypt_buffer" and nonces incrementation.
# On the other hand, without the single packet decryption, the "advanced" buffer selection cannot be implemented
# in "memory_manager.select_buffer", because the session id is unknown (encrypted).
# if control_byte.is_encrypted_transport(ctrl_byte):
# packet_payload = self._decrypt_single_packet_payload(packet_payload)
self.buffer = memory_manager.select_buffer(
self.get_channel_state(),
self.buffer,
packet_payload,
payload_length,
)
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"(cid: %s) handle_init_packet - payload len: %d",
utils.get_bytes_as_str(self.channel_id),
payload_length,
)
log.debug(
__name__,
"(cid: %s) handle_init_packet - buffer len: %d",
utils.get_bytes_as_str(self.channel_id),
len(self.buffer),
)
return self._buffer_packet_data(self.buffer, packet, 0)
def _handle_cont_packet(self, packet: utils.BufferType) -> None:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"(cid: %s) handle_cont_packet",
utils.get_bytes_as_str(self.channel_id),
)
if not self.is_cont_packet_expected:
raise ThpError("Continuation packet is not expected, ignoring")
return self._buffer_packet_data(self.buffer, packet, CONT_HEADER_LENGTH)
def _decrypt_single_packet_payload(
self, payload: utils.BufferType
) -> utils.BufferType:
# crypto.decrypt(b"\x00", b"\x00", payload_buffer, INIT_DATA_OFFSET, len(payload))
return payload
def decrypt_buffer(
self, message_length: int, offset: int = INIT_HEADER_LENGTH
) -> None:
noise_buffer = memoryview(self.buffer)[
offset : message_length - CHECKSUM_LENGTH - TAG_LENGTH
]
tag = self.buffer[
message_length
- CHECKSUM_LENGTH
- TAG_LENGTH : message_length
- CHECKSUM_LENGTH
]
if utils.DISABLE_ENCRYPTION:
is_tag_valid = tag == crypto.DUMMY_TAG
else:
key_receive = self.channel_cache.get(CHANNEL_KEY_RECEIVE)
nonce_receive = self.channel_cache.get_int(CHANNEL_NONCE_RECEIVE)
assert key_receive is not None
assert nonce_receive is not None
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"(cid: %s) Buffer before decryption: %s",
utils.get_bytes_as_str(self.channel_id),
hexlify(noise_buffer),
)
is_tag_valid = crypto.dec(
noise_buffer, tag, key_receive, nonce_receive, b""
)
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"(cid: %s) Buffer after decryption: %s",
utils.get_bytes_as_str(self.channel_id),
hexlify(noise_buffer),
)
self.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, nonce_receive + 1)
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"(cid: %s) Is decrypted tag valid? %s",
utils.get_bytes_as_str(self.channel_id),
str(is_tag_valid),
)
log.debug(
__name__,
"(cid: %s) Received tag: %s",
utils.get_bytes_as_str(self.channel_id),
(hexlify(tag).decode()),
)
log.debug(
__name__,
"(cid: %s) New nonce_receive: %i",
utils.get_bytes_as_str(self.channel_id),
nonce_receive + 1,
)
if not is_tag_valid:
raise ThpDecryptionError()
def _encrypt(self, buffer: utils.BufferType, noise_payload_len: int) -> None:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__, "(cid: %s) encrypt", utils.get_bytes_as_str(self.channel_id)
)
assert len(buffer) >= noise_payload_len + TAG_LENGTH + CHECKSUM_LENGTH
noise_buffer = memoryview(buffer)[0:noise_payload_len]
if utils.DISABLE_ENCRYPTION:
tag = crypto.DUMMY_TAG
else:
key_send = self.channel_cache.get(CHANNEL_KEY_SEND)
nonce_send = self.channel_cache.get_int(CHANNEL_NONCE_SEND)
assert key_send is not None
assert nonce_send is not None
tag = crypto.enc(noise_buffer, key_send, nonce_send, b"")
self.channel_cache.set_int(CHANNEL_NONCE_SEND, nonce_send + 1)
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "New nonce_send: %i", nonce_send + 1)
buffer[noise_payload_len : noise_payload_len + TAG_LENGTH] = tag
def _buffer_packet_data(
self, payload_buffer: utils.BufferType, packet: utils.BufferType, offset: int
) -> None:
self.bytes_read += utils.memcpy(payload_buffer, self.bytes_read, packet, offset)
def _finish_message(self) -> None:
self.bytes_read = 0
self.expected_payload_length = 0
self.is_cont_packet_expected = False
# CALLED BY WORKFLOW / SESSION CONTEXT
async def write(
self,
msg: protobuf.MessageType,
session_id: int = 0,
force: bool = False,
) -> None:
if __debug__ and utils.EMULATOR:
log.debug(
__name__,
"(cid: %s) write message: %s\n%s",
utils.get_bytes_as_str(self.channel_id),
msg.MESSAGE_NAME,
utils.dump_protobuf(msg),
)
self.buffer = memory_manager.get_write_buffer(self.buffer, msg)
noise_payload_len = memory_manager.encode_into_buffer(
self.buffer, msg, session_id
)
task = self.write_and_encrypt(self.buffer[:noise_payload_len], force)
if task is not None:
await task
def write_error(self, err_type: int) -> Awaitable[None]:
msg_data = err_type.to_bytes(1, "big")
length = len(msg_data) + CHECKSUM_LENGTH
header = PacketHeader.get_error_header(self.get_channel_id_int(), length)
return write_payload_to_wire_and_add_checksum(self.iface, header, msg_data)
def write_and_encrypt(
self, payload: bytes, force: bool = False
) -> Awaitable[None] | None:
payload_length = len(payload)
self._encrypt(self.buffer, payload_length)
payload_length = payload_length + TAG_LENGTH
if self.write_task_spawn is not None:
self.write_task_spawn.close() # UPS TODO might break something
print("\nCLOSED\n")
self._prepare_write()
if force:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__, "Writing FORCE message (without async or retransmission)."
)
return self._write_encrypted_payload_loop(
ENCRYPTED, memoryview(self.buffer[:payload_length])
)
self.write_task_spawn = loop.spawn(
self._write_encrypted_payload_loop(
ENCRYPTED, memoryview(self.buffer[:payload_length])
)
)
return None
def write_handshake_message(self, ctrl_byte: int, payload: bytes) -> None:
self._prepare_write()
self.write_task_spawn = loop.spawn(
self._write_encrypted_payload_loop(ctrl_byte, payload)
)
def _prepare_write(self) -> None:
# TODO add condition that disallows to write when can_send_message is false
ABP.set_sending_allowed(self.channel_cache, False)
async def _write_encrypted_payload_loop(
self, ctrl_byte: int, payload: bytes
) -> None:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"(cid %s) write_encrypted_payload_loop",
utils.get_bytes_as_str(self.channel_id),
)
payload_len = len(payload) + CHECKSUM_LENGTH
sync_bit = ABP.get_send_seq_bit(self.channel_cache)
ctrl_byte = control_byte.add_seq_bit_to_ctrl_byte(ctrl_byte, sync_bit)
header = PacketHeader(ctrl_byte, self.get_channel_id_int(), payload_len)
self.transmission_loop = TransmissionLoop(self, header, payload)
await self.transmission_loop.start()
ABP.set_send_seq_bit_to_opposite(self.channel_cache)
# Let the main loop be restarted and clear loop, if there is no other
# workflow and the state is ENCRYPTED_TRANSPORT
if self._can_clear_loop():
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"(cid: %s) clearing loop from channel",
utils.get_bytes_as_str(self.channel_id),
)
loop.clear()
def _can_clear_loop(self) -> bool:
return (
not workflow.tasks
) and self.get_channel_state() is ChannelState.ENCRYPTED_TRANSPORT

View File

@ -0,0 +1,34 @@
from typing import TYPE_CHECKING
from storage import cache_thp
from trezor import utils
from . import ChannelState, interface_manager
from .channel import Channel
if TYPE_CHECKING:
from trezorio import WireInterface
def create_new_channel(iface: WireInterface, buffer: utils.BufferType) -> Channel:
"""
Creates a new channel for the interface `iface` with the buffer `buffer`.
"""
channel_cache = cache_thp.get_new_channel(interface_manager.encode_iface(iface))
r = Channel(channel_cache)
r.set_buffer(buffer)
r.set_channel_state(ChannelState.TH1)
return r
def load_cached_channels(buffer: utils.BufferType) -> dict[int, Channel]:
"""
Returns all allocated channels from cache.
"""
channels: dict[int, Channel] = {}
cached_channels = cache_thp.get_all_allocated_channels()
for c in cached_channels:
channels[int.from_bytes(c.channel_id, "big")] = Channel(c)
for c in channels.values():
c.set_buffer(buffer)
return channels

View File

@ -0,0 +1,22 @@
from micropython import const
from trezor import utils
from trezor.crypto import crc
CHECKSUM_LENGTH = const(4)
def compute(data: bytes | utils.BufferType) -> bytes:
"""
Returns a CRC-32 checksum of the provided `data`.
"""
return crc.crc32(data).to_bytes(CHECKSUM_LENGTH, "big")
def is_valid(checksum: bytes | utils.BufferType, data: bytes) -> bool:
"""
Checks whether the CRC-32 checksum of the `data` is the same
as the checksum provided in `checksum`.
"""
data_checksum = compute(data)
return checksum == data_checksum

View File

@ -0,0 +1,50 @@
from micropython import const
from . import (
ACK_MESSAGE,
CONTINUATION_PACKET,
ENCRYPTED,
HANDSHAKE_COMP_REQ,
HANDSHAKE_INIT_REQ,
ThpError,
)
_CONTINUATION_PACKET_MASK = const(0x80)
_ACK_MASK = const(0xF7)
_DATA_MASK = const(0xE7)
def add_seq_bit_to_ctrl_byte(ctrl_byte: int, seq_bit: int) -> int:
if seq_bit == 0:
return ctrl_byte & 0xEF
if seq_bit == 1:
return ctrl_byte | 0x10
raise ThpError("Unexpected sequence bit")
def add_ack_bit_to_ctrl_byte(ctrl_byte: int, ack_bit: int) -> int:
if ack_bit == 0:
return ctrl_byte & 0xF7
if ack_bit == 1:
return ctrl_byte | 0x08
raise ThpError("Unexpected acknowledgement bit")
def is_ack(ctrl_byte: int) -> bool:
return ctrl_byte & _ACK_MASK == ACK_MESSAGE
def is_continuation(ctrl_byte: int) -> bool:
return ctrl_byte & _CONTINUATION_PACKET_MASK == CONTINUATION_PACKET
def is_encrypted_transport(ctrl_byte: int) -> bool:
return ctrl_byte & _DATA_MASK == ENCRYPTED
def is_handshake_init_req(ctrl_byte: int) -> bool:
return ctrl_byte & _DATA_MASK == HANDSHAKE_INIT_REQ
def is_handshake_comp_req(ctrl_byte: int) -> bool:
return ctrl_byte & _DATA_MASK == HANDSHAKE_COMP_REQ

View File

@ -0,0 +1,36 @@
from trezor.crypto import elligator2, random
from trezor.crypto.curve import curve25519
from trezor.crypto.hashlib import sha512
_PREFIX = b"\x08\x43\x50\x61\x63\x65\x32\x35\x35\x06"
_PADDING = b"\x6f\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x20"
class Cpace:
"""
CPace, a balanced composable PAKE: https://datatracker.ietf.org/doc/draft-irtf-cfrg-cpace/
"""
def __init__(self, cpace_host_public_key: bytes, handshake_hash: bytes) -> None:
self.handshake_hash: bytes = handshake_hash
self.host_public_key: bytes = cpace_host_public_key
self.shared_secret: bytes
self.trezor_private_key: bytes
self.trezor_public_key: bytes
def generate_keys_and_secret(self, code_code_entry: bytes) -> None:
"""
Generate ephemeral key pair and a shared secret using Elligator2 with X25519.
"""
sha_ctx = sha512(_PREFIX)
sha_ctx.update(code_code_entry)
sha_ctx.update(_PADDING)
sha_ctx.update(self.handshake_hash)
sha_ctx.update(b"\x00")
pregenerator = sha_ctx.digest()[:32]
generator = elligator2.map_to_curve25519(pregenerator)
self.trezor_private_key = random.bytes(32)
self.trezor_public_key = curve25519.multiply(self.trezor_private_key, generator)
self.shared_secret = curve25519.multiply(
self.trezor_private_key, self.host_public_key
)

View File

@ -0,0 +1,211 @@
from micropython import const
from trezorcrypto import aesgcm, bip32, curve25519, hmac
from storage import device
from trezor import log, utils
from trezor.crypto.hashlib import sha256
from trezor.wire.thp import ThpDecryptionError
# The HARDENED flag is taken from apps.common.paths
# It is not imported to save on resources
HARDENED = const(0x8000_0000)
PUBKEY_LENGTH = const(32)
if utils.DISABLE_ENCRYPTION:
DUMMY_TAG = b"\xA0\xA1\xA2\xA3\xA4\xA5\xA6\xA7\xA8\xA9\xB0\xB1\xB2\xB3\xB4\xB5"
if __debug__:
from ubinascii import hexlify
def enc(buffer: utils.BufferType, key: bytes, nonce: int, auth_data: bytes) -> bytes:
"""
Encrypts the provided `buffer` with AES-GCM (in place).
Returns a 16-byte long encryption tag.
"""
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "enc (key: %s, nonce: %d)", hexlify(key), nonce)
iv = _get_iv_from_nonce(nonce)
aes_ctx = aesgcm(key, iv)
aes_ctx.auth(auth_data)
aes_ctx.encrypt_in_place(buffer)
return aes_ctx.finish()
def dec(
buffer: utils.BufferType, tag: bytes, key: bytes, nonce: int, auth_data: bytes
) -> bool:
"""
Decrypts the provided buffer (in place). Returns `True` if the provided authentication `tag` is the same as
the tag computed in decryption, otherwise it returns `False`.
"""
iv = _get_iv_from_nonce(nonce)
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "dec (key: %s, nonce: %d)", hexlify(key), nonce)
aes_ctx = aesgcm(key, iv)
aes_ctx.auth(auth_data)
aes_ctx.decrypt_in_place(buffer)
computed_tag = aes_ctx.finish()
return computed_tag == tag
class BusyDecoder:
def __init__(self, key: bytes, nonce: int, auth_data: bytes) -> None:
iv = _get_iv_from_nonce(nonce)
self.aes_ctx = aesgcm(key, iv)
self.aes_ctx.auth(auth_data)
def decrypt_part(self, part: utils.BufferType) -> None:
self.aes_ctx.decrypt_in_place(part)
def finish_and_check_tag(self, tag: bytes) -> bool:
computed_tag = self.aes_ctx.finish()
return computed_tag == tag
PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00"
IV_1 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
IV_2 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
class Handshake:
"""
`Handshake` holds (temporary) values and keys that are used during the creation of an encrypted channel.
The following values should be saved for future use before disposing of this object:
- `h` - handshake hash, can be used to bind other values to the channel
- `key_receive` - key for decrypting incoming communication
- `key_send` - key for encrypting outgoing communication
"""
def __init__(self) -> None:
self.trezor_ephemeral_privkey: bytes
self.ck: bytes
self.k: bytes
self.h: bytes
self.key_receive: bytes
self.key_send: bytes
def handle_th1_crypto(
self,
device_properties: bytes,
host_ephemeral_pubkey: bytes,
) -> tuple[bytes, bytes, bytes]:
trezor_static_privkey, trezor_static_pubkey = _derive_static_key_pair()
self.trezor_ephemeral_privkey = curve25519.generate_secret()
trezor_ephemeral_pubkey = curve25519.publickey(self.trezor_ephemeral_privkey)
self.h = _hash_of_two(PROTOCOL_NAME, device_properties)
self.h = _hash_of_two(self.h, host_ephemeral_pubkey)
self.h = _hash_of_two(self.h, trezor_ephemeral_pubkey)
point = curve25519.multiply(
self.trezor_ephemeral_privkey, host_ephemeral_pubkey
)
self.ck, self.k = _hkdf(PROTOCOL_NAME, point)
mask = _hash_of_two(trezor_static_pubkey, trezor_ephemeral_pubkey)
trezor_masked_static_pubkey = curve25519.multiply(mask, trezor_static_pubkey)
aes_ctx = aesgcm(self.k, IV_1)
encrypted_trezor_static_pubkey = aes_ctx.encrypt(trezor_masked_static_pubkey)
if __debug__:
log.debug(__name__, "th1 - enc (key: %s, nonce: %d)", hexlify(self.k), 0)
aes_ctx.auth(self.h)
tag_to_encrypted_key = aes_ctx.finish()
encrypted_trezor_static_pubkey = (
encrypted_trezor_static_pubkey + tag_to_encrypted_key
)
self.h = _hash_of_two(self.h, encrypted_trezor_static_pubkey)
point = curve25519.multiply(trezor_static_privkey, host_ephemeral_pubkey)
self.ck, self.k = _hkdf(self.ck, curve25519.multiply(mask, point))
aes_ctx = aesgcm(self.k, IV_1)
aes_ctx.auth(self.h)
tag = aes_ctx.finish()
self.h = _hash_of_two(self.h, tag)
return (trezor_ephemeral_pubkey, encrypted_trezor_static_pubkey, tag)
def handle_th2_crypto(
self,
encrypted_host_static_pubkey: utils.BufferType,
encrypted_payload: utils.BufferType,
) -> None:
aes_ctx = aesgcm(self.k, IV_2)
# The new value of hash `h` MUST be computed before the `encrypted_host_static_pubkey` is decrypted.
# However, decryption of `encrypted_host_static_pubkey` MUST use the previous value of `h` for
# authentication of the gcm tag.
aes_ctx.auth(self.h) # Authenticate with the previous value of `h`
self.h = _hash_of_two(self.h, encrypted_host_static_pubkey) # Compute new value
aes_ctx.decrypt_in_place(
memoryview(encrypted_host_static_pubkey)[:PUBKEY_LENGTH]
)
if __debug__:
log.debug(__name__, "th2 - dec (key: %s, nonce: %d)", hexlify(self.k), 1)
host_static_pubkey = memoryview(encrypted_host_static_pubkey)[:PUBKEY_LENGTH]
tag = aes_ctx.finish()
if tag != encrypted_host_static_pubkey[-16:]:
raise ThpDecryptionError()
self.ck, self.k = _hkdf(
self.ck,
curve25519.multiply(self.trezor_ephemeral_privkey, host_static_pubkey),
)
aes_ctx = aesgcm(self.k, IV_1)
aes_ctx.auth(self.h)
aes_ctx.decrypt_in_place(memoryview(encrypted_payload)[:-16])
if __debug__:
log.debug(__name__, "th2 - dec (key: %s, nonce: %d)", hexlify(self.k), 0)
tag = aes_ctx.finish()
if tag != encrypted_payload[-16:]:
raise ThpDecryptionError()
self.h = _hash_of_two(self.h, memoryview(encrypted_payload)[:-16])
self.key_receive, self.key_send = _hkdf(self.ck, b"")
if __debug__:
log.debug(
__name__,
"(key_receive: %s, key_send: %s)",
hexlify(self.key_receive),
hexlify(self.key_send),
)
def get_handshake_completion_response(self, trezor_state: bytes) -> bytes:
aes_ctx = aesgcm(self.key_send, IV_1)
encrypted_trezor_state = aes_ctx.encrypt(trezor_state)
tag = aes_ctx.finish()
return encrypted_trezor_state + tag
def _derive_static_key_pair() -> tuple[bytes, bytes]:
node_int = HARDENED | int.from_bytes(b"\x00THP", "big")
node = bip32.from_seed(device.get_device_secret(), "curve25519")
node.derive(node_int)
trezor_static_privkey = node.private_key()
trezor_static_pubkey = node.public_key()[1:33]
# Note: the first byte (\x01) of the public key is removed, as it
# only indicates the type of the elliptic curve used
return trezor_static_privkey, trezor_static_pubkey
def get_trezor_static_pubkey() -> bytes:
_, pubkey = _derive_static_key_pair()
return pubkey
def _hkdf(chaining_key: bytes, input: bytes) -> tuple[bytes, bytes]:
temp_key = hmac(hmac.SHA256, chaining_key, input).digest()
output_1 = hmac(hmac.SHA256, temp_key, b"\x01").digest()
ctx_output_2 = hmac(hmac.SHA256, temp_key, output_1)
ctx_output_2.update(b"\x02")
output_2 = ctx_output_2.digest()
return (output_1, output_2)
def _hash_of_two(part_1: bytes, part_2: bytes) -> bytes:
ctx = sha256(part_1)
ctx.update(part_2)
return ctx.digest()
def _get_iv_from_nonce(nonce: int) -> bytes:
utils.ensure(nonce <= 0xFFFFFFFFFFFFFFFF, "Nonce overflow, terminate the channel")
return bytes(4) + nonce.to_bytes(8, "big")

View File

@ -0,0 +1,28 @@
from typing import TYPE_CHECKING
import usb
_WIRE_INTERFACE_USB = b"\x01"
# TODO _WIRE_INTERFACE_BLE = b"\x02"
if TYPE_CHECKING:
from trezorio import WireInterface
def decode_iface(cached_iface: bytes) -> WireInterface:
"""Decode the cached wire interface."""
if cached_iface == _WIRE_INTERFACE_USB:
iface = usb.iface_wire
if iface is None:
raise RuntimeError("There is no valid USB WireInterface")
return iface
# TODO implement bluetooth interface
raise Exception("Unknown WireInterface")
def encode_iface(iface: WireInterface) -> bytes:
"""Encode wire interface into bytes."""
if iface is usb.iface_wire:
return _WIRE_INTERFACE_USB
# TODO implement bluetooth interface
raise Exception("Unknown WireInterface")

View File

@ -0,0 +1,179 @@
from storage.cache_thp import SESSION_ID_LENGTH, TAG_LENGTH
from trezor import log, protobuf, utils
from trezor.wire.message_handler import get_msg_type
from . import ChannelState, ThpError
from .checksum import CHECKSUM_LENGTH
from .writer import (
INIT_HEADER_LENGTH,
MAX_PAYLOAD_LEN,
MESSAGE_TYPE_LENGTH,
PACKET_LENGTH,
)
def select_buffer(
channel_state: int,
channel_buffer: utils.BufferType,
packet_payload: utils.BufferType,
payload_length: int,
) -> utils.BufferType:
if channel_state is ChannelState.ENCRYPTED_TRANSPORT:
session_id = packet_payload[0]
if session_id == 0:
pass
# TODO use small buffer
else:
pass
# TODO use big buffer but only if the channel owns the buffer lock.
# Otherwise send BUSY message and return
else:
pass
# TODO use small buffer
try:
# TODO for now, we create a new big buffer every time. It should be changed
buffer: utils.BufferType = _get_buffer_for_read(payload_length, channel_buffer)
return buffer
except Exception as e:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.exception(__name__, e)
raise Exception("Failed to create a buffer for channel") # TODO handle better
def get_write_buffer(
buffer: utils.BufferType, msg: protobuf.MessageType
) -> utils.BufferType:
msg_size = protobuf.encoded_length(msg)
payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size
required_min_size = payload_size + CHECKSUM_LENGTH + TAG_LENGTH
if required_min_size > len(buffer):
return _get_buffer_for_write(required_min_size, buffer)
return buffer
def encode_into_buffer(
buffer: utils.BufferType, msg: protobuf.MessageType, session_id: int
) -> int:
# cannot write message without wire type
msg_type = msg.MESSAGE_WIRE_TYPE
if msg_type is None:
msg_type = get_msg_type(msg.MESSAGE_NAME)
assert msg_type is not None
msg_size = protobuf.encoded_length(msg)
payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size
_encode_session_into_buffer(memoryview(buffer), session_id)
_encode_message_type_into_buffer(memoryview(buffer), msg_type, SESSION_ID_LENGTH)
_encode_message_into_buffer(
memoryview(buffer), msg, SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH
)
return payload_size
def _encode_session_into_buffer(
buffer: memoryview, session_id: int, buffer_offset: int = 0
) -> None:
session_id_bytes = int.to_bytes(session_id, SESSION_ID_LENGTH, "big")
utils.memcpy(buffer, buffer_offset, session_id_bytes, 0)
def _encode_message_type_into_buffer(
buffer: memoryview, message_type: int, offset: int = 0
) -> None:
msg_type_bytes = int.to_bytes(message_type, MESSAGE_TYPE_LENGTH, "big")
utils.memcpy(buffer, offset, msg_type_bytes, 0)
def _encode_message_into_buffer(
buffer: memoryview, message: protobuf.MessageType, buffer_offset: int = 0
) -> None:
protobuf.encode(memoryview(buffer[buffer_offset:]), message)
def _get_buffer_for_read(
payload_length: int,
existing_buffer: utils.BufferType,
max_length: int = MAX_PAYLOAD_LEN,
) -> utils.BufferType:
length = payload_length + INIT_HEADER_LENGTH
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"get_buffer_for_read - length: %d, %s %s",
length,
"existing buffer type:",
type(existing_buffer),
)
if length > max_length:
raise ThpError("Message too large")
if length > len(existing_buffer):
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "Allocating a new buffer")
from .thp_main import get_raw_read_buffer
if length > len(get_raw_read_buffer()):
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"Required length is %d, where raw buffer has capacity only %d",
length,
len(get_raw_read_buffer()),
)
raise ThpError("Message is too large")
try:
payload: utils.BufferType = memoryview(get_raw_read_buffer())[:length]
except MemoryError:
payload = memoryview(get_raw_read_buffer())[:PACKET_LENGTH]
raise ThpError("Message is too large")
return payload
# reuse a part of the supplied buffer
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "Reusing already allocated buffer")
return memoryview(existing_buffer)[:length]
def _get_buffer_for_write(
payload_length: int,
existing_buffer: utils.BufferType,
max_length: int = MAX_PAYLOAD_LEN,
) -> utils.BufferType:
length = payload_length + INIT_HEADER_LENGTH
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"get_buffer_for_write - length: %d, %s %s",
length,
"existing buffer type:",
type(existing_buffer),
)
if length > max_length:
raise ThpError("Message too large")
if length > len(existing_buffer):
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "Creating a new write buffer from raw write buffer")
from .thp_main import get_raw_write_buffer
if length > len(get_raw_write_buffer()):
raise ThpError("Message is too large")
try:
payload: utils.BufferType = memoryview(get_raw_write_buffer())[:length]
except MemoryError:
payload = memoryview(get_raw_write_buffer())[:PACKET_LENGTH]
raise ThpError("Message is too large")
return payload
# reuse a part of the supplied buffer
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "Reusing already allocated buffer")
return memoryview(existing_buffer)[:length]

View File

@ -0,0 +1,262 @@
from typing import TYPE_CHECKING
from ubinascii import hexlify
import trezorui2
from trezor import loop, protobuf, workflow
from trezor.crypto import random
from trezor.wire import context, message_handler, protocol_common
from trezor.wire.context import UnexpectedMessageException
from trezor.wire.errors import ActionCancelled, SilentError
from trezor.wire.protocol_common import Context, Message
if TYPE_CHECKING:
from typing import Container
from trezor import ui
from .channel import Channel
from .cpace import Cpace
pass
if __debug__:
from trezor import log
class PairingDisplayData:
def __init__(self) -> None:
self.code_code_entry: int | None = None
self.code_qr_code: bytes | None = None
self.code_nfc_unidirectional: bytes | None = None
def get_display_layout(self) -> ui.Layout:
from trezor import ui
# TODO have different layouts when there is only QR code or only Code Entry
qr_str = ""
code_str = ""
if self.code_qr_code is not None:
qr_str = self._get_code_qr_code_str()
if self.code_code_entry is not None:
code_str = self._get_code_code_entry_str()
return ui.Layout(
trezorui2.show_address_details( # noqa
qr_title="Scan QR code to pair",
address=qr_str,
case_sensitive=True,
details_title="",
account="Code to rewrite:\n" + code_str,
path="",
xpubs=[],
)
)
def _get_code_code_entry_str(self) -> str:
if self.code_code_entry is not None:
code_str = f"{self.code_code_entry:06}"
if __debug__:
log.debug(__name__, "code_code_entry: %s", code_str)
return code_str[:3] + " " + code_str[3:]
raise Exception("Code entry string is not available")
def _get_code_qr_code_str(self) -> str:
if self.code_qr_code is not None:
code_str = (hexlify(self.code_qr_code)).decode("utf-8")
if __debug__:
log.debug(__name__, "code_qr_code_hexlified: %s", code_str)
return code_str
raise Exception("QR code string is not available")
class PairingContext(Context):
def __init__(self, channel_ctx: Channel) -> None:
super().__init__(channel_ctx.iface, channel_ctx.channel_id)
self.channel_ctx: Channel = channel_ctx
self.incoming_message = loop.mailbox()
self.secret: bytes = random.bytes(16)
self.display_data: PairingDisplayData = PairingDisplayData()
self.cpace: Cpace
self.host_name: str
async def handle(self, is_debug_session: bool = False) -> None:
# if __debug__:
# log.debug(__name__, "handle - start")
# if is_debug_session:
# import apps.debug
# apps.debug.DEBUG_CONTEXT = self
next_message: Message | None = None
while True:
try:
if next_message is None:
# If the previous run did not keep an unprocessed message for us,
# wait for a new one.
try:
message: Message = await self.incoming_message
except protocol_common.WireError as e:
if __debug__:
log.exception(__name__, e)
await self.write(message_handler.failure(e))
continue
else:
# Process the message from previous run.
message = next_message
next_message = None
try:
next_message = await handle_pairing_request_message(self, message)
except Exception as exc:
# Log and ignore. The session handler can only exit explicitly in the
# following finally block.
if __debug__:
log.exception(__name__, exc)
finally:
# Unload modules imported by the workflow. Should not raise.
# This is not done for the debug session because the snapshot taken
# in a debug session would clear modules which are in use by the
# workflow running on wire.
# TODO utils.unimport_end(modules)
if next_message is None:
# Shut down the loop if there is no next message waiting.
return # pylint: disable=lost-exception
except Exception as exc:
# Log and try again. The session handler can only exit explicitly via
# loop.clear() above. # TODO not updated comments
if __debug__:
log.exception(__name__, exc)
async def read(
self,
expected_types: Container[int],
expected_type: type[protobuf.MessageType] | None = None,
) -> protobuf.MessageType:
if __debug__:
exp_type: str = str(expected_type)
if expected_type is not None:
exp_type = expected_type.MESSAGE_NAME
log.debug(
__name__,
"Read - with expected types %s and expected type %s",
str(expected_types),
exp_type,
)
message: Message = await self.incoming_message
if message.type not in expected_types:
raise UnexpectedMessageException(message)
if expected_type is None:
name = message_handler.get_msg_name(message.type)
if name is None:
expected_type = protobuf.type_for_wire(message.type)
else:
expected_type = protobuf.type_for_name(name)
return message_handler.wrap_protobuf_load(message.data, expected_type)
async def write(self, msg: protobuf.MessageType) -> None:
return await self.channel_ctx.write(msg)
async def call(
self, msg: protobuf.MessageType, expected_type: type[protobuf.MessageType]
) -> protobuf.MessageType:
expected_wire_type = message_handler.get_msg_type(expected_type.MESSAGE_NAME)
if expected_wire_type is None:
expected_wire_type = expected_type.MESSAGE_WIRE_TYPE
assert expected_wire_type is not None
await self.write(msg)
del msg
return await self.read((expected_wire_type,), expected_type)
async def call_any(
self, msg: protobuf.MessageType, *expected_types: int
) -> protobuf.MessageType:
await self.write(msg)
del msg
return await self.read(expected_types)
async def handle_pairing_request_message(
pairing_ctx: PairingContext,
msg: protocol_common.Message,
) -> protocol_common.Message | None:
res_msg: protobuf.MessageType | None = None
from apps.thp.pairing import handle_pairing_request
if msg.type in workflow.ALLOW_WHILE_LOCKED:
workflow.autolock_interrupts_workflow = False
# Here we make sure we always respond with a Failure response
# in case of any errors.
try:
# Find a protobuf.MessageType subclass that describes this
# message. Raises if the type is not found.
name = message_handler.get_msg_name(msg.type)
if name is None:
req_type = protobuf.type_for_wire(msg.type)
else:
req_type = protobuf.type_for_name(name)
# Try to decode the message according to schema from
# `req_type`. Raises if the message is malformed.
req_msg = message_handler.wrap_protobuf_load(msg.data, req_type)
# Create the handler task.
task = handle_pairing_request(pairing_ctx, req_msg)
# Run the workflow task. Workflow can do more on-the-wire
# communication inside, but it should eventually return a
# response message, or raise an exception (a rather common
# thing to do). Exceptions are handled in the code below.
res_msg = await workflow.spawn(context.with_context(pairing_ctx, task))
except UnexpectedMessageException as exc:
# Workflow was trying to read a message from the wire, and
# something unexpected came in. See Context.read() for
# example, which expects some particular message and raises
# UnexpectedMessage if another one comes in.
# In order not to lose the message, we return it to the caller.
# TODO:
# We might handle only the few common cases here, like
# Initialize and Cancel.
return exc.msg
except SilentError as exc:
if __debug__:
log.error(__name__, "SilentError: %s", exc.message)
except BaseException as exc:
# Either:
# - the message had a type that has a registered handler, but does not have
# a protobuf class
# - the message was not valid protobuf
# - workflow raised some kind of an exception while running
# - something canceled the workflow from the outside
if __debug__:
if isinstance(exc, ActionCancelled):
log.debug(__name__, "cancelled: %s", exc.message)
elif isinstance(exc, loop.TaskClosed):
log.debug(__name__, "cancelled: loop task was closed")
else:
log.exception(__name__, exc)
res_msg = message_handler.failure(exc)
if res_msg is not None:
# perform the write outside the big try-except block, so that usb write
# problem bubbles up
await pairing_ctx.write(res_msg)
return None

View File

@ -0,0 +1,446 @@
import ustruct
from typing import TYPE_CHECKING
from storage.cache_common import (
CHANNEL_HANDSHAKE_HASH,
CHANNEL_KEY_RECEIVE,
CHANNEL_KEY_SEND,
CHANNEL_NONCE_RECEIVE,
CHANNEL_NONCE_SEND,
)
from storage.cache_thp import (
KEY_LENGTH,
MANAGEMENT_SESSION_ID,
SESSION_ID_LENGTH,
TAG_LENGTH,
update_channel_last_used,
update_session_last_used,
)
from trezor import log, loop, protobuf, utils
from trezor.enums import FailureType
from trezor.messages import Failure
from .. import message_handler
from ..errors import DataError
from ..protocol_common import Message
from . import (
ACK_MESSAGE,
HANDSHAKE_COMP_RES,
HANDSHAKE_INIT_RES,
ChannelState,
PacketHeader,
SessionState,
ThpDecryptionError,
ThpError,
ThpErrorType,
ThpInvalidDataError,
ThpUnallocatedSessionError,
)
from . import alternating_bit_protocol as ABP
from . import (
checksum,
control_byte,
get_enabled_pairing_methods,
get_encoded_device_properties,
session_manager,
)
from .checksum import CHECKSUM_LENGTH
from .crypto import PUBKEY_LENGTH, Handshake
from .writer import (
INIT_HEADER_LENGTH,
MESSAGE_TYPE_LENGTH,
write_payload_to_wire_and_add_checksum,
)
if TYPE_CHECKING:
from typing import Awaitable
from trezor.messages import ThpHandshakeCompletionReqNoisePayload
from .channel import Channel
if __debug__:
from ubinascii import hexlify
from . import state_to_str
_TREZOR_STATE_UNPAIRED = b"\x00"
_TREZOR_STATE_PAIRED = b"\x01"
async def handle_received_message(
ctx: Channel, message_buffer: utils.BufferType
) -> None:
"""Handle a message received from the channel."""
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "handle_received_message")
if utils.ALLOW_DEBUG_MESSAGES: # TODO remove after performance tests are done
try:
import micropython
print("micropython.mem_info() from received_message_handler.py")
micropython.mem_info()
print("Allocation count:", micropython.alloc_count()) # type: ignore ["alloc_count" is not a known attribute of module "micropython"]
except AttributeError:
print(
"To show allocation count, create the build with TREZOR_MEMPERF=1"
)
ctrl_byte, _, payload_length = ustruct.unpack(">BHH", message_buffer)
message_length = payload_length + INIT_HEADER_LENGTH
_check_checksum(message_length, message_buffer)
# Synchronization process
seq_bit = (ctrl_byte & 0x10) >> 4
ack_bit = (ctrl_byte & 0x08) >> 3
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"handle_completed_message - seq bit of message: %d, ack bit of message: %d",
seq_bit,
ack_bit,
)
# 0: Update "last-time used"
update_channel_last_used(ctx.channel_id)
# 1: Handle ACKs
if control_byte.is_ack(ctrl_byte):
await _handle_ack(ctx, ack_bit)
return
if _should_have_ctrl_byte_encrypted_transport(
ctx
) and not control_byte.is_encrypted_transport(ctrl_byte):
raise ThpError("Message is not encrypted. Ignoring")
# 2: Handle message with unexpected sequential bit
if seq_bit != ABP.get_expected_receive_seq_bit(ctx.channel_cache):
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "Received message with an unexpected sequential bit")
await _send_ack(ctx, ack_bit=seq_bit)
raise ThpError("Received message with an unexpected sequential bit")
# 3: Send ACK in response
await _send_ack(ctx, ack_bit=seq_bit)
ABP.set_expected_receive_seq_bit(ctx.channel_cache, 1 - seq_bit)
try:
await _handle_message_to_app_or_channel(
ctx, payload_length, message_length, ctrl_byte
)
except ThpUnallocatedSessionError as e:
error_message = Failure(code=FailureType.ThpUnallocatedSession)
await ctx.write(error_message, e.session_id)
except ThpDecryptionError:
await ctx.write_error(ThpErrorType.DECRYPTION_FAILED)
ctx.clear()
except ThpInvalidDataError:
await ctx.write_error(ThpErrorType.INVALID_DATA)
ctx.clear()
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "handle_received_message - end")
def _send_ack(ctx: Channel, ack_bit: int) -> Awaitable[None]:
ctrl_byte = control_byte.add_ack_bit_to_ctrl_byte(ACK_MESSAGE, ack_bit)
header = PacketHeader(ctrl_byte, ctx.get_channel_id_int(), CHECKSUM_LENGTH)
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"Writing ACK message to a channel with id: %d, ack_bit: %d",
ctx.get_channel_id_int(),
ack_bit,
)
return write_payload_to_wire_and_add_checksum(ctx.iface, header, b"")
def _check_checksum(message_length: int, message_buffer: utils.BufferType) -> None:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "check_checksum")
if not checksum.is_valid(
checksum=message_buffer[message_length - CHECKSUM_LENGTH : message_length],
data=memoryview(message_buffer)[: message_length - CHECKSUM_LENGTH],
):
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "Invalid checksum, ignoring message.")
raise ThpError("Invalid checksum, ignoring message.")
async def _handle_ack(ctx: Channel, ack_bit: int) -> None:
if not ABP.is_ack_valid(ctx.channel_cache, ack_bit):
return
# ACK is expected and it has correct sync bit
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "Received ACK message with correct ack bit")
if ctx.transmission_loop is not None:
ctx.transmission_loop.stop_immediately()
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "Stopped transmission loop")
ABP.set_sending_allowed(ctx.channel_cache, True)
if ctx.write_task_spawn is not None:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, 'Control to "write_encrypted_payload_loop" task')
await ctx.write_task_spawn
# Note that no the write_task_spawn could result in loop.clear(),
# which will result in termination of this function - any code after
# this await might not be executed
def _handle_message_to_app_or_channel(
ctx: Channel,
payload_length: int,
message_length: int,
ctrl_byte: int,
) -> Awaitable[None]:
state = ctx.get_channel_state()
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "state: %s", state_to_str(state))
if state is ChannelState.ENCRYPTED_TRANSPORT:
return _handle_state_ENCRYPTED_TRANSPORT(ctx, message_length)
if state is ChannelState.TH1:
return _handle_state_TH1(ctx, payload_length, message_length, ctrl_byte)
if state is ChannelState.TH2:
return _handle_state_TH2(ctx, message_length, ctrl_byte)
if _is_channel_state_pairing(state):
return _handle_pairing(ctx, message_length)
raise ThpError("Unimplemented channel state")
async def _handle_state_TH1(
ctx: Channel,
payload_length: int,
message_length: int,
ctrl_byte: int,
) -> None:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "handle_state_TH1")
if not control_byte.is_handshake_init_req(ctrl_byte):
raise ThpError("Message received is not a handshake init request!")
if not payload_length == PUBKEY_LENGTH + CHECKSUM_LENGTH:
raise ThpError("Message received is not a valid handshake init request!")
ctx.handshake = Handshake()
host_ephemeral_pubkey = bytearray(
ctx.buffer[INIT_HEADER_LENGTH : message_length - CHECKSUM_LENGTH]
)
trezor_ephemeral_pubkey, encrypted_trezor_static_pubkey, tag = (
ctx.handshake.handle_th1_crypto(
get_encoded_device_properties(ctx.iface), host_ephemeral_pubkey
)
)
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"trezor ephemeral pubkey: %s",
hexlify(trezor_ephemeral_pubkey).decode(),
)
log.debug(
__name__,
"encrypted trezor masked static pubkey: %s",
hexlify(encrypted_trezor_static_pubkey).decode(),
)
log.debug(__name__, "tag: %s", hexlify(tag))
payload = trezor_ephemeral_pubkey + encrypted_trezor_static_pubkey + tag
# send handshake init response message
ctx.write_handshake_message(HANDSHAKE_INIT_RES, payload)
ctx.set_channel_state(ChannelState.TH2)
return
async def _handle_state_TH2(ctx: Channel, message_length: int, ctrl_byte: int) -> None:
from apps.thp.credential_manager import validate_credential
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "handle_state_TH2")
if not control_byte.is_handshake_comp_req(ctrl_byte):
raise ThpError("Message received is not a handshake completion request!")
if ctx.handshake is None:
raise Exception("Handshake object is not prepared. Retry handshake.")
host_encrypted_static_pubkey = memoryview(ctx.buffer)[
INIT_HEADER_LENGTH : INIT_HEADER_LENGTH + KEY_LENGTH + TAG_LENGTH
]
handshake_completion_request_noise_payload = memoryview(ctx.buffer)[
INIT_HEADER_LENGTH + KEY_LENGTH + TAG_LENGTH : message_length - CHECKSUM_LENGTH
]
ctx.handshake.handle_th2_crypto(
host_encrypted_static_pubkey, handshake_completion_request_noise_payload
)
ctx.channel_cache.set(CHANNEL_KEY_RECEIVE, ctx.handshake.key_receive)
ctx.channel_cache.set(CHANNEL_KEY_SEND, ctx.handshake.key_send)
ctx.channel_cache.set(CHANNEL_HANDSHAKE_HASH, ctx.handshake.h)
ctx.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, 0)
ctx.channel_cache.set_int(CHANNEL_NONCE_SEND, 1)
noise_payload = _decode_message(
ctx.buffer[
INIT_HEADER_LENGTH
+ KEY_LENGTH
+ TAG_LENGTH : message_length
- CHECKSUM_LENGTH
- TAG_LENGTH
],
0,
"ThpHandshakeCompletionReqNoisePayload",
)
if TYPE_CHECKING:
assert ThpHandshakeCompletionReqNoisePayload.is_type_of(noise_payload)
enabled_methods = get_enabled_pairing_methods(ctx.iface)
for method in noise_payload.pairing_methods:
if method not in enabled_methods:
raise ThpInvalidDataError()
if method not in ctx.selected_pairing_methods:
ctx.selected_pairing_methods.append(method)
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"host static pubkey: %s, noise payload: %s",
utils.get_bytes_as_str(host_encrypted_static_pubkey),
utils.get_bytes_as_str(handshake_completion_request_noise_payload),
)
# key is decoded in handshake._handle_th2_crypto
host_static_pubkey = host_encrypted_static_pubkey[:PUBKEY_LENGTH]
paired: bool = False
if noise_payload.host_pairing_credential is not None:
try: # TODO change try-except for something better
paired = validate_credential(
noise_payload.host_pairing_credential,
host_static_pubkey,
)
except DataError as e:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.exception(__name__, e)
pass
trezor_state = _TREZOR_STATE_UNPAIRED
if paired:
trezor_state = _TREZOR_STATE_PAIRED
# send hanshake completion response
ctx.write_handshake_message(
HANDSHAKE_COMP_RES,
ctx.handshake.get_handshake_completion_response(trezor_state),
)
ctx.handshake = None
if paired:
ctx.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
else:
ctx.set_channel_state(ChannelState.TP1)
async def _handle_state_ENCRYPTED_TRANSPORT(ctx: Channel, message_length: int) -> None:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "handle_state_ENCRYPTED_TRANSPORT")
ctx.decrypt_buffer(message_length)
session_id, message_type = ustruct.unpack(
">BH", memoryview(ctx.buffer)[INIT_HEADER_LENGTH:]
)
if session_id not in ctx.sessions:
if session_id == MANAGEMENT_SESSION_ID:
s = session_manager.create_new_management_session(ctx)
else:
s = session_manager.get_session_from_cache(ctx, session_id)
if s is None:
raise ThpUnallocatedSessionError(session_id)
ctx.sessions[session_id] = s
loop.schedule(s.handle())
elif ctx.sessions[session_id].get_session_state() is SessionState.UNALLOCATED:
raise ThpUnallocatedSessionError(session_id)
s = ctx.sessions[session_id]
update_session_last_used(s.channel_id, (s.session_id).to_bytes(1, "big"))
s.incoming_message.put(
Message(
message_type,
ctx.buffer[
INIT_HEADER_LENGTH
+ MESSAGE_TYPE_LENGTH
+ SESSION_ID_LENGTH : message_length
- CHECKSUM_LENGTH
- TAG_LENGTH
],
)
)
async def _handle_pairing(ctx: Channel, message_length: int) -> None:
from .pairing_context import PairingContext
if ctx.connection_context is None:
ctx.connection_context = PairingContext(ctx)
loop.schedule(ctx.connection_context.handle())
ctx.decrypt_buffer(message_length)
message_type = ustruct.unpack(
">H", ctx.buffer[INIT_HEADER_LENGTH + SESSION_ID_LENGTH :]
)[0]
ctx.connection_context.incoming_message.put(
Message(
message_type,
ctx.buffer[
INIT_HEADER_LENGTH
+ MESSAGE_TYPE_LENGTH
+ SESSION_ID_LENGTH : message_length
- CHECKSUM_LENGTH
- TAG_LENGTH
],
)
)
def _should_have_ctrl_byte_encrypted_transport(ctx: Channel) -> bool:
if ctx.get_channel_state() in [
ChannelState.UNALLOCATED,
ChannelState.TH1,
ChannelState.TH2,
]:
return False
return True
def _decode_message(
buffer: bytes, msg_type: int, message_name: str | None = None
) -> protobuf.MessageType:
if __debug__:
log.debug(__name__, "decode message")
if message_name is not None:
expected_type = protobuf.type_for_name(message_name)
else:
expected_type = protobuf.type_for_wire(msg_type)
return message_handler.wrap_protobuf_load(buffer, expected_type)
def _is_channel_state_pairing(state: int) -> bool:
if state in (
ChannelState.TP1,
ChannelState.TP2,
ChannelState.TP3,
ChannelState.TP4,
ChannelState.TC1,
):
return True
return False

View File

@ -0,0 +1,175 @@
from typing import TYPE_CHECKING
from storage import cache_thp
from storage.cache_thp import MANAGEMENT_SESSION_ID, SessionThpCache
from trezor import log, loop, protobuf, utils
from trezor.wire import message_handler, protocol_common
from trezor.wire.context import UnexpectedMessageException
from trezor.wire.message_handler import failure, find_handler
from ..protocol_common import Context, Message
from . import SessionState
if TYPE_CHECKING:
from typing import Awaitable, Container
from storage.cache_common import DataCache
from ..message_handler import HandlerFinder
from .channel import Channel
pass
_EXIT_LOOP = True
_REPEAT_LOOP = False
if __debug__:
from trezor.utils import get_bytes_as_str
class GenericSessionContext(Context):
def __init__(self, channel: Channel, session_id: int) -> None:
super().__init__(channel.iface, channel.channel_id)
self.channel: Channel = channel
self.session_id: int = session_id
self.incoming_message = loop.mailbox()
self.handler_finder: HandlerFinder = find_handler
async def handle(self) -> None:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
self._handle_debug()
next_message: Message | None = None
while True:
message = next_message
next_message = None
try:
if await self._handle_message(message):
loop.schedule(self.handle())
return
except UnexpectedMessageException as unexpected:
# The workflow was interrupted by an unexpected message. We need to
# process it as if it was a new message...
next_message = unexpected.msg
continue
except Exception as exc:
# Log and try again.
if __debug__:
log.exception(__name__, exc)
def _handle_debug(self) -> None:
log.debug(
__name__,
"handle - start (channel_id (bytes): %s, session_id: %d)",
get_bytes_as_str(self.channel_id),
self.session_id,
)
async def _handle_message(
self,
next_message: Message | None,
) -> bool:
try:
if next_message is not None:
# Process the message from previous run.
message = next_message
next_message = None
else:
# Wait for a new message from wire
message = await self.incoming_message
except protocol_common.WireError as e:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.exception(__name__, e)
await self.write(failure(e))
return _REPEAT_LOOP
await message_handler.handle_single_message(
self,
message,
self.handler_finder,
)
return _EXIT_LOOP
async def read(
self,
expected_types: Container[int],
expected_type: type[protobuf.MessageType] | None = None,
) -> protobuf.MessageType:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
exp_type: str = str(expected_type)
if expected_type is not None:
exp_type = expected_type.MESSAGE_NAME
log.debug(
__name__,
"Read - with expected types %s and expected type %s",
str(expected_types),
exp_type,
)
message: Message = await self.incoming_message
if message.type not in expected_types:
if __debug__:
log.debug(
__name__,
"EXPECTED TYPES: %s\nRECEIVED TYPE: %s",
str(expected_types),
str(message.type),
)
raise UnexpectedMessageException(message)
if expected_type is None:
expected_type = protobuf.type_for_wire(message.type)
return message_handler.wrap_protobuf_load(message.data, expected_type)
async def write(self, msg: protobuf.MessageType) -> None:
return await self.channel.write(msg, self.session_id)
def write_force(self, msg: protobuf.MessageType) -> Awaitable[None]:
return self.channel.write(msg, self.session_id, force=True)
def get_session_state(self) -> SessionState: ...
class ManagementSessionContext(GenericSessionContext):
def __init__(
self, channel_ctx: Channel, session_id: int = MANAGEMENT_SESSION_ID
) -> None:
super().__init__(channel_ctx, session_id)
def get_session_state(self) -> SessionState:
return SessionState.MANAGEMENT
class SessionContext(GenericSessionContext):
def __init__(self, channel_ctx: Channel, session_cache: SessionThpCache) -> None:
if channel_ctx.channel_id != session_cache.channel_id:
raise Exception(
"The session has different channel id than the provided channel context!"
)
session_id = int.from_bytes(session_cache.session_id, "big")
super().__init__(channel_ctx, session_id)
self.session_cache = session_cache
# ACCESS TO SESSION DATA
def get_session_state(self) -> SessionState:
state = int.from_bytes(self.session_cache.state, "big")
return SessionState(state)
def set_session_state(self, state: SessionState) -> None:
self.session_cache.state = bytearray(state.to_bytes(1, "big"))
def release(self) -> None:
if self.session_cache is not None:
cache_thp.clear_session(self.session_cache)
# ACCESS TO CACHE
@property
def cache(self) -> DataCache:
return self.session_cache

View File

@ -0,0 +1,37 @@
from typing import TYPE_CHECKING
from storage import cache_thp
from .session_context import (
GenericSessionContext,
ManagementSessionContext,
SessionContext,
)
if TYPE_CHECKING:
from .channel import Channel
def create_new_session(channel_ctx: Channel) -> SessionContext:
session_cache = cache_thp.get_new_session(channel_ctx.channel_cache)
return SessionContext(channel_ctx, session_cache)
def create_new_management_session(
channel_ctx: Channel, session_id: int = cache_thp.MANAGEMENT_SESSION_ID
) -> ManagementSessionContext:
return ManagementSessionContext(channel_ctx, session_id)
def get_session_from_cache(
channel_ctx: Channel, session_id: int
) -> GenericSessionContext | None:
cached_sessions = cache_thp.get_allocated_sessions(channel_ctx.channel_id)
for s in cached_sessions:
print(s, s.channel_id, int.from_bytes(s.session_id, "big"))
if (
s.channel_id == channel_ctx.channel_id
and int.from_bytes(s.session_id, "big") == session_id
):
return SessionContext(channel_ctx, s)
return None

View File

@ -0,0 +1,187 @@
import ustruct
from micropython import const
from typing import TYPE_CHECKING
from storage.cache_thp import BROADCAST_CHANNEL_ID
from trezor import io, log, loop, utils
from . import (
CHANNEL_ALLOCATION_REQ,
CODEC_V1,
ChannelState,
PacketHeader,
ThpError,
ThpErrorType,
channel_manager,
checksum,
get_channel_allocation_response,
writer,
)
from .channel import Channel
from .checksum import CHECKSUM_LENGTH
from .writer import (
INIT_HEADER_LENGTH,
MAX_PAYLOAD_LEN,
PACKET_LENGTH,
write_payload_to_wire_and_add_checksum,
)
if TYPE_CHECKING:
from trezorio import WireInterface
_CID_REQ_PAYLOAD_LENGTH = const(12)
_READ_BUFFER: bytearray
_WRITE_BUFFER: bytearray
_CHANNELS: dict[int, Channel] = {}
def set_read_buffer(buffer: bytearray) -> None:
global _READ_BUFFER
_READ_BUFFER = buffer
def set_write_buffer(buffer: bytearray) -> None:
global _WRITE_BUFFER
_WRITE_BUFFER = buffer
def get_raw_read_buffer() -> bytearray:
global _READ_BUFFER
return _READ_BUFFER
def get_raw_write_buffer() -> bytearray:
global _WRITE_BUFFER
return _WRITE_BUFFER
async def thp_main_loop(iface: WireInterface) -> None:
global _CHANNELS
global _READ_BUFFER
_CHANNELS = channel_manager.load_cached_channels(_READ_BUFFER)
read = loop.wait(iface.iface_num() | io.POLL_READ)
while True:
try:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "thp_main_loop")
packet = await read
ctrl_byte, cid = ustruct.unpack(">BH", packet)
if ctrl_byte == CODEC_V1:
await _handle_codec_v1(iface, packet)
continue
if cid == BROADCAST_CHANNEL_ID:
await _handle_broadcast(iface, ctrl_byte, packet)
continue
if cid in _CHANNELS:
await _handle_allocated(iface, cid, packet)
else:
await _handle_unallocated(iface, cid)
except ThpError as e:
if __debug__:
log.exception(__name__, e)
async def _handle_codec_v1(iface: WireInterface, packet: bytes) -> None:
# If the received packet is not an initial codec_v1 packet, do not send error message
if not packet[1:3] == b"##":
return
if __debug__:
log.debug(__name__, "Received codec_v1 message, returning error")
error_message = _get_codec_v1_error_message()
await writer.write_packet_to_wire(iface, error_message)
async def _handle_broadcast(
iface: WireInterface, ctrl_byte: int, packet: utils.BufferType
) -> None:
global _READ_BUFFER
if ctrl_byte != CHANNEL_ALLOCATION_REQ:
raise ThpError("Unexpected ctrl_byte in a broadcast channel packet")
if __debug__:
log.debug(__name__, "Received valid message on the broadcast channel")
length, nonce = ustruct.unpack(">H8s", packet[3:])
payload = _get_buffer_for_payload(length, packet[5:], _CID_REQ_PAYLOAD_LENGTH)
if not checksum.is_valid(
payload[-4:],
packet[: _CID_REQ_PAYLOAD_LENGTH + INIT_HEADER_LENGTH - CHECKSUM_LENGTH],
):
raise ThpError("Checksum is not valid")
new_channel: Channel = channel_manager.create_new_channel(iface, _READ_BUFFER)
cid = int.from_bytes(new_channel.channel_id, "big")
_CHANNELS[cid] = new_channel
response_data = get_channel_allocation_response(
nonce, new_channel.channel_id, iface
)
response_header = PacketHeader.get_channel_allocation_response_header(
len(response_data) + CHECKSUM_LENGTH,
)
if __debug__:
log.debug(__name__, "New channel allocated with id %d", cid)
await write_payload_to_wire_and_add_checksum(iface, response_header, response_data)
async def _handle_allocated(
iface: WireInterface, cid: int, packet: utils.BufferType
) -> None:
channel = _CHANNELS[cid]
if channel is None:
await _handle_unallocated(iface, cid)
raise ThpError("Invalid state of a channel")
if channel.iface is not iface:
# TODO send error message to wire
raise ThpError("Channel has different WireInterface")
if channel.get_channel_state() != ChannelState.UNALLOCATED:
x = channel.receive_packet(packet)
if x is not None:
await x
async def _handle_unallocated(iface: WireInterface, cid: int) -> None:
data = (ThpErrorType.UNALLOCATED_CHANNEL).to_bytes(1, "big")
header = PacketHeader.get_error_header(cid, len(data) + CHECKSUM_LENGTH)
await write_payload_to_wire_and_add_checksum(iface, header, data)
def _get_buffer_for_payload(
payload_length: int,
existing_buffer: utils.BufferType,
max_length: int = MAX_PAYLOAD_LEN,
) -> utils.BufferType:
if payload_length > max_length:
raise ThpError("Message too large")
if payload_length > len(existing_buffer):
return _try_allocate_new_buffer(payload_length)
return _reuse_existing_buffer(payload_length, existing_buffer)
def _try_allocate_new_buffer(payload_length: int) -> utils.BufferType:
try:
payload: utils.BufferType = bytearray(payload_length)
except MemoryError:
payload = bytearray(PACKET_LENGTH)
raise ThpError("Message too large")
return payload
def _reuse_existing_buffer(
payload_length: int, existing_buffer: utils.BufferType
) -> utils.BufferType:
return memoryview(existing_buffer)[:payload_length]
def _get_codec_v1_error_message() -> bytes:
# Codec_v1 magic constant "?##" + Failure message type + msg_size
# + msg_data (code = "Failure_InvalidProtocol") + padding to 64 B
ERROR_MSG = b"\x3f\x23\x23\x00\x03\x00\x00\x00\x14\x08\x10\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
return ERROR_MSG

View File

@ -0,0 +1,54 @@
from micropython import const
from typing import TYPE_CHECKING
from trezor import loop
from .writer import write_payload_to_wire_and_add_checksum
if TYPE_CHECKING:
from . import PacketHeader
from .channel import Channel
MAX_RETRANSMISSION_COUNT = const(50)
MIN_RETRANSMISSION_COUNT = const(2)
class TransmissionLoop:
def __init__(
self, channel: Channel, header: PacketHeader, transport_payload: bytes
) -> None:
self.channel: Channel = channel
self.header: PacketHeader = header
self.transport_payload: bytes = transport_payload
self.wait_task: loop.spawn | None = None
self.min_retransmisson_count_achieved: bool = False
async def start(
self, max_retransmission_count: int = MAX_RETRANSMISSION_COUNT
) -> None:
self.min_retransmisson_count_achieved = False
for i in range(max_retransmission_count):
if i >= MIN_RETRANSMISSION_COUNT:
self.min_retransmisson_count_achieved = True
await write_payload_to_wire_and_add_checksum(
self.channel.iface, self.header, self.transport_payload
)
self.wait_task = loop.spawn(self._wait(i))
try:
await self.wait_task
except loop.TaskClosed:
self.wait_task = None
break
def stop_immediately(self) -> None:
if self.wait_task is not None:
self.wait_task.close()
self.wait_task = None
async def _wait(self, counter: int = 0) -> None:
timeout_ms = round(10200 - 1010000 / (counter + 100))
await loop.sleep(timeout_ms)
def __del__(self) -> None:
self.stop_immediately()

View File

@ -0,0 +1,92 @@
from micropython import const
from trezorcrypto import crc
from typing import TYPE_CHECKING
from trezor import io, log, loop, utils
from . import PacketHeader
INIT_HEADER_LENGTH = const(5)
CONT_HEADER_LENGTH = const(3)
PACKET_LENGTH = const(64)
CHECKSUM_LENGTH = const(4)
MAX_PAYLOAD_LEN = const(60000)
MESSAGE_TYPE_LENGTH = const(2)
if TYPE_CHECKING:
from trezorio import WireInterface
from typing import Awaitable, Sequence
def write_payload_to_wire_and_add_checksum(
iface: WireInterface, header: PacketHeader, transport_payload: bytes
) -> Awaitable[None]:
header_checksum: int = crc.crc32(header.to_bytes())
checksum: bytes = crc.crc32(transport_payload, header_checksum).to_bytes(
CHECKSUM_LENGTH, "big"
)
data = (transport_payload, checksum)
return write_payloads_to_wire(iface, header, data)
async def write_payloads_to_wire(
iface: WireInterface, header: PacketHeader, data: Sequence[bytes]
) -> None:
n_of_data = len(data)
total_length = sum(len(item) for item in data)
current_data_idx = 0
current_data_offset = 0
packet = bytearray(PACKET_LENGTH)
header.pack_to_init_buffer(packet)
packet_offset: int = INIT_HEADER_LENGTH
packet_number = 0
nwritten = 0
while nwritten < total_length:
if packet_number == 1:
header.pack_to_cont_buffer(packet)
if packet_number >= 1 and nwritten >= total_length - PACKET_LENGTH:
packet[:] = bytearray(PACKET_LENGTH)
header.pack_to_cont_buffer(packet)
while True:
n = utils.memcpy(
packet, packet_offset, data[current_data_idx], current_data_offset
)
packet_offset += n
current_data_offset += n
nwritten += n
if packet_offset < PACKET_LENGTH:
current_data_idx += 1
current_data_offset = 0
if current_data_idx >= n_of_data:
break
elif packet_offset == PACKET_LENGTH:
break
else:
raise Exception("Should not happen!!!")
packet_number += 1
packet_offset = CONT_HEADER_LENGTH
# write packet to wire (in-lined)
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__, "write_packet_to_wire: %s", utils.get_bytes_as_str(packet)
)
written_by_iface: int = 0
while written_by_iface < len(packet):
await loop.wait(iface.iface_num() | io.POLL_WRITE)
written_by_iface = iface.write(packet)
async def write_packet_to_wire(iface: WireInterface, packet: bytes) -> None:
while True:
await loop.wait(iface.iface_num() | io.POLL_WRITE)
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__, "write_packet_to_wire: %s", utils.get_bytes_as_str(packet)
)
n_written = iface.write(packet)
if n_written == len(packet):
return

View File

@ -1,9 +1,9 @@
import utime
from typing import TYPE_CHECKING
import storage.cache
import storage.cache_common as cache_common
from trezor import log, loop
from trezor.enums import MessageType
from trezor.enums import MessageType, ThpMessageType
if TYPE_CHECKING:
from typing import Callable
@ -17,9 +17,14 @@ if __debug__:
from trezor import utils
if utils.USE_THP:
protocol_specific = ThpMessageType.ThpCreateNewSession
else:
protocol_specific = MessageType.Initialize
ALLOW_WHILE_LOCKED = (
MessageType.Initialize,
protocol_specific,
MessageType.EndSession,
MessageType.GetFeatures,
MessageType.Cancel,
@ -153,7 +158,7 @@ def close_others() -> None:
if not task.is_running():
task.close()
storage.cache.homescreen_shown = None
cache_common.homescreen_shown = None
# if tasks were running, closing the last of them will run start_default
@ -211,11 +216,11 @@ class IdleTimer:
time and saves it to storage.cache. This is done to avoid losing an
active timer when workflow restart happens and tasks are lost.
"""
if _restore_from_cache and storage.cache.autolock_last_touch is not None:
now = storage.cache.autolock_last_touch
if _restore_from_cache and cache_common.autolock_last_touch is not None:
now = cache_common.autolock_last_touch
else:
now = utime.ticks_ms()
storage.cache.autolock_last_touch = now
cache_common.autolock_last_touch = now
for callback, task in self.tasks.items():
timeout_us = self.timeouts[callback]

View File

@ -0,0 +1,17 @@
from trezor.loop import wait
class MockHID:
def __init__(self, num):
self.num = num
self.data = []
def iface_num(self):
return self.num
def write(self, msg):
self.data.append(bytearray(msg))
return len(msg)
def wait_object(self, mode):
return wait(mode | self.num)

42
core/tests/myTests.sh Executable file
View File

@ -0,0 +1,42 @@
#!/usr/bin/env bash
declare -a results
declare -i passed=0 failed=0 exit_code=0
declare COLOR_GREEN='\e[32m' COLOR_RED='\e[91m' COLOR_RESET='\e[39m'
MICROPYTHON="${MICROPYTHON:-../build/unix/trezor-emu-core -X heapsize=2M}"
print_summary() {
echo
echo 'Summary:'
echo '-------------------'
printf '%b\n' "${results[@]}"
if [ $exit_code == 0 ]; then
echo -e "${COLOR_GREEN}PASSED:${COLOR_RESET} $passed/$num_of_tests tests OK!"
else
echo -e "${COLOR_RED}FAILED:${COLOR_RESET} $failed/$num_of_tests tests failed!"
fi
}
trap 'print_summary; echo -e "${COLOR_RED}Interrupted by user!${COLOR_RESET}"; exit 1' SIGINT
cd $(dirname $0)
[ -z "$*" ] && tests=(test_trezor.wire.t*.py ) || tests=($*)
declare -i num_of_tests=${#tests[@]}
for test_case in ${tests[@]}; do
echo ${MICROPYTHON}
echo ${test_case}
echo
if $MICROPYTHON $test_case; then
results+=("${COLOR_GREEN}OK:${COLOR_RESET} $test_case")
((passed++))
else
results+=("${COLOR_RED}FAIL:${COLOR_RESET} $test_case")
((failed++))
exit_code=1
fi
done
print_summary
exit $exit_code

View File

@ -1,4 +1,4 @@
from common import H_, await_result, unittest # isort:skip
from common import * # isort:skip
import storage.cache
from trezor import wire
@ -11,15 +11,35 @@ from trezor.messages import (
TxInput,
TxOutput,
)
from trezor.wire import context
from apps.bitcoin.authorization import FEE_RATE_DECIMALS, CoinJoinAuthorization
from apps.bitcoin.sign_tx.approvers import CoinJoinApprover
from apps.bitcoin.sign_tx.bitcoin import Bitcoin
from apps.bitcoin.sign_tx.tx_info import TxInfo
from apps.common import coins
from trezor.wire.codec.codec_context import CodecContext
if utils.USE_THP:
import thp_common
else:
import storage.cache_codec
class TestApprover(unittest.TestCase):
if utils.USE_THP:
def __init__(self):
if __debug__:
thp_common.suppres_debug_log()
thp_common.prepare_context()
super().__init__()
else:
def __init__(self):
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
super().__init__()
def setUp(self):
self.coin = coins.by_name("Bitcoin")
self.fee_rate_percent = 0.3
@ -47,7 +67,8 @@ class TestApprover(unittest.TestCase):
coin_name=self.coin.coin_name,
script_type=InputScriptType.SPENDTAPROOT,
)
storage.cache.start_session()
if not utils.USE_THP:
storage.cache_codec.start_session()
def make_coinjoin_request(self, inputs):
return CoinJoinRequest(

View File

@ -1,16 +1,36 @@
from common import H_, unittest # isort:skip
from common import * # isort:skip
import storage.cache
from trezor.enums import InputScriptType
from trezor.messages import AuthorizeCoinJoin, GetOwnershipProof, SignTx
from trezor.wire import context
from apps.bitcoin.authorization import CoinJoinAuthorization
from apps.common import coins
from trezor.wire.codec.codec_context import CodecContext
_ROUND_ID_LEN = 32
if utils.USE_THP:
import thp_common
else:
import storage.cache_codec
class TestAuthorization(unittest.TestCase):
if utils.USE_THP:
def __init__(self):
if __debug__:
thp_common.suppres_debug_log()
thp_common.prepare_context()
super().__init__()
else:
def __init__(self):
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
super().__init__()
coin = coins.by_name("Bitcoin")
@ -26,7 +46,8 @@ class TestAuthorization(unittest.TestCase):
)
self.authorization = CoinJoinAuthorization(self.msg_auth)
storage.cache.start_session()
if not utils.USE_THP:
storage.cache_codec.start_session()
def test_ownership_proof_account_depth_mismatch(self):
# Account depth mismatch.

View File

@ -1,230 +1,519 @@
from common import * # isort:skip
from common import * # isort:skip # noqa: F403
from mock_storage import mock_storage
from storage import cache
from storage import cache, cache_codec
from trezor.messages import EndSession, Initialize
from apps.base import handle_EndSession, handle_Initialize
from apps.base import handle_EndSession
from trezor.wire.codec.codec_context import CodecContext
KEY = 0
if utils.USE_THP:
import thp_common
from mock_wire_interface import MockHID
from storage import cache_thp
from trezor.wire.thp import ChannelState
from trezor.wire.thp.session_context import SessionContext
# Function moved from cache.py, as it was not used there
def is_session_started() -> bool:
return cache._active_session_idx is not None
_PROTOCOL_CACHE = cache_thp
else:
_PROTOCOL_CACHE = cache_codec
def is_session_started() -> bool:
return cache_codec.get_active_session() is not None
def get_active_session():
return cache_codec.get_active_session()
class TestStorageCache(unittest.TestCase):
class TestStorageCache(unittest.TestCase): # noqa: F405
def setUp(self):
cache.clear_all()
def test_start_session(self):
session_id_a = cache.start_session()
self.assertIsNotNone(session_id_a)
session_id_b = cache.start_session()
self.assertNotEqual(session_id_a, session_id_b)
if utils.USE_THP:
cache.clear_all()
with self.assertRaises(cache.InvalidSessionError):
cache.set(KEY, "something")
with self.assertRaises(cache.InvalidSessionError):
cache.get(KEY)
def __init__(self):
thp_common.suppres_debug_log()
# xthp_common.prepare_context()
# config.init()
super().__init__()
def test_end_session(self):
session_id = cache.start_session()
self.assertTrue(is_session_started())
cache.set(KEY, b"A")
cache.end_current_session()
self.assertFalse(is_session_started())
self.assertRaises(cache.InvalidSessionError, cache.get, KEY)
def setUp(self):
self.interface = MockHID(0xDEADBEEF)
cache.clear_all()
# ending an ended session should be a no-op
cache.end_current_session()
self.assertFalse(is_session_started())
def test_new_channel_and_session(self):
channel = thp_common.get_new_channel(self.interface)
session_id_a = cache.start_session(session_id)
# original session no longer exists
self.assertNotEqual(session_id_a, session_id)
# original session data no longer exists
self.assertIsNone(cache.get(KEY))
# Assert that channel is created without any sessions
self.assertEqual(len(channel.sessions), 0)
# create a new session
session_id_b = cache.start_session()
# switch back to original session
session_id = cache.start_session(session_id_a)
self.assertEqual(session_id, session_id_a)
# end original session
cache.end_current_session()
# switch back to B
session_id = cache.start_session(session_id_b)
self.assertEqual(session_id, session_id_b)
cid_1 = channel.channel_id
session_cache_1 = cache_thp.get_new_session(channel.channel_cache)
session_1 = SessionContext(channel, session_cache_1)
self.assertEqual(session_1.channel_id, cid_1)
def test_session_queue(self):
session_id = cache.start_session()
self.assertEqual(cache.start_session(session_id), session_id)
cache.set(KEY, b"A")
for i in range(cache._MAX_SESSIONS_COUNT):
cache.start_session()
self.assertNotEqual(cache.start_session(session_id), session_id)
self.assertIsNone(cache.get(KEY))
session_cache_2 = cache_thp.get_new_session(channel.channel_cache)
session_2 = SessionContext(channel, session_cache_2)
self.assertEqual(session_2.channel_id, cid_1)
self.assertEqual(session_1.channel_id, session_2.channel_id)
self.assertNotEqual(session_1.session_id, session_2.session_id)
def test_get_set(self):
session_id1 = cache.start_session()
cache.set(KEY, b"hello")
self.assertEqual(cache.get(KEY), b"hello")
channel_2 = thp_common.get_new_channel(self.interface)
cid_2 = channel_2.channel_id
self.assertNotEqual(cid_1, cid_2)
session_id2 = cache.start_session()
cache.set(KEY, b"world")
self.assertEqual(cache.get(KEY), b"world")
session_cache_3 = cache_thp.get_new_session(channel_2.channel_cache)
session_3 = SessionContext(channel_2, session_cache_3)
self.assertEqual(session_3.channel_id, cid_2)
cache.start_session(session_id2)
self.assertEqual(cache.get(KEY), b"world")
cache.start_session(session_id1)
self.assertEqual(cache.get(KEY), b"hello")
# Sessions 1 and 3 should have different channel_id, but the same session_id
self.assertNotEqual(session_1.channel_id, session_3.channel_id)
self.assertEqual(session_1.session_id, session_3.session_id)
cache.clear_all()
with self.assertRaises(cache.InvalidSessionError):
cache.get(KEY)
self.assertEqual(cache_thp._SESSIONS[0], session_cache_1)
self.assertNotEqual(cache_thp._SESSIONS[0], session_cache_2)
self.assertEqual(cache_thp._SESSIONS[0].channel_id, session_1.channel_id)
def test_get_set_int(self):
session_id1 = cache.start_session()
cache.set_int(KEY, 1234)
self.assertEqual(cache.get_int(KEY), 1234)
# Check that session data IS in cache for created sessions ONLY
for i in range(3):
self.assertNotEqual(cache_thp._SESSIONS[i].channel_id, b"")
self.assertNotEqual(cache_thp._SESSIONS[i].session_id, b"")
self.assertNotEqual(cache_thp._SESSIONS[i].last_usage, 0)
for i in range(3, cache_thp._MAX_SESSIONS_COUNT):
self.assertEqual(cache_thp._SESSIONS[i].channel_id, b"")
self.assertEqual(cache_thp._SESSIONS[i].session_id, b"")
self.assertEqual(cache_thp._SESSIONS[i].last_usage, 0)
session_id2 = cache.start_session()
cache.set_int(KEY, 5678)
self.assertEqual(cache.get_int(KEY), 5678)
# Check that session data IS NOT in cache after cache.clear_all()
cache.clear_all()
for session in cache_thp._SESSIONS:
self.assertEqual(session.channel_id, b"")
self.assertEqual(session.session_id, b"")
self.assertEqual(session.last_usage, 0)
self.assertEqual(session.state, b"\x00")
cache.start_session(session_id2)
self.assertEqual(cache.get_int(KEY), 5678)
cache.start_session(session_id1)
self.assertEqual(cache.get_int(KEY), 1234)
def test_channel_capacity_in_cache(self):
self.assertTrue(cache_thp._MAX_CHANNELS_COUNT >= 3)
channels = []
for i in range(cache_thp._MAX_CHANNELS_COUNT):
channels.append(thp_common.get_new_channel(self.interface))
channel_ids = [channel.channel_cache.channel_id for channel in channels]
cache.clear_all()
with self.assertRaises(cache.InvalidSessionError):
cache.get_int(KEY)
# Assert that each channel_id is unique and that cache and list of channels
# have the same "channels" on the same indexes
for i in range(len(channel_ids)):
self.assertEqual(cache_thp._CHANNELS[i].channel_id, channel_ids[i])
for j in range(i + 1, len(channel_ids)):
self.assertNotEqual(channel_ids[i], channel_ids[j])
def test_delete(self):
session_id1 = cache.start_session()
self.assertIsNone(cache.get(KEY))
cache.set(KEY, b"hello")
self.assertEqual(cache.get(KEY), b"hello")
cache.delete(KEY)
self.assertIsNone(cache.get(KEY))
# Create a new channel that is over the capacity
new_channel = thp_common.get_new_channel(self.interface)
for c in channels:
self.assertNotEqual(c.channel_id, new_channel.channel_id)
cache.set(KEY, b"hello")
cache.start_session()
self.assertIsNone(cache.get(KEY))
cache.set(KEY, b"hello")
self.assertEqual(cache.get(KEY), b"hello")
cache.delete(KEY)
self.assertIsNone(cache.get(KEY))
# Test that the oldest (least used) channel was replaced (_CHANNELS[0])
self.assertNotEqual(cache_thp._CHANNELS[0].channel_id, channel_ids[0])
self.assertEqual(cache_thp._CHANNELS[0].channel_id, new_channel.channel_id)
cache.start_session(session_id1)
self.assertEqual(cache.get(KEY), b"hello")
# Update the "last used" value of the second channel in cache (_CHANNELS[1]) and
# assert that it is not replaced when creating a new channel
cache_thp.update_channel_last_used(channel_ids[1])
new_new_channel = thp_common.get_new_channel(self.interface)
self.assertEqual(cache_thp._CHANNELS[1].channel_id, channel_ids[1])
def test_decorators(self):
run_count = 0
cache.start_session()
# Assert that it was in fact the _CHANNEL[2] that was replaced
self.assertNotEqual(cache_thp._CHANNELS[2].channel_id, channel_ids[2])
self.assertEqual(
cache_thp._CHANNELS[2].channel_id, new_new_channel.channel_id
)
@cache.stored(KEY)
def func():
nonlocal run_count
run_count += 1
return b"foo"
def test_session_capacity_in_cache(self):
self.assertTrue(cache_thp._MAX_SESSIONS_COUNT >= 4)
channel_cache_A = thp_common.get_new_channel(self.interface).channel_cache
channel_cache_B = thp_common.get_new_channel(self.interface).channel_cache
# cache is empty
self.assertIsNone(cache.get(KEY))
self.assertEqual(run_count, 0)
self.assertEqual(func(), b"foo")
# function was run
self.assertEqual(run_count, 1)
self.assertEqual(cache.get(KEY), b"foo")
# function does not run again but returns cached value
self.assertEqual(func(), b"foo")
self.assertEqual(run_count, 1)
sesions_A = []
cid = []
sid = []
for i in range(3):
sesions_A.append(cache_thp.get_new_session(channel_cache_A))
cid.append(sesions_A[i].channel_id)
sid.append(sesions_A[i].session_id)
@cache.stored_async(KEY)
async def async_func():
nonlocal run_count
run_count += 1
return b"bar"
sessions_B = []
for i in range(cache_thp._MAX_SESSIONS_COUNT - 3):
sessions_B.append(cache_thp.get_new_session(channel_cache_B))
# cache is still full
self.assertEqual(await_result(async_func()), b"foo")
self.assertEqual(run_count, 1)
for i in range(3):
self.assertEqual(sesions_A[i], cache_thp._SESSIONS[i])
self.assertEqual(cid[i], cache_thp._SESSIONS[i].channel_id)
self.assertEqual(sid[i], cache_thp._SESSIONS[i].session_id)
for i in range(3, cache_thp._MAX_SESSIONS_COUNT):
self.assertEqual(sessions_B[i - 3], cache_thp._SESSIONS[i])
cache.start_session()
self.assertEqual(await_result(async_func()), b"bar")
self.assertEqual(run_count, 2)
# awaitable is also run only once
self.assertEqual(await_result(async_func()), b"bar")
self.assertEqual(run_count, 2)
# Assert that new session replaces the oldest (least used) one (_SESSOIONS[0])
new_session = cache_thp.get_new_session(channel_cache_B)
self.assertEqual(new_session, cache_thp._SESSIONS[0])
self.assertNotEqual(new_session.channel_id, cid[0])
self.assertNotEqual(new_session.session_id, sid[0])
def test_empty_value(self):
cache.start_session()
# Assert that updating "last used" for session on channel A increases also
# the "last usage" of channel A.
self.assertTrue(channel_cache_A.last_usage < channel_cache_B.last_usage)
cache_thp.update_session_last_used(
channel_cache_A.channel_id, sesions_A[1].session_id
)
self.assertTrue(channel_cache_A.last_usage > channel_cache_B.last_usage)
self.assertIsNone(cache.get(KEY))
cache.set(KEY, b"")
self.assertEqual(cache.get(KEY), b"")
new_new_session = cache_thp.get_new_session(channel_cache_B)
cache.delete(KEY)
run_count = 0
# Assert that creating a new session on channel B shifts the "last usage" again
# and that _SESSIONS[1] was not replaced, but that _SESSIONS[2] was replaced
self.assertTrue(channel_cache_A.last_usage < channel_cache_B.last_usage)
self.assertEqual(sesions_A[1], cache_thp._SESSIONS[1])
self.assertNotEqual(sesions_A[2], cache_thp._SESSIONS[2])
self.assertEqual(new_new_session, cache_thp._SESSIONS[2])
@cache.stored(KEY)
def func():
nonlocal run_count
run_count += 1
return b""
def test_clear(self):
channel_A = thp_common.get_new_channel(self.interface)
channel_B = thp_common.get_new_channel(self.interface)
cid_A = channel_A.channel_id
cid_B = channel_B.channel_id
sessions = []
self.assertEqual(func(), b"")
# function gets called once
self.assertEqual(run_count, 1)
self.assertEqual(func(), b"")
# function is not called for a second time
self.assertEqual(run_count, 1)
for i in range(3):
sessions.append(cache_thp.get_new_session(channel_A.channel_cache))
sessions.append(cache_thp.get_new_session(channel_B.channel_cache))
@mock_storage
def test_Initialize(self):
def call_Initialize(**kwargs):
msg = Initialize(**kwargs)
return await_result(handle_Initialize(msg))
self.assertEqual(cache_thp._SESSIONS[2 * i].channel_id, cid_A)
self.assertNotEqual(cache_thp._SESSIONS[2 * i].last_usage, 0)
# calling Initialize without an ID allocates a new one
session_id = cache.start_session()
features = call_Initialize()
self.assertNotEqual(session_id, features.session_id)
self.assertEqual(cache_thp._SESSIONS[2 * i + 1].channel_id, cid_B)
self.assertNotEqual(cache_thp._SESSIONS[2 * i + 1].last_usage, 0)
# calling Initialize with the current ID does not allocate a new one
features = call_Initialize(session_id=session_id)
self.assertEqual(session_id, features.session_id)
# Assert that clearing of channel A works
self.assertNotEqual(channel_A.channel_cache.channel_id, b"")
self.assertNotEqual(channel_A.channel_cache.last_usage, 0)
self.assertEqual(channel_A.get_channel_state(), ChannelState.TH1)
# store "hello"
cache.set(KEY, b"hello")
# check that it is cleared
features = call_Initialize()
session_id = features.session_id
self.assertIsNone(cache.get(KEY))
# store "hello" again
cache.set(KEY, b"hello")
self.assertEqual(cache.get(KEY), b"hello")
channel_A.clear()
# supplying a different session ID starts a new cache
call_Initialize(session_id=b"A" * cache._SESSION_ID_LENGTH)
self.assertIsNone(cache.get(KEY))
self.assertEqual(channel_A.channel_cache.channel_id, b"")
self.assertEqual(channel_A.channel_cache.last_usage, 0)
self.assertEqual(channel_A.get_channel_state(), ChannelState.UNALLOCATED)
# but resuming a session loads the previous one
call_Initialize(session_id=session_id)
self.assertEqual(cache.get(KEY), b"hello")
# Assert that clearing channel A also cleared all its sessions
for i in range(3):
self.assertEqual(cache_thp._SESSIONS[2 * i].last_usage, 0)
self.assertEqual(cache_thp._SESSIONS[2 * i].channel_id, b"")
def test_EndSession(self):
self.assertRaises(cache.InvalidSessionError, cache.get, KEY)
cache.start_session()
self.assertTrue(is_session_started())
self.assertIsNone(cache.get(KEY))
await_result(handle_EndSession(EndSession()))
self.assertFalse(is_session_started())
self.assertRaises(cache.InvalidSessionError, cache.get, KEY)
self.assertNotEqual(cache_thp._SESSIONS[2 * i + 1].last_usage, 0)
self.assertEqual(cache_thp._SESSIONS[2 * i + 1].channel_id, cid_B)
cache.clear_all()
for session in cache_thp._SESSIONS:
self.assertEqual(session.last_usage, 0)
self.assertEqual(session.channel_id, b"")
for channel in cache_thp._CHANNELS:
self.assertEqual(channel.channel_id, b"")
self.assertEqual(channel.last_usage, 0)
self.assertEqual(
cache_thp._get_channel_state(channel), ChannelState.UNALLOCATED
)
def test_get_set(self):
channel = thp_common.get_new_channel(self.interface)
session_1 = cache_thp.get_new_session(channel.channel_cache)
session_1.set(KEY, b"hello")
self.assertEqual(session_1.get(KEY), b"hello")
session_2 = cache_thp.get_new_session(channel.channel_cache)
session_2.set(KEY, b"world")
self.assertEqual(session_2.get(KEY), b"world")
self.assertEqual(session_1.get(KEY), b"hello")
cache.clear_all()
self.assertIsNone(session_1.get(KEY))
self.assertIsNone(session_2.get(KEY))
def test_get_set_int(self):
channel = thp_common.get_new_channel(self.interface)
session_1 = cache_thp.get_new_session(channel.channel_cache)
session_1.set_int(KEY, 1234)
self.assertEqual(session_1.get_int(KEY), 1234)
session_2 = cache_thp.get_new_session(channel.channel_cache)
session_2.set_int(KEY, 5678)
self.assertEqual(session_2.get_int(KEY), 5678)
self.assertEqual(session_1.get_int(KEY), 1234)
cache.clear_all()
self.assertIsNone(session_1.get_int(KEY))
self.assertIsNone(session_2.get_int(KEY))
def test_get_set_bool(self):
channel = thp_common.get_new_channel(self.interface)
session_1 = cache_thp.get_new_session(channel.channel_cache)
with self.assertRaises(AssertionError) as e:
session_1.set_bool(KEY, True)
self.assertEqual(e.value.value, "Field does not have zero length!")
# Change length of first session field to 0 so that the length check passes
session_1.fields = (0,) + session_1.fields[1:]
# with self.assertRaises(AssertionError) as e:
session_1.set_bool(KEY, True)
self.assertEqual(session_1.get_bool(KEY), True)
session_2 = cache_thp.get_new_session(channel.channel_cache)
session_2.fields = session_2.fields = (0,) + session_2.fields[1:]
session_2.set_bool(KEY, False)
self.assertEqual(session_2.get_bool(KEY), False)
self.assertEqual(session_1.get_bool(KEY), True)
cache.clear_all()
# Default value is False
self.assertFalse(session_1.get_bool(KEY))
self.assertFalse(session_2.get_bool(KEY))
def test_delete(self):
channel = thp_common.get_new_channel(self.interface)
session_1 = cache_thp.get_new_session(channel.channel_cache)
self.assertIsNone(session_1.get(KEY))
session_1.set(KEY, b"hello")
self.assertEqual(session_1.get(KEY), b"hello")
session_1.delete(KEY)
self.assertIsNone(session_1.get(KEY))
session_1.set(KEY, b"hello")
session_2 = cache_thp.get_new_session(channel.channel_cache)
self.assertIsNone(session_2.get(KEY))
session_2.set(KEY, b"hello")
self.assertEqual(session_2.get(KEY), b"hello")
session_2.delete(KEY)
self.assertIsNone(session_2.get(KEY))
self.assertEqual(session_1.get(KEY), b"hello")
else:
def __init__(self):
# Context is needed to test decorators and handleInitialize
# It allows access to codec cache from different parts of the code
from trezor.wire import context
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
super().__init__()
def test_start_session(self):
session_id_a = cache_codec.start_session()
self.assertIsNotNone(session_id_a)
session_id_b = cache_codec.start_session()
self.assertNotEqual(session_id_a, session_id_b)
cache.clear_all()
self.assertIsNone(get_active_session())
for session in cache_codec._SESSIONS:
self.assertEqual(session.session_id, b"")
self.assertEqual(session.last_usage, 0)
def test_end_session(self):
session_id = cache_codec.start_session()
self.assertTrue(is_session_started())
get_active_session().set(KEY, b"A")
cache_codec.end_current_session()
self.assertFalse(is_session_started())
self.assertIsNone(get_active_session())
# ending an ended session should be a no-op
cache_codec.end_current_session()
self.assertFalse(is_session_started())
session_id_a = cache_codec.start_session(session_id)
# original session no longer exists
self.assertNotEqual(session_id_a, session_id)
# original session data no longer exists
self.assertIsNone(get_active_session().get(KEY))
# create a new session
session_id_b = cache_codec.start_session()
# switch back to original session
session_id = cache_codec.start_session(session_id_a)
self.assertEqual(session_id, session_id_a)
# end original session
cache_codec.end_current_session()
# switch back to B
session_id = cache_codec.start_session(session_id_b)
self.assertEqual(session_id, session_id_b)
def test_session_queue(self):
session_id = cache_codec.start_session()
self.assertEqual(cache_codec.start_session(session_id), session_id)
get_active_session().set(KEY, b"A")
for i in range(_PROTOCOL_CACHE._MAX_SESSIONS_COUNT):
cache_codec.start_session()
self.assertNotEqual(cache_codec.start_session(session_id), session_id)
self.assertIsNone(get_active_session().get(KEY))
def test_get_set(self):
session_id1 = cache_codec.start_session()
cache_codec.get_active_session().set(KEY, b"hello")
self.assertEqual(cache_codec.get_active_session().get(KEY), b"hello")
session_id2 = cache_codec.start_session()
cache_codec.get_active_session().set(KEY, b"world")
self.assertEqual(cache_codec.get_active_session().get(KEY), b"world")
cache_codec.start_session(session_id2)
self.assertEqual(cache_codec.get_active_session().get(KEY), b"world")
cache_codec.start_session(session_id1)
self.assertEqual(cache_codec.get_active_session().get(KEY), b"hello")
cache.clear_all()
self.assertIsNone(cache_codec.get_active_session())
def test_get_set_int(self):
session_id1 = cache_codec.start_session()
get_active_session().set_int(KEY, 1234)
self.assertEqual(get_active_session().get_int(KEY), 1234)
session_id2 = cache_codec.start_session()
get_active_session().set_int(KEY, 5678)
self.assertEqual(get_active_session().get_int(KEY), 5678)
cache_codec.start_session(session_id2)
self.assertEqual(get_active_session().get_int(KEY), 5678)
cache_codec.start_session(session_id1)
self.assertEqual(get_active_session().get_int(KEY), 1234)
cache.clear_all()
self.assertIsNone(get_active_session())
def test_delete(self):
session_id1 = cache_codec.start_session()
self.assertIsNone(get_active_session().get(KEY))
get_active_session().set(KEY, b"hello")
self.assertEqual(get_active_session().get(KEY), b"hello")
get_active_session().delete(KEY)
self.assertIsNone(get_active_session().get(KEY))
get_active_session().set(KEY, b"hello")
cache_codec.start_session()
self.assertIsNone(get_active_session().get(KEY))
get_active_session().set(KEY, b"hello")
self.assertEqual(get_active_session().get(KEY), b"hello")
get_active_session().delete(KEY)
self.assertIsNone(get_active_session().get(KEY))
cache_codec.start_session(session_id1)
self.assertEqual(get_active_session().get(KEY), b"hello")
def test_decorators(self):
run_count = 0
cache_codec.start_session()
from apps.common.cache import stored
@stored(KEY)
def func():
nonlocal run_count
run_count += 1
return b"foo"
# cache is empty
self.assertIsNone(get_active_session().get(KEY))
self.assertEqual(run_count, 0)
self.assertEqual(func(), b"foo")
# function was run
self.assertEqual(run_count, 1)
self.assertEqual(get_active_session().get(KEY), b"foo")
# function does not run again but returns cached value
self.assertEqual(func(), b"foo")
self.assertEqual(run_count, 1)
def test_empty_value(self):
cache_codec.start_session()
self.assertIsNone(get_active_session().get(KEY))
get_active_session().set(KEY, b"")
self.assertEqual(get_active_session().get(KEY), b"")
get_active_session().delete(KEY)
run_count = 0
from apps.common.cache import stored
@stored(KEY)
def func():
nonlocal run_count
run_count += 1
return b""
self.assertEqual(func(), b"")
# function gets called once
self.assertEqual(run_count, 1)
self.assertEqual(func(), b"")
# function is not called for a second time
self.assertEqual(run_count, 1)
if not utils.USE_THP:
@mock_storage
def test_Initialize(self):
from apps.base import handle_Initialize
def call_Initialize(**kwargs):
msg = Initialize(**kwargs)
return await_result(handle_Initialize(msg))
# calling Initialize without an ID allocates a new one
session_id = cache_codec.start_session()
features = call_Initialize()
self.assertNotEqual(session_id, features.session_id)
# calling Initialize with the current ID does not allocate a new one
features = call_Initialize(session_id=session_id)
self.assertEqual(session_id, features.session_id)
# store "hello"
get_active_session().set(KEY, b"hello")
# check that it is cleared
features = call_Initialize()
session_id = features.session_id
self.assertIsNone(get_active_session().get(KEY))
# store "hello" again
get_active_session().set(KEY, b"hello")
self.assertEqual(get_active_session().get(KEY), b"hello")
# supplying a different session ID starts a new session
call_Initialize(session_id=b"A" * _PROTOCOL_CACHE.SESSION_ID_LENGTH)
self.assertIsNone(get_active_session().get(KEY))
# but resuming a session loads the previous one
call_Initialize(session_id=session_id)
self.assertEqual(get_active_session().get(KEY), b"hello")
def test_EndSession(self):
self.assertIsNone(get_active_session())
cache_codec.start_session()
self.assertTrue(is_session_started())
self.assertIsNone(get_active_session().get(KEY))
await_result(handle_EndSession(EndSession()))
self.assertFalse(is_session_started())
self.assertIsNone(cache_codec.get_active_session())
if __name__ == "__main__":

View File

@ -2,27 +2,10 @@ from common import * # isort:skip
import ustruct
from mock_wire_interface import MockHID
from trezor import io
from trezor.loop import wait
from trezor.utils import chunks
from trezor.wire import codec_v1
class MockHID:
def __init__(self, num):
self.num = num
self.data = []
def iface_num(self):
return self.num
def write(self, msg):
self.data.append(bytearray(msg))
return len(msg)
def wait_object(self, mode):
return wait(mode | self.num)
from trezor.wire.codec import codec_v1
MESSAGE_TYPE = 0x4242

View File

@ -0,0 +1,94 @@
from common import * # isort:skip
if utils.USE_THP:
from trezor.wire.thp import checksum
@unittest.skipUnless(utils.USE_THP, "only needed for THP")
class TestTrezorHostProtocolChecksum(unittest.TestCase):
vectors_correct = [
(
b"",
b"\x00\x00\x00\x00",
),
(
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
b"\x19\x0A\x55\xAD",
),
(
bytes("a", "ascii"),
b"\xE8\xB7\xBE\x43",
),
(
bytes("abc", "ascii"),
b"\x35\x24\x41\xC2",
),
(
bytes("123456789", "ascii"),
b"\xCB\xF4\x39\x26",
),
(
bytes(
"12345678901234567890123456789012345678901234567890123456789012345678901234567890",
"ascii",
),
b"\x7C\xA9\x4A\x72",
),
(
b"\x76\x61\x72\x69\x6F\x75\x73\x20\x43\x52\x43\x20\x61\x6C\x67\x6F\x72\x69\x74\x68\x6D\x73\x20\x69\x6E\x70\x75\x74\x20\x64\x61\x74\x61",
b"\x9B\xD3\x66\xAE",
),
(
b"\x67\x3a\x5f\x0e\x39\xc0\x3c\x79\x58\x22\x74\x76\x64\x9e\x36\xe9\x0b\x04\x8c\xd2\xc0\x4d\x76\x63\x1a\xa2\x17\x85\xe8\x50\xa7\x14\x18\xfb\x86\xed\xa3\x59\x2d\x62\x62\x49\x64\x62\x26\x12\xdb\x95\x3d\xd6\xb5\xca\x4b\x22\x0d\xc5\x78\xb2\x12\x97\x8e\x54\x4e\x06\xb7\x9c\x90\xf5\xa0\x21\xa6\xc7\xd8\x39\xfd\xea\x3a\xf1\x7b\xa2\xe8\x71\x41\xd6\xcb\x1e\x5b\x0e\x29\xf7\x0c\xc7\x57\x8b\x53\x20\x1d\x2b\x41\x1c\x25\xf9\x07\xbb\xb4\x37\x79\x6a\x13\x1f\x6c\x43\x71\xc1\x1e\x70\xe6\x74\xd3\x9c\xbf\x32\x15\xee\xf2\xa7\x86\xbe\x59\x99\xc4\x10\x09\x8a\x6a\xaa\xd4\xd1\xd0\x71\xd2\x06\x1a\xdd\x2a\xa0\x08\xeb\x08\x6c\xfb\xd2\x2d\xfb\xaa\x72\x56\xeb\xd1\x92\x92\xe5\x0e\x95\x67\xf8\x38\xc3\xab\x59\x37\xe6\xfd\x42\xb0\xd0\x31\xd0\xcb\x8a\x66\xce\x2d\x53\x72\x1e\x72\xd3\x84\x25\xb0\xb8\x93\xd2\x61\x5b\x32\xd5\xe7\xe4\x0e\x31\x11\xaf\xdc\xb4\xb8\xee\xa4\x55\x16\x5f\x78\x86\x8b\x50\x4d\xc5\x6d\x6e\xfc\xe1\x6b\x06\x5b\x37\x84\x2a\x67\x95\x28\x00\xa4\xd1\x32\x9f\xbf\xe1\x64\xf8\x17\x47\xe1\xad\x8b\x72\xd2\xd9\x45\x5b\x73\x43\x3c\xe6\x21\xf7\x53\xa3\x73\xf9\x2a\xb0\xe9\x75\x5e\xa6\xbe\x9a\xad\xfc\xed\xb5\x46\x5b\x9f\xa9\x5a\x4f\xcb\xb6\x60\x96\x31\x91\x42\xca\xaf\xee\xa5\x0c\xe0\xab\x3e\x83\xb8\xac\x88\x10\x2c\x63\xd3\xc9\xd2\xf2\x44\xef\xea\x3d\x19\x24\x3c\x5b\xe7\x0c\x52\xfd\xfe\x47\x41\x14\xd5\x4c\x67\x8d\xdb\xe5\xd9\xfa\x67\x9c\x06\x31\x01\x92\xba\x96\xc4\x0d\xef\xf7\xc1\xe9\x23\x28\x0f\xae\x27\x9b\xff\x28\x0b\x3e\x85\x0c\xae\x02\xda\x27\xb6\x04\x51\x04\x43\x04\x99\x8c\xa3\x97\x1d\x84\xec\x55\x59\xfb\xf3\x84\xe5\xf8\x40\xf8\x5f\x81\x65\x92\x4c\x92\x7a\x07\x51\x8d\x6f\xff\x8d\x15\x36\x5c\x57\x7a\x5b\x3a\x63\x1c\x87\x65\xee\x54\xd5\x96\x50\x73\x1a\x9c\xff\x59\xe5\xea\x6f\x89\xd2\xbb\xa9\x6a\x12\x21\xf5\x08\x8e\x8a\xc0\xd8\xf5\x14\xe9\x9d\x7e\x99\x13\x88\x29\xa8\xb4\x22\x2a\x41\x7c\xc5\x10\xdf\x11\x5e\xf8\x8d\x0e\xd9\x98\xd5\xaf\xa8\xf9\x55\x1e\xe3\x29\xcd\x2c\x51\x7b\x8a\x8d\x52\xaa\x8b\x87\xae\x8e\xb2\xfa\x31\x27\x60\x90\xcb\x01\x6f\x7a\x79\x38\x04\x05\x7c\x11\x79\x10\x40\x33\x70\x75\xfd\x0b\x88\xa5\xcd\x35\xd8\xa6\x3b\xb0\x45\x82\x64\xd1\xb5\xdc\x06\xc9\x89\xf4\x16\x3e\xc7\xb3\xf1\x9d\xd3\xc5\xe3\xaf\xe8\x25\x86\x7a\x4a\xfd\x10\x5d\x20\xe5\x76\x5a\x22\x5f\x8f\xbc\xaa\x97\xee\xf2\xc2\x4c\x0e\xdc\x7b\xc4\xee\x53\xa3\xe0\xfa\xcd\x1e\x4e\x54\x1d\x5e\xe1\x51\x17\x1f\x1a\x75\x7f\xed\x12\xd7\xf7\xe3\x18\x56\x24\xcf\xc6\x96\x30\x77\x0d\x73\x98\x9c\x09\x69\xa3\xbc\x96\x5e\xaf\xde\x76\xa4\x66\x04\x6b\x36\x2a\xac\x6d\x37\xf8\x1e\xe1\x2a\x3e\x42\x2d\x1d\xe6\x46\xdd\x28\xb9\x08\x44\xa1\x9e\xb2\x22\x7a\x45\x8a\x37\x39\x74\xb4\xae\xc8\x3b\x40\xf7\xec\xbf\xfd\xe5\xde\xb2\x83\x5e\xa4\x46\x19\xa6\x9d\xb0\xe8\x76\x80\xbd\xc1\x80\x7a\xd9\xeb\xe7\x90\x5b\x81\x25\x21\xd9\x5b\x4a\x80\x48\x92\x71\x77\x04\xb2\xac\x05\xc9\xdf\x5e\x44\x5a\xae\x6e\xb3\xd8\x30\x5e\xdc\x77\x2f\x79\xc2\x8e\x8b\x28\x24\x06\x1b\x6f\x8d\x88\x53\x80\x55\x0c\x3a\x7b\x85\xb8\x96\x85\xe9\xf0\x57\x63\xfe\x32\x80\xff\x57\xc9\x3c\xdb\xf6\xcd\x67\x14\x47\x6c\x43\x3d\x6d\x48\x3f\x9c\x00\x60\x0e\xf5\x94\xe4\x52\x97\x86\xcd\xac\xbc\xe4\xe3\xe7\xee\xa2\x91\x6e\x92\xbb\xd1\x55\x0c\x5c\x0d\x63\xdb\x6b\xb8\x6e\x45\x48\x0f\xdf\x44\x48\xd2\xf5\xf7\x4d\x7b\xd4\x4d\xd3\xcd\xcd\x5b\x40\x60\xb1\xb2\x8e\xc9\x9a\x65\xc5\x06\x24\xcf\xe9\xcc\x5e\x2c\x49\x47\x38\x45\x5d\xc5\xc0\x0d\x8a\x07\x1c\xb3\xbb\xb1\x69\xf5\x6d\x0e\x9c\x96\x14\x93\x58\x0c\xc9\x48\x74\xfc\x35\xda\x7d\x4e\x32\x73\xa3\x77\x4a\x9e\xc5\xd1\x08\xfe\xa6\xa0\xf1\x66\x72\xea\xc7\xae\x21\x81\x0e\x8a\xba\x99\x06\x97\xfc\xc6\x2b\x69\x53\xc6\x67\xec\x5d\xa1\xfc\xa1\x3b\xdd\x2a\xd6\x8f\x31\xa7\x8d\xec\xfe\x0a\x3b\x6b\x39\x70\x70\x09\x72\x12\xbc\x84\x67\xca\xd2\x4a\x17\x33\x94\x45\x25\xc7\xfd\x1e\xa2\x4a\x9e\x27\x9d\xfb\x87\xea\xe4\xfd\xb0\x11\x06\x9d\x72\xb9\x1d\xea\x9b\x81\x2e\x6a\x36\x76\x62\xfa\xbe\x96\x67\x7d\x35\xdd\x5e\x5c\x4f\x41\x0d\xce\xdb\x13\xb0\x46\x89\x92\x45\x02\x39\x0f\xe6\xd1\x20\x96\x1c\x34\x00\x8c\xc9\xdf\xe3\xf0\xb6\x92\x3a\xda\x5c\x96\xd9\x0b\x7d\x57\xf5\x78\x11\xc0\xcf\xbf\xb0\x92\x3d\xe5\x6a\x67\x34\xce\xd9\x16\x08\xa0\x09\x42\x0b\x07\x13\x7c\x73\x0c\xc6\x50\x17\x42\xcf\xd9\x85\xd9\x23\x3c\xb1\x40\x40\x0f\x94\x20\xed\x2d\xbf\x10\x44\x6e\x64\x65\xe5\x1d\x5f\xec\x24\xd8\x4b\xe8\xc2\xfb\x06\x11\x24\x3f\xdf\x54\x2d\xe8\x4d\xc2\x1c\x27\x11\xb8\xb3\xd4",
b"\x6B\xA4\xEC\x92",
),
]
vectors_incorrect = [
(
b"",
b"\x00\x00\x00\x00\x00",
),
(
b"",
b"",
),
(
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
b"\x19\x0A\x55\xAE",
),
(
bytes("A", "ascii"),
b"\xE8\xB7\xBE\x43",
),
(
bytes("abc ", "ascii"),
b"\x35\x24\x41\xC2",
),
(
bytes("1234567890", "ascii"),
b"\xCB\xF4\x39\x26",
),
(
bytes(
"1234567890123456789012345678901234567890123456789012345678901234567890123456789",
"ascii",
),
b"\x7C\xA9\x4A\x72",
),
]
def test_computation(self):
for data, chksum in self.vectors_correct:
self.assertEqual(checksum.compute(data), chksum)
def test_validation_correct(self):
for data, chksum in self.vectors_correct:
self.assertTrue(checksum.is_valid(chksum, data))
def test_validation_incorrect(self):
for data, chksum in self.vectors_incorrect:
self.assertFalse(checksum.is_valid(chksum, data))
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,66 @@
from common import * # isort:skip
if utils.USE_THP:
import thp_common
from trezor import config
from trezor.messages import ThpCredentialMetadata
from apps.thp import credential_manager
def _issue_credential(host_name: str, host_static_pubkey: bytes) -> bytes:
metadata = ThpCredentialMetadata(host_name=host_name)
return credential_manager.issue_credential(host_static_pubkey, metadata)
@unittest.skipUnless(utils.USE_THP, "only needed for THP")
class TestTrezorHostProtocolCredentialManager(unittest.TestCase):
def __init__(self):
if __debug__:
thp_common.suppres_debug_log()
super().__init__()
def setUp(self):
config.init()
config.wipe()
def test_derive_cred_auth_key(self):
key1 = credential_manager.derive_cred_auth_key()
key2 = credential_manager.derive_cred_auth_key()
self.assertEqual(len(key1), 32)
self.assertEqual(key1, key2)
def test_invalidate_cred_auth_key(self):
key1 = credential_manager.derive_cred_auth_key()
credential_manager.invalidate_cred_auth_key()
key2 = credential_manager.derive_cred_auth_key()
self.assertNotEqual(key1, key2)
def test_credentials(self):
DUMMY_KEY_1 = b"\x00\x00"
DUMMY_KEY_2 = b"\xff\xff"
HOST_NAME_1 = "host_name"
HOST_NAME_2 = "different host_name"
cred_1 = _issue_credential(HOST_NAME_1, DUMMY_KEY_1)
cred_2 = _issue_credential(HOST_NAME_1, DUMMY_KEY_1)
self.assertEqual(cred_1, cred_2)
cred_3 = _issue_credential(HOST_NAME_2, DUMMY_KEY_1)
self.assertNotEqual(cred_1, cred_3)
self.assertTrue(credential_manager.validate_credential(cred_1, DUMMY_KEY_1))
self.assertTrue(credential_manager.validate_credential(cred_3, DUMMY_KEY_1))
self.assertFalse(credential_manager.validate_credential(cred_1, DUMMY_KEY_2))
credential_manager.invalidate_cred_auth_key()
cred_4 = _issue_credential(HOST_NAME_1, DUMMY_KEY_1)
self.assertNotEqual(cred_1, cred_4)
self.assertFalse(credential_manager.validate_credential(cred_1, DUMMY_KEY_1))
self.assertFalse(credential_manager.validate_credential(cred_3, DUMMY_KEY_1))
self.assertTrue(credential_manager.validate_credential(cred_4, DUMMY_KEY_1))
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,156 @@
from common import * # isort:skip
from trezorcrypto import aesgcm, curve25519
import storage
if utils.USE_THP:
import thp_common
from trezor.wire.thp import crypto
from trezor.wire.thp.crypto import IV_1, IV_2, Handshake
def get_dummy_device_secret():
return b"\x01\x02\x03\x04\x05\x06\x07\x08\x01\x02\x03\x04\x05\x06\x07\x08"
@unittest.skipUnless(utils.USE_THP, "only needed for THP")
class TestTrezorHostProtocolCrypto(unittest.TestCase):
if utils.USE_THP:
handshake = Handshake()
key_1 = b"\x00\x01\x02\x03\x04\x05\x06\x07\x00\x01\x02\x03\x04\x05\x06\x07\x00\x01\x02\x03\x04\x05\x06\x07\x00\x01\x02\x03\x04\x05\x06\x07"
# 0:key, 1:nonce, 2:auth_data, 3:plaintext, 4:expected_ciphertext, 5:expected_tag
vectors_enc = [
(
key_1,
0,
b"\x55\x64",
b"\x00\x01\x02\x03\x04\05\x06\x07\x08\x09",
b"e2c9dd152fbee5821ea7",
b"10625812de81b14a46b9f1e5100a6d0c",
),
(
key_1,
1,
b"\x55\x64",
b"\x00\x01\x02\x03\x04\05\x06\x07\x08\x09",
b"79811619ddb07c2b99f8",
b"71c6b872cdc499a7e9a3c7441f053214",
),
(
key_1,
369,
b"\x55\x64",
b"\x00\x01\x02\x03\x04\05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
b"03bd030390f2dfe815a61c2b157a064f",
b"c1200f8a7ae9a6d32cef0fff878d55c2",
),
(
key_1,
369,
b"\x55\x64\x73\x82\x91",
b"\x00\x01\x02\x03\x04\05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
b"03bd030390f2dfe815a61c2b157a064f",
b"693ac160cd93a20f7fc255f049d808d0",
),
]
# 0:chaining key, 1:input, 2:output_1, 3:output:2
vectors_hkdf = [
(
crypto.PROTOCOL_NAME,
b"\x01\x02",
b"c784373a217d6be057cddc6068e6748f255fc8beb6f99b7b90cbc64aad947514",
b"12695451e29bf08ffe5e4e6ab734b0c3d7cdd99b16cd409f57bd4eaa874944ba",
),
(
b"\xc7\x84\x37\x3a\x21\x7d\x6b\xe0\x57\xcd\xdc\x60\x68\xe6\x74\x8f\x25\x5f\xc8\xbe\xb6\xf9\x9b\x7b\x90\xcb\xc6\x4a\xad\x94\x75\x14",
b"\x31\x41\x59\x26\x52\x12\x34\x56\x78\x89\x04\xaa",
b"f88c1e08d5c3bae8f6e4a3d3324c8cbc60a805603e399e69c4bf4eacb27c2f48",
b"5f0216bdb7110ee05372286974da8c9c8b96e2efa15b4af430755f462bd79a76",
),
]
vectors_iv = [
(0, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"),
(1, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"),
(7, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x07"),
(1025, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04\x01"),
(4294967295, b"\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff"),
(0xFFFFFFFFFFFFFFFF, b"\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff"),
]
def __init__(self):
if __debug__:
thp_common.suppres_debug_log()
super().__init__()
def setUp(self):
utils.DISABLE_ENCRYPTION = False
def test_encryption(self):
for v in self.vectors_enc:
buffer = bytearray(v[3])
tag = crypto.enc(buffer, v[0], v[1], v[2])
self.assertEqual(hexlify(buffer), v[4])
self.assertEqual(hexlify(tag), v[5])
self.assertTrue(crypto.dec(buffer, tag, v[0], v[1], v[2]))
self.assertEqual(buffer, v[3])
def test_hkdf(self):
for v in self.vectors_hkdf:
ck, k = crypto._hkdf(v[0], v[1])
self.assertEqual(hexlify(ck), v[2])
self.assertEqual(hexlify(k), v[3])
def test_iv_from_nonce(self):
for v in self.vectors_iv:
x = v[0]
y = x.to_bytes(8, "big")
iv = crypto._get_iv_from_nonce(v[0])
self.assertEqual(iv, v[1])
with self.assertRaises(AssertionError) as e:
iv = crypto._get_iv_from_nonce(0xFFFFFFFFFFFFFFFF + 1)
self.assertEqual(e.value.value, "Nonce overflow, terminate the channel")
def test_incorrect_vectors(self):
pass
def test_th1_crypto(self):
storage.device.get_device_secret = get_dummy_device_secret
handshake = self.handshake
host_ephemeral_privkey = curve25519.generate_secret()
host_ephemeral_pubkey = curve25519.publickey(host_ephemeral_privkey)
handshake.handle_th1_crypto(b"", host_ephemeral_pubkey)
def test_th2_crypto(self):
handshake = self.handshake
host_static_privkey = curve25519.generate_secret()
host_static_pubkey = curve25519.publickey(host_static_privkey)
aes_ctx = aesgcm(handshake.k, IV_2)
aes_ctx.auth(handshake.h)
encrypted_host_static_pubkey = bytearray(
aes_ctx.encrypt(host_static_pubkey) + aes_ctx.finish()
)
# Code to encrypt Host's noise encrypted payload correctly:
protomsg = bytearray(b"\x10\x02\x10\x03")
temp_k = handshake.k
temp_h = handshake.h
temp_h = crypto._hash_of_two(temp_h, encrypted_host_static_pubkey)
_, temp_k = crypto._hkdf(
handshake.ck,
curve25519.multiply(handshake.trezor_ephemeral_privkey, host_static_pubkey),
)
aes_ctx = aesgcm(temp_k, IV_1)
aes_ctx.encrypt_in_place(protomsg)
aes_ctx.auth(temp_h)
tag = aes_ctx.finish()
encrypted_payload = bytearray(protomsg + tag)
# end of encrypted payload generation
handshake.handle_th2_crypto(encrypted_host_static_pubkey, encrypted_payload)
self.assertEqual(encrypted_payload[:4], b"\x10\x02\x10\x03")
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,377 @@
from common import * # isort:skip
from mock_wire_interface import MockHID
from trezor import config, io, protobuf
from trezor.crypto.curve import curve25519
from trezor.enums import ThpMessageType
from trezor.wire.errors import UnexpectedMessage
from trezor.wire.protocol_common import Message
if utils.USE_THP:
from typing import TYPE_CHECKING
import thp_common
from storage import cache_thp
from storage.cache_common import (
CHANNEL_HANDSHAKE_HASH,
CHANNEL_KEY_RECEIVE,
CHANNEL_KEY_SEND,
CHANNEL_NONCE_RECEIVE,
CHANNEL_NONCE_SEND,
)
from trezor.crypto import elligator2
from trezor.enums import ThpPairingMethod
from trezor.messages import (
ThpCodeEntryChallenge,
ThpCodeEntryCpaceHost,
ThpCodeEntryTag,
ThpCredentialRequest,
ThpEndRequest,
ThpStartPairingRequest,
)
from trezor.wire.thp import thp_main
from trezor.wire.thp import ChannelState, checksum, interface_manager
from trezor.wire.thp.crypto import Handshake
from trezor.wire.thp.pairing_context import PairingContext
from apps.thp import pairing
if TYPE_CHECKING:
from trezor.wire import WireInterface
def get_dummy_key() -> bytes:
return b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x01\x02\x03\x04\x05\x06\x07\x08\x09\x20\x01\x02\x03\x04\x05\x06\x07\x08\x09\x30\x31"
def dummy_encode_iface(iface: WireInterface):
return thp_common._MOCK_INTERFACE_HID
def send_channel_allocation_request(
interface: WireInterface, nonce: bytes | None = None
) -> bytes:
if nonce is None or len(nonce) != 8:
nonce = b"\x00\x11\x22\x33\x44\x55\x66\x77"
header = b"\x40\xff\xff\x00\x0c"
chksum = checksum.compute(header + nonce)
cid_req = header + nonce + chksum
gen = thp_main.thp_main_loop(interface)
expected_channel_index = cache_thp._get_next_channel_index()
gen.send(None)
gen.send(cid_req)
gen.send(None)
response_data = (
b"\x0a\x04\x54\x32\x54\x31\x10\x00\x18\x00\x20\x02\x28\x02\x28\x03\x28\x04"
)
response_without_crc = (
b"\x41\xff\xff\x00\x20"
+ nonce
+ cache_thp._CHANNELS[expected_channel_index].channel_id
+ response_data
)
chkcsum = checksum.compute(response_without_crc)
expected_response = response_without_crc + chkcsum + b"\x00" * 27
return expected_response
def get_channel_id_from_response(channel_allocation_response: bytes) -> int:
return int.from_bytes(channel_allocation_response[13:15], "big")
def get_ack(channel_id: bytes) -> bytes:
if len(channel_id) != 2:
raise Exception("Channel id should by two bytes long")
return (
b"\x20"
+ channel_id
+ b"\x00\x04"
+ checksum.compute(b"\x20" + channel_id + b"\x00\x04")
+ b"\x00" * 55
)
@unittest.skipUnless(utils.USE_THP, "only needed for THP")
class TestTrezorHostProtocol(unittest.TestCase):
def __init__(self):
if __debug__:
thp_common.suppres_debug_log()
interface_manager.encode_iface = dummy_encode_iface
super().__init__()
def setUp(self):
self.interface = MockHID(0xDEADBEEF)
buffer = bytearray(64)
buffer2 = bytearray(256)
thp_main.set_read_buffer(buffer)
thp_main.set_write_buffer(buffer2)
interface_manager.decode_iface = thp_common.dummy_decode_iface
def test_codec_message(self):
self.assertEqual(len(self.interface.data), 0)
gen = thp_main.thp_main_loop(self.interface)
gen.send(None)
# There should be a failiure response to received init packet (starts with "?##")
test_codec_message = b"?## Some data"
gen.send(test_codec_message)
gen.send(None)
self.assertEqual(len(self.interface.data), 1)
expected_response = b"?##\x00\x03\x00\x00\x00\x14\x08\x10"
self.assertEqual(
self.interface.data[-1][: len(expected_response)], expected_response
)
# There should be no response for continuation packet (starts with "?" only)
test_codec_message_2 = b"? Cont packet"
gen.send(test_codec_message_2)
with self.assertRaises(TypeError) as e:
gen.send(None)
self.assertEqual(e.value.value, "object with buffer protocol required")
self.assertEqual(len(self.interface.data), 1)
def test_message_on_unallocated_channel(self):
gen = thp_main.thp_main_loop(self.interface)
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
message_to_channel_789a = (
b"\x04\x78\x9a\x00\x0c\x00\x11\x22\x33\x44\x55\x66\x77\x96\x64\x3c\x6c"
)
gen.send(message_to_channel_789a)
gen.send(None)
unallocated_chanel_error_on_channel_789a = "42789a0005027b743563000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
self.assertEqual(
utils.get_bytes_as_str(self.interface.data[-1]),
unallocated_chanel_error_on_channel_789a,
)
def test_channel_allocation(self):
self.assertEqual(len(thp_main._CHANNELS), 0)
for c in cache_thp._CHANNELS:
self.assertEqual(int.from_bytes(c.state, "big"), ChannelState.UNALLOCATED)
expected_channel_index = cache_thp._get_next_channel_index()
expected_response = send_channel_allocation_request(self.interface)
self.assertEqual(self.interface.data[-1], expected_response)
cid = cache_thp._CHANNELS[expected_channel_index].channel_id
self.assertTrue(int.from_bytes(cid, "big") in thp_main._CHANNELS)
self.assertEqual(len(thp_main._CHANNELS), 1)
# test channel's default state is TH1:
cid = get_channel_id_from_response(self.interface.data[-1])
self.assertEqual(thp_main._CHANNELS[cid].get_channel_state(), ChannelState.TH1)
def test_invalid_encrypted_tag(self):
gen = thp_main.thp_main_loop(self.interface)
gen.send(None)
# prepare 2 new channels
expected_response_1 = send_channel_allocation_request(self.interface)
expected_response_2 = send_channel_allocation_request(self.interface)
self.assertEqual(self.interface.data[-2], expected_response_1)
self.assertEqual(self.interface.data[-1], expected_response_2)
# test invalid encryption tag
config.init()
config.wipe()
cid_1 = get_channel_id_from_response(expected_response_1)
channel = thp_main._CHANNELS[cid_1]
channel.iface = self.interface
channel.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
header = b"\x04" + channel.channel_id + b"\x00\x14"
tag = b"\x00" * 16
chksum = checksum.compute(header + tag)
message_with_invalid_tag = header + tag + chksum
channel.channel_cache.set(CHANNEL_KEY_RECEIVE, get_dummy_key())
channel.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, 0)
cid_1_bytes = int.to_bytes(cid_1, 2, "big")
expected_ack_on_received_message = get_ack(cid_1_bytes)
gen.send(message_with_invalid_tag)
gen.send(None)
self.assertEqual(
self.interface.data[-1],
expected_ack_on_received_message,
)
error_without_crc = b"\x42" + cid_1_bytes + b"\x00\x05\x03"
chksum_err = checksum.compute(error_without_crc)
gen.send(None)
decryption_failed_error = error_without_crc + chksum_err + b"\x00" * 54
self.assertEqual(
self.interface.data[-1],
decryption_failed_error,
)
def test_channel_errors(self):
gen = thp_main.thp_main_loop(self.interface)
gen.send(None)
# prepare 2 new channels
expected_response_1 = send_channel_allocation_request(self.interface)
expected_response_2 = send_channel_allocation_request(self.interface)
self.assertEqual(self.interface.data[-2], expected_response_1)
self.assertEqual(self.interface.data[-1], expected_response_2)
# test invalid encryption tag
config.init()
config.wipe()
cid_1 = get_channel_id_from_response(expected_response_1)
channel = thp_main._CHANNELS[cid_1]
channel.iface = self.interface
channel.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
header = b"\x04" + channel.channel_id + b"\x00\x14"
tag = b"\x00" * 16
chksum = checksum.compute(header + tag)
message_with_invalid_tag = header + tag + chksum
channel.channel_cache.set(CHANNEL_KEY_RECEIVE, get_dummy_key())
channel.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, 0)
cid_1_bytes = int.to_bytes(cid_1, 2, "big")
expected_ack_on_received_message = get_ack(cid_1_bytes)
gen.send(message_with_invalid_tag)
gen.send(None)
self.assertEqual(
self.interface.data[-1],
expected_ack_on_received_message,
)
error_without_crc = b"\x42" + cid_1_bytes + b"\x00\x05\x03"
chksum_err = checksum.compute(error_without_crc)
gen.send(None)
decryption_failed_error = error_without_crc + chksum_err + b"\x00" * 54
self.assertEqual(
self.interface.data[-1],
decryption_failed_error,
)
# test invalid tag in handshake phase
cid_2 = get_channel_id_from_response(expected_response_1)
cid_2_bytes = cid_2.to_bytes(2, "big")
channel = thp_main._CHANNELS[cid_2]
channel.iface = self.interface
channel.set_channel_state(ChannelState.TH2)
message_with_invalid_tag = b"\x0a\x12\x36\x00\x14\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\x91\x65\x4c\xf9"
channel.channel_cache.set(CHANNEL_KEY_RECEIVE, get_dummy_key())
channel.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, 0)
# gen.send(message_with_invalid_tag)
# gen.send(None)
# gen.send(None)
# for i in self.interface.data:
# print(utils.get_bytes_as_str(i))
def test_skip_pairing(self):
config.init()
config.wipe()
channel = next(iter(thp_main._CHANNELS.values()))
channel.selected_pairing_methods = [
ThpPairingMethod.NoMethod,
ThpPairingMethod.CodeEntry,
ThpPairingMethod.NFC_Unidirectional,
ThpPairingMethod.QrCode,
]
pairing_ctx = PairingContext(channel)
request_message = ThpStartPairingRequest()
channel.set_channel_state(ChannelState.TP1)
gen = pairing.handle_pairing_request(pairing_ctx, request_message)
with self.assertRaises(StopIteration):
gen.send(None)
self.assertEqual(channel.get_channel_state(), ChannelState.ENCRYPTED_TRANSPORT)
# Teardown: set back initial channel state value
channel.set_channel_state(ChannelState.TH1)
def TODO_test_pairing(self):
config.init()
config.wipe()
cid = get_channel_id_from_response(
send_channel_allocation_request(self.interface)
)
channel = thp_main._CHANNELS[cid]
channel.selected_pairing_methods = [
ThpPairingMethod.CodeEntry,
ThpPairingMethod.NFC_Unidirectional,
ThpPairingMethod.QrCode,
]
pairing_ctx = PairingContext(channel)
request_message = ThpStartPairingRequest()
with self.assertRaises(UnexpectedMessage) as e:
pairing.handle_pairing_request(pairing_ctx, request_message)
print(e.value.message)
channel.set_channel_state(ChannelState.TP1)
gen = pairing.handle_pairing_request(pairing_ctx, request_message)
channel.channel_cache.set(CHANNEL_KEY_SEND, get_dummy_key())
channel.channel_cache.set_int(CHANNEL_NONCE_SEND, 0)
channel.channel_cache.set(CHANNEL_HANDSHAKE_HASH, b"")
gen.send(None)
async def _dummy(ctx: PairingContext, expected_types):
return await ctx.read([1018, 1024])
pairing.show_display_data = _dummy
msg_code_entry = ThpCodeEntryChallenge(challenge=b"\x12\x34")
buffer: bytearray = bytearray(protobuf.encoded_length(msg_code_entry))
protobuf.encode(buffer, msg_code_entry)
code_entry_challenge = Message(ThpMessageType.ThpCodeEntryChallenge, buffer)
gen.send(code_entry_challenge)
# tag_qrc = b"\x55\xdf\x6c\xba\x0b\xe9\x5e\xd1\x4b\x78\x61\xec\xfa\x07\x9b\x5d\x37\x60\xd8\x79\x9c\xd7\x89\xb4\x22\xc1\x6f\x39\xde\x8f\x3b\xc3"
# tag_nfc = b"\x8f\xf0\xfa\x37\x0a\x5b\xdb\x29\x32\x21\xd8\x2f\x95\xdd\xb6\xb8\xee\xfd\x28\x6f\x56\x9f\xa9\x0b\x64\x8c\xfc\x62\x46\x5a\xdd\xd0"
pregenerator_host = b"\xf6\x94\xc3\x6f\xb3\xbd\xfb\xba\x2f\xfd\x0c\xd0\x71\xed\x54\x76\x73\x64\x37\xfa\x25\x85\x12\x8d\xcf\xb5\x6c\x02\xaf\x9d\xe8\xbe"
generator_host = elligator2.map_to_curve25519(pregenerator_host)
cpace_host_private_key = b"\x02\x80\x70\x3c\x06\x45\x19\x75\x87\x0c\x82\xe1\x64\x11\xc0\x18\x13\xb2\x29\x04\xb3\xf0\xe4\x1e\x6b\xfd\x77\x63\x11\x73\x07\xa9"
cpace_host_public_key: bytes = curve25519.multiply(
cpace_host_private_key, generator_host
)
msg = ThpCodeEntryCpaceHost(cpace_host_public_key=cpace_host_public_key)
# msg = ThpQrCodeTag(tag=tag_qrc)
# msg = ThpNfcUnidirectionalTag(tag=tag_nfc)
buffer: bytearray = bytearray(protobuf.encoded_length(msg))
protobuf.encode(buffer, msg)
user_message = Message(ThpMessageType.ThpCodeEntryCpaceHost, buffer)
gen.send(user_message)
tag_ent = b"\xd0\x15\xd6\x72\x7c\xa6\x9b\x2a\x07\xfa\x30\xee\x03\xf0\x2d\x04\xdc\x96\x06\x77\x0c\xbd\xb4\xaa\x77\xc7\x68\x6f\xae\xa9\xdd\x81"
msg = ThpCodeEntryTag(tag=tag_ent)
buffer: bytearray = bytearray(protobuf.encoded_length(msg))
protobuf.encode(buffer, msg)
user_message = Message(ThpMessageType.ThpCodeEntryTag, buffer)
gen.send(user_message)
host_static_pubkey = b"\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77"
msg = ThpCredentialRequest(host_static_pubkey=host_static_pubkey)
buffer: bytearray = bytearray(protobuf.encoded_length(msg))
protobuf.encode(buffer, msg)
credential_request = Message(ThpMessageType.ThpCredentialRequest, buffer)
gen.send(credential_request)
msg = ThpEndRequest()
buffer: bytearray = bytearray(protobuf.encoded_length(msg))
protobuf.encode(buffer, msg)
end_request = Message(1012, buffer)
with self.assertRaises(StopIteration) as e:
gen.send(end_request)
print("response message:", e.value.value.MESSAGE_NAME)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,151 @@
from common import * # isort:skip
from typing import Any, Awaitable
if utils.USE_THP:
import thp_common
from mock_wire_interface import MockHID
from trezor.wire.thp import writer
from trezor.wire.thp import ENCRYPTED, PacketHeader
@unittest.skipUnless(utils.USE_THP, "only needed for THP")
class TestTrezorHostProtocolWriter(unittest.TestCase):
short_payload_expected = b"04123400050700000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
longer_payload_expected = [
b"0412340100000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a",
b"8012343b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f7071727374757677",
b"80123478797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4",
b"801234b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1",
b"801234f2f3f4f5f6f7f8f9fafbfcfdfeff0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
]
eight_longer_payloads_expected = [
b"0412340800000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a",
b"8012343b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f7071727374757677",
b"80123478797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4",
b"801234b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1",
b"801234f2f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e",
b"8012342f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b",
b"8012346c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8",
b"801234a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5",
b"801234e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122",
b"801234232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f",
b"801234606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c",
b"8012349d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9",
b"801234dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a0b0c0d0e0f10111213141516",
b"8012341718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f50515253",
b"8012345455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f90",
b"8012349192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccd",
b"801234cecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a",
b"8012340b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f4041424344454647",
b"80123448494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f8081828384",
b"80123485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1",
b"801234c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfe",
b"801234ff000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b",
b"8012343c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f707172737475767778",
b"801234797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5",
b"801234b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2",
b"801234f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f",
b"801234303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c",
b"8012346d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9",
b"801234aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6",
b"801234e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20212223",
b"8012342425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f60",
b"8012346162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d",
b"8012349e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9da",
b"801234dbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000000000000000000000000000000000000000000000000",
]
empty_payload_with_checksum_expected = b"0412340004edbd479c00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
longer_payload_with_checksum_expected = [
b"0412340100000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a",
b"8012343b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f7071727374757677",
b"80123478797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4",
b"801234b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1",
b"801234f2f3f4f5f6f7f8f9fafbfcfdfefff40c65ee00000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
]
def await_until_result(self, task: Awaitable) -> Any:
with self.assertRaises(StopIteration):
while True:
task.send(None)
def __init__(self):
if __debug__:
thp_common.suppres_debug_log()
super().__init__()
def setUp(self):
self.interface = MockHID(0xDEADBEEF)
def test_write_empty_packet(self):
self.await_until_result(writer.write_packet_to_wire(self.interface, b""))
print(self.interface.data[0])
self.assertEqual(len(self.interface.data), 1)
self.assertEqual(self.interface.data[0], b"")
def test_write_empty_payload(self):
header = PacketHeader(ENCRYPTED, 4660, 4)
await_result(writer.write_payloads_to_wire(self.interface, header, (b"",)))
self.assertEqual(len(self.interface.data), 0)
def test_write_short_payload(self):
header = PacketHeader(ENCRYPTED, 4660, 5)
data = b"\x07"
self.await_until_result(
writer.write_payloads_to_wire(self.interface, header, (data,))
)
self.assertEqual(hexlify(self.interface.data[0]), self.short_payload_expected)
def test_write_longer_payload(self):
data = bytearray(range(256))
header = PacketHeader(ENCRYPTED, 4660, 256)
self.await_until_result(
writer.write_payloads_to_wire(self.interface, header, (data,))
)
for i in range(len(self.longer_payload_expected)):
self.assertEqual(
hexlify(self.interface.data[i]), self.longer_payload_expected[i]
)
def test_write_eight_longer_payloads(self):
data = bytearray(range(256))
header = PacketHeader(ENCRYPTED, 4660, 2048)
self.await_until_result(
writer.write_payloads_to_wire(
self.interface, header, (data, data, data, data, data, data, data, data)
)
)
for i in range(len(self.eight_longer_payloads_expected)):
self.assertEqual(
hexlify(self.interface.data[i]), self.eight_longer_payloads_expected[i]
)
def test_write_empty_payload_with_checksum(self):
header = PacketHeader(ENCRYPTED, 4660, 4)
self.await_until_result(
writer.write_payload_to_wire_and_add_checksum(self.interface, header, b"")
)
self.assertEqual(
hexlify(self.interface.data[0]), self.empty_payload_with_checksum_expected
)
def test_write_longer_payload_with_checksum(self):
data = bytearray(range(256))
header = PacketHeader(ENCRYPTED, 4660, 256)
self.await_until_result(
writer.write_payload_to_wire_and_add_checksum(self.interface, header, data)
)
for i in range(len(self.longer_payload_with_checksum_expected)):
self.assertEqual(
hexlify(self.interface.data[i]),
self.longer_payload_with_checksum_expected[i],
)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,338 @@
from common import * # isort:skip
import ustruct
from typing import TYPE_CHECKING
from mock_wire_interface import MockHID
from storage.cache_thp import BROADCAST_CHANNEL_ID
from trezor import io
from trezor.utils import chunks
from trezor.wire.protocol_common import Message
if utils.USE_THP:
import thp_common
import trezor.wire.thp
from trezor.wire.thp import thp_main
from trezor.wire.thp import alternating_bit_protocol as ABP
from trezor.wire.thp import checksum
from trezor.wire.thp.checksum import CHECKSUM_LENGTH
from trezor.wire.thp.writer import PACKET_LENGTH
if TYPE_CHECKING:
from trezorio import WireInterface
MESSAGE_TYPE = 0x4242
MESSAGE_TYPE_BYTES = b"\x42\x42"
_MESSAGE_TYPE_LEN = 2
PLAINTEXT_0 = 0x01
PLAINTEXT_1 = 0x11
COMMON_CID = 4660
CONT = 0x80
HEADER_INIT_LENGTH = 5
HEADER_CONT_LENGTH = 3
if utils.USE_THP:
INIT_MESSAGE_DATA_LENGTH = PACKET_LENGTH - HEADER_INIT_LENGTH - _MESSAGE_TYPE_LEN
def make_header(ctrl_byte, cid, length):
return ustruct.pack(">BHH", ctrl_byte, cid, length)
def make_cont_header():
return ustruct.pack(">BH", CONT, COMMON_CID)
def makeSimpleMessage(header, message_type, message_data):
return header + ustruct.pack(">H", message_type) + message_data
def makeCidRequest(header, message_data):
return header + message_data
def getPlaintext() -> bytes:
if ABP.get_expected_receive_seq_bit(THP.get_active_session()) == 1:
return PLAINTEXT_1
return PLAINTEXT_0
async def deprecated_read_message(
iface: WireInterface, buffer: utils.BufferType
) -> Message:
return Message(-1, b"\x00")
async def deprecated_write_message(
iface: WireInterface, message: Message, is_retransmission: bool = False
) -> None:
pass
# This test suite is an adaptation of test_trezor.wire.codec_v1
@unittest.skipUnless(utils.USE_THP, "only needed for THP")
class TestWireTrezorHostProtocolV1(unittest.TestCase):
def __init__(self):
if __debug__:
thp_common.suppres_debug_log()
super().__init__()
def setUp(self):
self.interface = MockHID(0xDEADBEEF)
def _simple(self):
cid_req_header = make_header(
ctrl_byte=0x40, cid=BROADCAST_CHANNEL_ID, length=12
)
cid_request_dummy_data = b"\x00\x11\x22\x33\x44\x55\x66\x77\x96\x64\x3c\x6c"
cid_req_message = makeCidRequest(cid_req_header, cid_request_dummy_data)
message_header = make_header(ctrl_byte=0x01, cid=COMMON_CID, length=18)
cid_request_dummy_data_checksum = b"\x67\x8e\xac\xe0"
message = makeSimpleMessage(
message_header,
MESSAGE_TYPE,
cid_request_dummy_data + cid_request_dummy_data_checksum,
)
buffer = bytearray(64)
gen = deprecated_read_message(self.interface, buffer)
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
gen.send(cid_req_message)
gen.send(None)
gen.send(message)
with self.assertRaises(StopIteration) as e:
gen.send(None)
# e.value is StopIteration. e.value.value is the return value of the call
result = e.value.value
self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data, cid_request_dummy_data)
buffer_without_zeroes = buffer[: len(message) - 5]
message_without_header = message[5:]
# message should have been read into the buffer
self.assertEqual(buffer_without_zeroes, message_without_header)
def _read_one_packet(self):
# zero length message - just a header
PLAINTEXT = getPlaintext()
header = make_header(
PLAINTEXT, cid=COMMON_CID, length=_MESSAGE_TYPE_LEN + CHECKSUM_LENGTH
)
chksum = checksum.compute(header + MESSAGE_TYPE_BYTES)
message = header + MESSAGE_TYPE_BYTES + chksum
buffer = bytearray(64)
gen = deprecated_read_message(self.interface, buffer)
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
gen.send(message)
with self.assertRaises(StopIteration) as e:
gen.send(None)
# e.value is StopIteration. e.value.value is the return value of the call
result = e.value.value
self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data, b"")
# message should have been read into the buffer
self.assertEqual(buffer, MESSAGE_TYPE_BYTES + chksum + b"\x00" * 58)
def _read_many_packets(self):
message = bytes(range(256))
header = make_header(
getPlaintext(),
COMMON_CID,
len(message) + _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH,
)
chksum = checksum.compute(header + MESSAGE_TYPE_BYTES + message)
# message = MESSAGE_TYPE_BYTES + message + checksum
# first packet is init header + 59 bytes of data
# other packets are cont header + 61 bytes of data
cont_header = make_cont_header()
packets = [header + MESSAGE_TYPE_BYTES + message[:INIT_MESSAGE_DATA_LENGTH]] + [
cont_header + chunk
for chunk in chunks(
message[INIT_MESSAGE_DATA_LENGTH:] + chksum,
64 - HEADER_CONT_LENGTH,
)
]
buffer = bytearray(262)
gen = deprecated_read_message(self.interface, buffer)
query = gen.send(None)
for packet in packets:
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
query = gen.send(packet)
# last packet will stop
with self.assertRaises(StopIteration) as e:
gen.send(None)
# e.value is StopIteration. e.value.value is the return value of the call
result = e.value.value
self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data, message)
# message should have been read into the buffer )
self.assertEqual(buffer, MESSAGE_TYPE_BYTES + message + chksum)
def _read_large_message(self):
message = b"hello world"
header = make_header(
getPlaintext(),
COMMON_CID,
_MESSAGE_TYPE_LEN + len(message) + CHECKSUM_LENGTH,
)
packet = (
header
+ MESSAGE_TYPE_BYTES
+ message
+ checksum.compute(header + MESSAGE_TYPE_BYTES + message)
)
# make sure we fit into one packet, to make this easier
self.assertTrue(len(packet) <= thp_main.PACKET_LENGTH)
buffer = bytearray(1)
self.assertTrue(len(buffer) <= len(packet))
gen = deprecated_read_message(self.interface, buffer)
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
gen.send(packet)
with self.assertRaises(StopIteration) as e:
gen.send(None)
# e.value is StopIteration. e.value.value is the return value of the call
result = e.value.value
self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data, message)
# read should have allocated its own buffer and not touch ours
self.assertEqual(buffer, b"\x00")
def _roundtrip(self):
message_payload = bytes(range(256))
message = Message(
MESSAGE_TYPE, message_payload, 1
) # TODO use different session id
gen = deprecated_write_message(self.interface, message)
# exhaust the iterator:
# (XXX we can only do this because the iterator is only accepting None and returns None)
for query in gen:
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
buffer = bytearray(1024)
gen = deprecated_read_message(self.interface, buffer)
query = gen.send(None)
for packet in self.interface.data:
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
print(utils.get_bytes_as_str(packet))
query = gen.send(packet)
with self.assertRaises(StopIteration) as e:
gen.send(None)
result = e.value.value
self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data, message.data)
def _write_one_packet(self):
message = Message(MESSAGE_TYPE, b"")
gen = deprecated_write_message(self.interface, message)
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
with self.assertRaises(StopIteration):
gen.send(None)
header = make_header(
getPlaintext(), COMMON_CID, _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH
)
expected_message = (
header
+ MESSAGE_TYPE_BYTES
+ checksum.compute(header + MESSAGE_TYPE_BYTES)
+ b"\x00" * (INIT_MESSAGE_DATA_LENGTH - CHECKSUM_LENGTH)
)
self.assertTrue(self.interface.data == [expected_message])
def _write_multiple_packets(self):
message_payload = bytes(range(256))
message = Message(MESSAGE_TYPE, message_payload)
gen = deprecated_write_message(self.interface, message)
header = make_header(
PLAINTEXT_1,
COMMON_CID,
len(message.data) + _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH,
)
cont_header = make_cont_header()
chksum = checksum.compute(
header + message.type.to_bytes(2, "big") + message.data
)
packets = [
header + MESSAGE_TYPE_BYTES + message.data[:INIT_MESSAGE_DATA_LENGTH]
] + [
cont_header + chunk
for chunk in chunks(
message.data[INIT_MESSAGE_DATA_LENGTH:] + chksum,
thp_main.PACKET_LENGTH - HEADER_CONT_LENGTH,
)
]
for _ in packets:
# we receive as many queries as there are packets
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
# the first sent None only started the generator. the len(packets)-th None
# will finish writing and raise StopIteration
with self.assertRaises(StopIteration):
gen.send(None)
# packets must be identical up to the last one
self.assertListEqual(packets[:-1], self.interface.data[:-1])
# last packet must be identical up to message length. remaining bytes in
# the 64-byte packets are garbage -- in particular, it's the bytes of the
# previous packet
last_packet = packets[-1] + packets[-2][len(packets[-1]) :]
self.assertEqual(last_packet, self.interface.data[-1])
def _read_huge_packet(self):
PACKET_COUNT = 1180
# message that takes up 1 180 USB packets
message_size = (PACKET_COUNT - 1) * (
PACKET_LENGTH - HEADER_CONT_LENGTH - CHECKSUM_LENGTH - _MESSAGE_TYPE_LEN
) + INIT_MESSAGE_DATA_LENGTH
# ensure that a message this big won't fit into memory
# Note: this control is changed, because THP has only 2 byte length field
self.assertTrue(message_size > thp_main.MAX_PAYLOAD_LEN)
# self.assertRaises(MemoryError, bytearray, message_size)
header = make_header(PLAINTEXT_1, COMMON_CID, message_size)
packet = header + MESSAGE_TYPE_BYTES + (b"\x00" * INIT_MESSAGE_DATA_LENGTH)
buffer = bytearray(65536)
gen = deprecated_read_message(self.interface, buffer)
query = gen.send(None)
# THP returns "Message too large" error after reading the message size,
# it is different from codec_v1 as it does not allow big enough messages
# to raise MemoryError in this test
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
with self.assertRaises(trezor.wire.thp.ThpError) as e:
query = gen.send(packet)
self.assertEqual(e.value.args[0], "Message too large")
if __name__ == "__main__":
unittest.main()

44
core/tests/thp_common.py Normal file
View File

@ -0,0 +1,44 @@
from trezor import utils
from trezor.wire.thp import ChannelState
if utils.USE_THP:
import unittest
from typing import TYPE_CHECKING
from mock_wire_interface import MockHID
from storage import cache_thp
from trezor.wire import context
from trezor.wire.thp import interface_manager
from trezor.wire.thp.channel import Channel
from trezor.wire.thp.session_context import SessionContext
_MOCK_INTERFACE_HID = b"\x00"
if TYPE_CHECKING:
from trezor.wire import WireInterface
def dummy_decode_iface(cached_iface: bytes):
return MockHID(0xDEADBEEF)
def get_new_channel(channel_iface: WireInterface | None = None) -> Channel:
interface_manager.decode_iface = dummy_decode_iface
channel_cache = cache_thp.get_new_channel(_MOCK_INTERFACE_HID)
channel = Channel(channel_cache)
channel.set_channel_state(ChannelState.TH1)
if channel_iface is not None:
channel.iface = channel_iface
return channel
def prepare_context() -> None:
channel = get_new_channel()
session_cache = cache_thp.get_new_session(channel.channel_cache)
session_ctx = SessionContext(channel, session_cache)
context.CURRENT_CONTEXT = session_ctx
if __debug__:
# Disable log.debug
def suppres_debug_log() -> None:
from trezor import log
log.debug = lambda name, msg, *args: None

View File

@ -2,7 +2,7 @@
import binascii
from trezorlib.client import TrezorClient
from trezorlib.transport_hid import HidTransport
from trezorlib.transport.hid import HidTransport
devices = HidTransport.enumerate()
if len(devices) > 0:

View File

@ -106,44 +106,44 @@ Frozen version. That means you do not need any other files to run it,
it is just a single binary file that you can execute directly.
**Are you looking for a Trezor T emulator? This is most likely it.**
### [core unix frozen R debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L317)
### [core unix frozen R debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L318)
### [core unix frozen T3T1 debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L332)
### [core unix frozen T3T1 debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L333)
### [core unix frozen R debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L346)
### [core unix frozen R debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L347)
### [core unix frozen T3T1 debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L369)
### [core unix frozen T3T1 debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L370)
### [core unix frozen debug asan build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L392)
### [core unix frozen debug asan build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L393)
### [core unix frozen debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L408)
### [core unix frozen debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L409)
### [core macos frozen regular build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L430)
### [core macos frozen regular build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L431)
### [crypto build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L455)
### [crypto build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L456)
Build of our cryptographic library, which is then incorporated into the other builds.
### [legacy fw regular build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L485)
### [legacy fw regular build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L486)
### [legacy fw regular debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L501)
### [legacy fw regular debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L502)
### [legacy fw btconly build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L518)
### [legacy fw btconly build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L519)
### [legacy fw btconly debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L537)
### [legacy fw btconly debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L538)
### [legacy emu regular debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L558)
### [legacy emu regular debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L559)
Regular version (not only Bitcoin) of above.
**Are you looking for a Trezor One emulator? This is most likely it.**
### [legacy emu regular debug asan build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L573)
### [legacy emu regular debug asan build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L574)
### [legacy emu regular debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L591)
### [legacy emu regular debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L592)
### [legacy emu btconly debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L617)
### [legacy emu btconly debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L618)
Build of Legacy into UNIX emulator. Use keyboard arrows to emulate button presses.
Bitcoin-only version.
### [legacy emu btconly debug asan build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L634)
### [legacy emu btconly debug asan build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L635)
---
## TEST stage - [test.yml](https://github.com/trezor/trezor-firmware/blob/master/ci/test.yml)

View File

@ -191,6 +191,12 @@ void fsm_sendFailure(FailureType code, const char *text)
case FailureType_Failure_InvalidSession:
text = _("Invalid session");
break;
case FailureType_Failure_ThpUnallocatedSession:
text = _("Unallocated session");
break;
case FailureType_Failure_InvalidProtocol:
text = _("Invalid protocol");
break;
case FailureType_Failure_FirmwareError:
text = _("Firmware error");
break;

View File

@ -10,7 +10,7 @@ SKIPPED_MESSAGES := Binance Cardano DebugMonero Eos Monero Ontology Ripple SdPro
EthereumTypedDataValueRequest EthereumTypedDataValueAck ShowDeviceTutorial \
UnlockBootloader AuthenticateDevice AuthenticityProof \
Solana StellarClaimClaimableBalanceOp \
ChangeLanguage TranslationDataRequest TranslationDataAck \
ChangeLanguage TranslationDataRequest TranslationDataAck Thp \
SetBrightness DebugLinkOptigaSetSecMax \
BenchmarkListNames BenchmarkRun BenchmarkNames BenchmarkResult

View File

@ -0,0 +1 @@
../../vendor/trezor-common/protob/messages-thp.proto

1555
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -11,6 +11,7 @@ trezor = {path = "./python", develop = true}
scons = "*"
protobuf = "*"
nanopb = "^0.4.3"
appdirs ="*"
## test tools
pytest = "^8"

View File

@ -6,3 +6,4 @@ libusb1>=1.6.4
construct>=2.9,!=2.10.55
typing_extensions>=4.7.1
construct-classes>=0.1.2
appdirs>=1.4.4

View File

@ -112,7 +112,7 @@ class Emulator:
start = time.monotonic()
try:
while True:
if transport._ping():
if transport.ping():
break
if self.process.poll() is not None:
raise RuntimeError("Emulator process died")

View File

@ -7,7 +7,7 @@ import typing as t
from importlib import metadata
from . import device
from .client import TrezorClient
from .transport.session import Session
try:
cryptography_version = metadata.version("cryptography")
@ -361,7 +361,7 @@ def verify_authentication_response(
def authenticate_device(
client: TrezorClient,
session: Session,
challenge: bytes | None = None,
*,
whitelist: t.Collection[bytes] | None = None,
@ -371,7 +371,7 @@ def authenticate_device(
if challenge is None:
challenge = secrets.token_bytes(16)
resp = device.authenticate(client, challenge)
resp = device.authenticate(session, challenge)
return verify_authentication_response(
challenge,

View File

@ -20,17 +20,17 @@ from . import messages
from .tools import expect
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
from .transport.session import Session
@expect(messages.BenchmarkNames)
def list_names(
client: "TrezorClient",
session: "Session",
) -> "MessageType":
return client.call(messages.BenchmarkListNames())
return session.call(messages.BenchmarkListNames())
@expect(messages.BenchmarkResult)
def run(client: "TrezorClient", name: str) -> "MessageType":
return client.call(messages.BenchmarkRun(name=name))
def run(session: "Session", name: str) -> "MessageType":
return session.call(messages.BenchmarkRun(name=name))

View File

@ -18,22 +18,22 @@ from typing import TYPE_CHECKING
from . import messages
from .protobuf import dict_to_proto
from .tools import expect, session
from .tools import expect
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
from .tools import Address
from .transport.session import Session
@expect(messages.BinanceAddress, field="address", ret_type=str)
def get_address(
client: "TrezorClient",
session: "Session",
address_n: "Address",
show_display: bool = False,
chunkify: bool = False,
) -> "MessageType":
return client.call(
return session.call(
messages.BinanceGetAddress(
address_n=address_n, show_display=show_display, chunkify=chunkify
)
@ -42,16 +42,15 @@ def get_address(
@expect(messages.BinancePublicKey, field="public_key", ret_type=bytes)
def get_public_key(
client: "TrezorClient", address_n: "Address", show_display: bool = False
session: "Session", address_n: "Address", show_display: bool = False
) -> "MessageType":
return client.call(
return session.call(
messages.BinanceGetPublicKey(address_n=address_n, show_display=show_display)
)
@session
def sign_tx(
client: "TrezorClient", address_n: "Address", tx_json: dict, chunkify: bool = False
session: "Session", address_n: "Address", tx_json: dict, chunkify: bool = False
) -> messages.BinanceSignedTx:
msg = tx_json["msgs"][0]
tx_msg = tx_json.copy()
@ -60,7 +59,7 @@ def sign_tx(
tx_msg["chunkify"] = chunkify
envelope = dict_to_proto(messages.BinanceSignTx, tx_msg)
response = client.call(envelope)
response = session.call(envelope)
if not isinstance(response, messages.BinanceTxRequest):
raise RuntimeError(
@ -77,7 +76,7 @@ def sign_tx(
else:
raise ValueError("can not determine msg type")
response = client.call(msg)
response = session.call(msg)
if not isinstance(response, messages.BinanceSignedTx):
raise RuntimeError(

View File

@ -13,7 +13,6 @@
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import warnings
from copy import copy
from decimal import Decimal
@ -23,12 +22,12 @@ from typing import TYPE_CHECKING, Any, AnyStr, List, Optional, Sequence, Tuple
from typing_extensions import Protocol, TypedDict
from . import exceptions, messages
from .tools import expect, prepare_message_bytes, session
from .tools import expect, prepare_message_bytes
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
from .tools import Address
from .transport.session import Session
class ScriptSig(TypedDict):
asm: str
@ -105,7 +104,7 @@ def from_json(json_dict: "Transaction") -> messages.TransactionType:
@expect(messages.PublicKey)
def get_public_node(
client: "TrezorClient",
session: "Session",
n: "Address",
ecdsa_curve_name: Optional[str] = None,
show_display: bool = False,
@ -116,13 +115,13 @@ def get_public_node(
unlock_path_mac: Optional[bytes] = None,
) -> "MessageType":
if unlock_path:
res = client.call(
res = session.call(
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac)
)
if not isinstance(res, messages.UnlockedPathRequest):
raise exceptions.TrezorException("Unexpected message")
return client.call(
return session.call(
messages.GetPublicKey(
address_n=n,
ecdsa_curve_name=ecdsa_curve_name,
@ -141,7 +140,7 @@ def get_address(*args: Any, **kwargs: Any):
@expect(messages.Address)
def get_authenticated_address(
client: "TrezorClient",
session: "Session",
coin_name: str,
n: "Address",
show_display: bool = False,
@ -153,13 +152,13 @@ def get_authenticated_address(
chunkify: bool = False,
) -> "MessageType":
if unlock_path:
res = client.call(
res = session.call(
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac)
)
if not isinstance(res, messages.UnlockedPathRequest):
raise exceptions.TrezorException("Unexpected message")
return client.call(
return session.call(
messages.GetAddress(
address_n=n,
coin_name=coin_name,
@ -172,15 +171,16 @@ def get_authenticated_address(
)
# TODO this is used by tests only
@expect(messages.OwnershipId, field="ownership_id", ret_type=bytes)
def get_ownership_id(
client: "TrezorClient",
session: "Session",
coin_name: str,
n: "Address",
multisig: Optional[messages.MultisigRedeemScriptType] = None,
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
) -> "MessageType":
return client.call(
return session.call(
messages.GetOwnershipId(
address_n=n,
coin_name=coin_name,
@ -190,8 +190,9 @@ def get_ownership_id(
)
# TODO this is used by tests only
def get_ownership_proof(
client: "TrezorClient",
session: "Session",
coin_name: str,
n: "Address",
multisig: Optional[messages.MultisigRedeemScriptType] = None,
@ -202,11 +203,11 @@ def get_ownership_proof(
preauthorized: bool = False,
) -> Tuple[bytes, bytes]:
if preauthorized:
res = client.call(messages.DoPreauthorized())
res = session.call(messages.DoPreauthorized())
if not isinstance(res, messages.PreauthorizedRequest):
raise exceptions.TrezorException("Unexpected message")
res = client.call(
res = session.call(
messages.GetOwnershipProof(
address_n=n,
coin_name=coin_name,
@ -226,7 +227,7 @@ def get_ownership_proof(
@expect(messages.MessageSignature)
def sign_message(
client: "TrezorClient",
session: "Session",
coin_name: str,
n: "Address",
message: AnyStr,
@ -234,7 +235,7 @@ def sign_message(
no_script_type: bool = False,
chunkify: bool = False,
) -> "MessageType":
return client.call(
return session.call(
messages.SignMessage(
coin_name=coin_name,
address_n=n,
@ -247,7 +248,7 @@ def sign_message(
def verify_message(
client: "TrezorClient",
session: "Session",
coin_name: str,
address: str,
signature: bytes,
@ -255,7 +256,7 @@ def verify_message(
chunkify: bool = False,
) -> bool:
try:
resp = client.call(
resp = session.call(
messages.VerifyMessage(
address=address,
signature=signature,
@ -269,9 +270,9 @@ def verify_message(
return isinstance(resp, messages.Success)
@session
# @session
def sign_tx(
client: "TrezorClient",
session: "Session",
coin_name: str,
inputs: Sequence[messages.TxInputType],
outputs: Sequence[messages.TxOutputType],
@ -319,17 +320,17 @@ def sign_tx(
setattr(signtx, name, value)
if unlock_path:
res = client.call(
res = session.call(
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac)
)
if not isinstance(res, messages.UnlockedPathRequest):
raise exceptions.TrezorException("Unexpected message")
elif preauthorized:
res = client.call(messages.DoPreauthorized())
res = session.call(messages.DoPreauthorized())
if not isinstance(res, messages.PreauthorizedRequest):
raise exceptions.TrezorException("Unexpected message")
res = client.call(signtx)
res = session.call(signtx)
# Prepare structure for signatures
signatures: List[Optional[bytes]] = [None] * len(inputs)
@ -388,7 +389,7 @@ def sign_tx(
if res.request_type == R.TXPAYMENTREQ:
assert res.details.request_index is not None
msg = payment_reqs[res.details.request_index]
res = client.call(msg)
res = session.call(msg)
else:
msg = messages.TransactionType()
if res.request_type == R.TXMETA:
@ -418,7 +419,7 @@ def sign_tx(
f"Unknown request type - {res.request_type}."
)
res = client.call(messages.TxAck(tx=msg))
res = session.call(messages.TxAck(tx=msg))
if not isinstance(res, messages.TxRequest):
raise exceptions.TrezorException("Unexpected message")
@ -432,7 +433,7 @@ def sign_tx(
@expect(messages.Success, field="message", ret_type=str)
def authorize_coinjoin(
client: "TrezorClient",
session: "Session",
coordinator: str,
max_rounds: int,
max_coordinator_fee_rate: int,
@ -441,7 +442,7 @@ def authorize_coinjoin(
coin_name: str,
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
) -> "MessageType":
return client.call(
return session.call(
messages.AuthorizeCoinJoin(
coordinator=coordinator,
max_rounds=max_rounds,

View File

@ -35,8 +35,8 @@ from . import exceptions, messages, tools
from .tools import expect
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
from .transport.session import Session
PROTOCOL_MAGICS = {
"mainnet": 764824073,
@ -825,7 +825,7 @@ def _get_collateral_inputs_items(
@expect(messages.CardanoAddress, field="address", ret_type=str)
def get_address(
client: "TrezorClient",
session: "Session",
address_parameters: messages.CardanoAddressParametersType,
protocol_magic: int = PROTOCOL_MAGICS["mainnet"],
network_id: int = NETWORK_IDS["mainnet"],
@ -833,7 +833,7 @@ def get_address(
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
chunkify: bool = False,
) -> "MessageType":
return client.call(
return session.call(
messages.CardanoGetAddress(
address_parameters=address_parameters,
protocol_magic=protocol_magic,
@ -847,12 +847,12 @@ def get_address(
@expect(messages.CardanoPublicKey)
def get_public_key(
client: "TrezorClient",
session: "Session",
address_n: List[int],
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
show_display: bool = False,
) -> "MessageType":
return client.call(
return session.call(
messages.CardanoGetPublicKey(
address_n=address_n,
derivation_type=derivation_type,
@ -863,12 +863,12 @@ def get_public_key(
@expect(messages.CardanoNativeScriptHash)
def get_native_script_hash(
client: "TrezorClient",
session: "Session",
native_script: messages.CardanoNativeScript,
display_format: messages.CardanoNativeScriptHashDisplayFormat = messages.CardanoNativeScriptHashDisplayFormat.HIDE,
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
) -> "MessageType":
return client.call(
return session.call(
messages.CardanoGetNativeScriptHash(
script=native_script,
display_format=display_format,
@ -878,7 +878,7 @@ def get_native_script_hash(
def sign_tx(
client: "TrezorClient",
session: "Session",
signing_mode: messages.CardanoTxSigningMode,
inputs: List[InputWithPath],
outputs: List[OutputWithData],
@ -915,7 +915,7 @@ def sign_tx(
signing_mode,
)
response = client.call(
response = session.call(
messages.CardanoSignTxInit(
signing_mode=signing_mode,
inputs_count=len(inputs),
@ -951,14 +951,14 @@ def sign_tx(
_get_certificates_items(certificates),
withdrawals,
):
response = client.call(tx_item)
response = session.call(tx_item)
if not isinstance(response, messages.CardanoTxItemAck):
raise UNEXPECTED_RESPONSE_ERROR
sign_tx_response: Dict[str, Any] = {}
if auxiliary_data is not None:
auxiliary_data_supplement = client.call(auxiliary_data)
auxiliary_data_supplement = session.call(auxiliary_data)
if not isinstance(
auxiliary_data_supplement, messages.CardanoTxAuxiliaryDataSupplement
):
@ -971,7 +971,7 @@ def sign_tx(
auxiliary_data_supplement.__dict__
)
response = client.call(messages.CardanoTxHostAck())
response = session.call(messages.CardanoTxHostAck())
if not isinstance(response, messages.CardanoTxItemAck):
raise UNEXPECTED_RESPONSE_ERROR
@ -980,24 +980,24 @@ def sign_tx(
_get_collateral_inputs_items(collateral_inputs),
required_signers,
):
response = client.call(tx_item)
response = session.call(tx_item)
if not isinstance(response, messages.CardanoTxItemAck):
raise UNEXPECTED_RESPONSE_ERROR
if collateral_return is not None:
for tx_item in _get_output_items(collateral_return):
response = client.call(tx_item)
response = session.call(tx_item)
if not isinstance(response, messages.CardanoTxItemAck):
raise UNEXPECTED_RESPONSE_ERROR
for reference_input in reference_inputs:
response = client.call(reference_input)
response = session.call(reference_input)
if not isinstance(response, messages.CardanoTxItemAck):
raise UNEXPECTED_RESPONSE_ERROR
sign_tx_response["witnesses"] = []
for witness_request in witness_requests:
response = client.call(witness_request)
response = session.call(witness_request)
if not isinstance(response, messages.CardanoTxWitnessResponse):
raise UNEXPECTED_RESPONSE_ERROR
sign_tx_response["witnesses"].append(
@ -1009,12 +1009,12 @@ def sign_tx(
}
)
response = client.call(messages.CardanoTxHostAck())
response = session.call(messages.CardanoTxHostAck())
if not isinstance(response, messages.CardanoTxBodyHash):
raise UNEXPECTED_RESPONSE_ERROR
sign_tx_response["tx_hash"] = response.tx_hash
response = client.call(messages.CardanoTxHostAck())
response = session.call(messages.CardanoTxHostAck())
if not isinstance(response, messages.CardanoSignTxFinished):
raise UNEXPECTED_RESPONSE_ERROR

View File

@ -14,33 +14,41 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import functools
import logging
import os
import sys
import typing as t
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
import click
from .. import exceptions, transport
from .. import exceptions, transport, ui
from ..client import TrezorClient
from ..ui import ClickUI, ScriptUI
from ..messages import Capability
from ..transport import Transport
from ..transport.thp import channel_database
if TYPE_CHECKING:
LOG = logging.getLogger(__name__)
if t.TYPE_CHECKING:
# Needed to enforce a return value from decorators
# More details: https://www.python.org/dev/peps/pep-0612/
from typing import TypeVar
from typing_extensions import Concatenate, ParamSpec
from ..transport import Transport
from ..ui import TrezorClientUI
P = ParamSpec("P")
R = TypeVar("R")
class ChoiceType(click.Choice):
def __init__(self, typemap: Dict[str, Any], case_sensitive: bool = True) -> None:
def __init__(
self, typemap: t.Dict[str, t.Any], case_sensitive: bool = True
) -> None:
super().__init__(list(typemap.keys()))
self.case_sensitive = case_sensitive
if case_sensitive:
@ -48,7 +56,7 @@ class ChoiceType(click.Choice):
else:
self.typemap = {k.lower(): v for k, v in typemap.items()}
def convert(self, value: Any, param: Any, ctx: click.Context) -> Any:
def convert(self, value: t.Any, param: t.Any, ctx: click.Context) -> t.Any:
if value in self.typemap.values():
return value
value = super().convert(value, param, ctx)
@ -57,11 +65,48 @@ class ChoiceType(click.Choice):
return self.typemap[value]
class TrezorConnection:
def get_passphrase(
passphrase_on_host: bool, available_on_device: bool
) -> t.Union[str, object]:
if available_on_device and not passphrase_on_host:
return ui.PASSPHRASE_ON_DEVICE
env_passphrase = os.getenv("PASSPHRASE")
if env_passphrase is not None:
ui.echo("Passphrase required. Using PASSPHRASE environment variable.")
return env_passphrase
while True:
try:
passphrase = ui.prompt(
"Passphrase required",
hide_input=True,
default="",
show_default=False,
)
# In case user sees the input on the screen, we do not need confirmation
if not ui.CAN_HANDLE_HIDDEN_INPUT:
return passphrase
second = ui.prompt(
"Confirm your passphrase",
hide_input=True,
default="",
show_default=False,
)
if passphrase == second:
return passphrase
else:
ui.echo("Passphrase did not match. Please try again.")
except click.Abort:
raise exceptions.Cancelled from None
class NewTrezorConnection:
def __init__(
self,
path: str,
session_id: Optional[bytes],
session_id: bytes | None,
passphrase_on_host: bool,
script: bool,
) -> None:
@ -70,6 +115,29 @@ class TrezorConnection:
self.passphrase_on_host = passphrase_on_host
self.script = script
def get_session(self, derive_cardano: bool = False):
client = self.get_client()
if self.session_id is not None:
pass # TODO Try resume - be careful of cardano derivation settings!
features = client.protocol.get_features()
passphrase_enabled = True # TODO what to do here?
if not passphrase_enabled:
return client.get_session(derive_cardano=derive_cardano)
# TODO Passphrase empty by default - ???
available_on_device = Capability.PassphraseEntry in features.capabilities
passphrase = get_passphrase(available_on_device, self.passphrase_on_host)
# TODO handle case when PASSPHRASE_ON_DEVICE is returned from get_passphrase func
if not isinstance(passphrase, str):
raise RuntimeError("Passphrase must be a str")
session = client.get_session(
passphrase=passphrase, derive_cardano=derive_cardano
)
return session
def get_transport(self) -> "Transport":
try:
# look for transport without prefix search
@ -82,19 +150,33 @@ class TrezorConnection:
# if this fails, we want the exception to bubble up to the caller
return transport.get_transport(self.path, prefix_search=True)
def get_ui(self) -> "TrezorClientUI":
if self.script:
# It is alright to return just the class object instead of instance,
# as the ScriptUI class object itself is the implementation of TrezorClientUI
# (ScriptUI is just a set of staticmethods)
return ScriptUI
else:
return ClickUI(passphrase_on_host=self.passphrase_on_host)
def get_client(self) -> TrezorClient:
transport = self.get_transport()
ui = self.get_ui()
return TrezorClient(transport, ui=ui, session_id=self.session_id)
stored_channels = channel_database.load_stored_channels()
stored_transport_paths = [ch.transport_path for ch in stored_channels]
path = transport.get_path()
if path in stored_transport_paths:
stored_channel_with_correct_transport_path = next(
ch for ch in stored_channels if ch.transport_path == path
)
try:
client = TrezorClient.resume(
transport, stored_channel_with_correct_transport_path
)
except Exception:
LOG.debug("Failed to resume a channel. Replacing by a new one.")
channel_database.remove_channel(path)
client = TrezorClient(transport)
else:
client = TrezorClient(transport)
return client
def get_management_session(self) -> Session:
client = self.get_client()
management_session = client.get_management_session()
return management_session
@contextmanager
def client_context(self):
@ -128,7 +210,131 @@ class TrezorConnection:
# other exceptions may cause a traceback
def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[P, R]":
# class TrezorConnection:
# def __init__(
# self,
# path: str,
# session_id: bytes | None,
# passphrase_on_host: bool,
# script: bool,
# ) -> None:
# self.path = path
# self.session_id = session_id
# self.passphrase_on_host = passphrase_on_host
# self.script = script
# def get_transport(self) -> "Transport":
# try:
# # look for transport without prefix search
# return transport.get_transport(self.path, prefix_search=False)
# except Exception:
# # most likely not found. try again below.
# pass
# # look for transport with prefix search
# # if this fails, we want the exception to bubble up to the caller
# return transport.get_transport(self.path, prefix_search=True)
# def get_ui(self) -> "TrezorClientUI":
# if self.script:
# # It is alright to return just the class object instead of instance,
# # as the ScriptUI class object itself is the implementation of TrezorClientUI
# # (ScriptUI is just a set of staticmethods)
# return ScriptUI
# else:
# return ClickUI(passphrase_on_host=self.passphrase_on_host)
# def get_client(self) -> TrezorClient:
# transport = self.get_transport()
# ui = self.get_ui()
# return TrezorClient(transport, ui=ui, session_id=self.session_id)
# @contextmanager
# def client_context(self):
# """Get a client instance as a context manager. Handle errors in a manner
# appropriate for end-users.
# Usage:
# >>> with obj.client_context() as client:
# >>> do_your_actions_here()
# """
# try:
# client = self.get_client()
# except transport.DeviceIsBusy:
# click.echo("Device is in use by another process.")
# sys.exit(1)
# except Exception:
# click.echo("Failed to find a Trezor device.")
# if self.path is not None:
# click.echo(f"Using path: {self.path}")
# sys.exit(1)
# try:
# yield client
# except exceptions.Cancelled:
# # handle cancel action
# click.echo("Action was cancelled.")
# sys.exit(1)
# except exceptions.TrezorException as e:
# # handle any Trezor-sent exceptions as user-readable
# raise click.ClickException(str(e)) from e
# # other exceptions may cause a traceback
from ..transport.session import Session
def with_cardano_session(
func: "t.Callable[Concatenate[Session, P], R]",
) -> "t.Callable[P, R]":
return with_session(func=func, derive_cardano=True)
def with_session(
func: "t.Callable[Concatenate[Session, P], R]", derive_cardano: bool = False
) -> "t.Callable[P, R]":
@click.pass_obj
@functools.wraps(func)
def function_with_session(
obj: NewTrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
) -> "R":
session = obj.get_session(derive_cardano)
try:
return func(session, *args, **kwargs)
finally:
pass
# TODO try end session if not resumed
# the return type of @click.pass_obj is improperly specified and pyright doesn't
# understand that it converts f(obj, *args, **kwargs) to f(*args, **kwargs)
return function_with_session # type: ignore [is incompatible with return type]
def with_management_session(
func: "t.Callable[Concatenate[Session, P], R]",
) -> "t.Callable[P, R]":
@click.pass_obj
@functools.wraps(func)
def function_with_management_session(
obj: NewTrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
) -> "R":
session = obj.get_management_session()
try:
return func(session, *args, **kwargs)
finally:
pass
# TODO try end session if not resumed
# the return type of @click.pass_obj is improperly specified and pyright doesn't
# understand that it converts f(obj, *args, **kwargs) to f(*args, **kwargs)
return function_with_management_session # type: ignore [is incompatible with return type]
def with_client(
func: "t.Callable[Concatenate[TrezorClient, P], R]",
) -> "t.Callable[P, R]":
"""Wrap a Click command in `with obj.client_context() as client`.
Sessions are handled transparently. The user is warned when session did not resume
@ -139,28 +345,66 @@ def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[
@click.pass_obj
@functools.wraps(func)
def trezorctl_command_with_client(
obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
obj: NewTrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
) -> "R":
with obj.client_context() as client:
session_was_resumed = obj.session_id == client.session_id
if not session_was_resumed and obj.session_id is not None:
# tried to resume but failed
click.echo("Warning: failed to resume session.", err=True)
# session_was_resumed = obj.session_id == client.session_id
# if not session_was_resumed and obj.session_id is not None:
# # tried to resume but failed
# click.echo("Warning: failed to resume session.", err=True)
click.echo(
"Warning: resume session detection is not implemented yet!", err=True
)
try:
return func(client, *args, **kwargs)
finally:
if not session_was_resumed:
try:
client.end_session()
except Exception:
pass
channel_database.save_channel(client.protocol)
# if not session_was_resumed:
# try:
# client.end_session()
# except Exception:
# pass
# the return type of @click.pass_obj is improperly specified and pyright doesn't
# understand that it converts f(obj, *args, **kwargs) to f(*args, **kwargs)
return trezorctl_command_with_client # type: ignore [is incompatible with return type]
# def with_client(
# func: "t.Callable[Concatenate[TrezorClient, P], R]",
# ) -> "t.Callable[P, R]":
# """Wrap a Click command in `with obj.client_context() as client`.
# Sessions are handled transparently. The user is warned when session did not resume
# cleanly. The session is closed after the command completes - unless the session
# was resumed, in which case it should remain open.
# """
# @click.pass_obj
# @functools.wraps(func)
# def trezorctl_command_with_client(
# obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
# ) -> "R":
# with obj.client_context() as client:
# session_was_resumed = obj.session_id == client.session_id
# if not session_was_resumed and obj.session_id is not None:
# # tried to resume but failed
# click.echo("Warning: failed to resume session.", err=True)
# try:
# return func(client, *args, **kwargs)
# finally:
# if not session_was_resumed:
# try:
# client.end_session()
# except Exception:
# pass
# # the return type of @click.pass_obj is improperly specified and pyright doesn't
# # understand that it converts f(obj, *args, **kwargs) to f(*args, **kwargs)
# return trezorctl_command_with_client
class AliasedGroup(click.Group):
"""Command group that handles aliases and Click 6.x compatibility.
@ -190,14 +434,14 @@ class AliasedGroup(click.Group):
def __init__(
self,
aliases: Optional[Dict[str, click.Command]] = None,
*args: Any,
**kwargs: Any,
aliases: t.Dict[str, click.Command] | None = None,
*args: t.Any,
**kwargs: t.Any,
) -> None:
super().__init__(*args, **kwargs)
self.aliases = aliases or {}
def get_command(self, ctx: click.Context, cmd_name: str) -> Optional[click.Command]:
def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None:
cmd_name = cmd_name.replace("_", "-")
# try to look up the real name
cmd = super().get_command(ctx, cmd_name)

View File

@ -20,17 +20,15 @@ from typing import TYPE_CHECKING, List, Optional
import click
from .. import benchmark
from . import with_client
from . import with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
def list_names_patern(
client: "TrezorClient", pattern: Optional[str] = None
) -> List[str]:
names = list(benchmark.list_names(client).names)
def list_names_patern(session: "Session", pattern: Optional[str] = None) -> List[str]:
names = list(benchmark.list_names(session).names)
if pattern is None:
return names
return [name for name in names if fnmatch(name, pattern)]
@ -43,10 +41,10 @@ def cli() -> None:
@cli.command()
@click.argument("pattern", required=False)
@with_client
def list_names(client: "TrezorClient", pattern: Optional[str] = None) -> None:
@with_session
def list_names(session: "Session", pattern: Optional[str] = None) -> None:
"""List names of all supported benchmarks"""
names = list_names_patern(client, pattern)
names = list_names_patern(session, pattern)
if len(names) == 0:
click.echo("No benchmark satisfies the pattern.")
else:
@ -56,13 +54,13 @@ def list_names(client: "TrezorClient", pattern: Optional[str] = None) -> None:
@cli.command()
@click.argument("pattern", required=False)
@with_client
def run(client: "TrezorClient", pattern: Optional[str]) -> None:
@with_session
def run(session: "Session", pattern: Optional[str]) -> None:
"""Run benchmark"""
names = list_names_patern(client, pattern)
names = list_names_patern(session, pattern)
if len(names) == 0:
click.echo("No benchmark satisfies the pattern.")
else:
for name in names:
result = benchmark.run(client, name)
result = benchmark.run(session, name)
click.echo(f"{name}: {result.value} {result.unit}")

View File

@ -20,11 +20,11 @@ from typing import TYPE_CHECKING, TextIO
import click
from .. import binance, tools
from . import with_client
from ..transport.session import Session
from . import with_session
if TYPE_CHECKING:
from .. import messages
from ..client import TrezorClient
PATH_HELP = "BIP-32 path to key, e.g. m/44h/714h/0h/0/0"
@ -39,23 +39,23 @@ def cli() -> None:
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def get_address(
client: "TrezorClient", address: str, show_display: bool, chunkify: bool
session: "Session", address: str, show_display: bool, chunkify: bool
) -> str:
"""Get Binance address for specified path."""
address_n = tools.parse_path(address)
return binance.get_address(client, address_n, show_display, chunkify)
return binance.get_address(session, address_n, show_display, chunkify)
@cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@with_client
def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str:
@with_session
def get_public_key(session: "Session", address: str, show_display: bool) -> str:
"""Get Binance public key."""
address_n = tools.parse_path(address)
return binance.get_public_key(client, address_n, show_display).hex()
return binance.get_public_key(session, address_n, show_display).hex()
@cli.command()
@ -63,13 +63,13 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) ->
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def sign_tx(
client: "TrezorClient", address: str, file: TextIO, chunkify: bool
session: "Session", address: str, file: TextIO, chunkify: bool
) -> "messages.BinanceSignedTx":
"""Sign Binance transaction.
Transaction must be provided as a JSON file.
"""
address_n = tools.parse_path(address)
return binance.sign_tx(client, address_n, json.load(file), chunkify=chunkify)
return binance.sign_tx(session, address_n, json.load(file), chunkify=chunkify)

View File

@ -13,6 +13,7 @@
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import base64
import json
@ -22,10 +23,10 @@ import click
import construct as c
from .. import btc, messages, protobuf, tools
from . import ChoiceType, with_client
from . import ChoiceType, with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
PURPOSE_BIP44 = 44
PURPOSE_BIP48 = 48
@ -168,15 +169,15 @@ def cli() -> None:
default=2,
)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def get_address(
client: "TrezorClient",
session: "Session",
coin: str,
address: str,
script_type: Optional[messages.InputScriptType],
script_type: messages.InputScriptType | None,
show_display: bool,
multisig_xpub: List[str],
multisig_threshold: Optional[int],
multisig_threshold: int | None,
multisig_suffix_length: int,
chunkify: bool,
) -> str:
@ -220,7 +221,7 @@ def get_address(
multisig = None
return btc.get_address(
client,
session,
coin,
address_n,
show_display,
@ -237,9 +238,9 @@ def get_address(
@click.option("-e", "--curve")
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS))
@click.option("-d", "--show-display", is_flag=True)
@with_client
@with_session
def get_public_node(
client: "TrezorClient",
session: "Session",
coin: str,
address: str,
curve: Optional[str],
@ -251,7 +252,7 @@ def get_public_node(
if script_type is None:
script_type = guess_script_type_from_path(address_n)
result = btc.get_public_node(
client,
session,
address_n,
ecdsa_curve_name=curve,
show_display=show_display,
@ -277,7 +278,7 @@ def _append_descriptor_checksum(desc: str) -> str:
def _get_descriptor(
client: "TrezorClient",
session: "Session",
coin: Optional[str],
account: int,
purpose: Optional[int],
@ -311,7 +312,7 @@ def _get_descriptor(
n = tools.parse_path(path)
pub = btc.get_public_node(
client,
session,
n,
show_display=show_display,
coin_name=coin,
@ -348,9 +349,9 @@ def _get_descriptor(
@click.option("-a", "--account-type", type=ChoiceType(ACCOUNT_TYPE_TO_BIP_PURPOSE))
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS))
@click.option("-d", "--show-display", is_flag=True)
@with_client
@with_session
def get_descriptor(
client: "TrezorClient",
session: "Session",
coin: Optional[str],
account: int,
account_type: Optional[int],
@ -360,7 +361,7 @@ def get_descriptor(
"""Get descriptor of given account."""
try:
return _get_descriptor(
client, coin, account, account_type, script_type, show_display
session, coin, account, account_type, script_type, show_display
)
except ValueError as e:
raise click.ClickException(str(e))
@ -375,8 +376,8 @@ def get_descriptor(
@click.option("-c", "--coin", is_flag=True, hidden=True, expose_value=False)
@click.option("-C", "--chunkify", is_flag=True)
@click.argument("json_file", type=click.File())
@with_client
def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None:
@with_session
def sign_tx(session: "Session", json_file: TextIO, chunkify: bool) -> None:
"""Sign transaction.
Transaction data must be provided in a JSON file. See `transaction-format.md` for
@ -401,7 +402,7 @@ def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None:
}
_, serialized_tx = btc.sign_tx(
client,
session,
coin,
inputs,
outputs,
@ -432,9 +433,9 @@ def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None:
)
@click.option("-C", "--chunkify", is_flag=True)
@click.argument("message")
@with_client
@with_session
def sign_message(
client: "TrezorClient",
session: "Session",
coin: str,
address: str,
message: str,
@ -447,7 +448,7 @@ def sign_message(
if script_type is None:
script_type = guess_script_type_from_path(address_n)
res = btc.sign_message(
client,
session,
coin,
address_n,
message,
@ -468,9 +469,9 @@ def sign_message(
@click.argument("address")
@click.argument("signature")
@click.argument("message")
@with_client
@with_session
def verify_message(
client: "TrezorClient",
session: "Session",
coin: str,
address: str,
signature: str,
@ -480,7 +481,7 @@ def verify_message(
"""Verify message."""
signature_bytes = base64.b64decode(signature)
return btc.verify_message(
client, coin, address, signature_bytes, message, chunkify=chunkify
session, coin, address, signature_bytes, message, chunkify=chunkify
)

View File

@ -20,10 +20,10 @@ from typing import TYPE_CHECKING, Optional, TextIO
import click
from .. import cardano, messages, tools
from . import ChoiceType, with_client
from . import ChoiceType, with_cardano_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
PATH_HELP = "BIP-32 path to key, e.g. m/44h/1815h/0h/0/0"
@ -62,9 +62,9 @@ def cli() -> None:
@click.option("-i", "--include-network-id", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True)
@click.option("-T", "--tag-cbor-sets", is_flag=True)
@with_client
@with_cardano_session
def sign_tx(
client: "TrezorClient",
session: "Session",
file: TextIO,
signing_mode: messages.CardanoTxSigningMode,
protocol_magic: int,
@ -123,9 +123,8 @@ def sign_tx(
for p in transaction["additional_witness_requests"]
]
client.init_device(derive_cardano=True)
sign_tx_response = cardano.sign_tx(
client,
session,
signing_mode,
inputs,
outputs,
@ -209,9 +208,9 @@ def sign_tx(
default=messages.CardanoDerivationType.ICARUS,
)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_cardano_session
def get_address(
client: "TrezorClient",
session: "Session",
address: str,
address_type: messages.CardanoAddressType,
staking_address: str,
@ -262,9 +261,8 @@ def get_address(
script_staking_hash_bytes,
)
client.init_device(derive_cardano=True)
return cardano.get_address(
client,
session,
address_parameters,
protocol_magic,
network_id,
@ -283,18 +281,17 @@ def get_address(
default=messages.CardanoDerivationType.ICARUS,
)
@click.option("-d", "--show-display", is_flag=True)
@with_client
@with_cardano_session
def get_public_key(
client: "TrezorClient",
session: "Session",
address: str,
derivation_type: messages.CardanoDerivationType,
show_display: bool,
) -> messages.CardanoPublicKey:
"""Get Cardano public key."""
address_n = tools.parse_path(address)
client.init_device(derive_cardano=True)
return cardano.get_public_key(
client, address_n, derivation_type=derivation_type, show_display=show_display
session, address_n, derivation_type=derivation_type, show_display=show_display
)
@ -312,9 +309,9 @@ def get_public_key(
type=ChoiceType({m.name: m for m in messages.CardanoDerivationType}),
default=messages.CardanoDerivationType.ICARUS,
)
@with_client
@with_cardano_session
def get_native_script_hash(
client: "TrezorClient",
session: "Session",
file: TextIO,
display_format: messages.CardanoNativeScriptHashDisplayFormat,
derivation_type: messages.CardanoDerivationType,
@ -323,7 +320,6 @@ def get_native_script_hash(
native_script_json = json.load(file)
native_script = cardano.parse_native_script(native_script_json)
client.init_device(derive_cardano=True)
return cardano.get_native_script_hash(
client, native_script, display_format, derivation_type=derivation_type
session, native_script, display_format, derivation_type=derivation_type
)

View File

@ -19,10 +19,10 @@ from typing import TYPE_CHECKING, Tuple
import click
from .. import misc, tools
from . import ChoiceType, with_client
from . import ChoiceType, with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
PROMPT_TYPE = ChoiceType(
@ -42,10 +42,10 @@ def cli() -> None:
@cli.command()
@click.argument("size", type=int)
@with_client
def get_entropy(client: "TrezorClient", size: int) -> str:
@with_session
def get_entropy(session: "Session", size: int) -> str:
"""Get random bytes from device."""
return misc.get_entropy(client, size).hex()
return misc.get_entropy(session, size).hex()
@cli.command()
@ -55,9 +55,9 @@ def get_entropy(client: "TrezorClient", size: int) -> str:
)
@click.argument("key")
@click.argument("value")
@with_client
@with_session
def encrypt_keyvalue(
client: "TrezorClient",
session: "Session",
address: str,
key: str,
value: str,
@ -75,7 +75,7 @@ def encrypt_keyvalue(
ask_on_encrypt, ask_on_decrypt = prompt
address_n = tools.parse_path(address)
return misc.encrypt_keyvalue(
client,
session,
address_n,
key,
value.encode(),
@ -91,9 +91,9 @@ def encrypt_keyvalue(
)
@click.argument("key")
@click.argument("value")
@with_client
@with_session
def decrypt_keyvalue(
client: "TrezorClient",
session: "Session",
address: str,
key: str,
value: str,
@ -112,7 +112,7 @@ def decrypt_keyvalue(
ask_on_encrypt, ask_on_decrypt = prompt
address_n = tools.parse_path(address)
return misc.decrypt_keyvalue(
client,
session,
address_n,
key,
bytes.fromhex(value),

View File

@ -18,16 +18,15 @@ from typing import TYPE_CHECKING, Union
import click
from .. import mapping, messages, protobuf
from ..client import TrezorClient
from ..debuglink import TrezorClientDebugLink
from ..debuglink import optiga_set_sec_max as debuglink_optiga_set_sec_max
from ..debuglink import prodtest_t1 as debuglink_prodtest_t1
from ..debuglink import record_screen
from . import with_client
from ..transport.session import Session
from . import with_management_session
if TYPE_CHECKING:
from . import TrezorConnection
from . import NewTrezorConnection
@click.group(name="debug")
@ -35,58 +34,58 @@ def cli() -> None:
"""Miscellaneous debug features."""
@cli.command()
@click.argument("message_name_or_type")
@click.argument("hex_data")
@click.pass_obj
def send_bytes(
obj: "TrezorConnection", message_name_or_type: str, hex_data: str
) -> None:
"""Send raw bytes to Trezor.
# @cli.command()
# @click.argument("message_name_or_type")
# @click.argument("hex_data")
# @click.pass_obj
# def send_bytes(
# obj: "NewTrezorConnection", message_name_or_type: str, hex_data: str
# ) -> None:
# """Send raw bytes to Trezor.
Message type and message data must be specified separately, due to how message
chunking works on the transport level. Message length is calculated and sent
automatically, and it is currently impossible to explicitly specify invalid length.
# Message type and message data must be specified separately, due to how message
# chunking works on the transport level. Message length is calculated and sent
# automatically, and it is currently impossible to explicitly specify invalid length.
MESSAGE_NAME_OR_TYPE can either be a number, or a name from the MessageType enum,
in which case the value of that enum is used.
"""
if message_name_or_type.isdigit():
message_type = int(message_name_or_type)
else:
message_type = getattr(messages.MessageType, message_name_or_type)
# MESSAGE_NAME_OR_TYPE can either be a number, or a name from the MessageType enum,
# in which case the value of that enum is used.
# """
# if message_name_or_type.isdigit():
# message_type = int(message_name_or_type)
# else:
# message_type = getattr(messages.MessageType, message_name_or_type)
if not isinstance(message_type, int):
raise click.ClickException("Invalid message type.")
# if not isinstance(message_type, int):
# raise click.ClickException("Invalid message type.")
try:
message_data = bytes.fromhex(hex_data)
except Exception as e:
raise click.ClickException("Invalid hex data.") from e
# try:
# message_data = bytes.fromhex(hex_data)
# except Exception as e:
# raise click.ClickException("Invalid hex data.") from e
transport = obj.get_transport()
transport.begin_session()
transport.write(message_type, message_data)
# transport = obj.get_transport()
# transport.deprecated_begin_session()
# transport.write(message_type, message_data)
response_type, response_data = transport.read()
transport.end_session()
# response_type, response_data = transport.read()
# transport.deprecated_end_session()
click.echo(f"Response type: {response_type}")
click.echo(f"Response data: {response_data.hex()}")
# click.echo(f"Response type: {response_type}")
# click.echo(f"Response data: {response_data.hex()}")
try:
msg = mapping.DEFAULT_MAPPING.decode(response_type, response_data)
click.echo("Parsed message:")
click.echo(protobuf.format_message(msg))
except Exception as e:
click.echo(f"Could not parse response: {e}")
# try:
# msg = mapping.DEFAULT_MAPPING.decode(response_type, response_data)
# click.echo("Parsed message:")
# click.echo(protobuf.format_message(msg))
# except Exception as e:
# click.echo(f"Could not parse response: {e}")
@cli.command()
@click.argument("directory", required=False)
@click.option("-s", "--stop", is_flag=True, help="Stop the recording")
@click.pass_obj
def record(obj: "TrezorConnection", directory: Union[str, None], stop: bool) -> None:
def record(obj: "NewTrezorConnection", directory: Union[str, None], stop: bool) -> None:
"""Record screen changes into a specified directory.
Recording can be stopped with `-s / --stop` option.
@ -95,7 +94,7 @@ def record(obj: "TrezorConnection", directory: Union[str, None], stop: bool) ->
def record_screen_from_connection(
obj: "TrezorConnection", directory: Union[str, None]
obj: "NewTrezorConnection", directory: Union[str, None]
) -> None:
"""Record screen helper to transform TrezorConnection into TrezorClientDebugLink."""
transport = obj.get_transport()
@ -106,17 +105,17 @@ def record_screen_from_connection(
@cli.command()
@with_client
def prodtest_t1(client: "TrezorClient") -> str:
@with_management_session
def prodtest_t1(session: "Session") -> str:
"""Perform a prodtest on Model One.
Only available on PRODTEST firmware and on T1B1. Formerly named self-test.
"""
return debuglink_prodtest_t1(client)
return debuglink_prodtest_t1(session)
@cli.command()
@with_client
def optiga_set_sec_max(client: "TrezorClient") -> str:
@with_management_session
def optiga_set_sec_max(session: "Session") -> str:
"""Set Optiga's security event counter to maximum."""
return debuglink_optiga_set_sec_max(client)
return debuglink_optiga_set_sec_max(session)

View File

@ -24,12 +24,12 @@ import click
import requests
from .. import debuglink, device, exceptions, messages, ui
from . import ChoiceType, with_client
from . import ChoiceType, with_management_session
if t.TYPE_CHECKING:
from ..client import TrezorClient
from ..protobuf import MessageType
from . import TrezorConnection
from ..transport.session import Session
from . import NewTrezorConnection
RECOVERY_DEVICE_INPUT_METHOD = {
"scrambled": messages.RecoveryDeviceInputMethod.ScrambledWords,
@ -64,17 +64,18 @@ def cli() -> None:
help="Wipe device in bootloader mode. This also erases the firmware.",
is_flag=True,
)
@with_client
def wipe(client: "TrezorClient", bootloader: bool) -> str:
@with_management_session
def wipe(session: "Session", bootloader: bool) -> str:
"""Reset device to factory defaults and remove all private data."""
features = session.features
if bootloader:
if not client.features.bootloader_mode:
if not features.bootloader_mode:
click.echo("Please switch your device to bootloader mode.")
sys.exit(1)
else:
click.echo("Wiping user data and firmware!")
else:
if client.features.bootloader_mode:
if features.bootloader_mode:
click.echo(
"Your device is in bootloader mode. This operation would also erase firmware."
)
@ -87,7 +88,9 @@ def wipe(client: "TrezorClient", bootloader: bool) -> str:
click.echo("Wiping user data!")
try:
return device.wipe(client)
return device.wipe(
session
) # TODO decide where the wipe should happen - management or regular session
except exceptions.TrezorFailure as e:
click.echo("Action failed: {} {}".format(*e.args))
sys.exit(3)
@ -103,9 +106,9 @@ def wipe(client: "TrezorClient", bootloader: bool) -> str:
@click.option("-a", "--academic", is_flag=True)
@click.option("-b", "--needs-backup", is_flag=True)
@click.option("-n", "--no-backup", is_flag=True)
@with_client
@with_management_session
def load(
client: "TrezorClient",
session: "Session",
mnemonic: t.Sequence[str],
pin: str,
passphrase_protection: bool,
@ -136,7 +139,7 @@ def load(
try:
return debuglink.load_device(
client,
session,
mnemonic=list(mnemonic),
pin=pin,
passphrase_protection=passphrase_protection,
@ -171,9 +174,9 @@ def load(
)
@click.option("-d", "--dry-run", is_flag=True)
@click.option("-b", "--unlock-repeated-backup", is_flag=True)
@with_client
@with_management_session
def recover(
client: "TrezorClient",
session: "Session",
words: str,
expand: bool,
pin_protection: bool,
@ -201,7 +204,7 @@ def recover(
type = messages.RecoveryType.UnlockRepeatedBackup
return device.recover(
client,
session,
word_count=int(words),
passphrase_protection=passphrase_protection,
pin_protection=pin_protection,
@ -222,9 +225,9 @@ def recover(
@click.option("-s", "--skip-backup", is_flag=True)
@click.option("-n", "--no-backup", is_flag=True)
@click.option("-b", "--backup-type", type=ChoiceType(BACKUP_TYPE))
@with_client
@with_management_session
def setup(
client: "TrezorClient",
session: "Session",
strength: int | None,
passphrase_protection: bool,
pin_protection: bool,
@ -241,7 +244,7 @@ def setup(
BT = messages.BackupType
if backup_type is None:
if client.version >= (2, 7, 1):
if session.version >= (2, 7, 1):
# SLIP39 extendable was introduced in 2.7.1
backup_type = BT.Slip39_Single_Extendable
else:
@ -251,10 +254,10 @@ def setup(
if (
backup_type
in (BT.Slip39_Single_Extendable, BT.Slip39_Basic, BT.Slip39_Basic_Extendable)
and messages.Capability.Shamir not in client.features.capabilities
and messages.Capability.Shamir not in session.features.capabilities
) or (
backup_type in (BT.Slip39_Advanced, BT.Slip39_Advanced_Extendable)
and messages.Capability.ShamirGroups not in client.features.capabilities
and messages.Capability.ShamirGroups not in session.features.capabilities
):
click.echo(
"WARNING: Your Trezor device does not indicate support for the requested\n"
@ -262,7 +265,7 @@ def setup(
)
return device.reset(
client,
session,
strength=strength,
passphrase_protection=passphrase_protection,
pin_protection=pin_protection,
@ -277,23 +280,21 @@ def setup(
@cli.command()
@click.option("-t", "--group-threshold", type=int)
@click.option("-g", "--group", "groups", type=(int, int), multiple=True, metavar="T N")
@with_client
@with_management_session
def backup(
client: "TrezorClient",
session: "Session",
group_threshold: int | None = None,
groups: t.Sequence[tuple[int, int]] = (),
) -> str:
"""Perform device seed backup."""
return device.backup(client, group_threshold, groups)
return device.backup(session, group_threshold, groups)
@cli.command()
@click.argument("operation", type=ChoiceType(SD_PROTECT_OPERATIONS))
@with_client
def sd_protect(
client: "TrezorClient", operation: messages.SdProtectOperationType
) -> str:
@with_management_session
def sd_protect(session: "Session", operation: messages.SdProtectOperationType) -> str:
"""Secure the device with SD card protection.
When SD card protection is enabled, a randomly generated secret is stored
@ -307,36 +308,36 @@ def sd_protect(
off - Remove SD card secret protection.
refresh - Replace the current SD card secret with a new one.
"""
if client.features.model == "1":
if session.features.model == "1":
raise click.ClickException("Trezor One does not support SD card protection.")
return device.sd_protect(client, operation)
return device.sd_protect(session, operation)
@cli.command()
@click.pass_obj
def reboot_to_bootloader(obj: "TrezorConnection") -> str:
def reboot_to_bootloader(obj: "NewTrezorConnection") -> str:
"""Reboot device into bootloader mode.
Currently only supported on Trezor Model One.
"""
# avoid using @with_client because it closes the session afterwards,
# avoid using @with_management_session because it closes the session afterwards,
# which triggers double prompt on device
with obj.client_context() as client:
return device.reboot_to_bootloader(client)
return device.reboot_to_bootloader(client.get_management_session())
@cli.command()
@with_client
def tutorial(client: "TrezorClient") -> str:
@with_management_session
def tutorial(session: "Session") -> str:
"""Show on-device tutorial."""
return device.show_device_tutorial(client)
return device.show_device_tutorial(session)
@cli.command()
@with_client
def unlock_bootloader(client: "TrezorClient") -> str:
@with_management_session
def unlock_bootloader(session: "Session") -> str:
"""Unlocks bootloader. Irreversible."""
return device.unlock_bootloader(client)
return device.unlock_bootloader(session)
@cli.command()
@ -347,11 +348,11 @@ def unlock_bootloader(client: "TrezorClient") -> str:
type=int,
help="Dialog expiry in seconds.",
)
@with_client
def set_busy(client: "TrezorClient", enable: bool | None, expiry: int | None) -> str:
@with_management_session
def set_busy(session: "Session", enable: bool | None, expiry: int | None) -> str:
"""Show a "Do not disconnect" dialog."""
if enable is False:
return device.set_busy(client, None)
return device.set_busy(session, None)
if expiry is None:
raise click.ClickException("Missing option '-e' / '--expiry'.")
@ -361,7 +362,7 @@ def set_busy(client: "TrezorClient", enable: bool | None, expiry: int | None) ->
f"Invalid value for '-e' / '--expiry': '{expiry}' is not a positive integer."
)
return device.set_busy(client, expiry * 1000)
return device.set_busy(session, expiry * 1000)
PUBKEY_WHITELIST_URL_TEMPLATE = (
@ -381,9 +382,9 @@ PUBKEY_WHITELIST_URL_TEMPLATE = (
is_flag=True,
help="Do not check intermediate certificates against the whitelist.",
)
@with_client
@with_management_session
def authenticate(
client: "TrezorClient",
session: "Session",
hex_challenge: str | None,
root: t.BinaryIO | None,
raw: bool | None,
@ -408,7 +409,7 @@ def authenticate(
challenge = bytes.fromhex(hex_challenge)
if raw:
msg = device.authenticate(client, challenge)
msg = device.authenticate(session, challenge)
click.echo(f"Challenge: {hex_challenge}")
click.echo(f"Signature of challenge: {msg.signature.hex()}")
@ -456,14 +457,14 @@ def authenticate(
else:
whitelist_json = requests.get(
PUBKEY_WHITELIST_URL_TEMPLATE.format(
model=client.model.internal_name.lower()
model=session.model.internal_name.lower()
)
).json()
whitelist = [bytes.fromhex(pk) for pk in whitelist_json["ca_pubkeys"]]
try:
authentication.authenticate_device(
client, challenge, root_pubkey=root_bytes, whitelist=whitelist
session, challenge, root_pubkey=root_bytes, whitelist=whitelist
)
except authentication.DeviceNotAuthentic:
click.echo("Device is not authentic.")

View File

@ -20,11 +20,11 @@ from typing import TYPE_CHECKING, TextIO
import click
from .. import eos, tools
from . import with_client
from . import with_session
if TYPE_CHECKING:
from .. import messages
from ..client import TrezorClient
from ..transport.session import Session
PATH_HELP = "BIP-32 path, e.g. m/44h/194h/0h/0/0"
@ -37,11 +37,11 @@ def cli() -> None:
@cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@with_client
def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str:
@with_session
def get_public_key(session: "Session", address: str, show_display: bool) -> str:
"""Get Eos public key in base58 encoding."""
address_n = tools.parse_path(address)
res = eos.get_public_key(client, address_n, show_display)
res = eos.get_public_key(session, address_n, show_display)
return f"WIF: {res.wif_public_key}\nRaw: {res.raw_public_key.hex()}"
@ -50,16 +50,16 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) ->
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def sign_transaction(
client: "TrezorClient", address: str, file: TextIO, chunkify: bool
session: "Session", address: str, file: TextIO, chunkify: bool
) -> "messages.EosSignedTx":
"""Sign EOS transaction."""
tx_json = json.load(file)
address_n = tools.parse_path(address)
return eos.sign_tx(
client,
session,
address_n,
tx_json["transaction"],
tx_json["chain_id"],

View File

@ -26,14 +26,14 @@ import click
from .. import _rlp, definitions, ethereum, tools
from ..messages import EthereumDefinitions
from . import with_client
from . import with_session
if TYPE_CHECKING:
import web3
from eth_typing import ChecksumAddress # noqa: I900
from web3.types import Wei
from ..client import TrezorClient
from ..transport.session import Session
PATH_HELP = "BIP-32 path, e.g. m/44h/60h/0h/0/0"
@ -268,24 +268,24 @@ def cli(
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def get_address(
client: "TrezorClient", address: str, show_display: bool, chunkify: bool
session: "Session", address: str, show_display: bool, chunkify: bool
) -> str:
"""Get Ethereum address in hex encoding."""
address_n = tools.parse_path(address)
network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE)
return ethereum.get_address(client, address_n, show_display, network, chunkify)
return ethereum.get_address(session, address_n, show_display, network, chunkify)
@cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@with_client
def get_public_node(client: "TrezorClient", address: str, show_display: bool) -> dict:
@with_session
def get_public_node(session: "Session", address: str, show_display: bool) -> dict:
"""Get Ethereum public node of given path."""
address_n = tools.parse_path(address)
result = ethereum.get_public_node(client, address_n, show_display=show_display)
result = ethereum.get_public_node(session, address_n, show_display=show_display)
return {
"node": {
"depth": result.node.depth,
@ -344,9 +344,9 @@ def get_public_node(client: "TrezorClient", address: str, show_display: bool) ->
@click.option("-C", "--chunkify", is_flag=True)
@click.argument("to_address")
@click.argument("amount", callback=_amount_to_int)
@with_client
@with_session
def sign_tx(
client: "TrezorClient",
session: "Session",
chain_id: int,
address: str,
amount: int,
@ -400,7 +400,7 @@ def sign_tx(
encoded_network = DEFINITIONS_SOURCE.get_network(chain_id)
address_n = tools.parse_path(address)
from_address = ethereum.get_address(
client, address_n, encoded_network=encoded_network
session, address_n, encoded_network=encoded_network
)
if token:
@ -446,7 +446,7 @@ def sign_tx(
assert max_gas_fee is not None
assert max_priority_fee is not None
sig = ethereum.sign_tx_eip1559(
client,
session,
n=address_n,
nonce=nonce,
gas_limit=gas_limit,
@ -465,7 +465,7 @@ def sign_tx(
gas_price = _get_web3().eth.gas_price
assert gas_price is not None
sig = ethereum.sign_tx(
client,
session,
n=address_n,
tx_type=tx_type,
nonce=nonce,
@ -526,14 +526,14 @@ def sign_tx(
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-C", "--chunkify", is_flag=True)
@click.argument("message")
@with_client
@with_session
def sign_message(
client: "TrezorClient", address: str, message: str, chunkify: bool
session: "Session", address: str, message: str, chunkify: bool
) -> Dict[str, str]:
"""Sign message with Ethereum address."""
address_n = tools.parse_path(address)
network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE)
ret = ethereum.sign_message(client, address_n, message, network, chunkify=chunkify)
ret = ethereum.sign_message(session, address_n, message, network, chunkify=chunkify)
output = {
"message": message,
"address": ret.address,
@ -550,9 +550,9 @@ def sign_message(
help="Be compatible with Metamask's signTypedData_v4 implementation",
)
@click.argument("file", type=click.File("r"))
@with_client
@with_session
def sign_typed_data(
client: "TrezorClient", address: str, metamask_v4_compat: bool, file: TextIO
session: "Session", address: str, metamask_v4_compat: bool, file: TextIO
) -> Dict[str, str]:
"""Sign typed data (EIP-712) with Ethereum address.
@ -565,7 +565,7 @@ def sign_typed_data(
defs = EthereumDefinitions(encoded_network=network)
data = json.loads(file.read())
ret = ethereum.sign_typed_data(
client,
session,
address_n,
data,
metamask_v4_compat=metamask_v4_compat,
@ -583,9 +583,9 @@ def sign_typed_data(
@click.argument("address")
@click.argument("signature")
@click.argument("message")
@with_client
@with_session
def verify_message(
client: "TrezorClient",
session: "Session",
address: str,
signature: str,
message: str,
@ -594,7 +594,7 @@ def verify_message(
"""Verify message signed with Ethereum address."""
signature_bytes = ethereum.decode_hex(signature)
return ethereum.verify_message(
client, address, signature_bytes, message, chunkify=chunkify
session, address, signature_bytes, message, chunkify=chunkify
)
@ -602,9 +602,9 @@ def verify_message(
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.argument("domain_hash_hex")
@click.argument("message_hash_hex")
@with_client
@with_session
def sign_typed_data_hash(
client: "TrezorClient", address: str, domain_hash_hex: str, message_hash_hex: str
session: "Session", address: str, domain_hash_hex: str, message_hash_hex: str
) -> Dict[str, str]:
"""
Sign hash of typed data (EIP-712) with Ethereum address.
@ -618,7 +618,7 @@ def sign_typed_data_hash(
message_hash = ethereum.decode_hex(message_hash_hex) if message_hash_hex else None
network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE)
ret = ethereum.sign_typed_data_hash(
client, address_n, domain_hash, message_hash, network
session, address_n, domain_hash, message_hash, network
)
output = {
"domain_hash": domain_hash_hex,

View File

@ -19,10 +19,10 @@ from typing import TYPE_CHECKING
import click
from .. import fido
from . import with_client
from . import with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
ALGORITHM_NAME = {-7: "ES256 (ECDSA w/ SHA-256)", -8: "EdDSA"}
@ -40,10 +40,10 @@ def credentials() -> None:
@credentials.command(name="list")
@with_client
def credentials_list(client: "TrezorClient") -> None:
@with_session
def credentials_list(session: "Session") -> None:
"""List all resident credentials on the device."""
creds = fido.list_credentials(client)
creds = fido.list_credentials(session)
for cred in creds:
click.echo("")
click.echo(f"WebAuthn credential at index {cred.index}:")
@ -79,23 +79,23 @@ def credentials_list(client: "TrezorClient") -> None:
@credentials.command(name="add")
@click.argument("hex_credential_id")
@with_client
def credentials_add(client: "TrezorClient", hex_credential_id: str) -> str:
@with_session
def credentials_add(session: "Session", hex_credential_id: str) -> str:
"""Add the credential with the given ID as a resident credential.
HEX_CREDENTIAL_ID is the credential ID as a hexadecimal string.
"""
return fido.add_credential(client, bytes.fromhex(hex_credential_id))
return fido.add_credential(session, bytes.fromhex(hex_credential_id))
@credentials.command(name="remove")
@click.option(
"-i", "--index", required=True, type=click.IntRange(0, 99), help="Credential index."
)
@with_client
def credentials_remove(client: "TrezorClient", index: int) -> str:
@with_session
def credentials_remove(session: "Session", index: int) -> str:
"""Remove the resident credential at the given index."""
return fido.remove_credential(client, index)
return fido.remove_credential(session, index)
#
@ -110,19 +110,19 @@ def counter() -> None:
@counter.command(name="set")
@click.argument("counter", type=int)
@with_client
def counter_set(client: "TrezorClient", counter: int) -> str:
@with_session
def counter_set(session: "Session", counter: int) -> str:
"""Set FIDO/U2F counter value."""
return fido.set_counter(client, counter)
return fido.set_counter(session, counter)
@counter.command(name="get-next")
@with_client
def counter_get_next(client: "TrezorClient") -> int:
@with_session
def counter_get_next(session: "Session") -> int:
"""Get-and-increase value of FIDO/U2F counter.
FIDO counter value cannot be read directly. On each U2F exchange, the counter value
is returned and atomically increased. This command performs the same operation
and returns the counter value.
"""
return fido.get_next_counter(client)
return fido.get_next_counter(session)

View File

@ -37,11 +37,12 @@ import requests
from .. import device, exceptions, firmware, messages, models
from ..firmware import models as fw_models
from ..models import TrezorModel
from . import ChoiceType, with_client
from . import ChoiceType, with_management_session
if TYPE_CHECKING:
from ..client import TrezorClient
from . import TrezorConnection
from ..transport.session import Session
from . import NewTrezorConnection
MODEL_CHOICE = ChoiceType(
{
@ -74,9 +75,9 @@ def _is_bootloader_onev2(client: "TrezorClient") -> bool:
This is the case from bootloader version 1.8.0, and also holds for firmware version
1.8.0 because that installs the appropriate bootloader.
"""
f = client.features
version = (f.major_version, f.minor_version, f.patch_version)
bootloader_onev2 = f.major_version == 1 and version >= (1, 8, 0)
features = client.features
version = client.version
bootloader_onev2 = features.major_version == 1 and version >= (1, 8, 0)
return bootloader_onev2
@ -306,25 +307,26 @@ def find_best_firmware_version(
If the specified version is not found, prints the closest available version
(higher than the specified one, if existing).
"""
features = client.features
model = client.model
if bitcoin_only is None:
bitcoin_only = _should_use_bitcoin_only(client.features)
bitcoin_only = _should_use_bitcoin_only(features)
def version_str(version: Iterable[int]) -> str:
return ".".join(map(str, version))
f = client.features
releases = get_all_firmware_releases(client.model, bitcoin_only, beta)
releases = get_all_firmware_releases(model, bitcoin_only, beta)
highest_version = releases[0]["version"]
if version:
want_version = [int(x) for x in version.split(".")]
if len(want_version) != 3:
click.echo("Please use the 'X.Y.Z' version format.")
if want_version[0] != f.major_version:
if want_version[0] != features.major_version:
click.echo(
f"Warning: Trezor {client.model.name} firmware version should be "
f"{f.major_version}.X.Y (requested: {version})"
f"Warning: Trezor {model.name} firmware version should be "
f"{features.major_version}.X.Y (requested: {version})"
)
else:
want_version = highest_version
@ -359,8 +361,8 @@ def find_best_firmware_version(
# to the newer one, in that case update to the minimal
# compatible version first
# Choosing the version key to compare based on (not) being in BL mode
client_version = [f.major_version, f.minor_version, f.patch_version]
if f.bootloader_mode:
client_version = client.version
if features.bootloader_mode:
key_to_compare = "min_bootloader_version"
else:
key_to_compare = "min_firmware_version"
@ -447,11 +449,11 @@ def extract_embedded_fw(
def upload_firmware_into_device(
client: "TrezorClient",
session: "Session",
firmware_data: bytes,
) -> None:
"""Perform the final act of loading the firmware into Trezor."""
f = client.features
f = session.features
try:
if f.major_version == 1 and f.firmware_present is not False:
# Trezor One does not send ButtonRequest
@ -461,7 +463,7 @@ def upload_firmware_into_device(
with click.progressbar(
label="Uploading", length=len(firmware_data), show_eta=False
) as bar:
firmware.update(client, firmware_data, bar.update)
firmware.update(session, firmware_data, bar.update)
except exceptions.Cancelled:
click.echo("Update aborted on device.")
except exceptions.TrezorException as e:
@ -519,7 +521,7 @@ def cli() -> None:
@click.pass_obj
# fmt: on
def verify(
obj: "TrezorConnection",
obj: "NewTrezorConnection",
filename: BinaryIO,
check_device: bool,
fingerprint: Optional[str],
@ -564,7 +566,7 @@ def verify(
@click.pass_obj
# fmt: on
def download(
obj: "TrezorConnection",
obj: "NewTrezorConnection",
output: Optional[BinaryIO],
model: Optional[TrezorModel],
version: Optional[str],
@ -630,7 +632,7 @@ def download(
# fmt: on
@click.pass_obj
def update(
obj: "TrezorConnection",
obj: "NewTrezorConnection",
filename: Optional[BinaryIO],
url: Optional[str],
version: Optional[str],
@ -654,6 +656,7 @@ def update(
against data.trezor.io information, if available.
"""
with obj.client_context() as client:
management_session = client.get_management_session()
if sum(bool(x) for x in (filename, url, version)) > 1:
click.echo("You can use only one of: filename, url, version.")
sys.exit(1)
@ -709,7 +712,7 @@ def update(
if _is_strict_update(client, firmware_data):
header_size = _get_firmware_header_size(firmware_data)
device.reboot_to_bootloader(
client,
management_session,
boot_command=messages.BootCommand.INSTALL_UPGRADE,
firmware_header=firmware_data[:header_size],
language_data=language_data,
@ -719,7 +722,7 @@ def update(
click.echo(
"WARNING: Seamless installation not possible, language data will not be uploaded."
)
device.reboot_to_bootloader(client)
device.reboot_to_bootloader(management_session)
click.echo("Waiting for bootloader...")
while True:
@ -735,13 +738,15 @@ def update(
click.echo("Please switch your device to bootloader mode.")
sys.exit(1)
upload_firmware_into_device(client=client, firmware_data=firmware_data)
upload_firmware_into_device(
session=client.get_management_session(), firmware_data=firmware_data
)
@cli.command()
@click.argument("hex_challenge", required=False)
@with_client
def get_hash(client: "TrezorClient", hex_challenge: Optional[str]) -> str:
@with_management_session
def get_hash(session: "Session", hex_challenge: Optional[str]) -> str:
"""Get a hash of the installed firmware combined with the optional challenge."""
challenge = bytes.fromhex(hex_challenge) if hex_challenge else None
return firmware.get_hash(client, challenge).hex()
return firmware.get_hash(session, challenge).hex()

View File

@ -19,10 +19,10 @@ from typing import TYPE_CHECKING, Dict
import click
from .. import messages, monero, tools
from . import ChoiceType, with_client
from . import ChoiceType, with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
PATH_HELP = "BIP-32 path, e.g. m/44h/128h/0h"
@ -42,9 +42,9 @@ def cli() -> None:
default=messages.MoneroNetworkType.MAINNET,
)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def get_address(
client: "TrezorClient",
session: "Session",
address: str,
show_display: bool,
network_type: messages.MoneroNetworkType,
@ -52,7 +52,7 @@ def get_address(
) -> bytes:
"""Get Monero address for specified path."""
address_n = tools.parse_path(address)
return monero.get_address(client, address_n, show_display, network_type, chunkify)
return monero.get_address(session, address_n, show_display, network_type, chunkify)
@cli.command()
@ -63,13 +63,13 @@ def get_address(
type=ChoiceType({m.name: m for m in messages.MoneroNetworkType}),
default=messages.MoneroNetworkType.MAINNET,
)
@with_client
@with_session
def get_watch_key(
client: "TrezorClient", address: str, network_type: messages.MoneroNetworkType
session: "Session", address: str, network_type: messages.MoneroNetworkType
) -> Dict[str, str]:
"""Get Monero watch key for specified path."""
address_n = tools.parse_path(address)
res = monero.get_watch_key(client, address_n, network_type)
res = monero.get_watch_key(session, address_n, network_type)
# TODO: could be made required in MoneroWatchKey
assert res.address is not None
assert res.watch_key is not None

View File

@ -21,10 +21,10 @@ import click
import requests
from .. import nem, tools
from . import with_client
from . import with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
PATH_HELP = "BIP-32 path, e.g. m/44h/134h/0h/0h"
@ -39,9 +39,9 @@ def cli() -> None:
@click.option("-N", "--network", type=int, default=0x68)
@click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def get_address(
client: "TrezorClient",
session: "Session",
address: str,
network: int,
show_display: bool,
@ -49,7 +49,7 @@ def get_address(
) -> str:
"""Get NEM address for specified path."""
address_n = tools.parse_path(address)
return nem.get_address(client, address_n, network, show_display, chunkify)
return nem.get_address(session, address_n, network, show_display, chunkify)
@cli.command()
@ -58,9 +58,9 @@ def get_address(
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.option("-b", "--broadcast", help="NIS to announce transaction to")
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def sign_tx(
client: "TrezorClient",
session: "Session",
address: str,
file: TextIO,
broadcast: Optional[str],
@ -71,7 +71,7 @@ def sign_tx(
Transaction file is expected in the NIS (RequestPrepareAnnounce) format.
"""
address_n = tools.parse_path(address)
transaction = nem.sign_tx(client, address_n, json.load(file), chunkify=chunkify)
transaction = nem.sign_tx(session, address_n, json.load(file), chunkify=chunkify)
payload = {"data": transaction.data.hex(), "signature": transaction.signature.hex()}

View File

@ -20,10 +20,10 @@ from typing import TYPE_CHECKING, TextIO
import click
from .. import ripple, tools
from . import with_client
from . import with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
PATH_HELP = "BIP-32 path to key, e.g. m/44h/144h/0h/0/0"
@ -37,13 +37,13 @@ def cli() -> None:
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def get_address(
client: "TrezorClient", address: str, show_display: bool, chunkify: bool
session: "Session", address: str, show_display: bool, chunkify: bool
) -> str:
"""Get Ripple address"""
address_n = tools.parse_path(address)
return ripple.get_address(client, address_n, show_display, chunkify)
return ripple.get_address(session, address_n, show_display, chunkify)
@cli.command()
@ -51,13 +51,13 @@ def get_address(
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
def sign_tx(client: "TrezorClient", address: str, file: TextIO, chunkify: bool) -> None:
@with_session
def sign_tx(session: "Session", address: str, file: TextIO, chunkify: bool) -> None:
"""Sign Ripple transaction"""
address_n = tools.parse_path(address)
msg = ripple.create_sign_tx_msg(json.load(file))
result = ripple.sign_tx(client, address_n, msg, chunkify=chunkify)
result = ripple.sign_tx(session, address_n, msg, chunkify=chunkify)
click.echo("Signature:")
click.echo(result.signature.hex())
click.echo()

View File

@ -24,10 +24,11 @@ import click
import requests
from .. import device, messages, toif
from . import AliasedGroup, ChoiceType, with_client
from ..transport.session import Session
from . import AliasedGroup, ChoiceType, with_management_session
if TYPE_CHECKING:
from ..client import TrezorClient
pass
try:
from PIL import Image
@ -180,18 +181,18 @@ def cli() -> None:
@cli.command()
@click.option("-r", "--remove", is_flag=True, hidden=True)
@click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False)
@with_client
def pin(client: "TrezorClient", enable: Optional[bool], remove: bool) -> str:
@with_management_session
def pin(session: "Session", enable: Optional[bool], remove: bool) -> str:
"""Set, change or remove PIN."""
# Remove argument is there for backwards compatibility
return device.change_pin(client, remove=_should_remove(enable, remove))
return device.change_pin(session, remove=_should_remove(enable, remove))
@cli.command()
@click.option("-r", "--remove", is_flag=True, hidden=True)
@click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False)
@with_client
def wipe_code(client: "TrezorClient", enable: Optional[bool], remove: bool) -> str:
@with_management_session
def wipe_code(session: "Session", enable: Optional[bool], remove: bool) -> str:
"""Set or remove the wipe code.
The wipe code functions as a "self-destruct PIN". If the wipe code is ever
@ -199,32 +200,32 @@ def wipe_code(client: "TrezorClient", enable: Optional[bool], remove: bool) -> s
removed and the device will be reset to factory defaults.
"""
# Remove argument is there for backwards compatibility
return device.change_wipe_code(client, remove=_should_remove(enable, remove))
return device.change_wipe_code(session, remove=_should_remove(enable, remove))
@cli.command()
# keep the deprecated -l/--label option, make it do nothing
@click.option("-l", "--label", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.argument("label")
@with_client
def label(client: "TrezorClient", label: str) -> str:
@with_management_session
def label(session: "Session", label: str) -> str:
"""Set new device label."""
return device.apply_settings(client, label=label)
return device.apply_settings(session, label=label)
@cli.command()
@with_client
def brightness(client: "TrezorClient") -> str:
@with_management_session
def brightness(session: "Session") -> str:
"""Set display brightness."""
return device.set_brightness(client)
return device.set_brightness(session)
@cli.command()
@click.argument("enable", type=ChoiceType({"on": True, "off": False}))
@with_client
def haptic_feedback(client: "TrezorClient", enable: bool) -> str:
@with_management_session
def haptic_feedback(session: "Session", enable: bool) -> str:
"""Enable or disable haptic feedback."""
return device.apply_settings(client, haptic_feedback=enable)
return device.apply_settings(session, haptic_feedback=enable)
@cli.command()
@ -233,9 +234,9 @@ def haptic_feedback(client: "TrezorClient", enable: bool) -> str:
"-r", "--remove", is_flag=True, default=False, help="Switch back to english."
)
@click.option("-d/-D", "--display/--no-display", default=None)
@with_client
@with_management_session
def language(
client: "TrezorClient", path_or_url: str | None, remove: bool, display: bool | None
session: "Session", path_or_url: str | None, remove: bool, display: bool | None
) -> str:
"""Set new language with translations."""
if remove != (path_or_url is None):
@ -260,29 +261,29 @@ def language(
f"Failed to load translations from {path_or_url}"
) from None
return device.change_language(
client, language_data=language_data, show_display=display
session, language_data=language_data, show_display=display
)
@cli.command()
@click.argument("rotation", type=ChoiceType(ROTATION))
@with_client
def display_rotation(client: "TrezorClient", rotation: messages.DisplayRotation) -> str:
@with_management_session
def display_rotation(session: "Session", rotation: messages.DisplayRotation) -> str:
"""Set display rotation.
Configure display rotation for Trezor Model T. The options are
north, east, south or west.
"""
return device.apply_settings(client, display_rotation=rotation)
return device.apply_settings(session, display_rotation=rotation)
@cli.command()
@click.argument("delay", type=str)
@with_client
def auto_lock_delay(client: "TrezorClient", delay: str) -> str:
@with_management_session
def auto_lock_delay(session: "Session", delay: str) -> str:
"""Set auto-lock delay (in seconds)."""
if not client.features.pin_protection:
if not session.features.pin_protection:
raise click.ClickException("Set up a PIN first")
value, unit = delay[:-1], delay[-1:]
@ -291,13 +292,13 @@ def auto_lock_delay(client: "TrezorClient", delay: str) -> str:
seconds = float(value) * units[unit]
else:
seconds = float(delay) # assume seconds if no unit is specified
return device.apply_settings(client, auto_lock_delay_ms=int(seconds * 1000))
return device.apply_settings(session, auto_lock_delay_ms=int(seconds * 1000))
@cli.command()
@click.argument("flags")
@with_client
def flags(client: "TrezorClient", flags: str) -> str:
@with_management_session
def flags(session: "Session", flags: str) -> str:
"""Set device flags."""
if flags.lower().startswith("0b"):
flags_int = int(flags, 2)
@ -305,7 +306,7 @@ def flags(client: "TrezorClient", flags: str) -> str:
flags_int = int(flags, 16)
else:
flags_int = int(flags)
return device.apply_flags(client, flags=flags_int)
return device.apply_flags(session, flags=flags_int)
@cli.command()
@ -314,8 +315,8 @@ def flags(client: "TrezorClient", flags: str) -> str:
"-f", "--filename", "_ignore", is_flag=True, hidden=True, expose_value=False
)
@click.option("-q", "--quality", type=int, default=90, help="JPEG quality (0-100)")
@with_client
def homescreen(client: "TrezorClient", filename: str, quality: int) -> str:
@with_management_session
def homescreen(session: "Session", filename: str, quality: int) -> str:
"""Set new homescreen.
To revert to default homescreen, use 'trezorctl set homescreen default'
@ -327,39 +328,39 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> str:
if not path.exists() or not path.is_file():
raise click.ClickException("Cannot open file")
if client.features.model == "1":
if session.features.model == "1":
img = image_to_t1(path)
else:
if client.features.homescreen_format == messages.HomescreenFormat.Jpeg:
if session.features.homescreen_format == messages.HomescreenFormat.Jpeg:
width = (
client.features.homescreen_width
if client.features.homescreen_width is not None
session.features.homescreen_width
if session.features.homescreen_width is not None
else 240
)
height = (
client.features.homescreen_height
if client.features.homescreen_height is not None
session.features.homescreen_height
if session.features.homescreen_height is not None
else 240
)
img = image_to_jpeg(path, width, height, quality)
elif client.features.homescreen_format == messages.HomescreenFormat.ToiG:
width = client.features.homescreen_width
height = client.features.homescreen_height
elif session.features.homescreen_format == messages.HomescreenFormat.ToiG:
width = session.features.homescreen_width
height = session.features.homescreen_height
if width is None or height is None:
raise click.ClickException("Device did not report homescreen size.")
img = image_to_toif(path, width, height, True)
elif (
client.features.homescreen_format == messages.HomescreenFormat.Toif
or client.features.homescreen_format is None
session.features.homescreen_format == messages.HomescreenFormat.Toif
or session.features.homescreen_format is None
):
width = (
client.features.homescreen_width
if client.features.homescreen_width is not None
session.features.homescreen_width
if session.features.homescreen_width is not None
else 144
)
height = (
client.features.homescreen_height
if client.features.homescreen_height is not None
session.features.homescreen_height
if session.features.homescreen_height is not None
else 144
)
img = image_to_toif(path, width, height, False)
@ -369,7 +370,7 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> str:
"Unknown image format requested by the device."
)
return device.apply_settings(client, homescreen=img)
return device.apply_settings(session, homescreen=img)
@cli.command()
@ -377,9 +378,9 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> str:
"--always", is_flag=True, help='Persist the "prompt" setting across Trezor reboots.'
)
@click.argument("level", type=ChoiceType(SAFETY_LEVELS))
@with_client
@with_management_session
def safety_checks(
client: "TrezorClient", always: bool, level: messages.SafetyCheckLevel
session: "Session", always: bool, level: messages.SafetyCheckLevel
) -> str:
"""Set safety check level.
@ -392,18 +393,18 @@ def safety_checks(
"""
if always and level == messages.SafetyCheckLevel.PromptTemporarily:
level = messages.SafetyCheckLevel.PromptAlways
return device.apply_settings(client, safety_checks=level)
return device.apply_settings(session, safety_checks=level)
@cli.command()
@click.argument("enable", type=ChoiceType({"on": True, "off": False}))
@with_client
def experimental_features(client: "TrezorClient", enable: bool) -> str:
@with_management_session
def experimental_features(session: "Session", enable: bool) -> str:
"""Enable or disable experimental message types.
This is a developer feature. Use with caution.
"""
return device.apply_settings(client, experimental_features=enable)
return device.apply_settings(session, experimental_features=enable)
#
@ -426,25 +427,25 @@ passphrase = cast(AliasedGroup, passphrase_main)
@passphrase.command(name="on")
@click.option("-f/-F", "--force-on-device/--no-force-on-device", default=None)
@with_client
def passphrase_on(client: "TrezorClient", force_on_device: Optional[bool]) -> str:
@with_management_session
def passphrase_on(session: "Session", force_on_device: Optional[bool]) -> str:
"""Enable passphrase."""
if client.features.passphrase_protection is not True:
if session.features.passphrase_protection is not True:
use_passphrase = True
else:
use_passphrase = None
return device.apply_settings(
client,
session,
use_passphrase=use_passphrase,
passphrase_always_on_device=force_on_device,
)
@passphrase.command(name="off")
@with_client
def passphrase_off(client: "TrezorClient") -> str:
@with_management_session
def passphrase_off(session: "Session") -> str:
"""Disable passphrase."""
return device.apply_settings(client, use_passphrase=False)
return device.apply_settings(session, use_passphrase=False)
# Registering the aliases for backwards compatibility
@ -457,10 +458,10 @@ passphrase.aliases = {
@passphrase.command(name="hide")
@click.argument("hide", type=ChoiceType({"on": True, "off": False}))
@with_client
def hide_passphrase_from_host(client: "TrezorClient", hide: bool) -> str:
@with_management_session
def hide_passphrase_from_host(session: "Session", hide: bool) -> str:
"""Enable or disable hiding passphrase coming from host.
This is a developer feature. Use with caution.
"""
return device.apply_settings(client, hide_passphrase_from_host=hide)
return device.apply_settings(session, hide_passphrase_from_host=hide)

View File

@ -4,10 +4,10 @@ from typing import TYPE_CHECKING, Optional, TextIO
import click
from .. import messages, solana, tools
from . import with_client
from . import with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
PATH_HELP = "BIP-32 path to key, e.g. m/44h/501h/0h/0h"
DEFAULT_PATH = "m/44h/501h/0h/0h"
@ -21,40 +21,40 @@ def cli() -> None:
@cli.command()
@click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@with_client
@with_session
def get_public_key(
client: "TrezorClient",
session: "Session",
address: str,
show_display: bool,
) -> messages.SolanaPublicKey:
"""Get Solana public key."""
address_n = tools.parse_path(address)
return solana.get_public_key(client, address_n, show_display)
return solana.get_public_key(session, address_n, show_display)
@cli.command()
@click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def get_address(
client: "TrezorClient",
session: "Session",
address: str,
show_display: bool,
chunkify: bool,
) -> messages.SolanaAddress:
"""Get Solana address."""
address_n = tools.parse_path(address)
return solana.get_address(client, address_n, show_display, chunkify)
return solana.get_address(session, address_n, show_display, chunkify)
@cli.command()
@click.argument("serialized_tx", type=str)
@click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP)
@click.option("-a", "--additional-information-file", type=click.File("r"))
@with_client
@with_session
def sign_tx(
client: "TrezorClient",
session: "Session",
address: str,
serialized_tx: str,
additional_information_file: Optional[TextIO],
@ -78,7 +78,7 @@ def sign_tx(
)
return solana.sign_tx(
client,
session,
address_n,
bytes.fromhex(serialized_tx),
additional_information,

View File

@ -21,10 +21,10 @@ from typing import TYPE_CHECKING
import click
from .. import stellar, tools
from . import with_client
from . import with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
try:
from stellar_sdk import (
@ -52,13 +52,13 @@ def cli() -> None:
)
@click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def get_address(
client: "TrezorClient", address: str, show_display: bool, chunkify: bool
session: "Session", address: str, show_display: bool, chunkify: bool
) -> str:
"""Get Stellar public address."""
address_n = tools.parse_path(address)
return stellar.get_address(client, address_n, show_display, chunkify)
return stellar.get_address(session, address_n, show_display, chunkify)
@cli.command()
@ -77,9 +77,9 @@ def get_address(
help="Network passphrase (blank for public network).",
)
@click.argument("b64envelope")
@with_client
@with_session
def sign_transaction(
client: "TrezorClient", b64envelope: str, address: str, network_passphrase: str
session: "Session", b64envelope: str, address: str, network_passphrase: str
) -> bytes:
"""Sign a base64-encoded transaction envelope.
@ -109,6 +109,6 @@ def sign_transaction(
address_n = tools.parse_path(address)
tx, operations = stellar.from_envelope(envelope)
resp = stellar.sign_tx(client, tx, operations, address_n, network_passphrase)
resp = stellar.sign_tx(session, tx, operations, address_n, network_passphrase)
return base64.b64encode(resp.signature)

Some files were not shown because too many files have changed in this diff Show More