mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-02-12 07:32:48 +00:00
feat: new THP
This commit is contained in:
parent
6cbf5e4064
commit
aaaeb3abca
@ -307,6 +307,7 @@ core unix frozen debug build:
|
||||
needs: []
|
||||
variables:
|
||||
PYOPT: "0"
|
||||
THP: "1"
|
||||
script:
|
||||
- $NIX_SHELL --run "poetry run make -C core build_unix_frozen"
|
||||
artifacts:
|
||||
|
@ -39,6 +39,8 @@ message Failure {
|
||||
Failure_PinMismatch = 12;
|
||||
Failure_WipeCodeMismatch = 13;
|
||||
Failure_InvalidSession = 14;
|
||||
Failure_ThpUnallocatedSession=15;
|
||||
Failure_InvalidProtocol=16;
|
||||
Failure_FirmwareError = 99;
|
||||
}
|
||||
}
|
||||
|
@ -110,6 +110,8 @@ message DebugLinkGetState {
|
||||
// trezor-core only - wait until current layout changes
|
||||
// changed in 2.6.4: multiple wait types instead of true/false.
|
||||
optional DebugWaitType wait_layout = 3 [default=IMMEDIATE];
|
||||
// THP only - it is used to get information from specified channel
|
||||
optional bytes thp_channel_id=4;
|
||||
}
|
||||
|
||||
/**
|
||||
@ -130,6 +132,9 @@ message DebugLinkState {
|
||||
optional uint32 reset_word_pos = 11; // index of mnemonic word the device is expecting during ResetDevice workflow
|
||||
optional management.BackupType mnemonic_type = 12; // current mnemonic type (BIP-39/SLIP-39)
|
||||
repeated string tokens = 13; // current layout represented as a list of string tokens
|
||||
optional uint32 thp_pairing_code_entry_code = 14;
|
||||
optional bytes thp_pairing_code_qr_code = 15;
|
||||
optional bytes thp_pairing_code_nfc_unidirectional = 16;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -9,6 +9,218 @@ import "options.proto";
|
||||
|
||||
option (include_in_bitcoin_only) = true;
|
||||
|
||||
/**
|
||||
* Mapping between Trezor wire identifier (uint) and a Thp protobuf message
|
||||
*/
|
||||
enum ThpMessageType {
|
||||
reserved 0 to 999; // Values reserved by other messages, see messages.proto
|
||||
|
||||
ThpMessageType_ThpCreateNewSession = 1000[(bitcoin_only)=true, (channel_in) = true];
|
||||
ThpMessageType_ThpNewSession = 1001[(bitcoin_only)=true, (channel_out) = true];
|
||||
ThpMessageType_ThpStartPairingRequest = 1008 [(bitcoin_only) = true, (pairing_in) = true];
|
||||
ThpMessageType_ThpPairingPreparationsFinished = 1009 [(bitcoin_only) = true, (pairing_out) = true];
|
||||
ThpMessageType_ThpCredentialRequest = 1010 [(bitcoin_only) = true, (pairing_in) = true];
|
||||
ThpMessageType_ThpCredentialResponse = 1011 [(bitcoin_only) = true, (pairing_out) = true];
|
||||
ThpMessageType_ThpEndRequest = 1012 [(bitcoin_only) = true, (pairing_in) = true];
|
||||
ThpMessageType_ThpEndResponse = 1013[(bitcoin_only) = true, (pairing_out) = true];
|
||||
ThpMessageType_ThpCodeEntryCommitment = 1016[(bitcoin_only)=true, (pairing_out) = true];
|
||||
ThpMessageType_ThpCodeEntryChallenge = 1017[(bitcoin_only)=true, (pairing_in) = true];
|
||||
ThpMessageType_ThpCodeEntryCpaceHost = 1018[(bitcoin_only)=true, (pairing_in) = true];
|
||||
ThpMessageType_ThpCodeEntryCpaceTrezor = 1019[(bitcoin_only)=true, (pairing_out) = true];
|
||||
ThpMessageType_ThpCodeEntryTag = 1020[(bitcoin_only)=true, (pairing_in) = true];
|
||||
ThpMessageType_ThpCodeEntrySecret = 1021[(bitcoin_only)=true, (pairing_out) = true];
|
||||
ThpMessageType_ThpQrCodeTag = 1024[(bitcoin_only)=true, (pairing_in) = true];
|
||||
ThpMessageType_ThpQrCodeSecret = 1025[(bitcoin_only)=true, (pairing_out) = true];
|
||||
ThpMessageType_ThpNfcUnidirectionalTag = 1032[(bitcoin_only)=true, (pairing_in) = true];
|
||||
ThpMessageType_ThpNfcUnidirectionalSecret = 1033[(bitcoin_only)=true, (pairing_in) = true];
|
||||
|
||||
reserved 1100 to 2147483647; // Values reserved by other messages, see messages.proto
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Numeric identifiers of pairing methods.
|
||||
* @embed
|
||||
*/
|
||||
enum ThpPairingMethod {
|
||||
NoMethod = 1; // Trust without MITM protection.
|
||||
CodeEntry = 2; // User types code diplayed on Trezor into the host application.
|
||||
QrCode = 3; // User scans code displayed on Trezor into host application.
|
||||
NFC_Unidirectional = 4; // Trezor transmits an authentication key to the host device via NFC.
|
||||
}
|
||||
|
||||
/**
|
||||
* @embed
|
||||
*/
|
||||
message ThpDeviceProperties {
|
||||
optional string internal_model = 1; // Internal model name e.g. "T2B1".
|
||||
optional uint32 model_variant = 2; // Encodes the device properties such as color.
|
||||
optional bool bootloader_mode = 3; // Indicates whether the device is in bootloader or firmware mode.
|
||||
optional uint32 protocol_version = 4; // The communication protocol version supported by the firmware.
|
||||
repeated ThpPairingMethod pairing_methods = 5; // The pairing methods supported by the Trezor.
|
||||
}
|
||||
|
||||
/**
|
||||
* @embed
|
||||
*/
|
||||
message ThpHandshakeCompletionReqNoisePayload {
|
||||
optional bytes host_pairing_credential = 1; // Host's pairing credential
|
||||
repeated ThpPairingMethod pairing_methods = 2; // The pairing methods chosen by the host
|
||||
}
|
||||
|
||||
/**
|
||||
* Request: Ask device for a new session with given passphrase.
|
||||
* @start
|
||||
* @next ThpNewSession
|
||||
*/
|
||||
message ThpCreateNewSession{
|
||||
optional string passphrase = 1;
|
||||
optional bool on_device = 2; // User wants to enter passphrase on the device
|
||||
optional bool derive_cardano = 3; // If True, Cardano keys will be derived. Ignored with BTC-only
|
||||
}
|
||||
|
||||
/**
|
||||
* Response: Contains session_id of the newly created session.
|
||||
* @end
|
||||
*/
|
||||
message ThpNewSession{
|
||||
optional uint32 new_session_id = 1;
|
||||
}
|
||||
|
||||
/**
|
||||
* Request: Start pairing process.
|
||||
* @start
|
||||
* @next ThpCodeEntryCommitment
|
||||
* @next ThpPairingPreparationsFinished
|
||||
*/
|
||||
message ThpStartPairingRequest{
|
||||
optional string host_name = 1; // Human-readable host name
|
||||
}
|
||||
|
||||
/**
|
||||
* Response: Pairing is ready for user input / OOB communication.
|
||||
* @next ThpCodeEntryCpace
|
||||
* @next ThpQrCodeTag
|
||||
* @next ThpNfcUnidirectionalTag
|
||||
*/
|
||||
message ThpPairingPreparationsFinished{
|
||||
}
|
||||
|
||||
/**
|
||||
* Response: If Code Entry is an allowed pairing option, Trezor responds with a commitment.
|
||||
* @next ThpCodeEntryChallenge
|
||||
*/
|
||||
message ThpCodeEntryCommitment {
|
||||
optional bytes commitment = 1; // SHA-256 of Trezor's random 32-byte secret
|
||||
}
|
||||
|
||||
/**
|
||||
* Response: Host responds to Trezor's Code Entry commitment with a challenge.
|
||||
* @next ThpPairingPreparationsFinished
|
||||
*/
|
||||
message ThpCodeEntryChallenge {
|
||||
optional bytes challenge = 1; // host's random 32-byte challenge
|
||||
}
|
||||
|
||||
/**
|
||||
* Request: User selected Code Entry option in Host. Host starts CPACE protocol with Trezor.
|
||||
* @next ThpCodeEntryCpaceTrezor
|
||||
*/
|
||||
message ThpCodeEntryCpaceHost {
|
||||
optional bytes cpace_host_public_key = 1; // Host's ephemeral CPace public key
|
||||
}
|
||||
|
||||
/**
|
||||
* Response: Trezor continues with the CPACE protocol.
|
||||
* @next ThpCodeEntryTag
|
||||
*/
|
||||
message ThpCodeEntryCpaceTrezor {
|
||||
optional bytes cpace_trezor_public_key = 1; // Trezor's ephemeral CPace public key
|
||||
}
|
||||
|
||||
/**
|
||||
* Response: Host continues with the CPACE protocol.
|
||||
* @next ThpCodeEntrySecret
|
||||
*/
|
||||
message ThpCodeEntryTag {
|
||||
optional bytes tag = 2; // SHA-256 of shared secret
|
||||
}
|
||||
|
||||
/**
|
||||
* Response: Trezor finishes the CPACE protocol.
|
||||
* @next ThpCredentialRequest
|
||||
* @next ThpEndRequest
|
||||
*/
|
||||
message ThpCodeEntrySecret {
|
||||
optional bytes secret = 1; // Trezor's secret
|
||||
}
|
||||
|
||||
/**
|
||||
* Request: User selected QR Code pairing option. Host sends a QR Tag.
|
||||
* @next ThpQrCodeSecret
|
||||
*/
|
||||
message ThpQrCodeTag {
|
||||
optional bytes tag = 1; // SHA-256 of shared secret
|
||||
}
|
||||
|
||||
/**
|
||||
* Response: Trezor sends the QR secret.
|
||||
* @next ThpCredentialRequest
|
||||
* @next ThpEndRequest
|
||||
*/
|
||||
message ThpQrCodeSecret {
|
||||
optional bytes secret = 1; // Trezor's secret
|
||||
}
|
||||
|
||||
/**
|
||||
* Request: User selected Unidirectional NFC pairing option. Host sends an Unidirectional NFC Tag.
|
||||
* @next ThpNfcUnidirectionalSecret
|
||||
*/
|
||||
message ThpNfcUnidirectionalTag {
|
||||
optional bytes tag = 1; // SHA-256 of shared secret
|
||||
}
|
||||
|
||||
/**
|
||||
* Response: Trezor sends the Unidirectioal NFC secret.
|
||||
* @next ThpCredentialRequest
|
||||
* @next ThpEndRequest
|
||||
*/
|
||||
message ThpNfcUnidirectionalSecret {
|
||||
optional bytes secret = 1; // Trezor's secret
|
||||
}
|
||||
|
||||
/**
|
||||
* Request: Host requests issuance of a new pairing credential.
|
||||
* @start
|
||||
* @next ThpCredentialResponse
|
||||
*/
|
||||
message ThpCredentialRequest {
|
||||
optional bytes host_static_pubkey = 1; // Host's static public key used in the handshake.
|
||||
}
|
||||
|
||||
/**
|
||||
* Response: Trezor issues a new pairing credential.
|
||||
* @next ThpCredentialRequest
|
||||
* @next ThpEndRequest
|
||||
*/
|
||||
message ThpCredentialResponse {
|
||||
optional bytes trezor_static_pubkey = 1; // Trezor's static public key used in the handshake.
|
||||
optional bytes credential = 2; // The pairing credential issued by the Trezor to the host.
|
||||
}
|
||||
|
||||
/**
|
||||
* Request: Host requests transition to the encrypted traffic phase.
|
||||
* @start
|
||||
* @next ThpEndResponse
|
||||
*/
|
||||
message ThpEndRequest {}
|
||||
|
||||
/**
|
||||
* Response: Trezor approves transition to the encrypted traffic phase
|
||||
* @end
|
||||
*/
|
||||
message ThpEndResponse {}
|
||||
|
||||
/**
|
||||
* Only for internal use.
|
||||
* @embed
|
||||
|
@ -37,6 +37,10 @@ The convention to achieve this is as follows:
|
||||
optional bool wire_tiny = 50006; // message is handled by Trezor when the USB stack is in tiny mode
|
||||
optional bool wire_bootloader = 50007; // message is only handled by Trezor Bootloader
|
||||
optional bool wire_no_fsm = 50008; // message is not handled by Trezor unless the USB stack is in tiny mode
|
||||
optional bool channel_in = 50009;
|
||||
optional bool channel_out = 50010;
|
||||
optional bool pairing_in = 50011;
|
||||
optional bool pairing_out = 50012;
|
||||
|
||||
optional bool bitcoin_only = 60000; // enum value is available on BITCOIN_ONLY build
|
||||
// (messages not marked bitcoin_only will be EXCLUDED)
|
||||
|
@ -62,6 +62,7 @@ INT_TYPES = (
|
||||
)
|
||||
|
||||
MESSAGE_TYPE_ENUM = "MessageType"
|
||||
THP_MESSAGE_TYPE_ENUM = "ThpMessageType"
|
||||
|
||||
LengthDelimited = c.Struct(
|
||||
"len" / c.VarInt,
|
||||
@ -239,6 +240,9 @@ class ProtoMessage:
|
||||
@classmethod
|
||||
def from_message(cls, descriptor: "Descriptor", message):
|
||||
message_type = find_by_name(descriptor.message_type_enum.value, message.name)
|
||||
thp_message_type = None
|
||||
if not isinstance(descriptor.thp_message_type_enum,tuple):
|
||||
thp_message_type = find_by_name(descriptor.thp_message_type_enum.value, message.name)
|
||||
# use extensions set on the message_type entry (if any)
|
||||
extensions = descriptor.get_extensions(message_type)
|
||||
# override with extensions set on the message itself
|
||||
@ -248,6 +252,8 @@ class ProtoMessage:
|
||||
wire_type = extensions["wire_type"]
|
||||
elif message_type is not None:
|
||||
wire_type = message_type.number
|
||||
elif thp_message_type is not None:
|
||||
wire_type = thp_message_type.number
|
||||
else:
|
||||
wire_type = None
|
||||
|
||||
@ -351,10 +357,13 @@ class Descriptor:
|
||||
]
|
||||
logging.debug(f"found {len(self.files)} bitcoin-only files")
|
||||
|
||||
# find message_type enum
|
||||
# find message_type and thp_message_type enum
|
||||
top_level_enums = itertools.chain.from_iterable(f.enum_type for f in self.files)
|
||||
self.message_type_enum = find_by_name(top_level_enums, MESSAGE_TYPE_ENUM, ())
|
||||
top_level_enums = itertools.chain.from_iterable(f.enum_type for f in self.files)
|
||||
self.thp_message_type_enum = find_by_name(top_level_enums, THP_MESSAGE_TYPE_ENUM, ())
|
||||
self.convert_enum_value_names(self.message_type_enum)
|
||||
self.convert_enum_value_names(self.thp_message_type_enum)
|
||||
|
||||
# find messages and enums
|
||||
self.messages = []
|
||||
@ -423,6 +432,8 @@ class Descriptor:
|
||||
self._nested_types_from_message(nested.orig)
|
||||
|
||||
def convert_enum_value_names(self, enum):
|
||||
if isinstance(enum,tuple):
|
||||
return
|
||||
for value in enum.value:
|
||||
value.name = strip_enum_prefix(enum.name, value.name)
|
||||
|
||||
@ -558,6 +569,8 @@ class RustBlobRenderer:
|
||||
enums = []
|
||||
cursor = 0
|
||||
for enum in sorted(self.descriptor.enums, key=lambda e: e.name):
|
||||
if enum.name == "MessageType":
|
||||
continue
|
||||
self.enum_map[enum.name] = cursor
|
||||
enum_blob = ENUM_ENTRY.build(sorted(v.number for v in enum.value))
|
||||
enums.append(enum_blob)
|
||||
|
@ -309,6 +309,12 @@ build_unix_frozen: templates build_cross ## build unix port with frozen modules
|
||||
TREZOR_MEMPERF="$(TREZOR_MEMPERF)" TREZOR_EMULATOR_FROZEN=1 \
|
||||
BENCHMARK="$(BENCHMARK)"
|
||||
|
||||
build_unix_frozen_debug: templates build_cross ## build unix port with frozen modules and DEBUG (PYOPT="0")
|
||||
$(SCONS) CFLAGS="$(CFLAGS)" $(UNIX_BUILD_DIR)/trezor-emu-core $(UNIX_PORT_OPTS) \
|
||||
TREZOR_MODEL="$(TREZOR_MODEL)" CMAKELISTS="$(CMAKELISTS)" THP="$(THP)"\
|
||||
PYOPT="0" BITCOIN_ONLY="$(BITCOIN_ONLY)" TREZOR_EMULATOR_ASAN="$(ADDRESS_SANITIZER)" \
|
||||
TREZOR_MEMPERF="$(TREZOR_MEMPERF)" TREZOR_EMULATOR_FROZEN=1 BENCHMARK="$(BENCHMARK)"
|
||||
|
||||
build_unix_debug: templates ## build unix port
|
||||
$(SCONS) --max-drift=1 CFLAGS="$(CFLAGS)" $(UNIX_BUILD_DIR)/trezor-emu-core $(UNIX_PORT_OPTS) \
|
||||
TREZOR_MODEL="$(TREZOR_MODEL)" CMAKELISTS="$(CMAKELISTS)" \
|
||||
|
@ -561,6 +561,10 @@ if FROZEN:
|
||||
] if not EVERYTHING else []
|
||||
))
|
||||
|
||||
if not THP or PYOPT == '0':
|
||||
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/codec/*.py'))
|
||||
if THP:
|
||||
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/thp/*.py'))
|
||||
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/*.py'))
|
||||
|
||||
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'storage/*.py',
|
||||
|
@ -629,6 +629,10 @@ if FROZEN:
|
||||
] if not EVERYTHING else []
|
||||
))
|
||||
|
||||
if not THP or PYOPT == '0':
|
||||
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/codec/*.py'))
|
||||
if THP:
|
||||
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/thp/*.py'))
|
||||
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/*.py'))
|
||||
|
||||
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'storage/*.py',
|
||||
|
@ -1731,7 +1731,7 @@ pub static mp_module_trezorui2: Module = obj_module! {
|
||||
/// """Calls drop on contents of the root component."""
|
||||
///
|
||||
/// class UiResult:
|
||||
/// """Result of an UI operation."""
|
||||
/// """Result of a UI operation."""
|
||||
/// pass
|
||||
///
|
||||
/// mock:global
|
||||
|
File diff suppressed because it is too large
Load Diff
62
core/src/all_modules.py
generated
62
core/src/all_modules.py
generated
@ -47,6 +47,12 @@ storage
|
||||
import storage
|
||||
storage.cache
|
||||
import storage.cache
|
||||
storage.cache_codec
|
||||
import storage.cache_codec
|
||||
storage.cache_common
|
||||
import storage.cache_common
|
||||
storage.cache_thp
|
||||
import storage.cache_thp
|
||||
storage.common
|
||||
import storage.common
|
||||
storage.debug
|
||||
@ -201,12 +207,20 @@ trezor.utils
|
||||
import trezor.utils
|
||||
trezor.wire
|
||||
import trezor.wire
|
||||
trezor.wire.codec_v1
|
||||
import trezor.wire.codec_v1
|
||||
trezor.wire.codec
|
||||
import trezor.wire.codec
|
||||
trezor.wire.codec.codec_context
|
||||
import trezor.wire.codec.codec_context
|
||||
trezor.wire.codec.codec_v1
|
||||
import trezor.wire.codec.codec_v1
|
||||
trezor.wire.context
|
||||
import trezor.wire.context
|
||||
trezor.wire.errors
|
||||
import trezor.wire.errors
|
||||
trezor.wire.message_handler
|
||||
import trezor.wire.message_handler
|
||||
trezor.wire.protocol_common
|
||||
import trezor.wire.protocol_common
|
||||
trezor.workflow
|
||||
import trezor.workflow
|
||||
apps
|
||||
@ -309,6 +323,8 @@ apps.common.backup
|
||||
import apps.common.backup
|
||||
apps.common.backup_types
|
||||
import apps.common.backup_types
|
||||
apps.common.cache
|
||||
import apps.common.cache
|
||||
apps.common.cbor
|
||||
import apps.common.cbor
|
||||
apps.common.coininfo
|
||||
@ -401,10 +417,52 @@ apps.workflow_handlers
|
||||
import apps.workflow_handlers
|
||||
|
||||
if utils.USE_THP:
|
||||
trezor.enums.ThpMessageType
|
||||
import trezor.enums.ThpMessageType
|
||||
trezor.enums.ThpPairingMethod
|
||||
import trezor.enums.ThpPairingMethod
|
||||
trezor.wire.thp
|
||||
import trezor.wire.thp
|
||||
trezor.wire.thp.alternating_bit_protocol
|
||||
import trezor.wire.thp.alternating_bit_protocol
|
||||
trezor.wire.thp.channel
|
||||
import trezor.wire.thp.channel
|
||||
trezor.wire.thp.channel_manager
|
||||
import trezor.wire.thp.channel_manager
|
||||
trezor.wire.thp.checksum
|
||||
import trezor.wire.thp.checksum
|
||||
trezor.wire.thp.control_byte
|
||||
import trezor.wire.thp.control_byte
|
||||
trezor.wire.thp.cpace
|
||||
import trezor.wire.thp.cpace
|
||||
trezor.wire.thp.crypto
|
||||
import trezor.wire.thp.crypto
|
||||
trezor.wire.thp.interface_manager
|
||||
import trezor.wire.thp.interface_manager
|
||||
trezor.wire.thp.memory_manager
|
||||
import trezor.wire.thp.memory_manager
|
||||
trezor.wire.thp.pairing_context
|
||||
import trezor.wire.thp.pairing_context
|
||||
trezor.wire.thp.received_message_handler
|
||||
import trezor.wire.thp.received_message_handler
|
||||
trezor.wire.thp.session_context
|
||||
import trezor.wire.thp.session_context
|
||||
trezor.wire.thp.session_manager
|
||||
import trezor.wire.thp.session_manager
|
||||
trezor.wire.thp.thp_main
|
||||
import trezor.wire.thp.thp_main
|
||||
trezor.wire.thp.transmission_loop
|
||||
import trezor.wire.thp.transmission_loop
|
||||
trezor.wire.thp.writer
|
||||
import trezor.wire.thp.writer
|
||||
apps.thp
|
||||
import apps.thp
|
||||
apps.thp.create_new_session
|
||||
import apps.thp.create_new_session
|
||||
apps.thp.credential_manager
|
||||
import apps.thp.credential_manager
|
||||
apps.thp.pairing
|
||||
import apps.thp.pairing
|
||||
|
||||
if not utils.BITCOIN_ONLY:
|
||||
trezor.enums.BinanceOrderSide
|
||||
|
@ -1,23 +1,20 @@
|
||||
from typing import Iterable
|
||||
|
||||
import storage.cache as storage_cache
|
||||
from storage.cache_common import (
|
||||
APP_COMMON_AUTHORIZATION_DATA,
|
||||
APP_COMMON_AUTHORIZATION_TYPE,
|
||||
)
|
||||
from trezor import protobuf
|
||||
from trezor.enums import MessageType
|
||||
from trezor.wire import context
|
||||
|
||||
WIRE_TYPES: dict[int, tuple[int, ...]] = {
|
||||
MessageType.AuthorizeCoinJoin: (MessageType.SignTx, MessageType.GetOwnershipProof),
|
||||
}
|
||||
|
||||
APP_COMMON_AUTHORIZATION_DATA = (
|
||||
storage_cache.APP_COMMON_AUTHORIZATION_DATA
|
||||
) # global_import_cache
|
||||
APP_COMMON_AUTHORIZATION_TYPE = (
|
||||
storage_cache.APP_COMMON_AUTHORIZATION_TYPE
|
||||
) # global_import_cache
|
||||
|
||||
|
||||
def is_set() -> bool:
|
||||
return storage_cache.get(APP_COMMON_AUTHORIZATION_TYPE) is not None
|
||||
return context.cache_get(APP_COMMON_AUTHORIZATION_TYPE) is not None
|
||||
|
||||
|
||||
def set(auth_message: protobuf.MessageType) -> None:
|
||||
@ -29,27 +26,27 @@ def set(auth_message: protobuf.MessageType) -> None:
|
||||
# (because only wire-level messages have wire_type, which we use as identifier)
|
||||
ensure(auth_message.MESSAGE_WIRE_TYPE is not None)
|
||||
assert auth_message.MESSAGE_WIRE_TYPE is not None # so that typechecker knows too
|
||||
storage_cache.set_int(APP_COMMON_AUTHORIZATION_TYPE, auth_message.MESSAGE_WIRE_TYPE)
|
||||
storage_cache.set(APP_COMMON_AUTHORIZATION_DATA, buffer)
|
||||
context.cache_set_int(APP_COMMON_AUTHORIZATION_TYPE, auth_message.MESSAGE_WIRE_TYPE)
|
||||
context.cache_set(APP_COMMON_AUTHORIZATION_DATA, buffer)
|
||||
|
||||
|
||||
def get() -> protobuf.MessageType | None:
|
||||
stored_auth_type = storage_cache.get_int(APP_COMMON_AUTHORIZATION_TYPE)
|
||||
stored_auth_type = context.cache_get_int(APP_COMMON_AUTHORIZATION_TYPE)
|
||||
if not stored_auth_type:
|
||||
return None
|
||||
|
||||
buffer = storage_cache.get(APP_COMMON_AUTHORIZATION_DATA, b"")
|
||||
buffer = context.cache_get(APP_COMMON_AUTHORIZATION_DATA, b"")
|
||||
return protobuf.load_message_buffer(buffer, stored_auth_type)
|
||||
|
||||
|
||||
def is_set_any_session(auth_type: MessageType) -> bool:
|
||||
return auth_type in storage_cache.get_int_all_sessions(
|
||||
return auth_type in context.cache_get_int_all_sessions(
|
||||
APP_COMMON_AUTHORIZATION_TYPE
|
||||
)
|
||||
|
||||
|
||||
def get_wire_types() -> Iterable[int]:
|
||||
stored_auth_type = storage_cache.get_int(APP_COMMON_AUTHORIZATION_TYPE)
|
||||
stored_auth_type = context.cache_get_int(APP_COMMON_AUTHORIZATION_TYPE)
|
||||
if stored_auth_type is None:
|
||||
return ()
|
||||
|
||||
@ -57,5 +54,5 @@ def get_wire_types() -> Iterable[int]:
|
||||
|
||||
|
||||
def clear() -> None:
|
||||
storage_cache.delete(APP_COMMON_AUTHORIZATION_TYPE)
|
||||
storage_cache.delete(APP_COMMON_AUTHORIZATION_DATA)
|
||||
context.cache_delete(APP_COMMON_AUTHORIZATION_TYPE)
|
||||
context.cache_delete(APP_COMMON_AUTHORIZATION_DATA)
|
||||
|
@ -1,5 +1,6 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trezor import utils
|
||||
from trezor.crypto import bip32
|
||||
from trezor.wire import DataError
|
||||
|
||||
@ -172,7 +173,10 @@ async def get_keychain(
|
||||
) -> Keychain:
|
||||
from .seed import get_seed
|
||||
|
||||
seed = await get_seed()
|
||||
if not utils.USE_THP:
|
||||
pass
|
||||
# try to ask for passphrase here
|
||||
seed = get_seed()
|
||||
keychain = Keychain(seed, curve, schemas, slip21_namespaces)
|
||||
return keychain
|
||||
|
||||
|
@ -1,83 +1,122 @@
|
||||
from micropython import const
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import storage.device as storage_device
|
||||
from trezor import utils
|
||||
from trezor.wire import DataError
|
||||
|
||||
_MAX_PASSPHRASE_LEN = const(50)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezor.messages import ThpCreateNewSession
|
||||
|
||||
|
||||
def is_enabled() -> bool:
|
||||
return storage_device.is_passphrase_enabled()
|
||||
|
||||
|
||||
async def get() -> str:
|
||||
from trezor import workflow
|
||||
|
||||
async def get_passphrase(msg: ThpCreateNewSession) -> str:
|
||||
if not is_enabled():
|
||||
return ""
|
||||
|
||||
if msg.on_device or storage_device.get_passphrase_always_on_device():
|
||||
passphrase = await _get_on_device()
|
||||
else:
|
||||
workflow.close_others() # request exclusive UI access
|
||||
if storage_device.get_passphrase_always_on_device():
|
||||
from trezor.ui.layouts import request_passphrase_on_device
|
||||
passphrase = msg.passphrase or ""
|
||||
if passphrase:
|
||||
await _handle_displaying_passphrase_from_host(passphrase)
|
||||
|
||||
passphrase = await request_passphrase_on_device(_MAX_PASSPHRASE_LEN)
|
||||
else:
|
||||
passphrase = await _request_on_host()
|
||||
if len(passphrase.encode()) > _MAX_PASSPHRASE_LEN:
|
||||
raise DataError(f"Maximum passphrase length is {_MAX_PASSPHRASE_LEN} bytes")
|
||||
|
||||
return passphrase
|
||||
|
||||
|
||||
async def _request_on_host() -> str:
|
||||
from trezor import TR
|
||||
from trezor.messages import PassphraseAck, PassphraseRequest
|
||||
from trezor.ui.layouts import request_passphrase_on_host
|
||||
from trezor.wire.context import call
|
||||
|
||||
request_passphrase_on_host()
|
||||
|
||||
request = PassphraseRequest()
|
||||
ack = await call(request, PassphraseAck)
|
||||
passphrase = ack.passphrase # local_cache_attribute
|
||||
|
||||
if ack.on_device:
|
||||
from trezor.ui.layouts import request_passphrase_on_device
|
||||
|
||||
if passphrase is not None:
|
||||
raise DataError("Passphrase provided when it should not be")
|
||||
return await request_passphrase_on_device(_MAX_PASSPHRASE_LEN)
|
||||
|
||||
if passphrase is None:
|
||||
raise DataError(
|
||||
"Passphrase not provided and on_device is False. Use empty string to set an empty passphrase."
|
||||
)
|
||||
|
||||
# non-empty passphrase
|
||||
if passphrase:
|
||||
from trezor.ui.layouts import confirm_action, confirm_blob
|
||||
|
||||
# We want to hide the passphrase, or show it, according to settings.
|
||||
if storage_device.get_hide_passphrase_from_host():
|
||||
await confirm_action(
|
||||
"passphrase_host1_hidden",
|
||||
TR.passphrase__wallet,
|
||||
description=TR.passphrase__from_host_not_shown,
|
||||
prompt_screen=True,
|
||||
prompt_title=TR.passphrase__access_wallet,
|
||||
)
|
||||
else:
|
||||
await confirm_action(
|
||||
"passphrase_host1",
|
||||
TR.passphrase__wallet,
|
||||
description=TR.passphrase__next_screen_will_show_passphrase,
|
||||
verb=TR.buttons__continue,
|
||||
)
|
||||
|
||||
await confirm_blob(
|
||||
"passphrase_host2",
|
||||
TR.passphrase__title_confirm,
|
||||
passphrase,
|
||||
)
|
||||
if len(passphrase.encode()) > _MAX_PASSPHRASE_LEN:
|
||||
raise DataError(f"Maximum passphrase length is {_MAX_PASSPHRASE_LEN} bytes")
|
||||
|
||||
return passphrase
|
||||
|
||||
|
||||
async def _get_on_device() -> str:
|
||||
from trezor import workflow
|
||||
from trezor.ui.layouts import request_passphrase_on_device
|
||||
|
||||
workflow.close_others() # request exclusive UI access
|
||||
passphrase = await request_passphrase_on_device(_MAX_PASSPHRASE_LEN)
|
||||
|
||||
return passphrase
|
||||
|
||||
|
||||
async def _handle_displaying_passphrase_from_host(passphrase: str) -> None:
|
||||
from trezor import TR
|
||||
from trezor.ui.layouts import confirm_action, confirm_blob
|
||||
|
||||
# We want to hide the passphrase, or show it, according to settings.
|
||||
if storage_device.get_hide_passphrase_from_host():
|
||||
await confirm_action(
|
||||
"passphrase_host1_hidden",
|
||||
TR.passphrase__wallet,
|
||||
description=TR.passphrase__from_host_not_shown,
|
||||
prompt_screen=True,
|
||||
prompt_title=TR.passphrase__access_wallet,
|
||||
)
|
||||
else:
|
||||
await confirm_action(
|
||||
"passphrase_host1",
|
||||
TR.passphrase__wallet,
|
||||
description=TR.passphrase__next_screen_will_show_passphrase,
|
||||
verb=TR.buttons__continue,
|
||||
)
|
||||
|
||||
await confirm_blob(
|
||||
"passphrase_host2",
|
||||
TR.passphrase__title_confirm,
|
||||
passphrase,
|
||||
)
|
||||
|
||||
|
||||
if not utils.USE_THP:
|
||||
|
||||
async def get() -> str:
|
||||
from trezor import workflow
|
||||
|
||||
if not is_enabled():
|
||||
return ""
|
||||
else:
|
||||
workflow.close_others() # request exclusive UI access
|
||||
if storage_device.get_passphrase_always_on_device():
|
||||
from trezor.ui.layouts import request_passphrase_on_device
|
||||
|
||||
passphrase = await request_passphrase_on_device(_MAX_PASSPHRASE_LEN)
|
||||
else:
|
||||
passphrase = await _request_on_host()
|
||||
if len(passphrase.encode()) > _MAX_PASSPHRASE_LEN:
|
||||
raise DataError(
|
||||
f"Maximum passphrase length is {_MAX_PASSPHRASE_LEN} bytes"
|
||||
)
|
||||
|
||||
return passphrase
|
||||
|
||||
async def _request_on_host() -> str:
|
||||
from trezor.messages import PassphraseAck, PassphraseRequest
|
||||
from trezor.ui.layouts import request_passphrase_on_host
|
||||
from trezor.wire.context import call
|
||||
|
||||
request_passphrase_on_host()
|
||||
|
||||
request = PassphraseRequest()
|
||||
ack = await call(request, PassphraseAck)
|
||||
passphrase = ack.passphrase # local_cache_attribute
|
||||
|
||||
if ack.on_device:
|
||||
from trezor.ui.layouts import request_passphrase_on_device
|
||||
|
||||
if passphrase is not None:
|
||||
raise DataError("Passphrase provided when it should not be")
|
||||
return await request_passphrase_on_device(_MAX_PASSPHRASE_LEN)
|
||||
|
||||
if passphrase is None:
|
||||
raise DataError(
|
||||
"Passphrase not provided and on_device is False. Use empty string to set an empty passphrase."
|
||||
)
|
||||
|
||||
# non-empty passphrase
|
||||
if passphrase:
|
||||
await _handle_displaying_passphrase_from_host(passphrase)
|
||||
|
||||
return passphrase
|
||||
|
@ -1,3 +1,6 @@
|
||||
from trezor.wire import message_handler
|
||||
from trezor.wire.protocol_common import Context
|
||||
|
||||
if not __debug__:
|
||||
from trezor.utils import halt
|
||||
|
||||
@ -70,9 +73,7 @@ if __debug__:
|
||||
"layout deadlock detected (did you send a ButtonAck?)"
|
||||
)
|
||||
|
||||
async def return_layout_change(
|
||||
ctx: wire.context.Context, detect_deadlock: bool = False
|
||||
) -> None:
|
||||
async def return_layout_change(ctx: Context, detect_deadlock: bool = False) -> None:
|
||||
# set up the wait
|
||||
storage.layout_watcher = True
|
||||
|
||||
@ -244,7 +245,11 @@ if __debug__:
|
||||
# If no exception was raised, the layout did not shut down. That means that it
|
||||
# just updated itself. The update is already live for the caller to retrieve.
|
||||
|
||||
def _state() -> DebugLinkState:
|
||||
def _state(
|
||||
thp_pairing_code_entry_code: int | None = None,
|
||||
thp_pairing_code_qr_code: bytes | None = None,
|
||||
thp_pairing_code_nfc_unidirectional: bytes | None = None,
|
||||
) -> DebugLinkState:
|
||||
from trezor.messages import DebugLinkState
|
||||
|
||||
from apps.common import mnemonic, passphrase
|
||||
@ -263,13 +268,45 @@ if __debug__:
|
||||
passphrase_protection=passphrase.is_enabled(),
|
||||
reset_entropy=storage.reset_internal_entropy,
|
||||
tokens=tokens,
|
||||
thp_pairing_code_entry_code=thp_pairing_code_entry_code,
|
||||
thp_pairing_code_qr_code=thp_pairing_code_qr_code,
|
||||
thp_pairing_code_nfc_unidirectional=thp_pairing_code_nfc_unidirectional,
|
||||
)
|
||||
|
||||
async def dispatch_DebugLinkGetState(
|
||||
msg: DebugLinkGetState,
|
||||
) -> DebugLinkState | None:
|
||||
|
||||
thp_pairing_code_entry_code: int | None = None
|
||||
thp_pairing_code_qr_code: bytes | None = None
|
||||
thp_pairing_code_nfc_unidirectional: bytes | None = None
|
||||
if utils.USE_THP and msg.thp_channel_id is not None:
|
||||
channel_id = int.from_bytes(msg.thp_channel_id, "big")
|
||||
|
||||
from trezor.wire.thp.channel import Channel
|
||||
from trezor.wire.thp.pairing_context import PairingContext
|
||||
from trezor.wire.thp.thp_main import _CHANNELS
|
||||
|
||||
channel: Channel | None = None
|
||||
ctx: PairingContext | None = None
|
||||
try:
|
||||
channel = _CHANNELS[channel_id]
|
||||
ctx = channel.connection_context
|
||||
except KeyError:
|
||||
pass
|
||||
if ctx is not None and isinstance(ctx, PairingContext):
|
||||
thp_pairing_code_entry_code = ctx.display_data.code_code_entry
|
||||
thp_pairing_code_qr_code = ctx.display_data.code_qr_code
|
||||
thp_pairing_code_nfc_unidirectional = (
|
||||
ctx.display_data.code_nfc_unidirectional
|
||||
)
|
||||
|
||||
if msg.wait_layout == DebugWaitType.IMMEDIATE:
|
||||
return _state()
|
||||
return _state(
|
||||
thp_pairing_code_entry_code,
|
||||
thp_pairing_code_qr_code,
|
||||
thp_pairing_code_nfc_unidirectional,
|
||||
)
|
||||
|
||||
assert DEBUG_CONTEXT is not None
|
||||
if msg.wait_layout == DebugWaitType.NEXT_LAYOUT:
|
||||
@ -280,7 +317,11 @@ if __debug__:
|
||||
if not layout_is_ready():
|
||||
return await return_layout_change(DEBUG_CONTEXT, detect_deadlock=True)
|
||||
else:
|
||||
return _state()
|
||||
return _state(
|
||||
thp_pairing_code_entry_code,
|
||||
thp_pairing_code_qr_code,
|
||||
thp_pairing_code_nfc_unidirectional,
|
||||
)
|
||||
|
||||
async def dispatch_DebugLinkRecordScreen(msg: DebugLinkRecordScreen) -> Success:
|
||||
if msg.target_directory:
|
||||
@ -356,11 +397,12 @@ if __debug__:
|
||||
|
||||
async def handle_session(iface: WireInterface) -> None:
|
||||
from trezor import protobuf, wire
|
||||
from trezor.wire import codec_v1, context
|
||||
from trezor.wire.codec import codec_v1
|
||||
from trezor.wire.codec.codec_context import CodecContext
|
||||
|
||||
global DEBUG_CONTEXT
|
||||
|
||||
DEBUG_CONTEXT = ctx = context.Context(iface, WIRE_BUFFER_DEBUG)
|
||||
DEBUG_CONTEXT = ctx = CodecContext(iface, WIRE_BUFFER_DEBUG)
|
||||
|
||||
if storage.layout_watcher:
|
||||
try:
|
||||
@ -391,7 +433,7 @@ if __debug__:
|
||||
)
|
||||
|
||||
if msg.type not in WORKFLOW_HANDLERS:
|
||||
await ctx.write(wire.unexpected_message())
|
||||
await ctx.write(message_handler.unexpected_message())
|
||||
continue
|
||||
|
||||
elif req_type is None:
|
||||
@ -402,7 +444,7 @@ if __debug__:
|
||||
await ctx.write(Success())
|
||||
continue
|
||||
|
||||
req_msg = wire.wrap_protobuf_load(msg.data, req_type)
|
||||
req_msg = message_handler.wrap_protobuf_load(msg.data, req_type)
|
||||
try:
|
||||
res_msg = await WORKFLOW_HANDLERS[msg.type](req_msg)
|
||||
except Exception as exc:
|
||||
|
@ -5,10 +5,11 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
async def get_nonce(msg: GetNonce) -> Nonce:
|
||||
from storage import cache
|
||||
from storage.cache_common import APP_COMMON_NONCE
|
||||
from trezor.crypto import random
|
||||
from trezor.messages import Nonce
|
||||
from trezor.wire.context import cache_set
|
||||
|
||||
nonce = random.bytes(32)
|
||||
cache.set(cache.APP_COMMON_NONCE, nonce)
|
||||
cache_set(APP_COMMON_NONCE, nonce)
|
||||
return Nonce(nonce=nonce)
|
||||
|
@ -89,7 +89,7 @@ async def reboot_to_bootloader(msg: RebootToBootloader) -> NoReturn:
|
||||
boot_args = None
|
||||
|
||||
ctx = get_context()
|
||||
await ctx.write(Success(message="Rebooting"))
|
||||
await ctx.write_force(Success(message="Rebooting"))
|
||||
# make sure the outgoing USB buffer is flushed
|
||||
await loop.wait(ctx.iface.iface_num() | io.POLL_WRITE)
|
||||
# reboot to the bootloader, pass the firmware header hash if any
|
||||
|
@ -24,6 +24,7 @@ async def recovery_device(msg: RecoveryDevice) -> Success:
|
||||
from trezor import TR, config, wire, workflow
|
||||
from trezor.enums import BackupType, ButtonRequestType
|
||||
from trezor.ui.layouts import confirm_action, confirm_reset_device
|
||||
from trezor.wire.context import try_get_ctx_ids
|
||||
|
||||
from apps.common import mnemonic
|
||||
from apps.common.request_pin import (
|
||||
@ -69,8 +70,8 @@ async def recovery_device(msg: RecoveryDevice) -> Success:
|
||||
if recovery_type == RecoveryType.NormalRecovery:
|
||||
await confirm_reset_device(TR.recovery__title_recover, recovery=True)
|
||||
|
||||
# wipe storage to make sure the device is in a clear state
|
||||
storage.reset()
|
||||
# wipe storage to make sure the device is in a clear state (except protocol cache)
|
||||
storage.reset(excluded=try_get_ctx_ids())
|
||||
|
||||
# set up pin if requested
|
||||
if msg.pin_protection:
|
||||
|
@ -3,8 +3,9 @@ from typing import TYPE_CHECKING
|
||||
import storage.device as storage_device
|
||||
import storage.recovery as storage_recovery
|
||||
import storage.recovery_shares as storage_recovery_shares
|
||||
from trezor import TR, wire
|
||||
from trezor import TR, utils, wire
|
||||
from trezor.messages import Success
|
||||
from trezor.wire import message_handler
|
||||
|
||||
from apps.common import backup_types
|
||||
|
||||
@ -38,18 +39,26 @@ async def recovery_process() -> Success:
|
||||
|
||||
recovery_type = storage_recovery.get_type()
|
||||
|
||||
wire.AVOID_RESTARTING_FOR = (
|
||||
MessageType.Initialize,
|
||||
MessageType.GetFeatures,
|
||||
MessageType.EndSession,
|
||||
)
|
||||
if utils.USE_THP:
|
||||
message_handler.AVOID_RESTARTING_FOR = (
|
||||
MessageType.GetFeatures,
|
||||
MessageType.EndSession,
|
||||
)
|
||||
else:
|
||||
message_handler.AVOID_RESTARTING_FOR = (
|
||||
MessageType.Initialize,
|
||||
MessageType.GetFeatures,
|
||||
MessageType.EndSession,
|
||||
)
|
||||
try:
|
||||
return await _continue_recovery_process()
|
||||
except recover.RecoveryAborted:
|
||||
storage_recovery.end_progress()
|
||||
backup.deactivate_repeated_backup()
|
||||
if recovery_type == RecoveryType.NormalRecovery:
|
||||
storage.wipe()
|
||||
from trezor.wire.context import try_get_ctx_ids
|
||||
|
||||
storage.wipe(excluded=try_get_ctx_ids())
|
||||
raise wire.ActionCancelled
|
||||
|
||||
|
||||
@ -59,11 +68,17 @@ async def _continue_repeated_backup() -> None:
|
||||
from apps.common import backup
|
||||
from apps.management.backup_device import perform_backup
|
||||
|
||||
wire.AVOID_RESTARTING_FOR = (
|
||||
MessageType.Initialize,
|
||||
MessageType.GetFeatures,
|
||||
MessageType.EndSession,
|
||||
)
|
||||
if utils.USE_THP:
|
||||
message_handler.AVOID_RESTARTING_FOR = (
|
||||
MessageType.GetFeatures,
|
||||
MessageType.EndSession,
|
||||
)
|
||||
else:
|
||||
message_handler.AVOID_RESTARTING_FOR = (
|
||||
MessageType.Initialize,
|
||||
MessageType.GetFeatures,
|
||||
MessageType.EndSession,
|
||||
)
|
||||
|
||||
try:
|
||||
await perform_backup(is_repeated_backup=True)
|
||||
|
@ -38,7 +38,7 @@ async def reset_device(msg: ResetDevice) -> Success:
|
||||
prompt_backup,
|
||||
show_wallet_created_success,
|
||||
)
|
||||
from trezor.wire.context import call
|
||||
from trezor.wire.context import call, try_get_ctx_ids
|
||||
|
||||
from apps.common.request_pin import request_pin_confirm
|
||||
|
||||
@ -60,8 +60,8 @@ async def reset_device(msg: ResetDevice) -> Success:
|
||||
# Rendering empty loader so users do not feel a freezing screen
|
||||
render_empty_loader(config.StorageMessage.PROCESSING_MSG)
|
||||
|
||||
# wipe storage to make sure the device is in a clear state
|
||||
storage.reset()
|
||||
# wipe storage to make sure the device is in a clear state (except protocol cache)
|
||||
storage.reset(excluded=try_get_ctx_ids())
|
||||
|
||||
# request and set new PIN
|
||||
if msg.pin_protection:
|
||||
@ -121,7 +121,7 @@ async def reset_device(msg: ResetDevice) -> Success:
|
||||
if perform_backup:
|
||||
await layout.show_backup_success()
|
||||
|
||||
return Success(message="Initialized")
|
||||
return Success(message="Initialized") # TODO: Why "Initialized?"
|
||||
|
||||
|
||||
async def _backup_bip39(mnemonic: str) -> None:
|
||||
|
@ -1,12 +1,19 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trezor.wire.context import get_context, try_get_ctx_ids
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezor.messages import Success, WipeDevice
|
||||
from typing import NoReturn
|
||||
|
||||
from trezor.messages import WipeDevice
|
||||
|
||||
if __debug__:
|
||||
from trezor import log
|
||||
|
||||
|
||||
async def wipe_device(msg: WipeDevice) -> Success:
|
||||
async def wipe_device(msg: WipeDevice) -> NoReturn:
|
||||
import storage
|
||||
from trezor import TR, config, translations
|
||||
from trezor import TR, config, loop, translations
|
||||
from trezor.enums import ButtonRequestType
|
||||
from trezor.messages import Success
|
||||
from trezor.pin import render_empty_loader
|
||||
@ -26,16 +33,22 @@ async def wipe_device(msg: WipeDevice) -> Success:
|
||||
br_code=ButtonRequestType.WipeDevice,
|
||||
)
|
||||
|
||||
if __debug__:
|
||||
log.debug(__name__, "Device wipe - start")
|
||||
|
||||
# start an empty progress screen so that the screen is not blank while waiting
|
||||
render_empty_loader(config.StorageMessage.PROCESSING_MSG)
|
||||
|
||||
# wipe storage
|
||||
storage.wipe()
|
||||
storage.wipe(excluded=try_get_ctx_ids())
|
||||
# erase translations
|
||||
translations.deinit()
|
||||
translations.erase()
|
||||
|
||||
await get_context().write_force(Success(message="Device wiped"))
|
||||
storage.wipe_cache()
|
||||
|
||||
# reload settings
|
||||
reload_settings_from_storage()
|
||||
|
||||
return Success(message="Device wiped")
|
||||
loop.clear()
|
||||
if __debug__:
|
||||
log.debug(__name__, "Device wipe - finished")
|
||||
|
49
core/src/apps/thp/create_new_session.py
Normal file
49
core/src/apps/thp/create_new_session.py
Normal file
@ -0,0 +1,49 @@
|
||||
from trezor import log, loop
|
||||
from trezor.enums import FailureType
|
||||
from trezor.messages import Failure, ThpCreateNewSession, ThpNewSession
|
||||
from trezor.wire.context import get_context
|
||||
from trezor.wire.errors import ActionCancelled, DataError
|
||||
from trezor.wire.thp import SessionState
|
||||
|
||||
|
||||
async def create_new_session(message: ThpCreateNewSession) -> ThpNewSession | Failure:
|
||||
from trezor.wire.thp.session_manager import create_new_session
|
||||
|
||||
from apps.common.seed import derive_and_store_roots
|
||||
|
||||
ctx = get_context()
|
||||
|
||||
# Assert that context `ctx` is ManagementSessionContext
|
||||
from trezor.wire.thp.session_context import ManagementSessionContext
|
||||
|
||||
assert isinstance(ctx, ManagementSessionContext)
|
||||
|
||||
channel = ctx.channel
|
||||
|
||||
# Do not use `ctx` beyond this point, as it is techically
|
||||
# allowed to change in between await statements
|
||||
|
||||
new_session = create_new_session(channel)
|
||||
try:
|
||||
await derive_and_store_roots(new_session, message)
|
||||
except DataError as e:
|
||||
return Failure(code=FailureType.DataError, message=e.message)
|
||||
except ActionCancelled as e:
|
||||
return Failure(code=FailureType.ActionCancelled, message=e.message)
|
||||
# TODO handle other errors
|
||||
|
||||
new_session.set_session_state(SessionState.ALLOCATED)
|
||||
channel.sessions[new_session.session_id] = new_session
|
||||
loop.schedule(new_session.handle())
|
||||
new_session_id: int = new_session.session_id
|
||||
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"create_new_session - new session created. Passphrase: %s, Session id: %d\n%s",
|
||||
message.passphrase if message.passphrase is not None else "",
|
||||
new_session.session_id,
|
||||
str(channel.sessions),
|
||||
)
|
||||
|
||||
return ThpNewSession(new_session_id=new_session_id)
|
@ -7,7 +7,7 @@ from trezor.messages import (
|
||||
ThpCredentialMetadata,
|
||||
ThpPairingCredential,
|
||||
)
|
||||
from trezor.wire import wrap_protobuf_load
|
||||
from trezor.wire import message_handler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from apps.common.paths import Slip21Path
|
||||
@ -72,7 +72,9 @@ def validate_credential(
|
||||
"""
|
||||
cred_auth_key = derive_cred_auth_key()
|
||||
expected_type = protobuf.type_for_name("ThpPairingCredential")
|
||||
credential = wrap_protobuf_load(encoded_pairing_credential_message, expected_type)
|
||||
credential = message_handler.wrap_protobuf_load(
|
||||
encoded_pairing_credential_message, expected_type
|
||||
)
|
||||
assert ThpPairingCredential.is_type_of(credential)
|
||||
proto_msg = ThpAuthenticatedCredentialData(
|
||||
host_static_pubkey=host_static_pubkey,
|
||||
|
403
core/src/apps/thp/pairing.py
Normal file
403
core/src/apps/thp/pairing.py
Normal file
@ -0,0 +1,403 @@
|
||||
from typing import TYPE_CHECKING
|
||||
from ubinascii import hexlify
|
||||
|
||||
from trezor import loop, protobuf
|
||||
from trezor.crypto.hashlib import sha256
|
||||
from trezor.enums import ThpMessageType, ThpPairingMethod
|
||||
from trezor.messages import (
|
||||
Cancel,
|
||||
ThpCodeEntryChallenge,
|
||||
ThpCodeEntryCommitment,
|
||||
ThpCodeEntryCpaceHost,
|
||||
ThpCodeEntryCpaceTrezor,
|
||||
ThpCodeEntrySecret,
|
||||
ThpCodeEntryTag,
|
||||
ThpCredentialMetadata,
|
||||
ThpCredentialRequest,
|
||||
ThpCredentialResponse,
|
||||
ThpEndRequest,
|
||||
ThpEndResponse,
|
||||
ThpNfcUnidirectionalSecret,
|
||||
ThpNfcUnidirectionalTag,
|
||||
ThpPairingPreparationsFinished,
|
||||
ThpQrCodeSecret,
|
||||
ThpQrCodeTag,
|
||||
ThpStartPairingRequest,
|
||||
)
|
||||
from trezor.wire.errors import ActionCancelled, SilentError, UnexpectedMessage
|
||||
from trezor.wire.thp import ChannelState, ThpError, crypto
|
||||
from trezor.wire.thp.pairing_context import PairingContext
|
||||
|
||||
from .credential_manager import issue_credential
|
||||
|
||||
if __debug__:
|
||||
from trezor import log
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, Callable, Concatenate, Container, ParamSpec, Tuple
|
||||
|
||||
P = ParamSpec("P")
|
||||
FuncWithContext = Callable[Concatenate[PairingContext, P], Any]
|
||||
|
||||
#
|
||||
# Helpers - decorators
|
||||
|
||||
|
||||
def check_state_and_log(
|
||||
*allowed_states: ChannelState,
|
||||
) -> Callable[[FuncWithContext], FuncWithContext]:
|
||||
def decorator(f: FuncWithContext) -> FuncWithContext:
|
||||
def inner(context: PairingContext, *args: P.args, **kwargs: P.kwargs) -> object:
|
||||
_check_state(context, *allowed_states)
|
||||
if __debug__:
|
||||
try:
|
||||
log.debug(__name__, "started %s", f.__name__)
|
||||
except AttributeError:
|
||||
log.debug(
|
||||
__name__,
|
||||
"started a function that cannot be named, because it raises AttributeError, eg. closure",
|
||||
)
|
||||
return f(context, *args, **kwargs)
|
||||
|
||||
return inner
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def check_method_is_allowed(
|
||||
pairing_method: ThpPairingMethod,
|
||||
) -> Callable[[FuncWithContext], FuncWithContext]:
|
||||
def decorator(f: FuncWithContext) -> FuncWithContext:
|
||||
def inner(context: PairingContext, *args: P.args, **kwargs: P.kwargs) -> object:
|
||||
_check_method_is_allowed(context, pairing_method)
|
||||
return f(context, *args, **kwargs)
|
||||
|
||||
return inner
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
#
|
||||
# Pairing handlers
|
||||
|
||||
|
||||
@check_state_and_log(ChannelState.TP1)
|
||||
async def handle_pairing_request(
|
||||
ctx: PairingContext, message: protobuf.MessageType
|
||||
) -> ThpEndResponse:
|
||||
|
||||
if not ThpStartPairingRequest.is_type_of(message):
|
||||
raise UnexpectedMessage("Unexpected message")
|
||||
|
||||
ctx.host_name = message.host_name or ""
|
||||
|
||||
skip_pairing = _is_method_included(ctx, ThpPairingMethod.NoMethod)
|
||||
if skip_pairing:
|
||||
return await _end_pairing(ctx)
|
||||
|
||||
await _prepare_pairing(ctx)
|
||||
await ctx.write(ThpPairingPreparationsFinished())
|
||||
ctx.channel_ctx.set_channel_state(ChannelState.TP3)
|
||||
response = await show_display_data(
|
||||
ctx, _get_possible_pairing_methods_and_cancel(ctx)
|
||||
)
|
||||
|
||||
if Cancel.is_type_of(response):
|
||||
ctx.channel_ctx.clear()
|
||||
raise SilentError("Action was cancelled by the Host")
|
||||
# TODO disable NFC (if enabled)
|
||||
response = await _handle_different_pairing_methods(ctx, response)
|
||||
|
||||
while ThpCredentialRequest.is_type_of(response):
|
||||
response = await _handle_credential_request(ctx, response)
|
||||
|
||||
return await _handle_end_request(ctx, response)
|
||||
|
||||
|
||||
async def _prepare_pairing(ctx: PairingContext) -> None:
|
||||
|
||||
if _is_method_included(ctx, ThpPairingMethod.CodeEntry):
|
||||
await _handle_code_entry_is_included(ctx)
|
||||
|
||||
if _is_method_included(ctx, ThpPairingMethod.QrCode):
|
||||
_handle_qr_code_is_included(ctx)
|
||||
|
||||
if _is_method_included(ctx, ThpPairingMethod.NFC_Unidirectional):
|
||||
_handle_nfc_unidirectional_is_included(ctx)
|
||||
|
||||
|
||||
async def show_display_data(
|
||||
ctx: PairingContext, expected_types: Container[int] = ()
|
||||
) -> type[protobuf.MessageType]:
|
||||
from trezorui2 import CANCELLED
|
||||
|
||||
read_task = ctx.read(expected_types)
|
||||
cancel_task = ctx.display_data.get_display_layout()
|
||||
race = loop.race(read_task, cancel_task.get_result())
|
||||
result: type[protobuf.MessageType] = await race
|
||||
|
||||
if result is CANCELLED:
|
||||
raise ActionCancelled
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@check_state_and_log(ChannelState.TP1)
|
||||
async def _handle_code_entry_is_included(ctx: PairingContext) -> None:
|
||||
commitment = sha256(ctx.secret).digest()
|
||||
|
||||
challenge_message = await ctx.call( # noqa: F841
|
||||
ThpCodeEntryCommitment(commitment=commitment), ThpCodeEntryChallenge
|
||||
)
|
||||
ctx.channel_ctx.set_channel_state(ChannelState.TP2)
|
||||
|
||||
if not ThpCodeEntryChallenge.is_type_of(challenge_message):
|
||||
raise UnexpectedMessage("Unexpected message")
|
||||
|
||||
if challenge_message.challenge is None:
|
||||
raise Exception("Invalid message")
|
||||
sha_ctx = sha256(ctx.channel_ctx.get_handshake_hash())
|
||||
sha_ctx.update(ctx.secret)
|
||||
sha_ctx.update(challenge_message.challenge)
|
||||
sha_ctx.update(bytes("PairingMethod_CodeEntry", "utf-8"))
|
||||
code_code_entry_hash = sha_ctx.digest()
|
||||
ctx.display_data.code_code_entry = (
|
||||
int.from_bytes(code_code_entry_hash, "big") % 1000000
|
||||
)
|
||||
|
||||
|
||||
@check_state_and_log(ChannelState.TP1, ChannelState.TP2)
|
||||
def _handle_qr_code_is_included(ctx: PairingContext) -> None:
|
||||
sha_ctx = sha256(ctx.channel_ctx.get_handshake_hash())
|
||||
sha_ctx.update(ctx.secret)
|
||||
sha_ctx.update(bytes("PairingMethod_QrCode", "utf-8"))
|
||||
ctx.display_data.code_qr_code = sha_ctx.digest()[:16]
|
||||
|
||||
|
||||
@check_state_and_log(ChannelState.TP1, ChannelState.TP2)
|
||||
def _handle_nfc_unidirectional_is_included(ctx: PairingContext) -> None:
|
||||
sha_ctx = sha256(ctx.channel_ctx.get_handshake_hash())
|
||||
sha_ctx.update(ctx.secret)
|
||||
sha_ctx.update(bytes("PairingMethod_NfcUnidirectional", "utf-8"))
|
||||
ctx.display_data.code_nfc_unidirectional = sha_ctx.digest()[:16]
|
||||
|
||||
|
||||
@check_state_and_log(ChannelState.TP3)
|
||||
async def _handle_different_pairing_methods(
|
||||
ctx: PairingContext, response: protobuf.MessageType
|
||||
) -> protobuf.MessageType:
|
||||
if ThpCodeEntryCpaceHost.is_type_of(response):
|
||||
return await _handle_code_entry_cpace(ctx, response)
|
||||
if ThpQrCodeTag.is_type_of(response):
|
||||
return await _handle_qr_code_tag(ctx, response)
|
||||
if ThpNfcUnidirectionalTag.is_type_of(response):
|
||||
return await _handle_nfc_unidirectional_tag(ctx, response)
|
||||
raise UnexpectedMessage("Unexpected message")
|
||||
|
||||
|
||||
@check_state_and_log(ChannelState.TP3)
|
||||
@check_method_is_allowed(ThpPairingMethod.CodeEntry)
|
||||
async def _handle_code_entry_cpace(
|
||||
ctx: PairingContext, message: protobuf.MessageType
|
||||
) -> protobuf.MessageType:
|
||||
from trezor.wire.thp.cpace import Cpace
|
||||
|
||||
# TODO check that ThpCodeEntryCpaceHost message is valid
|
||||
|
||||
if TYPE_CHECKING:
|
||||
assert isinstance(message, ThpCodeEntryCpaceHost)
|
||||
if message.cpace_host_public_key is None:
|
||||
raise ThpError("Message ThpCodeEntryCpaceHost has no public key")
|
||||
|
||||
ctx.cpace = Cpace(
|
||||
message.cpace_host_public_key,
|
||||
ctx.channel_ctx.get_handshake_hash(),
|
||||
)
|
||||
assert ctx.display_data.code_code_entry is not None
|
||||
ctx.cpace.generate_keys_and_secret(
|
||||
ctx.display_data.code_code_entry.to_bytes(6, "big")
|
||||
)
|
||||
|
||||
ctx.channel_ctx.set_channel_state(ChannelState.TP4)
|
||||
response = await ctx.call(
|
||||
ThpCodeEntryCpaceTrezor(cpace_trezor_public_key=ctx.cpace.trezor_public_key),
|
||||
ThpCodeEntryTag,
|
||||
)
|
||||
return await _handle_code_entry_tag(ctx, response)
|
||||
|
||||
|
||||
@check_state_and_log(ChannelState.TP4)
|
||||
@check_method_is_allowed(ThpPairingMethod.CodeEntry)
|
||||
async def _handle_code_entry_tag(
|
||||
ctx: PairingContext, message: protobuf.MessageType
|
||||
) -> protobuf.MessageType:
|
||||
|
||||
if TYPE_CHECKING:
|
||||
assert isinstance(message, ThpCodeEntryTag)
|
||||
|
||||
expected_tag = sha256(ctx.cpace.shared_secret).digest()
|
||||
if expected_tag != message.tag:
|
||||
print(
|
||||
"expected code entry tag:", hexlify(expected_tag).decode()
|
||||
) # TODO remove after testing
|
||||
print(
|
||||
"expected code entry shared secret:",
|
||||
hexlify(ctx.cpace.shared_secret).decode(),
|
||||
) # TODO remove after testing
|
||||
raise ThpError("Unexpected Code Entry Tag")
|
||||
|
||||
return await _handle_secret_reveal(
|
||||
ctx,
|
||||
msg=ThpCodeEntrySecret(secret=ctx.secret),
|
||||
)
|
||||
|
||||
|
||||
@check_state_and_log(ChannelState.TP3)
|
||||
@check_method_is_allowed(ThpPairingMethod.QrCode)
|
||||
async def _handle_qr_code_tag(
|
||||
ctx: PairingContext, message: protobuf.MessageType
|
||||
) -> protobuf.MessageType:
|
||||
if TYPE_CHECKING:
|
||||
assert isinstance(message, ThpQrCodeTag)
|
||||
assert ctx.display_data.code_qr_code is not None
|
||||
expected_tag = sha256(ctx.display_data.code_qr_code).digest()
|
||||
if expected_tag != message.tag:
|
||||
print(
|
||||
"expected qr code tag:", hexlify(expected_tag).decode()
|
||||
) # TODO remove after testing
|
||||
print(
|
||||
"expected code qr code tag:",
|
||||
hexlify(ctx.display_data.code_qr_code).decode(),
|
||||
) # TODO remove after testing
|
||||
print(
|
||||
"expected secret:", hexlify(ctx.secret).decode()
|
||||
) # TODO remove after testing
|
||||
raise ThpError("Unexpected QR Code Tag")
|
||||
|
||||
return await _handle_secret_reveal(
|
||||
ctx,
|
||||
msg=ThpQrCodeSecret(secret=ctx.secret),
|
||||
)
|
||||
|
||||
|
||||
@check_state_and_log(ChannelState.TP3)
|
||||
@check_method_is_allowed(ThpPairingMethod.NFC_Unidirectional)
|
||||
async def _handle_nfc_unidirectional_tag(
|
||||
ctx: PairingContext, message: protobuf.MessageType
|
||||
) -> protobuf.MessageType:
|
||||
if TYPE_CHECKING:
|
||||
assert isinstance(message, ThpNfcUnidirectionalTag)
|
||||
|
||||
expected_tag = sha256(ctx.display_data.code_nfc_unidirectional).digest()
|
||||
if expected_tag != message.tag:
|
||||
print(
|
||||
"expected nfc tag:", hexlify(expected_tag).decode()
|
||||
) # TODO remove after testing
|
||||
raise ThpError("Unexpected NFC Unidirectional Tag")
|
||||
|
||||
return await _handle_secret_reveal(
|
||||
ctx,
|
||||
msg=ThpNfcUnidirectionalSecret(secret=ctx.secret),
|
||||
)
|
||||
|
||||
|
||||
@check_state_and_log(ChannelState.TP3, ChannelState.TP4)
|
||||
async def _handle_secret_reveal(
|
||||
ctx: PairingContext,
|
||||
msg: protobuf.MessageType,
|
||||
) -> protobuf.MessageType:
|
||||
ctx.channel_ctx.set_channel_state(ChannelState.TC1)
|
||||
return await ctx.call_any(
|
||||
msg,
|
||||
ThpMessageType.ThpCredentialRequest,
|
||||
ThpMessageType.ThpEndRequest,
|
||||
)
|
||||
|
||||
|
||||
@check_state_and_log(ChannelState.TC1)
|
||||
async def _handle_credential_request(
|
||||
ctx: PairingContext, message: protobuf.MessageType
|
||||
) -> protobuf.MessageType:
|
||||
ctx.secret
|
||||
|
||||
if not ThpCredentialRequest.is_type_of(message):
|
||||
raise UnexpectedMessage("Unexpected message")
|
||||
if message.host_static_pubkey is None:
|
||||
raise Exception("Invalid message") # TODO change failure type
|
||||
|
||||
trezor_static_pubkey = crypto.get_trezor_static_pubkey()
|
||||
credential_metadata = ThpCredentialMetadata(host_name=ctx.host_name)
|
||||
credential = issue_credential(message.host_static_pubkey, credential_metadata)
|
||||
|
||||
return await ctx.call_any(
|
||||
ThpCredentialResponse(
|
||||
trezor_static_pubkey=trezor_static_pubkey, credential=credential
|
||||
),
|
||||
ThpMessageType.ThpCredentialRequest,
|
||||
ThpMessageType.ThpEndRequest,
|
||||
)
|
||||
|
||||
|
||||
@check_state_and_log(ChannelState.TC1)
|
||||
async def _handle_end_request(
|
||||
ctx: PairingContext, message: protobuf.MessageType
|
||||
) -> ThpEndResponse:
|
||||
if not ThpEndRequest.is_type_of(message):
|
||||
raise UnexpectedMessage("Unexpected message")
|
||||
return await _end_pairing(ctx)
|
||||
|
||||
|
||||
async def _end_pairing(ctx: PairingContext) -> ThpEndResponse:
|
||||
ctx.channel_ctx.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
|
||||
return ThpEndResponse()
|
||||
|
||||
|
||||
#
|
||||
# Helpers - checkers
|
||||
|
||||
|
||||
def _check_state(ctx: PairingContext, *allowed_states: ChannelState) -> None:
|
||||
if ctx.channel_ctx.get_channel_state() not in allowed_states:
|
||||
raise UnexpectedMessage("Unexpected message")
|
||||
|
||||
|
||||
def _check_method_is_allowed(ctx: PairingContext, method: ThpPairingMethod) -> None:
|
||||
if not _is_method_included(ctx, method):
|
||||
raise ThpError("Unexpected pairing method")
|
||||
|
||||
|
||||
def _is_method_included(ctx: PairingContext, method: ThpPairingMethod) -> bool:
|
||||
return method in ctx.channel_ctx.selected_pairing_methods
|
||||
|
||||
|
||||
#
|
||||
# Helpers - getters
|
||||
|
||||
|
||||
def _get_possible_pairing_methods_and_cancel(ctx: PairingContext) -> Tuple[int, ...]:
|
||||
r = _get_possible_pairing_methods(ctx)
|
||||
mtype = Cancel.MESSAGE_WIRE_TYPE
|
||||
return r + ((mtype,) if mtype is not None else ())
|
||||
|
||||
|
||||
def _get_possible_pairing_methods(ctx: PairingContext) -> Tuple[int, ...]:
|
||||
r = tuple(
|
||||
_get_message_type_for_method(method)
|
||||
for method in ctx.channel_ctx.selected_pairing_methods
|
||||
)
|
||||
if __debug__:
|
||||
from trezor.messages import DebugLinkGetState
|
||||
|
||||
mtype = DebugLinkGetState.MESSAGE_WIRE_TYPE
|
||||
return r + ((mtype,) if mtype is not None else ())
|
||||
return r
|
||||
|
||||
|
||||
def _get_message_type_for_method(method: int) -> int:
|
||||
if method is ThpPairingMethod.CodeEntry:
|
||||
return ThpMessageType.ThpCodeEntryCpaceHost
|
||||
if method is ThpPairingMethod.NFC_Unidirectional:
|
||||
return ThpMessageType.ThpNfcUnidirectionalTag
|
||||
if method is ThpPairingMethod.QrCode:
|
||||
return ThpMessageType.ThpQrCodeTag
|
||||
raise ValueError("Unexpected pairing method - no message type available")
|
@ -1,8 +1,6 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface
|
||||
|
||||
from trezor.wire import Handler, Msg
|
||||
|
||||
|
||||
@ -37,6 +35,13 @@ def _find_message_handler_module(msg_type: int) -> str:
|
||||
if __debug__ and msg_type == MessageType.BenchmarkRun:
|
||||
return "apps.benchmark.run"
|
||||
|
||||
if utils.USE_THP:
|
||||
from trezor.enums import ThpMessageType
|
||||
|
||||
# thp management
|
||||
if msg_type == ThpMessageType.ThpCreateNewSession:
|
||||
return "apps.thp.create_new_session"
|
||||
|
||||
# management
|
||||
if msg_type == MessageType.ResetDevice:
|
||||
return "apps.management.reset_device"
|
||||
@ -215,7 +220,7 @@ def _find_message_handler_module(msg_type: int) -> str:
|
||||
raise ValueError
|
||||
|
||||
|
||||
def find_registered_handler(iface: WireInterface, msg_type: int) -> Handler | None:
|
||||
def find_registered_handler(msg_type: int) -> Handler | None:
|
||||
if msg_type in workflow_handlers:
|
||||
# Message has a handler available, return it directly.
|
||||
return workflow_handlers[msg_type]
|
||||
|
@ -29,7 +29,7 @@ if __debug__:
|
||||
|
||||
# trezor.pin imports trezor.utils
|
||||
# We need it as an always-active module because trezor.pin.show_pin_timeout is used
|
||||
# as an UI callback for storage, which can be invoked at any time
|
||||
# as a UI callback for storage, which can be invoked at any time
|
||||
import trezor.pin # noqa: F401
|
||||
|
||||
# === Prepare the USB interfaces first. Do not connect to the host yet.
|
||||
|
@ -1,11 +1,27 @@
|
||||
# make sure to import cache unconditionally at top level so that it is imported (and retained) together with the storage module
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from storage import cache, common, device
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Tuple
|
||||
|
||||
def wipe() -> None:
|
||||
pass
|
||||
|
||||
|
||||
def wipe(excluded: Tuple[bytes, bytes] | None) -> None:
|
||||
"""
|
||||
TODO REPHRASE SO THAT IT IS TRUE! Wipes the storage. Using `exclude_protocol=False` destroys the THP communication channel.
|
||||
If the device should communicate after wipe, use `exclude_protocol=True` and clear cache manually later using
|
||||
`wipe_cache()`.
|
||||
"""
|
||||
from trezor import config
|
||||
|
||||
config.wipe()
|
||||
cache.clear_all(excluded)
|
||||
|
||||
|
||||
def wipe_cache() -> None:
|
||||
cache.clear_all()
|
||||
|
||||
|
||||
@ -21,12 +37,12 @@ def init_unlocked() -> None:
|
||||
common.set_bool(common.APP_DEVICE, device.INITIALIZED, True, public=True)
|
||||
|
||||
|
||||
def reset() -> None:
|
||||
def reset(excluded: Tuple[bytes, bytes] | None) -> None:
|
||||
"""
|
||||
Wipes storage but keeps the device id unchanged.
|
||||
"""
|
||||
device_id = device.get_device_id()
|
||||
wipe()
|
||||
wipe(excluded)
|
||||
common.set(common.APP_DEVICE, device.DEVICE_ID, device_id.encode(), public=True)
|
||||
|
||||
|
||||
|
361
core/src/storage/cache_thp.py
Normal file
361
core/src/storage/cache_thp.py
Normal file
@ -0,0 +1,361 @@
|
||||
import builtins
|
||||
from micropython import const
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from storage.cache_common import DataCache
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Tuple
|
||||
|
||||
pass
|
||||
|
||||
if __debug__:
|
||||
from trezor import log
|
||||
|
||||
pass
|
||||
|
||||
# THP specific constants
|
||||
_MAX_CHANNELS_COUNT = const(10)
|
||||
_MAX_SESSIONS_COUNT = const(20)
|
||||
|
||||
|
||||
_CHANNEL_STATE_LENGTH = const(1)
|
||||
_WIRE_INTERFACE_LENGTH = const(1)
|
||||
_SESSION_STATE_LENGTH = const(1)
|
||||
_CHANNEL_ID_LENGTH = const(2)
|
||||
SESSION_ID_LENGTH = const(1)
|
||||
BROADCAST_CHANNEL_ID = const(0xFFFF)
|
||||
KEY_LENGTH = const(32)
|
||||
TAG_LENGTH = const(16)
|
||||
_UNALLOCATED_STATE = const(0)
|
||||
_MANAGEMENT_STATE = const(2)
|
||||
MANAGEMENT_SESSION_ID = const(0)
|
||||
|
||||
|
||||
class ThpDataCache(DataCache):
|
||||
def __init__(self) -> None:
|
||||
self.channel_id = bytearray(_CHANNEL_ID_LENGTH)
|
||||
self.last_usage = 0
|
||||
super().__init__()
|
||||
|
||||
def clear(self) -> None:
|
||||
self.channel_id[:] = b""
|
||||
self.last_usage = 0
|
||||
super().clear()
|
||||
|
||||
|
||||
class ChannelCache(ThpDataCache):
|
||||
def __init__(self) -> None:
|
||||
self.host_ephemeral_pubkey = bytearray(KEY_LENGTH)
|
||||
self.state = bytearray(_CHANNEL_STATE_LENGTH)
|
||||
self.iface = bytearray(1) # TODO add decoding
|
||||
self.sync = 0x80 # can_send_bit | sync_receive_bit | sync_send_bit | rfu(5)
|
||||
self.session_id_counter = 0x00
|
||||
self.fields = (
|
||||
32, # CHANNEL_HANDSHAKE_HASH
|
||||
32, # CHANNEL_KEY_RECEIVE
|
||||
32, # CHANNEL_KEY_SEND
|
||||
8, # CHANNEL_NONCE_RECEIVE
|
||||
8, # CHANNEL_NONCE_SEND
|
||||
)
|
||||
super().__init__()
|
||||
|
||||
def clear(self) -> None:
|
||||
self.state[:] = bytearray(
|
||||
int.to_bytes(0, _CHANNEL_STATE_LENGTH, "big")
|
||||
) # Set state to UNALLOCATED
|
||||
self.host_ephemeral_pubkey[:] = bytearray(KEY_LENGTH)
|
||||
self.state[:] = bytearray(_CHANNEL_STATE_LENGTH)
|
||||
self.iface[:] = bytearray(1)
|
||||
super().clear()
|
||||
|
||||
|
||||
class SessionThpCache(ThpDataCache):
|
||||
def __init__(self) -> None:
|
||||
from trezor import utils
|
||||
|
||||
self.session_id = bytearray(SESSION_ID_LENGTH)
|
||||
self.state = bytearray(_SESSION_STATE_LENGTH)
|
||||
if utils.BITCOIN_ONLY:
|
||||
self.fields = (
|
||||
64, # APP_COMMON_SEED
|
||||
2, # APP_COMMON_AUTHORIZATION_TYPE
|
||||
128, # APP_COMMON_AUTHORIZATION_DATA
|
||||
32, # APP_COMMON_NONCE
|
||||
)
|
||||
else:
|
||||
self.fields = (
|
||||
64, # APP_COMMON_SEED
|
||||
2, # APP_COMMON_AUTHORIZATION_TYPE
|
||||
128, # APP_COMMON_AUTHORIZATION_DATA
|
||||
32, # APP_COMMON_NONCE
|
||||
0, # APP_COMMON_DERIVE_CARDANO
|
||||
96, # APP_CARDANO_ICARUS_SECRET
|
||||
96, # APP_CARDANO_ICARUS_TREZOR_SECRET
|
||||
0, # APP_MONERO_LIVE_REFRESH
|
||||
)
|
||||
super().__init__()
|
||||
|
||||
def clear(self) -> None:
|
||||
self.state[:] = bytearray(int.to_bytes(0, 1, "big")) # Set state to UNALLOCATED
|
||||
self.session_id[:] = b""
|
||||
super().clear()
|
||||
|
||||
|
||||
_CHANNELS: list[ChannelCache] = []
|
||||
_SESSIONS: list[SessionThpCache] = []
|
||||
cid_counter: int = 0
|
||||
|
||||
# Last-used counter
|
||||
_usage_counter = 0
|
||||
|
||||
|
||||
def initialize() -> None:
|
||||
global _CHANNELS
|
||||
global _SESSIONS
|
||||
global cid_counter
|
||||
|
||||
for _ in range(_MAX_CHANNELS_COUNT):
|
||||
_CHANNELS.append(ChannelCache())
|
||||
for _ in range(_MAX_SESSIONS_COUNT):
|
||||
_SESSIONS.append(SessionThpCache())
|
||||
|
||||
for channel in _CHANNELS:
|
||||
channel.clear()
|
||||
for session in _SESSIONS:
|
||||
session.clear()
|
||||
|
||||
from trezorcrypto import random
|
||||
|
||||
cid_counter = random.uniform(0xFFFE)
|
||||
|
||||
|
||||
def get_new_channel(iface: bytes) -> ChannelCache:
|
||||
if len(iface) != _WIRE_INTERFACE_LENGTH:
|
||||
raise Exception("Invalid WireInterface (encoded) length")
|
||||
|
||||
new_cid = get_next_channel_id()
|
||||
index = _get_next_channel_index()
|
||||
|
||||
# clear sessions from replaced channel
|
||||
if _get_channel_state(_CHANNELS[index]) != _UNALLOCATED_STATE:
|
||||
old_cid = _CHANNELS[index].channel_id
|
||||
clear_sessions_with_channel_id(old_cid)
|
||||
|
||||
_CHANNELS[index] = ChannelCache()
|
||||
_CHANNELS[index].channel_id[:] = new_cid
|
||||
_CHANNELS[index].last_usage = _get_usage_counter_and_increment()
|
||||
_CHANNELS[index].state[:] = bytearray(
|
||||
_UNALLOCATED_STATE.to_bytes(_CHANNEL_STATE_LENGTH, "big")
|
||||
)
|
||||
_CHANNELS[index].iface[:] = bytearray(iface)
|
||||
return _CHANNELS[index]
|
||||
|
||||
|
||||
def update_channel_last_used(channel_id: bytes) -> None:
|
||||
for channel in _CHANNELS:
|
||||
if channel.channel_id == channel_id:
|
||||
channel.last_usage = _get_usage_counter_and_increment()
|
||||
return
|
||||
|
||||
|
||||
def update_session_last_used(channel_id: bytes, session_id: bytes) -> None:
|
||||
for session in _SESSIONS:
|
||||
if session.channel_id == channel_id and session.session_id == session_id:
|
||||
session.last_usage = _get_usage_counter_and_increment()
|
||||
update_channel_last_used(channel_id)
|
||||
return
|
||||
|
||||
|
||||
def get_all_allocated_channels() -> list[ChannelCache]:
|
||||
_list: list[ChannelCache] = []
|
||||
for channel in _CHANNELS:
|
||||
if _get_channel_state(channel) != _UNALLOCATED_STATE:
|
||||
_list.append(channel)
|
||||
return _list
|
||||
|
||||
|
||||
def get_allocated_sessions(channel_id: bytes) -> list[SessionThpCache]:
|
||||
if __debug__:
|
||||
from trezor.utils import get_bytes_as_str
|
||||
_list: list[SessionThpCache] = []
|
||||
for session in _SESSIONS:
|
||||
if _get_session_state(session) == _UNALLOCATED_STATE:
|
||||
continue
|
||||
if session.channel_id != channel_id:
|
||||
continue
|
||||
_list.append(session)
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"session with channel_id: %s and session_id: %s is in ALLOCATED state",
|
||||
get_bytes_as_str(session.channel_id),
|
||||
get_bytes_as_str(session.session_id),
|
||||
)
|
||||
|
||||
return _list
|
||||
|
||||
|
||||
def set_channel_host_ephemeral_key(channel: ChannelCache, key: bytearray) -> None:
|
||||
if len(key) != KEY_LENGTH:
|
||||
raise Exception("Invalid key length")
|
||||
channel.host_ephemeral_pubkey = key
|
||||
|
||||
|
||||
def get_new_session(channel: ChannelCache) -> SessionThpCache:
|
||||
new_sid = get_next_session_id(channel)
|
||||
index = _get_next_session_index()
|
||||
|
||||
_SESSIONS[index] = SessionThpCache()
|
||||
_SESSIONS[index].channel_id[:] = channel.channel_id
|
||||
_SESSIONS[index].session_id[:] = new_sid
|
||||
_SESSIONS[index].last_usage = _get_usage_counter_and_increment()
|
||||
channel.last_usage = (
|
||||
_get_usage_counter_and_increment()
|
||||
) # increment also use of the channel so it does not get replaced
|
||||
_SESSIONS[index].state[:] = bytearray(
|
||||
_UNALLOCATED_STATE.to_bytes(_SESSION_STATE_LENGTH, "big")
|
||||
)
|
||||
return _SESSIONS[index]
|
||||
|
||||
|
||||
def _get_usage_counter_and_increment() -> int:
|
||||
global _usage_counter
|
||||
_usage_counter += 1
|
||||
return _usage_counter
|
||||
|
||||
|
||||
def _get_next_channel_index() -> int:
|
||||
idx = _get_unallocated_channel_index()
|
||||
if idx is not None:
|
||||
return idx
|
||||
return get_least_recently_used_item(_CHANNELS, max_count=_MAX_CHANNELS_COUNT)
|
||||
|
||||
|
||||
def _get_next_session_index() -> int:
|
||||
idx = _get_unallocated_session_index()
|
||||
if idx is not None:
|
||||
return idx
|
||||
return get_least_recently_used_item(_SESSIONS, max_count=_MAX_SESSIONS_COUNT)
|
||||
|
||||
|
||||
def _get_unallocated_channel_index() -> int | None:
|
||||
for i in range(_MAX_CHANNELS_COUNT):
|
||||
if _get_channel_state(_CHANNELS[i]) is _UNALLOCATED_STATE:
|
||||
return i
|
||||
return None
|
||||
|
||||
|
||||
def _get_unallocated_session_index() -> int | None:
|
||||
for i in range(_MAX_SESSIONS_COUNT):
|
||||
if (_SESSIONS[i]) is _UNALLOCATED_STATE:
|
||||
return i
|
||||
return None
|
||||
|
||||
|
||||
def _get_channel_state(channel: ChannelCache) -> int:
|
||||
return int.from_bytes(channel.state, "big")
|
||||
|
||||
|
||||
def _get_session_state(session: SessionThpCache) -> int:
|
||||
return int.from_bytes(session.state, "big")
|
||||
|
||||
|
||||
def get_next_channel_id() -> bytes:
|
||||
global cid_counter
|
||||
while True:
|
||||
cid_counter += 1
|
||||
if cid_counter >= BROADCAST_CHANNEL_ID:
|
||||
cid_counter = 1
|
||||
if _is_cid_unique():
|
||||
break
|
||||
return cid_counter.to_bytes(_CHANNEL_ID_LENGTH, "big")
|
||||
|
||||
|
||||
def get_next_session_id(channel: ChannelCache) -> bytes:
|
||||
while True:
|
||||
if channel.session_id_counter >= 255:
|
||||
channel.session_id_counter = 1
|
||||
else:
|
||||
channel.session_id_counter += 1
|
||||
if _is_session_id_unique(channel):
|
||||
break
|
||||
new_sid = channel.session_id_counter
|
||||
return new_sid.to_bytes(SESSION_ID_LENGTH, "big")
|
||||
|
||||
|
||||
def _is_session_id_unique(channel: ChannelCache) -> bool:
|
||||
for session in _SESSIONS:
|
||||
if session.channel_id == channel.channel_id:
|
||||
if session.session_id == channel.session_id_counter:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _is_cid_unique() -> bool:
|
||||
global cid_counter
|
||||
cid_counter_bytes = cid_counter.to_bytes(_CHANNEL_ID_LENGTH, "big")
|
||||
for channel in _CHANNELS:
|
||||
if channel.channel_id == cid_counter_bytes:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_least_recently_used_item(
|
||||
list: list[ChannelCache] | list[SessionThpCache], max_count: int
|
||||
) -> int:
|
||||
global _usage_counter
|
||||
lru_counter = _usage_counter + 1
|
||||
lru_item_index = 0
|
||||
for i in range(max_count):
|
||||
if list[i].last_usage < lru_counter:
|
||||
lru_counter = list[i].last_usage
|
||||
lru_item_index = i
|
||||
return lru_item_index
|
||||
|
||||
|
||||
def get_int_all_sessions(key: int) -> builtins.set[int]:
|
||||
values = builtins.set()
|
||||
for session in _SESSIONS:
|
||||
encoded = session.get(key)
|
||||
if encoded is not None:
|
||||
values.add(int.from_bytes(encoded, "big"))
|
||||
return values
|
||||
|
||||
|
||||
def clear_sessions_with_channel_id(channel_id: bytes) -> None:
|
||||
for session in _SESSIONS:
|
||||
if session.channel_id == channel_id:
|
||||
session.clear()
|
||||
|
||||
|
||||
def clear_session(session: SessionThpCache) -> None:
|
||||
for s in _SESSIONS:
|
||||
if s.channel_id == session.channel_id and s.session_id == session.session_id:
|
||||
session.clear()
|
||||
|
||||
|
||||
def clear_all() -> None:
|
||||
for session in _SESSIONS:
|
||||
session.clear()
|
||||
for channel in _CHANNELS:
|
||||
channel.clear()
|
||||
|
||||
|
||||
def clear_all_except_one_session_keys(excluded: Tuple[bytes, bytes]) -> None:
|
||||
cid, sid = excluded
|
||||
|
||||
for channel in _CHANNELS:
|
||||
if channel.channel_id != cid:
|
||||
channel.clear()
|
||||
|
||||
for session in _SESSIONS:
|
||||
if session.channel_id != cid and session.session_id != sid:
|
||||
session.clear()
|
||||
else:
|
||||
s_last_usage = session.last_usage
|
||||
session.clear()
|
||||
session.last_usage = s_last_usage
|
||||
session.state = bytearray(_MANAGEMENT_STATE.to_bytes(1, "big"))
|
||||
session.session_id[:] = bytearray(sid)
|
||||
session.channel_id[:] = bytearray(cid)
|
2
core/src/trezor/enums/FailureType.py
generated
2
core/src/trezor/enums/FailureType.py
generated
@ -16,4 +16,6 @@ NotInitialized = 11
|
||||
PinMismatch = 12
|
||||
WipeCodeMismatch = 13
|
||||
InvalidSession = 14
|
||||
ThpUnallocatedSession = 15
|
||||
InvalidProtocol = 16
|
||||
FirmwareError = 99
|
||||
|
22
core/src/trezor/enums/ThpMessageType.py
generated
Normal file
22
core/src/trezor/enums/ThpMessageType.py
generated
Normal file
@ -0,0 +1,22 @@
|
||||
# Automatically generated by pb2py
|
||||
# fmt: off
|
||||
# isort:skip_file
|
||||
|
||||
ThpCreateNewSession = 1000
|
||||
ThpNewSession = 1001
|
||||
ThpStartPairingRequest = 1008
|
||||
ThpPairingPreparationsFinished = 1009
|
||||
ThpCredentialRequest = 1010
|
||||
ThpCredentialResponse = 1011
|
||||
ThpEndRequest = 1012
|
||||
ThpEndResponse = 1013
|
||||
ThpCodeEntryCommitment = 1016
|
||||
ThpCodeEntryChallenge = 1017
|
||||
ThpCodeEntryCpaceHost = 1018
|
||||
ThpCodeEntryCpaceTrezor = 1019
|
||||
ThpCodeEntryTag = 1020
|
||||
ThpCodeEntrySecret = 1021
|
||||
ThpQrCodeTag = 1024
|
||||
ThpQrCodeSecret = 1025
|
||||
ThpNfcUnidirectionalTag = 1032
|
||||
ThpNfcUnidirectionalSecret = 1033
|
8
core/src/trezor/enums/ThpPairingMethod.py
generated
Normal file
8
core/src/trezor/enums/ThpPairingMethod.py
generated
Normal file
@ -0,0 +1,8 @@
|
||||
# Automatically generated by pb2py
|
||||
# fmt: off
|
||||
# isort:skip_file
|
||||
|
||||
NoMethod = 1
|
||||
CodeEntry = 2
|
||||
QrCode = 3
|
||||
NFC_Unidirectional = 4
|
28
core/src/trezor/enums/__init__.py
generated
28
core/src/trezor/enums/__init__.py
generated
@ -39,6 +39,8 @@ if TYPE_CHECKING:
|
||||
PinMismatch = 12
|
||||
WipeCodeMismatch = 13
|
||||
InvalidSession = 14
|
||||
ThpUnallocatedSession = 15
|
||||
InvalidProtocol = 16
|
||||
FirmwareError = 99
|
||||
|
||||
class ButtonRequestType(IntEnum):
|
||||
@ -343,6 +345,32 @@ if TYPE_CHECKING:
|
||||
Nay = 1
|
||||
Pass = 2
|
||||
|
||||
class ThpMessageType(IntEnum):
|
||||
ThpCreateNewSession = 1000
|
||||
ThpNewSession = 1001
|
||||
ThpStartPairingRequest = 1008
|
||||
ThpPairingPreparationsFinished = 1009
|
||||
ThpCredentialRequest = 1010
|
||||
ThpCredentialResponse = 1011
|
||||
ThpEndRequest = 1012
|
||||
ThpEndResponse = 1013
|
||||
ThpCodeEntryCommitment = 1016
|
||||
ThpCodeEntryChallenge = 1017
|
||||
ThpCodeEntryCpaceHost = 1018
|
||||
ThpCodeEntryCpaceTrezor = 1019
|
||||
ThpCodeEntryTag = 1020
|
||||
ThpCodeEntrySecret = 1021
|
||||
ThpQrCodeTag = 1024
|
||||
ThpQrCodeSecret = 1025
|
||||
ThpNfcUnidirectionalTag = 1032
|
||||
ThpNfcUnidirectionalSecret = 1033
|
||||
|
||||
class ThpPairingMethod(IntEnum):
|
||||
NoMethod = 1
|
||||
CodeEntry = 2
|
||||
QrCode = 3
|
||||
NFC_Unidirectional = 4
|
||||
|
||||
class MessageType(IntEnum):
|
||||
Initialize = 0
|
||||
Ping = 1
|
||||
|
282
core/src/trezor/messages.py
generated
282
core/src/trezor/messages.py
generated
@ -67,6 +67,8 @@ if TYPE_CHECKING:
|
||||
from trezor.enums import StellarSignerType # noqa: F401
|
||||
from trezor.enums import TezosBallotType # noqa: F401
|
||||
from trezor.enums import TezosContractType # noqa: F401
|
||||
from trezor.enums import ThpMessageType # noqa: F401
|
||||
from trezor.enums import ThpPairingMethod # noqa: F401
|
||||
from trezor.enums import WordRequestType # noqa: F401
|
||||
|
||||
class BenchmarkListNames(protobuf.MessageType):
|
||||
@ -2863,11 +2865,13 @@ if TYPE_CHECKING:
|
||||
|
||||
class DebugLinkGetState(protobuf.MessageType):
|
||||
wait_layout: "DebugWaitType"
|
||||
thp_channel_id: "bytes | None"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
wait_layout: "DebugWaitType | None" = None,
|
||||
thp_channel_id: "bytes | None" = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@ -2889,6 +2893,9 @@ if TYPE_CHECKING:
|
||||
reset_word_pos: "int | None"
|
||||
mnemonic_type: "BackupType | None"
|
||||
tokens: "list[str]"
|
||||
thp_pairing_code_entry_code: "int | None"
|
||||
thp_pairing_code_qr_code: "bytes | None"
|
||||
thp_pairing_code_nfc_unidirectional: "bytes | None"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -2906,6 +2913,9 @@ if TYPE_CHECKING:
|
||||
recovery_word_pos: "int | None" = None,
|
||||
reset_word_pos: "int | None" = None,
|
||||
mnemonic_type: "BackupType | None" = None,
|
||||
thp_pairing_code_entry_code: "int | None" = None,
|
||||
thp_pairing_code_qr_code: "bytes | None" = None,
|
||||
thp_pairing_code_nfc_unidirectional: "bytes | None" = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@ -6127,6 +6137,278 @@ if TYPE_CHECKING:
|
||||
def is_type_of(cls, msg: Any) -> TypeGuard["TezosManagerTransfer"]:
|
||||
return isinstance(msg, cls)
|
||||
|
||||
class ThpDeviceProperties(protobuf.MessageType):
|
||||
internal_model: "str | None"
|
||||
model_variant: "int | None"
|
||||
bootloader_mode: "bool | None"
|
||||
protocol_version: "int | None"
|
||||
pairing_methods: "list[ThpPairingMethod]"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
pairing_methods: "list[ThpPairingMethod] | None" = None,
|
||||
internal_model: "str | None" = None,
|
||||
model_variant: "int | None" = None,
|
||||
bootloader_mode: "bool | None" = None,
|
||||
protocol_version: "int | None" = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def is_type_of(cls, msg: Any) -> TypeGuard["ThpDeviceProperties"]:
|
||||
return isinstance(msg, cls)
|
||||
|
||||
class ThpHandshakeCompletionReqNoisePayload(protobuf.MessageType):
|
||||
host_pairing_credential: "bytes | None"
|
||||
pairing_methods: "list[ThpPairingMethod]"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
pairing_methods: "list[ThpPairingMethod] | None" = None,
|
||||
host_pairing_credential: "bytes | None" = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def is_type_of(cls, msg: Any) -> TypeGuard["ThpHandshakeCompletionReqNoisePayload"]:
|
||||
return isinstance(msg, cls)
|
||||
|
||||
class ThpCreateNewSession(protobuf.MessageType):
|
||||
passphrase: "str | None"
|
||||
on_device: "bool | None"
|
||||
derive_cardano: "bool | None"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
passphrase: "str | None" = None,
|
||||
on_device: "bool | None" = None,
|
||||
derive_cardano: "bool | None" = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCreateNewSession"]:
|
||||
return isinstance(msg, cls)
|
||||
|
||||
class ThpNewSession(protobuf.MessageType):
|
||||
new_session_id: "int | None"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
new_session_id: "int | None" = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def is_type_of(cls, msg: Any) -> TypeGuard["ThpNewSession"]:
|
||||
return isinstance(msg, cls)
|
||||
|
||||
class ThpStartPairingRequest(protobuf.MessageType):
|
||||
host_name: "str | None"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
host_name: "str | None" = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def is_type_of(cls, msg: Any) -> TypeGuard["ThpStartPairingRequest"]:
|
||||
return isinstance(msg, cls)
|
||||
|
||||
class ThpPairingPreparationsFinished(protobuf.MessageType):
|
||||
|
||||
@classmethod
|
||||
def is_type_of(cls, msg: Any) -> TypeGuard["ThpPairingPreparationsFinished"]:
|
||||
return isinstance(msg, cls)
|
||||
|
||||
class ThpCodeEntryCommitment(protobuf.MessageType):
|
||||
commitment: "bytes | None"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
commitment: "bytes | None" = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntryCommitment"]:
|
||||
return isinstance(msg, cls)
|
||||
|
||||
class ThpCodeEntryChallenge(protobuf.MessageType):
|
||||
challenge: "bytes | None"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
challenge: "bytes | None" = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntryChallenge"]:
|
||||
return isinstance(msg, cls)
|
||||
|
||||
class ThpCodeEntryCpaceHost(protobuf.MessageType):
|
||||
cpace_host_public_key: "bytes | None"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
cpace_host_public_key: "bytes | None" = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntryCpaceHost"]:
|
||||
return isinstance(msg, cls)
|
||||
|
||||
class ThpCodeEntryCpaceTrezor(protobuf.MessageType):
|
||||
cpace_trezor_public_key: "bytes | None"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
cpace_trezor_public_key: "bytes | None" = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntryCpaceTrezor"]:
|
||||
return isinstance(msg, cls)
|
||||
|
||||
class ThpCodeEntryTag(protobuf.MessageType):
|
||||
tag: "bytes | None"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
tag: "bytes | None" = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntryTag"]:
|
||||
return isinstance(msg, cls)
|
||||
|
||||
class ThpCodeEntrySecret(protobuf.MessageType):
|
||||
secret: "bytes | None"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
secret: "bytes | None" = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntrySecret"]:
|
||||
return isinstance(msg, cls)
|
||||
|
||||
class ThpQrCodeTag(protobuf.MessageType):
|
||||
tag: "bytes | None"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
tag: "bytes | None" = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def is_type_of(cls, msg: Any) -> TypeGuard["ThpQrCodeTag"]:
|
||||
return isinstance(msg, cls)
|
||||
|
||||
class ThpQrCodeSecret(protobuf.MessageType):
|
||||
secret: "bytes | None"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
secret: "bytes | None" = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def is_type_of(cls, msg: Any) -> TypeGuard["ThpQrCodeSecret"]:
|
||||
return isinstance(msg, cls)
|
||||
|
||||
class ThpNfcUnidirectionalTag(protobuf.MessageType):
|
||||
tag: "bytes | None"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
tag: "bytes | None" = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def is_type_of(cls, msg: Any) -> TypeGuard["ThpNfcUnidirectionalTag"]:
|
||||
return isinstance(msg, cls)
|
||||
|
||||
class ThpNfcUnidirectionalSecret(protobuf.MessageType):
|
||||
secret: "bytes | None"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
secret: "bytes | None" = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def is_type_of(cls, msg: Any) -> TypeGuard["ThpNfcUnidirectionalSecret"]:
|
||||
return isinstance(msg, cls)
|
||||
|
||||
class ThpCredentialRequest(protobuf.MessageType):
|
||||
host_static_pubkey: "bytes | None"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
host_static_pubkey: "bytes | None" = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCredentialRequest"]:
|
||||
return isinstance(msg, cls)
|
||||
|
||||
class ThpCredentialResponse(protobuf.MessageType):
|
||||
trezor_static_pubkey: "bytes | None"
|
||||
credential: "bytes | None"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
trezor_static_pubkey: "bytes | None" = None,
|
||||
credential: "bytes | None" = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCredentialResponse"]:
|
||||
return isinstance(msg, cls)
|
||||
|
||||
class ThpEndRequest(protobuf.MessageType):
|
||||
|
||||
@classmethod
|
||||
def is_type_of(cls, msg: Any) -> TypeGuard["ThpEndRequest"]:
|
||||
return isinstance(msg, cls)
|
||||
|
||||
class ThpEndResponse(protobuf.MessageType):
|
||||
|
||||
@classmethod
|
||||
def is_type_of(cls, msg: Any) -> TypeGuard["ThpEndResponse"]:
|
||||
return isinstance(msg, cls)
|
||||
|
||||
class ThpCredentialMetadata(protobuf.MessageType):
|
||||
host_name: "str | None"
|
||||
|
||||
|
@ -328,7 +328,7 @@ class Layout(Generic[T]):
|
||||
|
||||
def _paint(self) -> None:
|
||||
"""Paint the layout and ensure that homescreen cache is properly invalidated."""
|
||||
import storage.cache as storage_cache
|
||||
import storage.cache_common as storage_cache
|
||||
|
||||
painted = self.layout.paint()
|
||||
if painted:
|
||||
|
@ -35,6 +35,10 @@ from typing import TYPE_CHECKING
|
||||
|
||||
DISABLE_ANIMATION = 0
|
||||
|
||||
DISABLE_ENCRYPTION: bool = False
|
||||
|
||||
ALLOW_DEBUG_MESSAGES: bool = True
|
||||
|
||||
if __debug__:
|
||||
if EMULATOR:
|
||||
import uos
|
||||
@ -111,6 +115,7 @@ def presize_module(modname: str, size: int) -> None:
|
||||
|
||||
|
||||
if __debug__:
|
||||
from ubinascii import hexlify
|
||||
|
||||
def mem_dump(filename: str) -> None:
|
||||
from micropython import mem_info
|
||||
@ -127,6 +132,9 @@ if __debug__:
|
||||
else:
|
||||
mem_info(True)
|
||||
|
||||
def get_bytes_as_str(a: bytes) -> str:
|
||||
return hexlify(a).decode("utf-8")
|
||||
|
||||
|
||||
def ensure(cond: bool, msg: str | None = None) -> None:
|
||||
if not cond:
|
||||
|
@ -8,6 +8,12 @@ class Error(Exception):
|
||||
self.message = message
|
||||
|
||||
|
||||
class SilentError(Exception):
|
||||
def __init__(self, message: str) -> None:
|
||||
super().__init__()
|
||||
self.message = message
|
||||
|
||||
|
||||
class UnexpectedMessage(Error):
|
||||
def __init__(self, message: str) -> None:
|
||||
super().__init__(FailureType.UnexpectedMessage, message)
|
||||
|
184
core/src/trezor/wire/thp/__init__.py
Normal file
184
core/src/trezor/wire/thp/__init__.py
Normal file
@ -0,0 +1,184 @@
|
||||
import ustruct
|
||||
from micropython import const
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from storage.cache_thp import BROADCAST_CHANNEL_ID
|
||||
from trezor import protobuf, utils
|
||||
from trezor.enums import ThpPairingMethod
|
||||
from trezor.messages import ThpDeviceProperties
|
||||
|
||||
from ..protocol_common import WireError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from enum import IntEnum
|
||||
|
||||
from trezor.wire import WireInterface
|
||||
from typing_extensions import Self
|
||||
else:
|
||||
IntEnum = object
|
||||
|
||||
CODEC_V1 = const(0x3F)
|
||||
|
||||
HANDSHAKE_INIT_REQ = const(0x00)
|
||||
HANDSHAKE_INIT_RES = const(0x01)
|
||||
HANDSHAKE_COMP_REQ = const(0x02)
|
||||
HANDSHAKE_COMP_RES = const(0x03)
|
||||
ENCRYPTED = const(0x04)
|
||||
|
||||
ACK_MESSAGE = const(0x20)
|
||||
CHANNEL_ALLOCATION_REQ = const(0x40)
|
||||
_CHANNEL_ALLOCATION_RES = const(0x41)
|
||||
_ERROR = const(0x42)
|
||||
CONTINUATION_PACKET = const(0x80)
|
||||
|
||||
|
||||
class ThpError(WireError):
|
||||
pass
|
||||
|
||||
|
||||
class ThpDecryptionError(ThpError):
|
||||
pass
|
||||
|
||||
|
||||
class ThpInvalidDataError(ThpError):
|
||||
pass
|
||||
|
||||
|
||||
class ThpUnallocatedSessionError(ThpError):
|
||||
|
||||
def __init__(self, session_id: int) -> None:
|
||||
self.session_id = session_id
|
||||
|
||||
|
||||
class ThpErrorType(IntEnum):
|
||||
TRANSPORT_BUSY = 1
|
||||
UNALLOCATED_CHANNEL = 2
|
||||
DECRYPTION_FAILED = 3
|
||||
INVALID_DATA = 4
|
||||
|
||||
|
||||
class ChannelState(IntEnum):
|
||||
UNALLOCATED = 0
|
||||
TH1 = 1
|
||||
TH2 = 2
|
||||
TP1 = 3
|
||||
TP2 = 4
|
||||
TP3 = 5
|
||||
TP4 = 6
|
||||
TC1 = 7
|
||||
ENCRYPTED_TRANSPORT = 8
|
||||
|
||||
|
||||
class SessionState(IntEnum):
|
||||
UNALLOCATED = 0
|
||||
ALLOCATED = 1
|
||||
MANAGEMENT = 2
|
||||
|
||||
|
||||
class PacketHeader:
|
||||
format_str_init = ">BHH"
|
||||
format_str_cont = ">BH"
|
||||
|
||||
def __init__(self, ctrl_byte: int, cid: int, length: int) -> None:
|
||||
self.ctrl_byte = ctrl_byte
|
||||
self.cid = cid
|
||||
self.length = length
|
||||
|
||||
def to_bytes(self) -> bytes:
|
||||
return ustruct.pack(self.format_str_init, self.ctrl_byte, self.cid, self.length)
|
||||
|
||||
def pack_to_init_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None:
|
||||
"""
|
||||
Packs header information in the form of **intial** packet
|
||||
into the provided buffer.
|
||||
"""
|
||||
ustruct.pack_into(
|
||||
self.format_str_init,
|
||||
buffer,
|
||||
buffer_offset,
|
||||
self.ctrl_byte,
|
||||
self.cid,
|
||||
self.length,
|
||||
)
|
||||
|
||||
def pack_to_cont_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None:
|
||||
"""
|
||||
Packs header information in the form of **continuation** packet header
|
||||
into the provided buffer.
|
||||
"""
|
||||
ustruct.pack_into(
|
||||
self.format_str_cont, buffer, buffer_offset, CONTINUATION_PACKET, self.cid
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_error_header(cls, cid: int, length: int) -> Self:
|
||||
"""
|
||||
Returns header for protocol-level error messages.
|
||||
"""
|
||||
return cls(_ERROR, cid, length)
|
||||
|
||||
@classmethod
|
||||
def get_channel_allocation_response_header(cls, length: int) -> Self:
|
||||
"""
|
||||
Returns header for allocation response handshake message.
|
||||
"""
|
||||
return cls(_CHANNEL_ALLOCATION_RES, BROADCAST_CHANNEL_ID, length)
|
||||
|
||||
|
||||
_DEFAULT_ENABLED_PAIRING_METHODS = [
|
||||
ThpPairingMethod.CodeEntry,
|
||||
ThpPairingMethod.QrCode,
|
||||
ThpPairingMethod.NFC_Unidirectional,
|
||||
]
|
||||
|
||||
|
||||
def get_enabled_pairing_methods(
|
||||
iface: WireInterface | None = None,
|
||||
) -> list[ThpPairingMethod]:
|
||||
"""
|
||||
Returns pairing methods that are currently allowed by the device
|
||||
with respect to the wire interface the host communicates on.
|
||||
"""
|
||||
import usb
|
||||
|
||||
methods = _DEFAULT_ENABLED_PAIRING_METHODS.copy()
|
||||
if iface is not None and iface is usb.iface_wire:
|
||||
methods.append(ThpPairingMethod.NoMethod)
|
||||
return methods
|
||||
|
||||
|
||||
def _get_device_properties(iface: WireInterface) -> ThpDeviceProperties:
|
||||
# TODO define model variants
|
||||
return ThpDeviceProperties(
|
||||
pairing_methods=get_enabled_pairing_methods(iface),
|
||||
internal_model=utils.INTERNAL_MODEL,
|
||||
model_variant=0,
|
||||
bootloader_mode=False,
|
||||
protocol_version=2,
|
||||
)
|
||||
|
||||
|
||||
def get_encoded_device_properties(iface: WireInterface) -> bytes:
|
||||
props = _get_device_properties(iface)
|
||||
length = protobuf.encoded_length(props)
|
||||
encoded_properties = bytearray(length)
|
||||
protobuf.encode(encoded_properties, props)
|
||||
return encoded_properties
|
||||
|
||||
|
||||
def get_channel_allocation_response(
|
||||
nonce: bytes, new_cid: bytes, iface: WireInterface
|
||||
) -> bytes:
|
||||
props_msg = get_encoded_device_properties(iface)
|
||||
return nonce + new_cid + props_msg
|
||||
|
||||
|
||||
if __debug__:
|
||||
|
||||
def state_to_str(state: int) -> str:
|
||||
name = {
|
||||
v: k for k, v in ChannelState.__dict__.items() if not k.startswith("__")
|
||||
}.get(state)
|
||||
if name is not None:
|
||||
return name
|
||||
return "UNKNOWN_STATE"
|
102
core/src/trezor/wire/thp/alternating_bit_protocol.py
Normal file
102
core/src/trezor/wire/thp/alternating_bit_protocol.py
Normal file
@ -0,0 +1,102 @@
|
||||
from storage.cache_thp import ChannelCache
|
||||
from trezor import log, utils
|
||||
from trezor.wire.thp import ThpError
|
||||
|
||||
|
||||
def is_ack_valid(cache: ChannelCache, ack_bit: int) -> bool:
|
||||
"""
|
||||
Checks if:
|
||||
- an ACK message is expected
|
||||
- the received ACK message acknowledges correct sequence number (bit)
|
||||
"""
|
||||
if not _is_ack_expected(cache):
|
||||
return False
|
||||
|
||||
if not _has_ack_correct_sync_bit(cache, ack_bit):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _is_ack_expected(cache: ChannelCache) -> bool:
|
||||
is_expected: bool = not is_sending_allowed(cache)
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES and not is_expected:
|
||||
log.debug(__name__, "Received unexpected ACK message")
|
||||
return is_expected
|
||||
|
||||
|
||||
def _has_ack_correct_sync_bit(cache: ChannelCache, sync_bit: int) -> bool:
|
||||
is_correct: bool = get_send_seq_bit(cache) == sync_bit
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES and not is_correct:
|
||||
log.debug(__name__, "Received ACK message with wrong ack bit")
|
||||
return is_correct
|
||||
|
||||
|
||||
def is_sending_allowed(cache: ChannelCache) -> bool:
|
||||
"""
|
||||
Checks whether sending a message in the provided channel is allowed.
|
||||
|
||||
Note: Sending a message in a channel before receipt of ACK message for the previously
|
||||
sent message (in the channel) is prohibited, as it can lead to desynchronization.
|
||||
"""
|
||||
return bool(cache.sync >> 7)
|
||||
|
||||
|
||||
def get_send_seq_bit(cache: ChannelCache) -> int:
|
||||
"""
|
||||
Returns the sequential number (bit) of the next message to be sent
|
||||
in the provided channel.
|
||||
"""
|
||||
return (cache.sync & 0x20) >> 5
|
||||
|
||||
|
||||
def get_expected_receive_seq_bit(cache: ChannelCache) -> int:
|
||||
"""
|
||||
Returns the (expected) sequential number (bit) of the next message
|
||||
to be received in the provided channel.
|
||||
"""
|
||||
return (cache.sync & 0x40) >> 6
|
||||
|
||||
|
||||
def set_sending_allowed(cache: ChannelCache, sending_allowed: bool) -> None:
|
||||
"""
|
||||
Set the flag whether sending a message in this channel is allowed or not.
|
||||
"""
|
||||
cache.sync &= 0x7F
|
||||
if sending_allowed:
|
||||
cache.sync |= 0x80
|
||||
|
||||
|
||||
def set_expected_receive_seq_bit(cache: ChannelCache, seq_bit: int) -> None:
|
||||
"""
|
||||
Set the expected sequential number (bit) of the next message to be received
|
||||
in the provided channel
|
||||
"""
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(__name__, "Set sync receive expected seq bit to %d", seq_bit)
|
||||
if seq_bit not in (0, 1):
|
||||
raise ThpError("Unexpected receive sync bit")
|
||||
|
||||
# set second bit to "seq_bit" value
|
||||
cache.sync &= 0xBF
|
||||
if seq_bit:
|
||||
cache.sync |= 0x40
|
||||
|
||||
|
||||
def _set_send_seq_bit(cache: ChannelCache, seq_bit: int) -> None:
|
||||
if seq_bit not in (0, 1):
|
||||
raise ThpError("Unexpected send seq bit")
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(__name__, "setting sync send seq bit to %d", seq_bit)
|
||||
# set third bit to "seq_bit" value
|
||||
cache.sync &= 0xDF
|
||||
if seq_bit:
|
||||
cache.sync |= 0x20
|
||||
|
||||
|
||||
def set_send_seq_bit_to_opposite(cache: ChannelCache) -> None:
|
||||
"""
|
||||
Set the sequential bit of the "next message to be send" to the opposite value,
|
||||
i.e. 1 -> 0 and 0 -> 1
|
||||
"""
|
||||
_set_send_seq_bit(cache=cache, seq_bit=1 - get_send_seq_bit(cache))
|
405
core/src/trezor/wire/thp/channel.py
Normal file
405
core/src/trezor/wire/thp/channel.py
Normal file
@ -0,0 +1,405 @@
|
||||
import ustruct
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from storage.cache_common import (
|
||||
CHANNEL_HANDSHAKE_HASH,
|
||||
CHANNEL_KEY_RECEIVE,
|
||||
CHANNEL_KEY_SEND,
|
||||
CHANNEL_NONCE_RECEIVE,
|
||||
CHANNEL_NONCE_SEND,
|
||||
)
|
||||
from storage.cache_thp import TAG_LENGTH, ChannelCache, clear_sessions_with_channel_id
|
||||
from trezor import log, loop, protobuf, utils, workflow
|
||||
|
||||
from . import ENCRYPTED, ChannelState, PacketHeader, ThpDecryptionError, ThpError
|
||||
from . import alternating_bit_protocol as ABP
|
||||
from . import (
|
||||
control_byte,
|
||||
crypto,
|
||||
interface_manager,
|
||||
memory_manager,
|
||||
received_message_handler,
|
||||
)
|
||||
from .checksum import CHECKSUM_LENGTH
|
||||
from .transmission_loop import TransmissionLoop
|
||||
from .writer import (
|
||||
CONT_HEADER_LENGTH,
|
||||
INIT_HEADER_LENGTH,
|
||||
write_payload_to_wire_and_add_checksum,
|
||||
)
|
||||
|
||||
if __debug__:
|
||||
from ubinascii import hexlify
|
||||
|
||||
from . import state_to_str
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface
|
||||
from typing import Awaitable
|
||||
|
||||
from .pairing_context import PairingContext
|
||||
from .session_context import GenericSessionContext
|
||||
|
||||
|
||||
class Channel:
|
||||
"""
|
||||
THP protocol encrypted communication channel.
|
||||
"""
|
||||
|
||||
def __init__(self, channel_cache: ChannelCache) -> None:
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(__name__, "channel initialization")
|
||||
self.iface: WireInterface = interface_manager.decode_iface(channel_cache.iface)
|
||||
self.channel_cache: ChannelCache = channel_cache
|
||||
self.is_cont_packet_expected: bool = False
|
||||
self.expected_payload_length: int = 0
|
||||
self.bytes_read: int = 0
|
||||
self.buffer: utils.BufferType
|
||||
self.channel_id: bytes = channel_cache.channel_id
|
||||
self.selected_pairing_methods = []
|
||||
self.sessions: dict[int, GenericSessionContext] = {}
|
||||
self.write_task_spawn: loop.spawn | None = None
|
||||
self.connection_context: PairingContext | None = None
|
||||
self.transmission_loop: TransmissionLoop | None = None
|
||||
self.handshake: crypto.Handshake | None = None
|
||||
|
||||
def clear(self) -> None:
|
||||
clear_sessions_with_channel_id(self.channel_id)
|
||||
self.channel_cache.clear()
|
||||
|
||||
# ACCESS TO CHANNEL_DATA
|
||||
def get_channel_id_int(self) -> int:
|
||||
return int.from_bytes(self.channel_id, "big")
|
||||
|
||||
def get_channel_state(self) -> int:
|
||||
state = int.from_bytes(self.channel_cache.state, "big")
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(
|
||||
__name__,
|
||||
"(cid: %s) get_channel_state: %s",
|
||||
utils.get_bytes_as_str(self.channel_id),
|
||||
state_to_str(state),
|
||||
)
|
||||
return state
|
||||
|
||||
def get_handshake_hash(self) -> bytes:
|
||||
h = self.channel_cache.get(CHANNEL_HANDSHAKE_HASH)
|
||||
assert h is not None
|
||||
return h
|
||||
|
||||
def set_channel_state(self, state: ChannelState) -> None:
|
||||
self.channel_cache.state = bytearray(state.to_bytes(1, "big"))
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(
|
||||
__name__,
|
||||
"(cid: %s) set_channel_state: %s",
|
||||
utils.get_bytes_as_str(self.channel_id),
|
||||
state_to_str(state),
|
||||
)
|
||||
|
||||
def set_buffer(self, buffer: utils.BufferType) -> None:
|
||||
self.buffer = buffer
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(
|
||||
__name__,
|
||||
"(cid: %s) set_buffer: %s",
|
||||
utils.get_bytes_as_str(self.channel_id),
|
||||
type(self.buffer),
|
||||
)
|
||||
|
||||
# CALLED BY THP_MAIN_LOOP
|
||||
|
||||
def receive_packet(self, packet: utils.BufferType) -> Awaitable[None] | None:
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(
|
||||
__name__,
|
||||
"(cid: %s) receive_packet",
|
||||
utils.get_bytes_as_str(self.channel_id),
|
||||
)
|
||||
|
||||
self._handle_received_packet(packet)
|
||||
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(
|
||||
__name__,
|
||||
"(cid: %s) self.buffer: %s",
|
||||
utils.get_bytes_as_str(self.channel_id),
|
||||
utils.get_bytes_as_str(self.buffer),
|
||||
)
|
||||
|
||||
if self.expected_payload_length + INIT_HEADER_LENGTH == self.bytes_read:
|
||||
self._finish_message()
|
||||
return received_message_handler.handle_received_message(self, self.buffer)
|
||||
elif self.expected_payload_length + INIT_HEADER_LENGTH > self.bytes_read:
|
||||
self.is_cont_packet_expected = True
|
||||
else:
|
||||
raise ThpError(
|
||||
"Read more bytes than is the expected length of the message!"
|
||||
)
|
||||
return None
|
||||
|
||||
def _handle_received_packet(self, packet: utils.BufferType) -> None:
|
||||
ctrl_byte = packet[0]
|
||||
if control_byte.is_continuation(ctrl_byte):
|
||||
return self._handle_cont_packet(packet)
|
||||
return self._handle_init_packet(packet)
|
||||
|
||||
def _handle_init_packet(self, packet: utils.BufferType) -> None:
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(
|
||||
__name__,
|
||||
"(cid: %s) handle_init_packet",
|
||||
utils.get_bytes_as_str(self.channel_id),
|
||||
)
|
||||
# ctrl_byte, _, payload_length = ustruct.unpack(">BHH", packet) # TODO use this with single packet decryption
|
||||
_, _, payload_length = ustruct.unpack(">BHH", packet)
|
||||
self.expected_payload_length = payload_length
|
||||
packet_payload = memoryview(packet)[INIT_HEADER_LENGTH:]
|
||||
|
||||
# If the channel does not "own" the buffer lock, decrypt first packet
|
||||
# TODO do it only when needed!
|
||||
# TODO FIX: If "_decrypt_single_packet_payload" is implemented, it will (possibly) break "decrypt_buffer" and nonces incrementation.
|
||||
# On the other hand, without the single packet decryption, the "advanced" buffer selection cannot be implemented
|
||||
# in "memory_manager.select_buffer", because the session id is unknown (encrypted).
|
||||
|
||||
# if control_byte.is_encrypted_transport(ctrl_byte):
|
||||
# packet_payload = self._decrypt_single_packet_payload(packet_payload)
|
||||
|
||||
self.buffer = memory_manager.select_buffer(
|
||||
self.get_channel_state(),
|
||||
self.buffer,
|
||||
packet_payload,
|
||||
payload_length,
|
||||
)
|
||||
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(
|
||||
__name__,
|
||||
"(cid: %s) handle_init_packet - payload len: %d",
|
||||
utils.get_bytes_as_str(self.channel_id),
|
||||
payload_length,
|
||||
)
|
||||
log.debug(
|
||||
__name__,
|
||||
"(cid: %s) handle_init_packet - buffer len: %d",
|
||||
utils.get_bytes_as_str(self.channel_id),
|
||||
len(self.buffer),
|
||||
)
|
||||
return self._buffer_packet_data(self.buffer, packet, 0)
|
||||
|
||||
def _handle_cont_packet(self, packet: utils.BufferType) -> None:
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(
|
||||
__name__,
|
||||
"(cid: %s) handle_cont_packet",
|
||||
utils.get_bytes_as_str(self.channel_id),
|
||||
)
|
||||
if not self.is_cont_packet_expected:
|
||||
raise ThpError("Continuation packet is not expected, ignoring")
|
||||
return self._buffer_packet_data(self.buffer, packet, CONT_HEADER_LENGTH)
|
||||
|
||||
def _decrypt_single_packet_payload(
|
||||
self, payload: utils.BufferType
|
||||
) -> utils.BufferType:
|
||||
# crypto.decrypt(b"\x00", b"\x00", payload_buffer, INIT_DATA_OFFSET, len(payload))
|
||||
return payload
|
||||
|
||||
def decrypt_buffer(
|
||||
self, message_length: int, offset: int = INIT_HEADER_LENGTH
|
||||
) -> None:
|
||||
noise_buffer = memoryview(self.buffer)[
|
||||
offset : message_length - CHECKSUM_LENGTH - TAG_LENGTH
|
||||
]
|
||||
tag = self.buffer[
|
||||
message_length
|
||||
- CHECKSUM_LENGTH
|
||||
- TAG_LENGTH : message_length
|
||||
- CHECKSUM_LENGTH
|
||||
]
|
||||
if utils.DISABLE_ENCRYPTION:
|
||||
is_tag_valid = tag == crypto.DUMMY_TAG
|
||||
else:
|
||||
key_receive = self.channel_cache.get(CHANNEL_KEY_RECEIVE)
|
||||
nonce_receive = self.channel_cache.get_int(CHANNEL_NONCE_RECEIVE)
|
||||
|
||||
assert key_receive is not None
|
||||
assert nonce_receive is not None
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(
|
||||
__name__,
|
||||
"(cid: %s) Buffer before decryption: %s",
|
||||
utils.get_bytes_as_str(self.channel_id),
|
||||
hexlify(noise_buffer),
|
||||
)
|
||||
is_tag_valid = crypto.dec(
|
||||
noise_buffer, tag, key_receive, nonce_receive, b""
|
||||
)
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(
|
||||
__name__,
|
||||
"(cid: %s) Buffer after decryption: %s",
|
||||
utils.get_bytes_as_str(self.channel_id),
|
||||
hexlify(noise_buffer),
|
||||
)
|
||||
|
||||
self.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, nonce_receive + 1)
|
||||
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(
|
||||
__name__,
|
||||
"(cid: %s) Is decrypted tag valid? %s",
|
||||
utils.get_bytes_as_str(self.channel_id),
|
||||
str(is_tag_valid),
|
||||
)
|
||||
log.debug(
|
||||
__name__,
|
||||
"(cid: %s) Received tag: %s",
|
||||
utils.get_bytes_as_str(self.channel_id),
|
||||
(hexlify(tag).decode()),
|
||||
)
|
||||
log.debug(
|
||||
__name__,
|
||||
"(cid: %s) New nonce_receive: %i",
|
||||
utils.get_bytes_as_str(self.channel_id),
|
||||
nonce_receive + 1,
|
||||
)
|
||||
|
||||
if not is_tag_valid:
|
||||
raise ThpDecryptionError()
|
||||
|
||||
def _encrypt(self, buffer: utils.BufferType, noise_payload_len: int) -> None:
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(
|
||||
__name__, "(cid: %s) encrypt", utils.get_bytes_as_str(self.channel_id)
|
||||
)
|
||||
assert len(buffer) >= noise_payload_len + TAG_LENGTH + CHECKSUM_LENGTH
|
||||
|
||||
noise_buffer = memoryview(buffer)[0:noise_payload_len]
|
||||
|
||||
if utils.DISABLE_ENCRYPTION:
|
||||
tag = crypto.DUMMY_TAG
|
||||
else:
|
||||
key_send = self.channel_cache.get(CHANNEL_KEY_SEND)
|
||||
nonce_send = self.channel_cache.get_int(CHANNEL_NONCE_SEND)
|
||||
|
||||
assert key_send is not None
|
||||
assert nonce_send is not None
|
||||
|
||||
tag = crypto.enc(noise_buffer, key_send, nonce_send, b"")
|
||||
|
||||
self.channel_cache.set_int(CHANNEL_NONCE_SEND, nonce_send + 1)
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(__name__, "New nonce_send: %i", nonce_send + 1)
|
||||
|
||||
buffer[noise_payload_len : noise_payload_len + TAG_LENGTH] = tag
|
||||
|
||||
def _buffer_packet_data(
|
||||
self, payload_buffer: utils.BufferType, packet: utils.BufferType, offset: int
|
||||
) -> None:
|
||||
self.bytes_read += utils.memcpy(payload_buffer, self.bytes_read, packet, offset)
|
||||
|
||||
def _finish_message(self) -> None:
|
||||
self.bytes_read = 0
|
||||
self.expected_payload_length = 0
|
||||
self.is_cont_packet_expected = False
|
||||
|
||||
# CALLED BY WORKFLOW / SESSION CONTEXT
|
||||
|
||||
async def write(
|
||||
self,
|
||||
msg: protobuf.MessageType,
|
||||
session_id: int = 0,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
if __debug__ and utils.EMULATOR:
|
||||
log.debug(
|
||||
__name__,
|
||||
"(cid: %s) write message: %s\n%s",
|
||||
utils.get_bytes_as_str(self.channel_id),
|
||||
msg.MESSAGE_NAME,
|
||||
utils.dump_protobuf(msg),
|
||||
)
|
||||
|
||||
self.buffer = memory_manager.get_write_buffer(self.buffer, msg)
|
||||
noise_payload_len = memory_manager.encode_into_buffer(
|
||||
self.buffer, msg, session_id
|
||||
)
|
||||
task = self.write_and_encrypt(self.buffer[:noise_payload_len], force)
|
||||
if task is not None:
|
||||
await task
|
||||
|
||||
def write_error(self, err_type: int) -> Awaitable[None]:
|
||||
msg_data = err_type.to_bytes(1, "big")
|
||||
length = len(msg_data) + CHECKSUM_LENGTH
|
||||
header = PacketHeader.get_error_header(self.get_channel_id_int(), length)
|
||||
return write_payload_to_wire_and_add_checksum(self.iface, header, msg_data)
|
||||
|
||||
def write_and_encrypt(
|
||||
self, payload: bytes, force: bool = False
|
||||
) -> Awaitable[None] | None:
|
||||
payload_length = len(payload)
|
||||
self._encrypt(self.buffer, payload_length)
|
||||
payload_length = payload_length + TAG_LENGTH
|
||||
|
||||
if self.write_task_spawn is not None:
|
||||
self.write_task_spawn.close() # UPS TODO might break something
|
||||
print("\nCLOSED\n")
|
||||
self._prepare_write()
|
||||
if force:
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(
|
||||
__name__, "Writing FORCE message (without async or retransmission)."
|
||||
)
|
||||
return self._write_encrypted_payload_loop(
|
||||
ENCRYPTED, memoryview(self.buffer[:payload_length])
|
||||
)
|
||||
self.write_task_spawn = loop.spawn(
|
||||
self._write_encrypted_payload_loop(
|
||||
ENCRYPTED, memoryview(self.buffer[:payload_length])
|
||||
)
|
||||
)
|
||||
return None
|
||||
|
||||
def write_handshake_message(self, ctrl_byte: int, payload: bytes) -> None:
|
||||
self._prepare_write()
|
||||
self.write_task_spawn = loop.spawn(
|
||||
self._write_encrypted_payload_loop(ctrl_byte, payload)
|
||||
)
|
||||
|
||||
def _prepare_write(self) -> None:
|
||||
# TODO add condition that disallows to write when can_send_message is false
|
||||
ABP.set_sending_allowed(self.channel_cache, False)
|
||||
|
||||
async def _write_encrypted_payload_loop(
|
||||
self, ctrl_byte: int, payload: bytes
|
||||
) -> None:
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(
|
||||
__name__,
|
||||
"(cid %s) write_encrypted_payload_loop",
|
||||
utils.get_bytes_as_str(self.channel_id),
|
||||
)
|
||||
payload_len = len(payload) + CHECKSUM_LENGTH
|
||||
sync_bit = ABP.get_send_seq_bit(self.channel_cache)
|
||||
ctrl_byte = control_byte.add_seq_bit_to_ctrl_byte(ctrl_byte, sync_bit)
|
||||
header = PacketHeader(ctrl_byte, self.get_channel_id_int(), payload_len)
|
||||
self.transmission_loop = TransmissionLoop(self, header, payload)
|
||||
await self.transmission_loop.start()
|
||||
|
||||
ABP.set_send_seq_bit_to_opposite(self.channel_cache)
|
||||
|
||||
# Let the main loop be restarted and clear loop, if there is no other
|
||||
# workflow and the state is ENCRYPTED_TRANSPORT
|
||||
if self._can_clear_loop():
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(
|
||||
__name__,
|
||||
"(cid: %s) clearing loop from channel",
|
||||
utils.get_bytes_as_str(self.channel_id),
|
||||
)
|
||||
loop.clear()
|
||||
|
||||
def _can_clear_loop(self) -> bool:
|
||||
return (
|
||||
not workflow.tasks
|
||||
) and self.get_channel_state() is ChannelState.ENCRYPTED_TRANSPORT
|
34
core/src/trezor/wire/thp/channel_manager.py
Normal file
34
core/src/trezor/wire/thp/channel_manager.py
Normal file
@ -0,0 +1,34 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from storage import cache_thp
|
||||
from trezor import utils
|
||||
|
||||
from . import ChannelState, interface_manager
|
||||
from .channel import Channel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface
|
||||
|
||||
|
||||
def create_new_channel(iface: WireInterface, buffer: utils.BufferType) -> Channel:
|
||||
"""
|
||||
Creates a new channel for the interface `iface` with the buffer `buffer`.
|
||||
"""
|
||||
channel_cache = cache_thp.get_new_channel(interface_manager.encode_iface(iface))
|
||||
r = Channel(channel_cache)
|
||||
r.set_buffer(buffer)
|
||||
r.set_channel_state(ChannelState.TH1)
|
||||
return r
|
||||
|
||||
|
||||
def load_cached_channels(buffer: utils.BufferType) -> dict[int, Channel]:
|
||||
"""
|
||||
Returns all allocated channels from cache.
|
||||
"""
|
||||
channels: dict[int, Channel] = {}
|
||||
cached_channels = cache_thp.get_all_allocated_channels()
|
||||
for c in cached_channels:
|
||||
channels[int.from_bytes(c.channel_id, "big")] = Channel(c)
|
||||
for c in channels.values():
|
||||
c.set_buffer(buffer)
|
||||
return channels
|
22
core/src/trezor/wire/thp/checksum.py
Normal file
22
core/src/trezor/wire/thp/checksum.py
Normal file
@ -0,0 +1,22 @@
|
||||
from micropython import const
|
||||
|
||||
from trezor import utils
|
||||
from trezor.crypto import crc
|
||||
|
||||
CHECKSUM_LENGTH = const(4)
|
||||
|
||||
|
||||
def compute(data: bytes | utils.BufferType) -> bytes:
|
||||
"""
|
||||
Returns a CRC-32 checksum of the provided `data`.
|
||||
"""
|
||||
return crc.crc32(data).to_bytes(CHECKSUM_LENGTH, "big")
|
||||
|
||||
|
||||
def is_valid(checksum: bytes | utils.BufferType, data: bytes) -> bool:
|
||||
"""
|
||||
Checks whether the CRC-32 checksum of the `data` is the same
|
||||
as the checksum provided in `checksum`.
|
||||
"""
|
||||
data_checksum = compute(data)
|
||||
return checksum == data_checksum
|
50
core/src/trezor/wire/thp/control_byte.py
Normal file
50
core/src/trezor/wire/thp/control_byte.py
Normal file
@ -0,0 +1,50 @@
|
||||
from micropython import const
|
||||
|
||||
from . import (
|
||||
ACK_MESSAGE,
|
||||
CONTINUATION_PACKET,
|
||||
ENCRYPTED,
|
||||
HANDSHAKE_COMP_REQ,
|
||||
HANDSHAKE_INIT_REQ,
|
||||
ThpError,
|
||||
)
|
||||
|
||||
_CONTINUATION_PACKET_MASK = const(0x80)
|
||||
_ACK_MASK = const(0xF7)
|
||||
_DATA_MASK = const(0xE7)
|
||||
|
||||
|
||||
def add_seq_bit_to_ctrl_byte(ctrl_byte: int, seq_bit: int) -> int:
|
||||
if seq_bit == 0:
|
||||
return ctrl_byte & 0xEF
|
||||
if seq_bit == 1:
|
||||
return ctrl_byte | 0x10
|
||||
raise ThpError("Unexpected sequence bit")
|
||||
|
||||
|
||||
def add_ack_bit_to_ctrl_byte(ctrl_byte: int, ack_bit: int) -> int:
|
||||
if ack_bit == 0:
|
||||
return ctrl_byte & 0xF7
|
||||
if ack_bit == 1:
|
||||
return ctrl_byte | 0x08
|
||||
raise ThpError("Unexpected acknowledgement bit")
|
||||
|
||||
|
||||
def is_ack(ctrl_byte: int) -> bool:
|
||||
return ctrl_byte & _ACK_MASK == ACK_MESSAGE
|
||||
|
||||
|
||||
def is_continuation(ctrl_byte: int) -> bool:
|
||||
return ctrl_byte & _CONTINUATION_PACKET_MASK == CONTINUATION_PACKET
|
||||
|
||||
|
||||
def is_encrypted_transport(ctrl_byte: int) -> bool:
|
||||
return ctrl_byte & _DATA_MASK == ENCRYPTED
|
||||
|
||||
|
||||
def is_handshake_init_req(ctrl_byte: int) -> bool:
|
||||
return ctrl_byte & _DATA_MASK == HANDSHAKE_INIT_REQ
|
||||
|
||||
|
||||
def is_handshake_comp_req(ctrl_byte: int) -> bool:
|
||||
return ctrl_byte & _DATA_MASK == HANDSHAKE_COMP_REQ
|
36
core/src/trezor/wire/thp/cpace.py
Normal file
36
core/src/trezor/wire/thp/cpace.py
Normal file
@ -0,0 +1,36 @@
|
||||
from trezor.crypto import elligator2, random
|
||||
from trezor.crypto.curve import curve25519
|
||||
from trezor.crypto.hashlib import sha512
|
||||
|
||||
_PREFIX = b"\x08\x43\x50\x61\x63\x65\x32\x35\x35\x06"
|
||||
_PADDING = b"\x6f\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x20"
|
||||
|
||||
|
||||
class Cpace:
|
||||
"""
|
||||
CPace, a balanced composable PAKE: https://datatracker.ietf.org/doc/draft-irtf-cfrg-cpace/
|
||||
"""
|
||||
|
||||
def __init__(self, cpace_host_public_key: bytes, handshake_hash: bytes) -> None:
|
||||
self.handshake_hash: bytes = handshake_hash
|
||||
self.host_public_key: bytes = cpace_host_public_key
|
||||
self.shared_secret: bytes
|
||||
self.trezor_private_key: bytes
|
||||
self.trezor_public_key: bytes
|
||||
|
||||
def generate_keys_and_secret(self, code_code_entry: bytes) -> None:
|
||||
"""
|
||||
Generate ephemeral key pair and a shared secret using Elligator2 with X25519.
|
||||
"""
|
||||
sha_ctx = sha512(_PREFIX)
|
||||
sha_ctx.update(code_code_entry)
|
||||
sha_ctx.update(_PADDING)
|
||||
sha_ctx.update(self.handshake_hash)
|
||||
sha_ctx.update(b"\x00")
|
||||
pregenerator = sha_ctx.digest()[:32]
|
||||
generator = elligator2.map_to_curve25519(pregenerator)
|
||||
self.trezor_private_key = random.bytes(32)
|
||||
self.trezor_public_key = curve25519.multiply(self.trezor_private_key, generator)
|
||||
self.shared_secret = curve25519.multiply(
|
||||
self.trezor_private_key, self.host_public_key
|
||||
)
|
211
core/src/trezor/wire/thp/crypto.py
Normal file
211
core/src/trezor/wire/thp/crypto.py
Normal file
@ -0,0 +1,211 @@
|
||||
from micropython import const
|
||||
from trezorcrypto import aesgcm, bip32, curve25519, hmac
|
||||
|
||||
from storage import device
|
||||
from trezor import log, utils
|
||||
from trezor.crypto.hashlib import sha256
|
||||
from trezor.wire.thp import ThpDecryptionError
|
||||
|
||||
# The HARDENED flag is taken from apps.common.paths
|
||||
# It is not imported to save on resources
|
||||
HARDENED = const(0x8000_0000)
|
||||
PUBKEY_LENGTH = const(32)
|
||||
if utils.DISABLE_ENCRYPTION:
|
||||
DUMMY_TAG = b"\xA0\xA1\xA2\xA3\xA4\xA5\xA6\xA7\xA8\xA9\xB0\xB1\xB2\xB3\xB4\xB5"
|
||||
|
||||
if __debug__:
|
||||
from ubinascii import hexlify
|
||||
|
||||
|
||||
def enc(buffer: utils.BufferType, key: bytes, nonce: int, auth_data: bytes) -> bytes:
|
||||
"""
|
||||
Encrypts the provided `buffer` with AES-GCM (in place).
|
||||
Returns a 16-byte long encryption tag.
|
||||
"""
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(__name__, "enc (key: %s, nonce: %d)", hexlify(key), nonce)
|
||||
iv = _get_iv_from_nonce(nonce)
|
||||
aes_ctx = aesgcm(key, iv)
|
||||
aes_ctx.auth(auth_data)
|
||||
aes_ctx.encrypt_in_place(buffer)
|
||||
return aes_ctx.finish()
|
||||
|
||||
|
||||
def dec(
|
||||
buffer: utils.BufferType, tag: bytes, key: bytes, nonce: int, auth_data: bytes
|
||||
) -> bool:
|
||||
"""
|
||||
Decrypts the provided buffer (in place). Returns `True` if the provided authentication `tag` is the same as
|
||||
the tag computed in decryption, otherwise it returns `False`.
|
||||
"""
|
||||
iv = _get_iv_from_nonce(nonce)
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(__name__, "dec (key: %s, nonce: %d)", hexlify(key), nonce)
|
||||
aes_ctx = aesgcm(key, iv)
|
||||
aes_ctx.auth(auth_data)
|
||||
aes_ctx.decrypt_in_place(buffer)
|
||||
computed_tag = aes_ctx.finish()
|
||||
return computed_tag == tag
|
||||
|
||||
|
||||
class BusyDecoder:
|
||||
def __init__(self, key: bytes, nonce: int, auth_data: bytes) -> None:
|
||||
iv = _get_iv_from_nonce(nonce)
|
||||
self.aes_ctx = aesgcm(key, iv)
|
||||
self.aes_ctx.auth(auth_data)
|
||||
|
||||
def decrypt_part(self, part: utils.BufferType) -> None:
|
||||
self.aes_ctx.decrypt_in_place(part)
|
||||
|
||||
def finish_and_check_tag(self, tag: bytes) -> bool:
|
||||
computed_tag = self.aes_ctx.finish()
|
||||
return computed_tag == tag
|
||||
|
||||
|
||||
PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00"
|
||||
IV_1 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
|
||||
IV_2 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
|
||||
|
||||
|
||||
class Handshake:
|
||||
"""
|
||||
`Handshake` holds (temporary) values and keys that are used during the creation of an encrypted channel.
|
||||
The following values should be saved for future use before disposing of this object:
|
||||
- `h` - handshake hash, can be used to bind other values to the channel
|
||||
- `key_receive` - key for decrypting incoming communication
|
||||
- `key_send` - key for encrypting outgoing communication
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.trezor_ephemeral_privkey: bytes
|
||||
self.ck: bytes
|
||||
self.k: bytes
|
||||
self.h: bytes
|
||||
self.key_receive: bytes
|
||||
self.key_send: bytes
|
||||
|
||||
def handle_th1_crypto(
|
||||
self,
|
||||
device_properties: bytes,
|
||||
host_ephemeral_pubkey: bytes,
|
||||
) -> tuple[bytes, bytes, bytes]:
|
||||
|
||||
trezor_static_privkey, trezor_static_pubkey = _derive_static_key_pair()
|
||||
self.trezor_ephemeral_privkey = curve25519.generate_secret()
|
||||
trezor_ephemeral_pubkey = curve25519.publickey(self.trezor_ephemeral_privkey)
|
||||
self.h = _hash_of_two(PROTOCOL_NAME, device_properties)
|
||||
self.h = _hash_of_two(self.h, host_ephemeral_pubkey)
|
||||
self.h = _hash_of_two(self.h, trezor_ephemeral_pubkey)
|
||||
point = curve25519.multiply(
|
||||
self.trezor_ephemeral_privkey, host_ephemeral_pubkey
|
||||
)
|
||||
self.ck, self.k = _hkdf(PROTOCOL_NAME, point)
|
||||
mask = _hash_of_two(trezor_static_pubkey, trezor_ephemeral_pubkey)
|
||||
trezor_masked_static_pubkey = curve25519.multiply(mask, trezor_static_pubkey)
|
||||
aes_ctx = aesgcm(self.k, IV_1)
|
||||
encrypted_trezor_static_pubkey = aes_ctx.encrypt(trezor_masked_static_pubkey)
|
||||
if __debug__:
|
||||
log.debug(__name__, "th1 - enc (key: %s, nonce: %d)", hexlify(self.k), 0)
|
||||
aes_ctx.auth(self.h)
|
||||
tag_to_encrypted_key = aes_ctx.finish()
|
||||
encrypted_trezor_static_pubkey = (
|
||||
encrypted_trezor_static_pubkey + tag_to_encrypted_key
|
||||
)
|
||||
self.h = _hash_of_two(self.h, encrypted_trezor_static_pubkey)
|
||||
point = curve25519.multiply(trezor_static_privkey, host_ephemeral_pubkey)
|
||||
self.ck, self.k = _hkdf(self.ck, curve25519.multiply(mask, point))
|
||||
aes_ctx = aesgcm(self.k, IV_1)
|
||||
aes_ctx.auth(self.h)
|
||||
tag = aes_ctx.finish()
|
||||
self.h = _hash_of_two(self.h, tag)
|
||||
return (trezor_ephemeral_pubkey, encrypted_trezor_static_pubkey, tag)
|
||||
|
||||
def handle_th2_crypto(
|
||||
self,
|
||||
encrypted_host_static_pubkey: utils.BufferType,
|
||||
encrypted_payload: utils.BufferType,
|
||||
) -> None:
|
||||
|
||||
aes_ctx = aesgcm(self.k, IV_2)
|
||||
|
||||
# The new value of hash `h` MUST be computed before the `encrypted_host_static_pubkey` is decrypted.
|
||||
# However, decryption of `encrypted_host_static_pubkey` MUST use the previous value of `h` for
|
||||
# authentication of the gcm tag.
|
||||
aes_ctx.auth(self.h) # Authenticate with the previous value of `h`
|
||||
self.h = _hash_of_two(self.h, encrypted_host_static_pubkey) # Compute new value
|
||||
aes_ctx.decrypt_in_place(
|
||||
memoryview(encrypted_host_static_pubkey)[:PUBKEY_LENGTH]
|
||||
)
|
||||
if __debug__:
|
||||
log.debug(__name__, "th2 - dec (key: %s, nonce: %d)", hexlify(self.k), 1)
|
||||
host_static_pubkey = memoryview(encrypted_host_static_pubkey)[:PUBKEY_LENGTH]
|
||||
tag = aes_ctx.finish()
|
||||
if tag != encrypted_host_static_pubkey[-16:]:
|
||||
raise ThpDecryptionError()
|
||||
|
||||
self.ck, self.k = _hkdf(
|
||||
self.ck,
|
||||
curve25519.multiply(self.trezor_ephemeral_privkey, host_static_pubkey),
|
||||
)
|
||||
aes_ctx = aesgcm(self.k, IV_1)
|
||||
aes_ctx.auth(self.h)
|
||||
aes_ctx.decrypt_in_place(memoryview(encrypted_payload)[:-16])
|
||||
if __debug__:
|
||||
log.debug(__name__, "th2 - dec (key: %s, nonce: %d)", hexlify(self.k), 0)
|
||||
tag = aes_ctx.finish()
|
||||
if tag != encrypted_payload[-16:]:
|
||||
raise ThpDecryptionError()
|
||||
|
||||
self.h = _hash_of_two(self.h, memoryview(encrypted_payload)[:-16])
|
||||
self.key_receive, self.key_send = _hkdf(self.ck, b"")
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"(key_receive: %s, key_send: %s)",
|
||||
hexlify(self.key_receive),
|
||||
hexlify(self.key_send),
|
||||
)
|
||||
|
||||
def get_handshake_completion_response(self, trezor_state: bytes) -> bytes:
|
||||
aes_ctx = aesgcm(self.key_send, IV_1)
|
||||
encrypted_trezor_state = aes_ctx.encrypt(trezor_state)
|
||||
tag = aes_ctx.finish()
|
||||
return encrypted_trezor_state + tag
|
||||
|
||||
|
||||
def _derive_static_key_pair() -> tuple[bytes, bytes]:
|
||||
node_int = HARDENED | int.from_bytes(b"\x00THP", "big")
|
||||
node = bip32.from_seed(device.get_device_secret(), "curve25519")
|
||||
node.derive(node_int)
|
||||
|
||||
trezor_static_privkey = node.private_key()
|
||||
trezor_static_pubkey = node.public_key()[1:33]
|
||||
# Note: the first byte (\x01) of the public key is removed, as it
|
||||
# only indicates the type of the elliptic curve used
|
||||
|
||||
return trezor_static_privkey, trezor_static_pubkey
|
||||
|
||||
|
||||
def get_trezor_static_pubkey() -> bytes:
|
||||
_, pubkey = _derive_static_key_pair()
|
||||
return pubkey
|
||||
|
||||
|
||||
def _hkdf(chaining_key: bytes, input: bytes) -> tuple[bytes, bytes]:
|
||||
temp_key = hmac(hmac.SHA256, chaining_key, input).digest()
|
||||
output_1 = hmac(hmac.SHA256, temp_key, b"\x01").digest()
|
||||
ctx_output_2 = hmac(hmac.SHA256, temp_key, output_1)
|
||||
ctx_output_2.update(b"\x02")
|
||||
output_2 = ctx_output_2.digest()
|
||||
return (output_1, output_2)
|
||||
|
||||
|
||||
def _hash_of_two(part_1: bytes, part_2: bytes) -> bytes:
|
||||
ctx = sha256(part_1)
|
||||
ctx.update(part_2)
|
||||
return ctx.digest()
|
||||
|
||||
|
||||
def _get_iv_from_nonce(nonce: int) -> bytes:
|
||||
utils.ensure(nonce <= 0xFFFFFFFFFFFFFFFF, "Nonce overflow, terminate the channel")
|
||||
return bytes(4) + nonce.to_bytes(8, "big")
|
28
core/src/trezor/wire/thp/interface_manager.py
Normal file
28
core/src/trezor/wire/thp/interface_manager.py
Normal file
@ -0,0 +1,28 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import usb
|
||||
|
||||
_WIRE_INTERFACE_USB = b"\x01"
|
||||
# TODO _WIRE_INTERFACE_BLE = b"\x02"
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface
|
||||
|
||||
|
||||
def decode_iface(cached_iface: bytes) -> WireInterface:
|
||||
"""Decode the cached wire interface."""
|
||||
if cached_iface == _WIRE_INTERFACE_USB:
|
||||
iface = usb.iface_wire
|
||||
if iface is None:
|
||||
raise RuntimeError("There is no valid USB WireInterface")
|
||||
return iface
|
||||
# TODO implement bluetooth interface
|
||||
raise Exception("Unknown WireInterface")
|
||||
|
||||
|
||||
def encode_iface(iface: WireInterface) -> bytes:
|
||||
"""Encode wire interface into bytes."""
|
||||
if iface is usb.iface_wire:
|
||||
return _WIRE_INTERFACE_USB
|
||||
# TODO implement bluetooth interface
|
||||
raise Exception("Unknown WireInterface")
|
179
core/src/trezor/wire/thp/memory_manager.py
Normal file
179
core/src/trezor/wire/thp/memory_manager.py
Normal file
@ -0,0 +1,179 @@
|
||||
from storage.cache_thp import SESSION_ID_LENGTH, TAG_LENGTH
|
||||
from trezor import log, protobuf, utils
|
||||
from trezor.wire.message_handler import get_msg_type
|
||||
|
||||
from . import ChannelState, ThpError
|
||||
from .checksum import CHECKSUM_LENGTH
|
||||
from .writer import (
|
||||
INIT_HEADER_LENGTH,
|
||||
MAX_PAYLOAD_LEN,
|
||||
MESSAGE_TYPE_LENGTH,
|
||||
PACKET_LENGTH,
|
||||
)
|
||||
|
||||
|
||||
def select_buffer(
|
||||
channel_state: int,
|
||||
channel_buffer: utils.BufferType,
|
||||
packet_payload: utils.BufferType,
|
||||
payload_length: int,
|
||||
) -> utils.BufferType:
|
||||
|
||||
if channel_state is ChannelState.ENCRYPTED_TRANSPORT:
|
||||
session_id = packet_payload[0]
|
||||
if session_id == 0:
|
||||
pass
|
||||
# TODO use small buffer
|
||||
else:
|
||||
pass
|
||||
# TODO use big buffer but only if the channel owns the buffer lock.
|
||||
# Otherwise send BUSY message and return
|
||||
else:
|
||||
pass
|
||||
# TODO use small buffer
|
||||
try:
|
||||
# TODO for now, we create a new big buffer every time. It should be changed
|
||||
buffer: utils.BufferType = _get_buffer_for_read(payload_length, channel_buffer)
|
||||
return buffer
|
||||
except Exception as e:
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.exception(__name__, e)
|
||||
raise Exception("Failed to create a buffer for channel") # TODO handle better
|
||||
|
||||
|
||||
def get_write_buffer(
|
||||
buffer: utils.BufferType, msg: protobuf.MessageType
|
||||
) -> utils.BufferType:
|
||||
msg_size = protobuf.encoded_length(msg)
|
||||
payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size
|
||||
required_min_size = payload_size + CHECKSUM_LENGTH + TAG_LENGTH
|
||||
|
||||
if required_min_size > len(buffer):
|
||||
return _get_buffer_for_write(required_min_size, buffer)
|
||||
return buffer
|
||||
|
||||
|
||||
def encode_into_buffer(
|
||||
buffer: utils.BufferType, msg: protobuf.MessageType, session_id: int
|
||||
) -> int:
|
||||
# cannot write message without wire type
|
||||
msg_type = msg.MESSAGE_WIRE_TYPE
|
||||
if msg_type is None:
|
||||
msg_type = get_msg_type(msg.MESSAGE_NAME)
|
||||
assert msg_type is not None
|
||||
|
||||
msg_size = protobuf.encoded_length(msg)
|
||||
payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size
|
||||
|
||||
_encode_session_into_buffer(memoryview(buffer), session_id)
|
||||
_encode_message_type_into_buffer(memoryview(buffer), msg_type, SESSION_ID_LENGTH)
|
||||
_encode_message_into_buffer(
|
||||
memoryview(buffer), msg, SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH
|
||||
)
|
||||
|
||||
return payload_size
|
||||
|
||||
|
||||
def _encode_session_into_buffer(
|
||||
buffer: memoryview, session_id: int, buffer_offset: int = 0
|
||||
) -> None:
|
||||
session_id_bytes = int.to_bytes(session_id, SESSION_ID_LENGTH, "big")
|
||||
utils.memcpy(buffer, buffer_offset, session_id_bytes, 0)
|
||||
|
||||
|
||||
def _encode_message_type_into_buffer(
|
||||
buffer: memoryview, message_type: int, offset: int = 0
|
||||
) -> None:
|
||||
msg_type_bytes = int.to_bytes(message_type, MESSAGE_TYPE_LENGTH, "big")
|
||||
utils.memcpy(buffer, offset, msg_type_bytes, 0)
|
||||
|
||||
|
||||
def _encode_message_into_buffer(
|
||||
buffer: memoryview, message: protobuf.MessageType, buffer_offset: int = 0
|
||||
) -> None:
|
||||
protobuf.encode(memoryview(buffer[buffer_offset:]), message)
|
||||
|
||||
|
||||
def _get_buffer_for_read(
|
||||
payload_length: int,
|
||||
existing_buffer: utils.BufferType,
|
||||
max_length: int = MAX_PAYLOAD_LEN,
|
||||
) -> utils.BufferType:
|
||||
length = payload_length + INIT_HEADER_LENGTH
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(
|
||||
__name__,
|
||||
"get_buffer_for_read - length: %d, %s %s",
|
||||
length,
|
||||
"existing buffer type:",
|
||||
type(existing_buffer),
|
||||
)
|
||||
if length > max_length:
|
||||
raise ThpError("Message too large")
|
||||
|
||||
if length > len(existing_buffer):
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(__name__, "Allocating a new buffer")
|
||||
|
||||
from .thp_main import get_raw_read_buffer
|
||||
|
||||
if length > len(get_raw_read_buffer()):
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(
|
||||
__name__,
|
||||
"Required length is %d, where raw buffer has capacity only %d",
|
||||
length,
|
||||
len(get_raw_read_buffer()),
|
||||
)
|
||||
raise ThpError("Message is too large")
|
||||
|
||||
try:
|
||||
payload: utils.BufferType = memoryview(get_raw_read_buffer())[:length]
|
||||
except MemoryError:
|
||||
payload = memoryview(get_raw_read_buffer())[:PACKET_LENGTH]
|
||||
raise ThpError("Message is too large")
|
||||
return payload
|
||||
|
||||
# reuse a part of the supplied buffer
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(__name__, "Reusing already allocated buffer")
|
||||
return memoryview(existing_buffer)[:length]
|
||||
|
||||
|
||||
def _get_buffer_for_write(
|
||||
payload_length: int,
|
||||
existing_buffer: utils.BufferType,
|
||||
max_length: int = MAX_PAYLOAD_LEN,
|
||||
) -> utils.BufferType:
|
||||
length = payload_length + INIT_HEADER_LENGTH
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(
|
||||
__name__,
|
||||
"get_buffer_for_write - length: %d, %s %s",
|
||||
length,
|
||||
"existing buffer type:",
|
||||
type(existing_buffer),
|
||||
)
|
||||
if length > max_length:
|
||||
raise ThpError("Message too large")
|
||||
|
||||
if length > len(existing_buffer):
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(__name__, "Creating a new write buffer from raw write buffer")
|
||||
|
||||
from .thp_main import get_raw_write_buffer
|
||||
|
||||
if length > len(get_raw_write_buffer()):
|
||||
raise ThpError("Message is too large")
|
||||
|
||||
try:
|
||||
payload: utils.BufferType = memoryview(get_raw_write_buffer())[:length]
|
||||
except MemoryError:
|
||||
payload = memoryview(get_raw_write_buffer())[:PACKET_LENGTH]
|
||||
raise ThpError("Message is too large")
|
||||
return payload
|
||||
|
||||
# reuse a part of the supplied buffer
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(__name__, "Reusing already allocated buffer")
|
||||
return memoryview(existing_buffer)[:length]
|
262
core/src/trezor/wire/thp/pairing_context.py
Normal file
262
core/src/trezor/wire/thp/pairing_context.py
Normal file
@ -0,0 +1,262 @@
|
||||
from typing import TYPE_CHECKING
|
||||
from ubinascii import hexlify
|
||||
|
||||
import trezorui2
|
||||
from trezor import loop, protobuf, workflow
|
||||
from trezor.crypto import random
|
||||
from trezor.wire import context, message_handler, protocol_common
|
||||
from trezor.wire.context import UnexpectedMessageException
|
||||
from trezor.wire.errors import ActionCancelled, SilentError
|
||||
from trezor.wire.protocol_common import Context, Message
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Container
|
||||
|
||||
from trezor import ui
|
||||
|
||||
from .channel import Channel
|
||||
from .cpace import Cpace
|
||||
|
||||
pass
|
||||
|
||||
if __debug__:
|
||||
from trezor import log
|
||||
|
||||
|
||||
class PairingDisplayData:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.code_code_entry: int | None = None
|
||||
self.code_qr_code: bytes | None = None
|
||||
self.code_nfc_unidirectional: bytes | None = None
|
||||
|
||||
def get_display_layout(self) -> ui.Layout:
|
||||
from trezor import ui
|
||||
|
||||
# TODO have different layouts when there is only QR code or only Code Entry
|
||||
qr_str = ""
|
||||
code_str = ""
|
||||
if self.code_qr_code is not None:
|
||||
qr_str = self._get_code_qr_code_str()
|
||||
if self.code_code_entry is not None:
|
||||
code_str = self._get_code_code_entry_str()
|
||||
|
||||
return ui.Layout(
|
||||
trezorui2.show_address_details( # noqa
|
||||
qr_title="Scan QR code to pair",
|
||||
address=qr_str,
|
||||
case_sensitive=True,
|
||||
details_title="",
|
||||
account="Code to rewrite:\n" + code_str,
|
||||
path="",
|
||||
xpubs=[],
|
||||
)
|
||||
)
|
||||
|
||||
def _get_code_code_entry_str(self) -> str:
|
||||
if self.code_code_entry is not None:
|
||||
code_str = f"{self.code_code_entry:06}"
|
||||
if __debug__:
|
||||
log.debug(__name__, "code_code_entry: %s", code_str)
|
||||
|
||||
return code_str[:3] + " " + code_str[3:]
|
||||
raise Exception("Code entry string is not available")
|
||||
|
||||
def _get_code_qr_code_str(self) -> str:
|
||||
if self.code_qr_code is not None:
|
||||
code_str = (hexlify(self.code_qr_code)).decode("utf-8")
|
||||
if __debug__:
|
||||
log.debug(__name__, "code_qr_code_hexlified: %s", code_str)
|
||||
return code_str
|
||||
raise Exception("QR code string is not available")
|
||||
|
||||
|
||||
class PairingContext(Context):
|
||||
|
||||
def __init__(self, channel_ctx: Channel) -> None:
|
||||
super().__init__(channel_ctx.iface, channel_ctx.channel_id)
|
||||
self.channel_ctx: Channel = channel_ctx
|
||||
self.incoming_message = loop.mailbox()
|
||||
self.secret: bytes = random.bytes(16)
|
||||
|
||||
self.display_data: PairingDisplayData = PairingDisplayData()
|
||||
self.cpace: Cpace
|
||||
self.host_name: str
|
||||
|
||||
async def handle(self, is_debug_session: bool = False) -> None:
|
||||
# if __debug__:
|
||||
# log.debug(__name__, "handle - start")
|
||||
# if is_debug_session:
|
||||
# import apps.debug
|
||||
|
||||
# apps.debug.DEBUG_CONTEXT = self
|
||||
|
||||
next_message: Message | None = None
|
||||
|
||||
while True:
|
||||
try:
|
||||
if next_message is None:
|
||||
# If the previous run did not keep an unprocessed message for us,
|
||||
# wait for a new one.
|
||||
try:
|
||||
message: Message = await self.incoming_message
|
||||
except protocol_common.WireError as e:
|
||||
if __debug__:
|
||||
log.exception(__name__, e)
|
||||
await self.write(message_handler.failure(e))
|
||||
continue
|
||||
else:
|
||||
# Process the message from previous run.
|
||||
message = next_message
|
||||
next_message = None
|
||||
|
||||
try:
|
||||
next_message = await handle_pairing_request_message(self, message)
|
||||
except Exception as exc:
|
||||
# Log and ignore. The session handler can only exit explicitly in the
|
||||
# following finally block.
|
||||
if __debug__:
|
||||
log.exception(__name__, exc)
|
||||
finally:
|
||||
# Unload modules imported by the workflow. Should not raise.
|
||||
# This is not done for the debug session because the snapshot taken
|
||||
# in a debug session would clear modules which are in use by the
|
||||
# workflow running on wire.
|
||||
# TODO utils.unimport_end(modules)
|
||||
|
||||
if next_message is None:
|
||||
|
||||
# Shut down the loop if there is no next message waiting.
|
||||
return # pylint: disable=lost-exception
|
||||
|
||||
except Exception as exc:
|
||||
# Log and try again. The session handler can only exit explicitly via
|
||||
# loop.clear() above. # TODO not updated comments
|
||||
if __debug__:
|
||||
log.exception(__name__, exc)
|
||||
|
||||
async def read(
|
||||
self,
|
||||
expected_types: Container[int],
|
||||
expected_type: type[protobuf.MessageType] | None = None,
|
||||
) -> protobuf.MessageType:
|
||||
if __debug__:
|
||||
exp_type: str = str(expected_type)
|
||||
if expected_type is not None:
|
||||
exp_type = expected_type.MESSAGE_NAME
|
||||
log.debug(
|
||||
__name__,
|
||||
"Read - with expected types %s and expected type %s",
|
||||
str(expected_types),
|
||||
exp_type,
|
||||
)
|
||||
|
||||
message: Message = await self.incoming_message
|
||||
|
||||
if message.type not in expected_types:
|
||||
raise UnexpectedMessageException(message)
|
||||
|
||||
if expected_type is None:
|
||||
name = message_handler.get_msg_name(message.type)
|
||||
if name is None:
|
||||
expected_type = protobuf.type_for_wire(message.type)
|
||||
else:
|
||||
expected_type = protobuf.type_for_name(name)
|
||||
|
||||
return message_handler.wrap_protobuf_load(message.data, expected_type)
|
||||
|
||||
async def write(self, msg: protobuf.MessageType) -> None:
|
||||
return await self.channel_ctx.write(msg)
|
||||
|
||||
async def call(
|
||||
self, msg: protobuf.MessageType, expected_type: type[protobuf.MessageType]
|
||||
) -> protobuf.MessageType:
|
||||
expected_wire_type = message_handler.get_msg_type(expected_type.MESSAGE_NAME)
|
||||
if expected_wire_type is None:
|
||||
expected_wire_type = expected_type.MESSAGE_WIRE_TYPE
|
||||
|
||||
assert expected_wire_type is not None
|
||||
|
||||
await self.write(msg)
|
||||
del msg
|
||||
|
||||
return await self.read((expected_wire_type,), expected_type)
|
||||
|
||||
async def call_any(
|
||||
self, msg: protobuf.MessageType, *expected_types: int
|
||||
) -> protobuf.MessageType:
|
||||
await self.write(msg)
|
||||
del msg
|
||||
return await self.read(expected_types)
|
||||
|
||||
|
||||
async def handle_pairing_request_message(
|
||||
pairing_ctx: PairingContext,
|
||||
msg: protocol_common.Message,
|
||||
) -> protocol_common.Message | None:
|
||||
|
||||
res_msg: protobuf.MessageType | None = None
|
||||
|
||||
from apps.thp.pairing import handle_pairing_request
|
||||
|
||||
if msg.type in workflow.ALLOW_WHILE_LOCKED:
|
||||
workflow.autolock_interrupts_workflow = False
|
||||
|
||||
# Here we make sure we always respond with a Failure response
|
||||
# in case of any errors.
|
||||
try:
|
||||
# Find a protobuf.MessageType subclass that describes this
|
||||
# message. Raises if the type is not found.
|
||||
name = message_handler.get_msg_name(msg.type)
|
||||
if name is None:
|
||||
req_type = protobuf.type_for_wire(msg.type)
|
||||
else:
|
||||
req_type = protobuf.type_for_name(name)
|
||||
|
||||
# Try to decode the message according to schema from
|
||||
# `req_type`. Raises if the message is malformed.
|
||||
req_msg = message_handler.wrap_protobuf_load(msg.data, req_type)
|
||||
|
||||
# Create the handler task.
|
||||
task = handle_pairing_request(pairing_ctx, req_msg)
|
||||
|
||||
# Run the workflow task. Workflow can do more on-the-wire
|
||||
# communication inside, but it should eventually return a
|
||||
# response message, or raise an exception (a rather common
|
||||
# thing to do). Exceptions are handled in the code below.
|
||||
res_msg = await workflow.spawn(context.with_context(pairing_ctx, task))
|
||||
|
||||
except UnexpectedMessageException as exc:
|
||||
# Workflow was trying to read a message from the wire, and
|
||||
# something unexpected came in. See Context.read() for
|
||||
# example, which expects some particular message and raises
|
||||
# UnexpectedMessage if another one comes in.
|
||||
# In order not to lose the message, we return it to the caller.
|
||||
# TODO:
|
||||
# We might handle only the few common cases here, like
|
||||
# Initialize and Cancel.
|
||||
return exc.msg
|
||||
except SilentError as exc:
|
||||
if __debug__:
|
||||
log.error(__name__, "SilentError: %s", exc.message)
|
||||
except BaseException as exc:
|
||||
# Either:
|
||||
# - the message had a type that has a registered handler, but does not have
|
||||
# a protobuf class
|
||||
# - the message was not valid protobuf
|
||||
# - workflow raised some kind of an exception while running
|
||||
# - something canceled the workflow from the outside
|
||||
if __debug__:
|
||||
if isinstance(exc, ActionCancelled):
|
||||
log.debug(__name__, "cancelled: %s", exc.message)
|
||||
elif isinstance(exc, loop.TaskClosed):
|
||||
log.debug(__name__, "cancelled: loop task was closed")
|
||||
else:
|
||||
log.exception(__name__, exc)
|
||||
res_msg = message_handler.failure(exc)
|
||||
|
||||
if res_msg is not None:
|
||||
# perform the write outside the big try-except block, so that usb write
|
||||
# problem bubbles up
|
||||
await pairing_ctx.write(res_msg)
|
||||
return None
|
446
core/src/trezor/wire/thp/received_message_handler.py
Normal file
446
core/src/trezor/wire/thp/received_message_handler.py
Normal file
@ -0,0 +1,446 @@
|
||||
import ustruct
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from storage.cache_common import (
|
||||
CHANNEL_HANDSHAKE_HASH,
|
||||
CHANNEL_KEY_RECEIVE,
|
||||
CHANNEL_KEY_SEND,
|
||||
CHANNEL_NONCE_RECEIVE,
|
||||
CHANNEL_NONCE_SEND,
|
||||
)
|
||||
from storage.cache_thp import (
|
||||
KEY_LENGTH,
|
||||
MANAGEMENT_SESSION_ID,
|
||||
SESSION_ID_LENGTH,
|
||||
TAG_LENGTH,
|
||||
update_channel_last_used,
|
||||
update_session_last_used,
|
||||
)
|
||||
from trezor import log, loop, protobuf, utils
|
||||
from trezor.enums import FailureType
|
||||
from trezor.messages import Failure
|
||||
|
||||
from .. import message_handler
|
||||
from ..errors import DataError
|
||||
from ..protocol_common import Message
|
||||
from . import (
|
||||
ACK_MESSAGE,
|
||||
HANDSHAKE_COMP_RES,
|
||||
HANDSHAKE_INIT_RES,
|
||||
ChannelState,
|
||||
PacketHeader,
|
||||
SessionState,
|
||||
ThpDecryptionError,
|
||||
ThpError,
|
||||
ThpErrorType,
|
||||
ThpInvalidDataError,
|
||||
ThpUnallocatedSessionError,
|
||||
)
|
||||
from . import alternating_bit_protocol as ABP
|
||||
from . import (
|
||||
checksum,
|
||||
control_byte,
|
||||
get_enabled_pairing_methods,
|
||||
get_encoded_device_properties,
|
||||
session_manager,
|
||||
)
|
||||
from .checksum import CHECKSUM_LENGTH
|
||||
from .crypto import PUBKEY_LENGTH, Handshake
|
||||
from .writer import (
|
||||
INIT_HEADER_LENGTH,
|
||||
MESSAGE_TYPE_LENGTH,
|
||||
write_payload_to_wire_and_add_checksum,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Awaitable
|
||||
|
||||
from trezor.messages import ThpHandshakeCompletionReqNoisePayload
|
||||
|
||||
from .channel import Channel
|
||||
|
||||
if __debug__:
|
||||
from ubinascii import hexlify
|
||||
|
||||
from . import state_to_str
|
||||
|
||||
|
||||
_TREZOR_STATE_UNPAIRED = b"\x00"
|
||||
_TREZOR_STATE_PAIRED = b"\x01"
|
||||
|
||||
|
||||
async def handle_received_message(
|
||||
ctx: Channel, message_buffer: utils.BufferType
|
||||
) -> None:
|
||||
"""Handle a message received from the channel."""
|
||||
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(__name__, "handle_received_message")
|
||||
if utils.ALLOW_DEBUG_MESSAGES: # TODO remove after performance tests are done
|
||||
try:
|
||||
import micropython
|
||||
|
||||
print("micropython.mem_info() from received_message_handler.py")
|
||||
micropython.mem_info()
|
||||
print("Allocation count:", micropython.alloc_count()) # type: ignore ["alloc_count" is not a known attribute of module "micropython"]
|
||||
except AttributeError:
|
||||
print(
|
||||
"To show allocation count, create the build with TREZOR_MEMPERF=1"
|
||||
)
|
||||
ctrl_byte, _, payload_length = ustruct.unpack(">BHH", message_buffer)
|
||||
message_length = payload_length + INIT_HEADER_LENGTH
|
||||
|
||||
_check_checksum(message_length, message_buffer)
|
||||
|
||||
# Synchronization process
|
||||
seq_bit = (ctrl_byte & 0x10) >> 4
|
||||
ack_bit = (ctrl_byte & 0x08) >> 3
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(
|
||||
__name__,
|
||||
"handle_completed_message - seq bit of message: %d, ack bit of message: %d",
|
||||
seq_bit,
|
||||
ack_bit,
|
||||
)
|
||||
# 0: Update "last-time used"
|
||||
update_channel_last_used(ctx.channel_id)
|
||||
|
||||
# 1: Handle ACKs
|
||||
if control_byte.is_ack(ctrl_byte):
|
||||
await _handle_ack(ctx, ack_bit)
|
||||
return
|
||||
|
||||
if _should_have_ctrl_byte_encrypted_transport(
|
||||
ctx
|
||||
) and not control_byte.is_encrypted_transport(ctrl_byte):
|
||||
raise ThpError("Message is not encrypted. Ignoring")
|
||||
|
||||
# 2: Handle message with unexpected sequential bit
|
||||
if seq_bit != ABP.get_expected_receive_seq_bit(ctx.channel_cache):
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(__name__, "Received message with an unexpected sequential bit")
|
||||
await _send_ack(ctx, ack_bit=seq_bit)
|
||||
raise ThpError("Received message with an unexpected sequential bit")
|
||||
|
||||
# 3: Send ACK in response
|
||||
await _send_ack(ctx, ack_bit=seq_bit)
|
||||
|
||||
ABP.set_expected_receive_seq_bit(ctx.channel_cache, 1 - seq_bit)
|
||||
|
||||
try:
|
||||
await _handle_message_to_app_or_channel(
|
||||
ctx, payload_length, message_length, ctrl_byte
|
||||
)
|
||||
except ThpUnallocatedSessionError as e:
|
||||
error_message = Failure(code=FailureType.ThpUnallocatedSession)
|
||||
await ctx.write(error_message, e.session_id)
|
||||
except ThpDecryptionError:
|
||||
await ctx.write_error(ThpErrorType.DECRYPTION_FAILED)
|
||||
ctx.clear()
|
||||
except ThpInvalidDataError:
|
||||
await ctx.write_error(ThpErrorType.INVALID_DATA)
|
||||
ctx.clear()
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(__name__, "handle_received_message - end")
|
||||
|
||||
|
||||
def _send_ack(ctx: Channel, ack_bit: int) -> Awaitable[None]:
|
||||
ctrl_byte = control_byte.add_ack_bit_to_ctrl_byte(ACK_MESSAGE, ack_bit)
|
||||
header = PacketHeader(ctrl_byte, ctx.get_channel_id_int(), CHECKSUM_LENGTH)
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(
|
||||
__name__,
|
||||
"Writing ACK message to a channel with id: %d, ack_bit: %d",
|
||||
ctx.get_channel_id_int(),
|
||||
ack_bit,
|
||||
)
|
||||
return write_payload_to_wire_and_add_checksum(ctx.iface, header, b"")
|
||||
|
||||
|
||||
def _check_checksum(message_length: int, message_buffer: utils.BufferType) -> None:
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(__name__, "check_checksum")
|
||||
if not checksum.is_valid(
|
||||
checksum=message_buffer[message_length - CHECKSUM_LENGTH : message_length],
|
||||
data=memoryview(message_buffer)[: message_length - CHECKSUM_LENGTH],
|
||||
):
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(__name__, "Invalid checksum, ignoring message.")
|
||||
raise ThpError("Invalid checksum, ignoring message.")
|
||||
|
||||
|
||||
async def _handle_ack(ctx: Channel, ack_bit: int) -> None:
|
||||
if not ABP.is_ack_valid(ctx.channel_cache, ack_bit):
|
||||
return
|
||||
# ACK is expected and it has correct sync bit
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(__name__, "Received ACK message with correct ack bit")
|
||||
if ctx.transmission_loop is not None:
|
||||
ctx.transmission_loop.stop_immediately()
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(__name__, "Stopped transmission loop")
|
||||
|
||||
ABP.set_sending_allowed(ctx.channel_cache, True)
|
||||
|
||||
if ctx.write_task_spawn is not None:
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(__name__, 'Control to "write_encrypted_payload_loop" task')
|
||||
await ctx.write_task_spawn
|
||||
# Note that no the write_task_spawn could result in loop.clear(),
|
||||
# which will result in termination of this function - any code after
|
||||
# this await might not be executed
|
||||
|
||||
|
||||
def _handle_message_to_app_or_channel(
|
||||
ctx: Channel,
|
||||
payload_length: int,
|
||||
message_length: int,
|
||||
ctrl_byte: int,
|
||||
) -> Awaitable[None]:
|
||||
state = ctx.get_channel_state()
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(__name__, "state: %s", state_to_str(state))
|
||||
|
||||
if state is ChannelState.ENCRYPTED_TRANSPORT:
|
||||
return _handle_state_ENCRYPTED_TRANSPORT(ctx, message_length)
|
||||
|
||||
if state is ChannelState.TH1:
|
||||
return _handle_state_TH1(ctx, payload_length, message_length, ctrl_byte)
|
||||
|
||||
if state is ChannelState.TH2:
|
||||
return _handle_state_TH2(ctx, message_length, ctrl_byte)
|
||||
|
||||
if _is_channel_state_pairing(state):
|
||||
return _handle_pairing(ctx, message_length)
|
||||
|
||||
raise ThpError("Unimplemented channel state")
|
||||
|
||||
|
||||
async def _handle_state_TH1(
|
||||
ctx: Channel,
|
||||
payload_length: int,
|
||||
message_length: int,
|
||||
ctrl_byte: int,
|
||||
) -> None:
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(__name__, "handle_state_TH1")
|
||||
if not control_byte.is_handshake_init_req(ctrl_byte):
|
||||
raise ThpError("Message received is not a handshake init request!")
|
||||
if not payload_length == PUBKEY_LENGTH + CHECKSUM_LENGTH:
|
||||
raise ThpError("Message received is not a valid handshake init request!")
|
||||
|
||||
ctx.handshake = Handshake()
|
||||
|
||||
host_ephemeral_pubkey = bytearray(
|
||||
ctx.buffer[INIT_HEADER_LENGTH : message_length - CHECKSUM_LENGTH]
|
||||
)
|
||||
trezor_ephemeral_pubkey, encrypted_trezor_static_pubkey, tag = (
|
||||
ctx.handshake.handle_th1_crypto(
|
||||
get_encoded_device_properties(ctx.iface), host_ephemeral_pubkey
|
||||
)
|
||||
)
|
||||
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(
|
||||
__name__,
|
||||
"trezor ephemeral pubkey: %s",
|
||||
hexlify(trezor_ephemeral_pubkey).decode(),
|
||||
)
|
||||
log.debug(
|
||||
__name__,
|
||||
"encrypted trezor masked static pubkey: %s",
|
||||
hexlify(encrypted_trezor_static_pubkey).decode(),
|
||||
)
|
||||
log.debug(__name__, "tag: %s", hexlify(tag))
|
||||
|
||||
payload = trezor_ephemeral_pubkey + encrypted_trezor_static_pubkey + tag
|
||||
|
||||
# send handshake init response message
|
||||
ctx.write_handshake_message(HANDSHAKE_INIT_RES, payload)
|
||||
ctx.set_channel_state(ChannelState.TH2)
|
||||
return
|
||||
|
||||
|
||||
async def _handle_state_TH2(ctx: Channel, message_length: int, ctrl_byte: int) -> None:
|
||||
from apps.thp.credential_manager import validate_credential
|
||||
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(__name__, "handle_state_TH2")
|
||||
if not control_byte.is_handshake_comp_req(ctrl_byte):
|
||||
raise ThpError("Message received is not a handshake completion request!")
|
||||
if ctx.handshake is None:
|
||||
raise Exception("Handshake object is not prepared. Retry handshake.")
|
||||
|
||||
host_encrypted_static_pubkey = memoryview(ctx.buffer)[
|
||||
INIT_HEADER_LENGTH : INIT_HEADER_LENGTH + KEY_LENGTH + TAG_LENGTH
|
||||
]
|
||||
handshake_completion_request_noise_payload = memoryview(ctx.buffer)[
|
||||
INIT_HEADER_LENGTH + KEY_LENGTH + TAG_LENGTH : message_length - CHECKSUM_LENGTH
|
||||
]
|
||||
|
||||
ctx.handshake.handle_th2_crypto(
|
||||
host_encrypted_static_pubkey, handshake_completion_request_noise_payload
|
||||
)
|
||||
|
||||
ctx.channel_cache.set(CHANNEL_KEY_RECEIVE, ctx.handshake.key_receive)
|
||||
ctx.channel_cache.set(CHANNEL_KEY_SEND, ctx.handshake.key_send)
|
||||
ctx.channel_cache.set(CHANNEL_HANDSHAKE_HASH, ctx.handshake.h)
|
||||
ctx.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, 0)
|
||||
ctx.channel_cache.set_int(CHANNEL_NONCE_SEND, 1)
|
||||
|
||||
noise_payload = _decode_message(
|
||||
ctx.buffer[
|
||||
INIT_HEADER_LENGTH
|
||||
+ KEY_LENGTH
|
||||
+ TAG_LENGTH : message_length
|
||||
- CHECKSUM_LENGTH
|
||||
- TAG_LENGTH
|
||||
],
|
||||
0,
|
||||
"ThpHandshakeCompletionReqNoisePayload",
|
||||
)
|
||||
if TYPE_CHECKING:
|
||||
assert ThpHandshakeCompletionReqNoisePayload.is_type_of(noise_payload)
|
||||
enabled_methods = get_enabled_pairing_methods(ctx.iface)
|
||||
for method in noise_payload.pairing_methods:
|
||||
if method not in enabled_methods:
|
||||
raise ThpInvalidDataError()
|
||||
if method not in ctx.selected_pairing_methods:
|
||||
ctx.selected_pairing_methods.append(method)
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(
|
||||
__name__,
|
||||
"host static pubkey: %s, noise payload: %s",
|
||||
utils.get_bytes_as_str(host_encrypted_static_pubkey),
|
||||
utils.get_bytes_as_str(handshake_completion_request_noise_payload),
|
||||
)
|
||||
|
||||
# key is decoded in handshake._handle_th2_crypto
|
||||
host_static_pubkey = host_encrypted_static_pubkey[:PUBKEY_LENGTH]
|
||||
|
||||
paired: bool = False
|
||||
|
||||
if noise_payload.host_pairing_credential is not None:
|
||||
try: # TODO change try-except for something better
|
||||
paired = validate_credential(
|
||||
noise_payload.host_pairing_credential,
|
||||
host_static_pubkey,
|
||||
)
|
||||
except DataError as e:
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.exception(__name__, e)
|
||||
pass
|
||||
|
||||
trezor_state = _TREZOR_STATE_UNPAIRED
|
||||
if paired:
|
||||
trezor_state = _TREZOR_STATE_PAIRED
|
||||
# send hanshake completion response
|
||||
ctx.write_handshake_message(
|
||||
HANDSHAKE_COMP_RES,
|
||||
ctx.handshake.get_handshake_completion_response(trezor_state),
|
||||
)
|
||||
|
||||
ctx.handshake = None
|
||||
|
||||
if paired:
|
||||
ctx.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
|
||||
else:
|
||||
ctx.set_channel_state(ChannelState.TP1)
|
||||
|
||||
|
||||
async def _handle_state_ENCRYPTED_TRANSPORT(ctx: Channel, message_length: int) -> None:
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(__name__, "handle_state_ENCRYPTED_TRANSPORT")
|
||||
|
||||
ctx.decrypt_buffer(message_length)
|
||||
session_id, message_type = ustruct.unpack(
|
||||
">BH", memoryview(ctx.buffer)[INIT_HEADER_LENGTH:]
|
||||
)
|
||||
if session_id not in ctx.sessions:
|
||||
if session_id == MANAGEMENT_SESSION_ID:
|
||||
s = session_manager.create_new_management_session(ctx)
|
||||
else:
|
||||
s = session_manager.get_session_from_cache(ctx, session_id)
|
||||
if s is None:
|
||||
raise ThpUnallocatedSessionError(session_id)
|
||||
ctx.sessions[session_id] = s
|
||||
loop.schedule(s.handle())
|
||||
|
||||
elif ctx.sessions[session_id].get_session_state() is SessionState.UNALLOCATED:
|
||||
raise ThpUnallocatedSessionError(session_id)
|
||||
|
||||
s = ctx.sessions[session_id]
|
||||
update_session_last_used(s.channel_id, (s.session_id).to_bytes(1, "big"))
|
||||
|
||||
s.incoming_message.put(
|
||||
Message(
|
||||
message_type,
|
||||
ctx.buffer[
|
||||
INIT_HEADER_LENGTH
|
||||
+ MESSAGE_TYPE_LENGTH
|
||||
+ SESSION_ID_LENGTH : message_length
|
||||
- CHECKSUM_LENGTH
|
||||
- TAG_LENGTH
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def _handle_pairing(ctx: Channel, message_length: int) -> None:
|
||||
from .pairing_context import PairingContext
|
||||
|
||||
if ctx.connection_context is None:
|
||||
ctx.connection_context = PairingContext(ctx)
|
||||
loop.schedule(ctx.connection_context.handle())
|
||||
|
||||
ctx.decrypt_buffer(message_length)
|
||||
message_type = ustruct.unpack(
|
||||
">H", ctx.buffer[INIT_HEADER_LENGTH + SESSION_ID_LENGTH :]
|
||||
)[0]
|
||||
|
||||
ctx.connection_context.incoming_message.put(
|
||||
Message(
|
||||
message_type,
|
||||
ctx.buffer[
|
||||
INIT_HEADER_LENGTH
|
||||
+ MESSAGE_TYPE_LENGTH
|
||||
+ SESSION_ID_LENGTH : message_length
|
||||
- CHECKSUM_LENGTH
|
||||
- TAG_LENGTH
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _should_have_ctrl_byte_encrypted_transport(ctx: Channel) -> bool:
|
||||
if ctx.get_channel_state() in [
|
||||
ChannelState.UNALLOCATED,
|
||||
ChannelState.TH1,
|
||||
ChannelState.TH2,
|
||||
]:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _decode_message(
|
||||
buffer: bytes, msg_type: int, message_name: str | None = None
|
||||
) -> protobuf.MessageType:
|
||||
if __debug__:
|
||||
log.debug(__name__, "decode message")
|
||||
if message_name is not None:
|
||||
expected_type = protobuf.type_for_name(message_name)
|
||||
else:
|
||||
expected_type = protobuf.type_for_wire(msg_type)
|
||||
return message_handler.wrap_protobuf_load(buffer, expected_type)
|
||||
|
||||
|
||||
def _is_channel_state_pairing(state: int) -> bool:
|
||||
if state in (
|
||||
ChannelState.TP1,
|
||||
ChannelState.TP2,
|
||||
ChannelState.TP3,
|
||||
ChannelState.TP4,
|
||||
ChannelState.TC1,
|
||||
):
|
||||
return True
|
||||
return False
|
175
core/src/trezor/wire/thp/session_context.py
Normal file
175
core/src/trezor/wire/thp/session_context.py
Normal file
@ -0,0 +1,175 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from storage import cache_thp
|
||||
from storage.cache_thp import MANAGEMENT_SESSION_ID, SessionThpCache
|
||||
from trezor import log, loop, protobuf, utils
|
||||
from trezor.wire import message_handler, protocol_common
|
||||
from trezor.wire.context import UnexpectedMessageException
|
||||
from trezor.wire.message_handler import failure, find_handler
|
||||
|
||||
from ..protocol_common import Context, Message
|
||||
from . import SessionState
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Awaitable, Container
|
||||
|
||||
from storage.cache_common import DataCache
|
||||
|
||||
from ..message_handler import HandlerFinder
|
||||
from .channel import Channel
|
||||
|
||||
pass
|
||||
|
||||
_EXIT_LOOP = True
|
||||
_REPEAT_LOOP = False
|
||||
|
||||
if __debug__:
|
||||
from trezor.utils import get_bytes_as_str
|
||||
|
||||
|
||||
class GenericSessionContext(Context):
|
||||
|
||||
def __init__(self, channel: Channel, session_id: int) -> None:
|
||||
super().__init__(channel.iface, channel.channel_id)
|
||||
self.channel: Channel = channel
|
||||
self.session_id: int = session_id
|
||||
self.incoming_message = loop.mailbox()
|
||||
self.handler_finder: HandlerFinder = find_handler
|
||||
|
||||
async def handle(self) -> None:
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
self._handle_debug()
|
||||
|
||||
next_message: Message | None = None
|
||||
|
||||
while True:
|
||||
message = next_message
|
||||
next_message = None
|
||||
try:
|
||||
if await self._handle_message(message):
|
||||
loop.schedule(self.handle())
|
||||
return
|
||||
except UnexpectedMessageException as unexpected:
|
||||
# The workflow was interrupted by an unexpected message. We need to
|
||||
# process it as if it was a new message...
|
||||
next_message = unexpected.msg
|
||||
continue
|
||||
except Exception as exc:
|
||||
# Log and try again.
|
||||
if __debug__:
|
||||
log.exception(__name__, exc)
|
||||
|
||||
def _handle_debug(self) -> None:
|
||||
log.debug(
|
||||
__name__,
|
||||
"handle - start (channel_id (bytes): %s, session_id: %d)",
|
||||
get_bytes_as_str(self.channel_id),
|
||||
self.session_id,
|
||||
)
|
||||
|
||||
async def _handle_message(
|
||||
self,
|
||||
next_message: Message | None,
|
||||
) -> bool:
|
||||
|
||||
try:
|
||||
if next_message is not None:
|
||||
# Process the message from previous run.
|
||||
message = next_message
|
||||
next_message = None
|
||||
else:
|
||||
# Wait for a new message from wire
|
||||
message = await self.incoming_message
|
||||
|
||||
except protocol_common.WireError as e:
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.exception(__name__, e)
|
||||
await self.write(failure(e))
|
||||
return _REPEAT_LOOP
|
||||
|
||||
await message_handler.handle_single_message(
|
||||
self,
|
||||
message,
|
||||
self.handler_finder,
|
||||
)
|
||||
return _EXIT_LOOP
|
||||
|
||||
async def read(
|
||||
self,
|
||||
expected_types: Container[int],
|
||||
expected_type: type[protobuf.MessageType] | None = None,
|
||||
) -> protobuf.MessageType:
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
exp_type: str = str(expected_type)
|
||||
if expected_type is not None:
|
||||
exp_type = expected_type.MESSAGE_NAME
|
||||
log.debug(
|
||||
__name__,
|
||||
"Read - with expected types %s and expected type %s",
|
||||
str(expected_types),
|
||||
exp_type,
|
||||
)
|
||||
message: Message = await self.incoming_message
|
||||
if message.type not in expected_types:
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"EXPECTED TYPES: %s\nRECEIVED TYPE: %s",
|
||||
str(expected_types),
|
||||
str(message.type),
|
||||
)
|
||||
raise UnexpectedMessageException(message)
|
||||
|
||||
if expected_type is None:
|
||||
expected_type = protobuf.type_for_wire(message.type)
|
||||
|
||||
return message_handler.wrap_protobuf_load(message.data, expected_type)
|
||||
|
||||
async def write(self, msg: protobuf.MessageType) -> None:
|
||||
return await self.channel.write(msg, self.session_id)
|
||||
|
||||
def write_force(self, msg: protobuf.MessageType) -> Awaitable[None]:
|
||||
return self.channel.write(msg, self.session_id, force=True)
|
||||
|
||||
def get_session_state(self) -> SessionState: ...
|
||||
|
||||
|
||||
class ManagementSessionContext(GenericSessionContext):
|
||||
|
||||
def __init__(
|
||||
self, channel_ctx: Channel, session_id: int = MANAGEMENT_SESSION_ID
|
||||
) -> None:
|
||||
super().__init__(channel_ctx, session_id)
|
||||
|
||||
def get_session_state(self) -> SessionState:
|
||||
return SessionState.MANAGEMENT
|
||||
|
||||
|
||||
class SessionContext(GenericSessionContext):
|
||||
|
||||
def __init__(self, channel_ctx: Channel, session_cache: SessionThpCache) -> None:
|
||||
if channel_ctx.channel_id != session_cache.channel_id:
|
||||
raise Exception(
|
||||
"The session has different channel id than the provided channel context!"
|
||||
)
|
||||
session_id = int.from_bytes(session_cache.session_id, "big")
|
||||
super().__init__(channel_ctx, session_id)
|
||||
self.session_cache = session_cache
|
||||
|
||||
# ACCESS TO SESSION DATA
|
||||
|
||||
def get_session_state(self) -> SessionState:
|
||||
state = int.from_bytes(self.session_cache.state, "big")
|
||||
return SessionState(state)
|
||||
|
||||
def set_session_state(self, state: SessionState) -> None:
|
||||
self.session_cache.state = bytearray(state.to_bytes(1, "big"))
|
||||
|
||||
def release(self) -> None:
|
||||
if self.session_cache is not None:
|
||||
cache_thp.clear_session(self.session_cache)
|
||||
|
||||
# ACCESS TO CACHE
|
||||
@property
|
||||
def cache(self) -> DataCache:
|
||||
return self.session_cache
|
37
core/src/trezor/wire/thp/session_manager.py
Normal file
37
core/src/trezor/wire/thp/session_manager.py
Normal file
@ -0,0 +1,37 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from storage import cache_thp
|
||||
|
||||
from .session_context import (
|
||||
GenericSessionContext,
|
||||
ManagementSessionContext,
|
||||
SessionContext,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .channel import Channel
|
||||
|
||||
|
||||
def create_new_session(channel_ctx: Channel) -> SessionContext:
|
||||
session_cache = cache_thp.get_new_session(channel_ctx.channel_cache)
|
||||
return SessionContext(channel_ctx, session_cache)
|
||||
|
||||
|
||||
def create_new_management_session(
|
||||
channel_ctx: Channel, session_id: int = cache_thp.MANAGEMENT_SESSION_ID
|
||||
) -> ManagementSessionContext:
|
||||
return ManagementSessionContext(channel_ctx, session_id)
|
||||
|
||||
|
||||
def get_session_from_cache(
|
||||
channel_ctx: Channel, session_id: int
|
||||
) -> GenericSessionContext | None:
|
||||
cached_sessions = cache_thp.get_allocated_sessions(channel_ctx.channel_id)
|
||||
for s in cached_sessions:
|
||||
print(s, s.channel_id, int.from_bytes(s.session_id, "big"))
|
||||
if (
|
||||
s.channel_id == channel_ctx.channel_id
|
||||
and int.from_bytes(s.session_id, "big") == session_id
|
||||
):
|
||||
return SessionContext(channel_ctx, s)
|
||||
return None
|
187
core/src/trezor/wire/thp/thp_main.py
Normal file
187
core/src/trezor/wire/thp/thp_main.py
Normal file
@ -0,0 +1,187 @@
|
||||
import ustruct
|
||||
from micropython import const
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from storage.cache_thp import BROADCAST_CHANNEL_ID
|
||||
from trezor import io, log, loop, utils
|
||||
|
||||
from . import (
|
||||
CHANNEL_ALLOCATION_REQ,
|
||||
CODEC_V1,
|
||||
ChannelState,
|
||||
PacketHeader,
|
||||
ThpError,
|
||||
ThpErrorType,
|
||||
channel_manager,
|
||||
checksum,
|
||||
get_channel_allocation_response,
|
||||
writer,
|
||||
)
|
||||
from .channel import Channel
|
||||
from .checksum import CHECKSUM_LENGTH
|
||||
from .writer import (
|
||||
INIT_HEADER_LENGTH,
|
||||
MAX_PAYLOAD_LEN,
|
||||
PACKET_LENGTH,
|
||||
write_payload_to_wire_and_add_checksum,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface
|
||||
|
||||
_CID_REQ_PAYLOAD_LENGTH = const(12)
|
||||
_READ_BUFFER: bytearray
|
||||
_WRITE_BUFFER: bytearray
|
||||
_CHANNELS: dict[int, Channel] = {}
|
||||
|
||||
|
||||
def set_read_buffer(buffer: bytearray) -> None:
|
||||
global _READ_BUFFER
|
||||
_READ_BUFFER = buffer
|
||||
|
||||
|
||||
def set_write_buffer(buffer: bytearray) -> None:
|
||||
global _WRITE_BUFFER
|
||||
_WRITE_BUFFER = buffer
|
||||
|
||||
|
||||
def get_raw_read_buffer() -> bytearray:
|
||||
global _READ_BUFFER
|
||||
return _READ_BUFFER
|
||||
|
||||
|
||||
def get_raw_write_buffer() -> bytearray:
|
||||
global _WRITE_BUFFER
|
||||
return _WRITE_BUFFER
|
||||
|
||||
|
||||
async def thp_main_loop(iface: WireInterface) -> None:
|
||||
global _CHANNELS
|
||||
global _READ_BUFFER
|
||||
_CHANNELS = channel_manager.load_cached_channels(_READ_BUFFER)
|
||||
|
||||
read = loop.wait(iface.iface_num() | io.POLL_READ)
|
||||
|
||||
while True:
|
||||
try:
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(__name__, "thp_main_loop")
|
||||
packet = await read
|
||||
ctrl_byte, cid = ustruct.unpack(">BH", packet)
|
||||
|
||||
if ctrl_byte == CODEC_V1:
|
||||
await _handle_codec_v1(iface, packet)
|
||||
continue
|
||||
|
||||
if cid == BROADCAST_CHANNEL_ID:
|
||||
await _handle_broadcast(iface, ctrl_byte, packet)
|
||||
continue
|
||||
|
||||
if cid in _CHANNELS:
|
||||
await _handle_allocated(iface, cid, packet)
|
||||
else:
|
||||
await _handle_unallocated(iface, cid)
|
||||
|
||||
except ThpError as e:
|
||||
if __debug__:
|
||||
log.exception(__name__, e)
|
||||
|
||||
|
||||
async def _handle_codec_v1(iface: WireInterface, packet: bytes) -> None:
|
||||
# If the received packet is not an initial codec_v1 packet, do not send error message
|
||||
if not packet[1:3] == b"##":
|
||||
return
|
||||
if __debug__:
|
||||
log.debug(__name__, "Received codec_v1 message, returning error")
|
||||
error_message = _get_codec_v1_error_message()
|
||||
await writer.write_packet_to_wire(iface, error_message)
|
||||
|
||||
|
||||
async def _handle_broadcast(
|
||||
iface: WireInterface, ctrl_byte: int, packet: utils.BufferType
|
||||
) -> None:
|
||||
global _READ_BUFFER
|
||||
if ctrl_byte != CHANNEL_ALLOCATION_REQ:
|
||||
raise ThpError("Unexpected ctrl_byte in a broadcast channel packet")
|
||||
if __debug__:
|
||||
log.debug(__name__, "Received valid message on the broadcast channel")
|
||||
|
||||
length, nonce = ustruct.unpack(">H8s", packet[3:])
|
||||
payload = _get_buffer_for_payload(length, packet[5:], _CID_REQ_PAYLOAD_LENGTH)
|
||||
if not checksum.is_valid(
|
||||
payload[-4:],
|
||||
packet[: _CID_REQ_PAYLOAD_LENGTH + INIT_HEADER_LENGTH - CHECKSUM_LENGTH],
|
||||
):
|
||||
raise ThpError("Checksum is not valid")
|
||||
|
||||
new_channel: Channel = channel_manager.create_new_channel(iface, _READ_BUFFER)
|
||||
cid = int.from_bytes(new_channel.channel_id, "big")
|
||||
_CHANNELS[cid] = new_channel
|
||||
|
||||
response_data = get_channel_allocation_response(
|
||||
nonce, new_channel.channel_id, iface
|
||||
)
|
||||
response_header = PacketHeader.get_channel_allocation_response_header(
|
||||
len(response_data) + CHECKSUM_LENGTH,
|
||||
)
|
||||
if __debug__:
|
||||
log.debug(__name__, "New channel allocated with id %d", cid)
|
||||
|
||||
await write_payload_to_wire_and_add_checksum(iface, response_header, response_data)
|
||||
|
||||
|
||||
async def _handle_allocated(
|
||||
iface: WireInterface, cid: int, packet: utils.BufferType
|
||||
) -> None:
|
||||
channel = _CHANNELS[cid]
|
||||
if channel is None:
|
||||
await _handle_unallocated(iface, cid)
|
||||
raise ThpError("Invalid state of a channel")
|
||||
if channel.iface is not iface:
|
||||
# TODO send error message to wire
|
||||
raise ThpError("Channel has different WireInterface")
|
||||
|
||||
if channel.get_channel_state() != ChannelState.UNALLOCATED:
|
||||
x = channel.receive_packet(packet)
|
||||
if x is not None:
|
||||
await x
|
||||
|
||||
|
||||
async def _handle_unallocated(iface: WireInterface, cid: int) -> None:
|
||||
data = (ThpErrorType.UNALLOCATED_CHANNEL).to_bytes(1, "big")
|
||||
header = PacketHeader.get_error_header(cid, len(data) + CHECKSUM_LENGTH)
|
||||
await write_payload_to_wire_and_add_checksum(iface, header, data)
|
||||
|
||||
|
||||
def _get_buffer_for_payload(
|
||||
payload_length: int,
|
||||
existing_buffer: utils.BufferType,
|
||||
max_length: int = MAX_PAYLOAD_LEN,
|
||||
) -> utils.BufferType:
|
||||
if payload_length > max_length:
|
||||
raise ThpError("Message too large")
|
||||
if payload_length > len(existing_buffer):
|
||||
return _try_allocate_new_buffer(payload_length)
|
||||
return _reuse_existing_buffer(payload_length, existing_buffer)
|
||||
|
||||
|
||||
def _try_allocate_new_buffer(payload_length: int) -> utils.BufferType:
|
||||
try:
|
||||
payload: utils.BufferType = bytearray(payload_length)
|
||||
except MemoryError:
|
||||
payload = bytearray(PACKET_LENGTH)
|
||||
raise ThpError("Message too large")
|
||||
return payload
|
||||
|
||||
|
||||
def _reuse_existing_buffer(
|
||||
payload_length: int, existing_buffer: utils.BufferType
|
||||
) -> utils.BufferType:
|
||||
return memoryview(existing_buffer)[:payload_length]
|
||||
|
||||
|
||||
def _get_codec_v1_error_message() -> bytes:
|
||||
# Codec_v1 magic constant "?##" + Failure message type + msg_size
|
||||
# + msg_data (code = "Failure_InvalidProtocol") + padding to 64 B
|
||||
ERROR_MSG = b"\x3f\x23\x23\x00\x03\x00\x00\x00\x14\x08\x10\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
|
||||
return ERROR_MSG
|
54
core/src/trezor/wire/thp/transmission_loop.py
Normal file
54
core/src/trezor/wire/thp/transmission_loop.py
Normal file
@ -0,0 +1,54 @@
|
||||
from micropython import const
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trezor import loop
|
||||
|
||||
from .writer import write_payload_to_wire_and_add_checksum
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import PacketHeader
|
||||
from .channel import Channel
|
||||
|
||||
MAX_RETRANSMISSION_COUNT = const(50)
|
||||
MIN_RETRANSMISSION_COUNT = const(2)
|
||||
|
||||
|
||||
class TransmissionLoop:
|
||||
|
||||
def __init__(
|
||||
self, channel: Channel, header: PacketHeader, transport_payload: bytes
|
||||
) -> None:
|
||||
self.channel: Channel = channel
|
||||
self.header: PacketHeader = header
|
||||
self.transport_payload: bytes = transport_payload
|
||||
self.wait_task: loop.spawn | None = None
|
||||
self.min_retransmisson_count_achieved: bool = False
|
||||
|
||||
async def start(
|
||||
self, max_retransmission_count: int = MAX_RETRANSMISSION_COUNT
|
||||
) -> None:
|
||||
self.min_retransmisson_count_achieved = False
|
||||
for i in range(max_retransmission_count):
|
||||
if i >= MIN_RETRANSMISSION_COUNT:
|
||||
self.min_retransmisson_count_achieved = True
|
||||
await write_payload_to_wire_and_add_checksum(
|
||||
self.channel.iface, self.header, self.transport_payload
|
||||
)
|
||||
self.wait_task = loop.spawn(self._wait(i))
|
||||
try:
|
||||
await self.wait_task
|
||||
except loop.TaskClosed:
|
||||
self.wait_task = None
|
||||
break
|
||||
|
||||
def stop_immediately(self) -> None:
|
||||
if self.wait_task is not None:
|
||||
self.wait_task.close()
|
||||
self.wait_task = None
|
||||
|
||||
async def _wait(self, counter: int = 0) -> None:
|
||||
timeout_ms = round(10200 - 1010000 / (counter + 100))
|
||||
await loop.sleep(timeout_ms)
|
||||
|
||||
def __del__(self) -> None:
|
||||
self.stop_immediately()
|
92
core/src/trezor/wire/thp/writer.py
Normal file
92
core/src/trezor/wire/thp/writer.py
Normal file
@ -0,0 +1,92 @@
|
||||
from micropython import const
|
||||
from trezorcrypto import crc
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trezor import io, log, loop, utils
|
||||
|
||||
from . import PacketHeader
|
||||
|
||||
INIT_HEADER_LENGTH = const(5)
|
||||
CONT_HEADER_LENGTH = const(3)
|
||||
PACKET_LENGTH = const(64)
|
||||
CHECKSUM_LENGTH = const(4)
|
||||
MAX_PAYLOAD_LEN = const(60000)
|
||||
MESSAGE_TYPE_LENGTH = const(2)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface
|
||||
from typing import Awaitable, Sequence
|
||||
|
||||
|
||||
def write_payload_to_wire_and_add_checksum(
|
||||
iface: WireInterface, header: PacketHeader, transport_payload: bytes
|
||||
) -> Awaitable[None]:
|
||||
header_checksum: int = crc.crc32(header.to_bytes())
|
||||
checksum: bytes = crc.crc32(transport_payload, header_checksum).to_bytes(
|
||||
CHECKSUM_LENGTH, "big"
|
||||
)
|
||||
data = (transport_payload, checksum)
|
||||
return write_payloads_to_wire(iface, header, data)
|
||||
|
||||
|
||||
async def write_payloads_to_wire(
|
||||
iface: WireInterface, header: PacketHeader, data: Sequence[bytes]
|
||||
) -> None:
|
||||
n_of_data = len(data)
|
||||
total_length = sum(len(item) for item in data)
|
||||
|
||||
current_data_idx = 0
|
||||
current_data_offset = 0
|
||||
|
||||
packet = bytearray(PACKET_LENGTH)
|
||||
header.pack_to_init_buffer(packet)
|
||||
packet_offset: int = INIT_HEADER_LENGTH
|
||||
packet_number = 0
|
||||
nwritten = 0
|
||||
while nwritten < total_length:
|
||||
if packet_number == 1:
|
||||
header.pack_to_cont_buffer(packet)
|
||||
if packet_number >= 1 and nwritten >= total_length - PACKET_LENGTH:
|
||||
packet[:] = bytearray(PACKET_LENGTH)
|
||||
header.pack_to_cont_buffer(packet)
|
||||
while True:
|
||||
n = utils.memcpy(
|
||||
packet, packet_offset, data[current_data_idx], current_data_offset
|
||||
)
|
||||
packet_offset += n
|
||||
current_data_offset += n
|
||||
nwritten += n
|
||||
|
||||
if packet_offset < PACKET_LENGTH:
|
||||
current_data_idx += 1
|
||||
current_data_offset = 0
|
||||
if current_data_idx >= n_of_data:
|
||||
break
|
||||
elif packet_offset == PACKET_LENGTH:
|
||||
break
|
||||
else:
|
||||
raise Exception("Should not happen!!!")
|
||||
packet_number += 1
|
||||
packet_offset = CONT_HEADER_LENGTH
|
||||
|
||||
# write packet to wire (in-lined)
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(
|
||||
__name__, "write_packet_to_wire: %s", utils.get_bytes_as_str(packet)
|
||||
)
|
||||
written_by_iface: int = 0
|
||||
while written_by_iface < len(packet):
|
||||
await loop.wait(iface.iface_num() | io.POLL_WRITE)
|
||||
written_by_iface = iface.write(packet)
|
||||
|
||||
|
||||
async def write_packet_to_wire(iface: WireInterface, packet: bytes) -> None:
|
||||
while True:
|
||||
await loop.wait(iface.iface_num() | io.POLL_WRITE)
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(
|
||||
__name__, "write_packet_to_wire: %s", utils.get_bytes_as_str(packet)
|
||||
)
|
||||
n_written = iface.write(packet)
|
||||
if n_written == len(packet):
|
||||
return
|
@ -1,9 +1,9 @@
|
||||
import utime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import storage.cache
|
||||
import storage.cache_common as cache_common
|
||||
from trezor import log, loop
|
||||
from trezor.enums import MessageType
|
||||
from trezor.enums import MessageType, ThpMessageType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Callable
|
||||
@ -17,9 +17,14 @@ if __debug__:
|
||||
|
||||
from trezor import utils
|
||||
|
||||
if utils.USE_THP:
|
||||
protocol_specific = ThpMessageType.ThpCreateNewSession
|
||||
else:
|
||||
protocol_specific = MessageType.Initialize
|
||||
|
||||
|
||||
ALLOW_WHILE_LOCKED = (
|
||||
MessageType.Initialize,
|
||||
protocol_specific,
|
||||
MessageType.EndSession,
|
||||
MessageType.GetFeatures,
|
||||
MessageType.Cancel,
|
||||
@ -153,7 +158,7 @@ def close_others() -> None:
|
||||
if not task.is_running():
|
||||
task.close()
|
||||
|
||||
storage.cache.homescreen_shown = None
|
||||
cache_common.homescreen_shown = None
|
||||
|
||||
# if tasks were running, closing the last of them will run start_default
|
||||
|
||||
@ -211,11 +216,11 @@ class IdleTimer:
|
||||
time and saves it to storage.cache. This is done to avoid losing an
|
||||
active timer when workflow restart happens and tasks are lost.
|
||||
"""
|
||||
if _restore_from_cache and storage.cache.autolock_last_touch is not None:
|
||||
now = storage.cache.autolock_last_touch
|
||||
if _restore_from_cache and cache_common.autolock_last_touch is not None:
|
||||
now = cache_common.autolock_last_touch
|
||||
else:
|
||||
now = utime.ticks_ms()
|
||||
storage.cache.autolock_last_touch = now
|
||||
cache_common.autolock_last_touch = now
|
||||
|
||||
for callback, task in self.tasks.items():
|
||||
timeout_us = self.timeouts[callback]
|
||||
|
17
core/tests/mock_wire_interface.py
Normal file
17
core/tests/mock_wire_interface.py
Normal file
@ -0,0 +1,17 @@
|
||||
from trezor.loop import wait
|
||||
|
||||
|
||||
class MockHID:
|
||||
def __init__(self, num):
|
||||
self.num = num
|
||||
self.data = []
|
||||
|
||||
def iface_num(self):
|
||||
return self.num
|
||||
|
||||
def write(self, msg):
|
||||
self.data.append(bytearray(msg))
|
||||
return len(msg)
|
||||
|
||||
def wait_object(self, mode):
|
||||
return wait(mode | self.num)
|
42
core/tests/myTests.sh
Executable file
42
core/tests/myTests.sh
Executable file
@ -0,0 +1,42 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
declare -a results
|
||||
declare -i passed=0 failed=0 exit_code=0
|
||||
declare COLOR_GREEN='\e[32m' COLOR_RED='\e[91m' COLOR_RESET='\e[39m'
|
||||
MICROPYTHON="${MICROPYTHON:-../build/unix/trezor-emu-core -X heapsize=2M}"
|
||||
print_summary() {
|
||||
echo
|
||||
echo 'Summary:'
|
||||
echo '-------------------'
|
||||
printf '%b\n' "${results[@]}"
|
||||
if [ $exit_code == 0 ]; then
|
||||
echo -e "${COLOR_GREEN}PASSED:${COLOR_RESET} $passed/$num_of_tests tests OK!"
|
||||
else
|
||||
echo -e "${COLOR_RED}FAILED:${COLOR_RESET} $failed/$num_of_tests tests failed!"
|
||||
fi
|
||||
}
|
||||
|
||||
trap 'print_summary; echo -e "${COLOR_RED}Interrupted by user!${COLOR_RESET}"; exit 1' SIGINT
|
||||
|
||||
cd $(dirname $0)
|
||||
|
||||
[ -z "$*" ] && tests=(test_trezor.wire.t*.py ) || tests=($*)
|
||||
|
||||
declare -i num_of_tests=${#tests[@]}
|
||||
|
||||
for test_case in ${tests[@]}; do
|
||||
echo ${MICROPYTHON}
|
||||
echo ${test_case}
|
||||
echo
|
||||
if $MICROPYTHON $test_case; then
|
||||
results+=("${COLOR_GREEN}OK:${COLOR_RESET} $test_case")
|
||||
((passed++))
|
||||
else
|
||||
results+=("${COLOR_RED}FAIL:${COLOR_RESET} $test_case")
|
||||
((failed++))
|
||||
exit_code=1
|
||||
fi
|
||||
done
|
||||
|
||||
print_summary
|
||||
exit $exit_code
|
@ -1,4 +1,4 @@
|
||||
from common import H_, await_result, unittest # isort:skip
|
||||
from common import * # isort:skip
|
||||
|
||||
import storage.cache
|
||||
from trezor import wire
|
||||
@ -11,15 +11,35 @@ from trezor.messages import (
|
||||
TxInput,
|
||||
TxOutput,
|
||||
)
|
||||
from trezor.wire import context
|
||||
|
||||
from apps.bitcoin.authorization import FEE_RATE_DECIMALS, CoinJoinAuthorization
|
||||
from apps.bitcoin.sign_tx.approvers import CoinJoinApprover
|
||||
from apps.bitcoin.sign_tx.bitcoin import Bitcoin
|
||||
from apps.bitcoin.sign_tx.tx_info import TxInfo
|
||||
from apps.common import coins
|
||||
from trezor.wire.codec.codec_context import CodecContext
|
||||
|
||||
if utils.USE_THP:
|
||||
import thp_common
|
||||
else:
|
||||
import storage.cache_codec
|
||||
|
||||
class TestApprover(unittest.TestCase):
|
||||
if utils.USE_THP:
|
||||
|
||||
def __init__(self):
|
||||
if __debug__:
|
||||
thp_common.suppres_debug_log()
|
||||
thp_common.prepare_context()
|
||||
super().__init__()
|
||||
|
||||
else:
|
||||
|
||||
def __init__(self):
|
||||
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
|
||||
super().__init__()
|
||||
|
||||
def setUp(self):
|
||||
self.coin = coins.by_name("Bitcoin")
|
||||
self.fee_rate_percent = 0.3
|
||||
@ -47,7 +67,8 @@ class TestApprover(unittest.TestCase):
|
||||
coin_name=self.coin.coin_name,
|
||||
script_type=InputScriptType.SPENDTAPROOT,
|
||||
)
|
||||
storage.cache.start_session()
|
||||
if not utils.USE_THP:
|
||||
storage.cache_codec.start_session()
|
||||
|
||||
def make_coinjoin_request(self, inputs):
|
||||
return CoinJoinRequest(
|
||||
|
@ -1,16 +1,36 @@
|
||||
from common import H_, unittest # isort:skip
|
||||
from common import * # isort:skip
|
||||
|
||||
import storage.cache
|
||||
from trezor.enums import InputScriptType
|
||||
from trezor.messages import AuthorizeCoinJoin, GetOwnershipProof, SignTx
|
||||
from trezor.wire import context
|
||||
|
||||
from apps.bitcoin.authorization import CoinJoinAuthorization
|
||||
from apps.common import coins
|
||||
from trezor.wire.codec.codec_context import CodecContext
|
||||
|
||||
_ROUND_ID_LEN = 32
|
||||
|
||||
if utils.USE_THP:
|
||||
import thp_common
|
||||
else:
|
||||
import storage.cache_codec
|
||||
|
||||
|
||||
class TestAuthorization(unittest.TestCase):
|
||||
if utils.USE_THP:
|
||||
|
||||
def __init__(self):
|
||||
if __debug__:
|
||||
thp_common.suppres_debug_log()
|
||||
thp_common.prepare_context()
|
||||
super().__init__()
|
||||
|
||||
else:
|
||||
|
||||
def __init__(self):
|
||||
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
|
||||
super().__init__()
|
||||
|
||||
coin = coins.by_name("Bitcoin")
|
||||
|
||||
@ -26,7 +46,8 @@ class TestAuthorization(unittest.TestCase):
|
||||
)
|
||||
|
||||
self.authorization = CoinJoinAuthorization(self.msg_auth)
|
||||
storage.cache.start_session()
|
||||
if not utils.USE_THP:
|
||||
storage.cache_codec.start_session()
|
||||
|
||||
def test_ownership_proof_account_depth_mismatch(self):
|
||||
# Account depth mismatch.
|
||||
|
@ -1,230 +1,519 @@
|
||||
from common import * # isort:skip
|
||||
from common import * # isort:skip # noqa: F403
|
||||
|
||||
from mock_storage import mock_storage
|
||||
from storage import cache
|
||||
from storage import cache, cache_codec
|
||||
from trezor.messages import EndSession, Initialize
|
||||
|
||||
from apps.base import handle_EndSession, handle_Initialize
|
||||
from apps.base import handle_EndSession
|
||||
from trezor.wire.codec.codec_context import CodecContext
|
||||
|
||||
KEY = 0
|
||||
|
||||
if utils.USE_THP:
|
||||
import thp_common
|
||||
from mock_wire_interface import MockHID
|
||||
from storage import cache_thp
|
||||
from trezor.wire.thp import ChannelState
|
||||
from trezor.wire.thp.session_context import SessionContext
|
||||
|
||||
# Function moved from cache.py, as it was not used there
|
||||
def is_session_started() -> bool:
|
||||
return cache._active_session_idx is not None
|
||||
_PROTOCOL_CACHE = cache_thp
|
||||
|
||||
else:
|
||||
_PROTOCOL_CACHE = cache_codec
|
||||
|
||||
def is_session_started() -> bool:
|
||||
return cache_codec.get_active_session() is not None
|
||||
|
||||
def get_active_session():
|
||||
return cache_codec.get_active_session()
|
||||
|
||||
|
||||
class TestStorageCache(unittest.TestCase):
|
||||
class TestStorageCache(unittest.TestCase): # noqa: F405
|
||||
def setUp(self):
|
||||
cache.clear_all()
|
||||
|
||||
def test_start_session(self):
|
||||
session_id_a = cache.start_session()
|
||||
self.assertIsNotNone(session_id_a)
|
||||
session_id_b = cache.start_session()
|
||||
self.assertNotEqual(session_id_a, session_id_b)
|
||||
if utils.USE_THP:
|
||||
|
||||
cache.clear_all()
|
||||
with self.assertRaises(cache.InvalidSessionError):
|
||||
cache.set(KEY, "something")
|
||||
with self.assertRaises(cache.InvalidSessionError):
|
||||
cache.get(KEY)
|
||||
def __init__(self):
|
||||
thp_common.suppres_debug_log()
|
||||
# xthp_common.prepare_context()
|
||||
# config.init()
|
||||
super().__init__()
|
||||
|
||||
def test_end_session(self):
|
||||
session_id = cache.start_session()
|
||||
self.assertTrue(is_session_started())
|
||||
cache.set(KEY, b"A")
|
||||
cache.end_current_session()
|
||||
self.assertFalse(is_session_started())
|
||||
self.assertRaises(cache.InvalidSessionError, cache.get, KEY)
|
||||
def setUp(self):
|
||||
self.interface = MockHID(0xDEADBEEF)
|
||||
cache.clear_all()
|
||||
|
||||
# ending an ended session should be a no-op
|
||||
cache.end_current_session()
|
||||
self.assertFalse(is_session_started())
|
||||
def test_new_channel_and_session(self):
|
||||
channel = thp_common.get_new_channel(self.interface)
|
||||
|
||||
session_id_a = cache.start_session(session_id)
|
||||
# original session no longer exists
|
||||
self.assertNotEqual(session_id_a, session_id)
|
||||
# original session data no longer exists
|
||||
self.assertIsNone(cache.get(KEY))
|
||||
# Assert that channel is created without any sessions
|
||||
self.assertEqual(len(channel.sessions), 0)
|
||||
|
||||
# create a new session
|
||||
session_id_b = cache.start_session()
|
||||
# switch back to original session
|
||||
session_id = cache.start_session(session_id_a)
|
||||
self.assertEqual(session_id, session_id_a)
|
||||
# end original session
|
||||
cache.end_current_session()
|
||||
# switch back to B
|
||||
session_id = cache.start_session(session_id_b)
|
||||
self.assertEqual(session_id, session_id_b)
|
||||
cid_1 = channel.channel_id
|
||||
session_cache_1 = cache_thp.get_new_session(channel.channel_cache)
|
||||
session_1 = SessionContext(channel, session_cache_1)
|
||||
self.assertEqual(session_1.channel_id, cid_1)
|
||||
|
||||
def test_session_queue(self):
|
||||
session_id = cache.start_session()
|
||||
self.assertEqual(cache.start_session(session_id), session_id)
|
||||
cache.set(KEY, b"A")
|
||||
for i in range(cache._MAX_SESSIONS_COUNT):
|
||||
cache.start_session()
|
||||
self.assertNotEqual(cache.start_session(session_id), session_id)
|
||||
self.assertIsNone(cache.get(KEY))
|
||||
session_cache_2 = cache_thp.get_new_session(channel.channel_cache)
|
||||
session_2 = SessionContext(channel, session_cache_2)
|
||||
self.assertEqual(session_2.channel_id, cid_1)
|
||||
self.assertEqual(session_1.channel_id, session_2.channel_id)
|
||||
self.assertNotEqual(session_1.session_id, session_2.session_id)
|
||||
|
||||
def test_get_set(self):
|
||||
session_id1 = cache.start_session()
|
||||
cache.set(KEY, b"hello")
|
||||
self.assertEqual(cache.get(KEY), b"hello")
|
||||
channel_2 = thp_common.get_new_channel(self.interface)
|
||||
cid_2 = channel_2.channel_id
|
||||
self.assertNotEqual(cid_1, cid_2)
|
||||
|
||||
session_id2 = cache.start_session()
|
||||
cache.set(KEY, b"world")
|
||||
self.assertEqual(cache.get(KEY), b"world")
|
||||
session_cache_3 = cache_thp.get_new_session(channel_2.channel_cache)
|
||||
session_3 = SessionContext(channel_2, session_cache_3)
|
||||
self.assertEqual(session_3.channel_id, cid_2)
|
||||
|
||||
cache.start_session(session_id2)
|
||||
self.assertEqual(cache.get(KEY), b"world")
|
||||
cache.start_session(session_id1)
|
||||
self.assertEqual(cache.get(KEY), b"hello")
|
||||
# Sessions 1 and 3 should have different channel_id, but the same session_id
|
||||
self.assertNotEqual(session_1.channel_id, session_3.channel_id)
|
||||
self.assertEqual(session_1.session_id, session_3.session_id)
|
||||
|
||||
cache.clear_all()
|
||||
with self.assertRaises(cache.InvalidSessionError):
|
||||
cache.get(KEY)
|
||||
self.assertEqual(cache_thp._SESSIONS[0], session_cache_1)
|
||||
self.assertNotEqual(cache_thp._SESSIONS[0], session_cache_2)
|
||||
self.assertEqual(cache_thp._SESSIONS[0].channel_id, session_1.channel_id)
|
||||
|
||||
def test_get_set_int(self):
|
||||
session_id1 = cache.start_session()
|
||||
cache.set_int(KEY, 1234)
|
||||
self.assertEqual(cache.get_int(KEY), 1234)
|
||||
# Check that session data IS in cache for created sessions ONLY
|
||||
for i in range(3):
|
||||
self.assertNotEqual(cache_thp._SESSIONS[i].channel_id, b"")
|
||||
self.assertNotEqual(cache_thp._SESSIONS[i].session_id, b"")
|
||||
self.assertNotEqual(cache_thp._SESSIONS[i].last_usage, 0)
|
||||
for i in range(3, cache_thp._MAX_SESSIONS_COUNT):
|
||||
self.assertEqual(cache_thp._SESSIONS[i].channel_id, b"")
|
||||
self.assertEqual(cache_thp._SESSIONS[i].session_id, b"")
|
||||
self.assertEqual(cache_thp._SESSIONS[i].last_usage, 0)
|
||||
|
||||
session_id2 = cache.start_session()
|
||||
cache.set_int(KEY, 5678)
|
||||
self.assertEqual(cache.get_int(KEY), 5678)
|
||||
# Check that session data IS NOT in cache after cache.clear_all()
|
||||
cache.clear_all()
|
||||
for session in cache_thp._SESSIONS:
|
||||
self.assertEqual(session.channel_id, b"")
|
||||
self.assertEqual(session.session_id, b"")
|
||||
self.assertEqual(session.last_usage, 0)
|
||||
self.assertEqual(session.state, b"\x00")
|
||||
|
||||
cache.start_session(session_id2)
|
||||
self.assertEqual(cache.get_int(KEY), 5678)
|
||||
cache.start_session(session_id1)
|
||||
self.assertEqual(cache.get_int(KEY), 1234)
|
||||
def test_channel_capacity_in_cache(self):
|
||||
self.assertTrue(cache_thp._MAX_CHANNELS_COUNT >= 3)
|
||||
channels = []
|
||||
for i in range(cache_thp._MAX_CHANNELS_COUNT):
|
||||
channels.append(thp_common.get_new_channel(self.interface))
|
||||
channel_ids = [channel.channel_cache.channel_id for channel in channels]
|
||||
|
||||
cache.clear_all()
|
||||
with self.assertRaises(cache.InvalidSessionError):
|
||||
cache.get_int(KEY)
|
||||
# Assert that each channel_id is unique and that cache and list of channels
|
||||
# have the same "channels" on the same indexes
|
||||
for i in range(len(channel_ids)):
|
||||
self.assertEqual(cache_thp._CHANNELS[i].channel_id, channel_ids[i])
|
||||
for j in range(i + 1, len(channel_ids)):
|
||||
self.assertNotEqual(channel_ids[i], channel_ids[j])
|
||||
|
||||
def test_delete(self):
|
||||
session_id1 = cache.start_session()
|
||||
self.assertIsNone(cache.get(KEY))
|
||||
cache.set(KEY, b"hello")
|
||||
self.assertEqual(cache.get(KEY), b"hello")
|
||||
cache.delete(KEY)
|
||||
self.assertIsNone(cache.get(KEY))
|
||||
# Create a new channel that is over the capacity
|
||||
new_channel = thp_common.get_new_channel(self.interface)
|
||||
for c in channels:
|
||||
self.assertNotEqual(c.channel_id, new_channel.channel_id)
|
||||
|
||||
cache.set(KEY, b"hello")
|
||||
cache.start_session()
|
||||
self.assertIsNone(cache.get(KEY))
|
||||
cache.set(KEY, b"hello")
|
||||
self.assertEqual(cache.get(KEY), b"hello")
|
||||
cache.delete(KEY)
|
||||
self.assertIsNone(cache.get(KEY))
|
||||
# Test that the oldest (least used) channel was replaced (_CHANNELS[0])
|
||||
self.assertNotEqual(cache_thp._CHANNELS[0].channel_id, channel_ids[0])
|
||||
self.assertEqual(cache_thp._CHANNELS[0].channel_id, new_channel.channel_id)
|
||||
|
||||
cache.start_session(session_id1)
|
||||
self.assertEqual(cache.get(KEY), b"hello")
|
||||
# Update the "last used" value of the second channel in cache (_CHANNELS[1]) and
|
||||
# assert that it is not replaced when creating a new channel
|
||||
cache_thp.update_channel_last_used(channel_ids[1])
|
||||
new_new_channel = thp_common.get_new_channel(self.interface)
|
||||
self.assertEqual(cache_thp._CHANNELS[1].channel_id, channel_ids[1])
|
||||
|
||||
def test_decorators(self):
|
||||
run_count = 0
|
||||
cache.start_session()
|
||||
# Assert that it was in fact the _CHANNEL[2] that was replaced
|
||||
self.assertNotEqual(cache_thp._CHANNELS[2].channel_id, channel_ids[2])
|
||||
self.assertEqual(
|
||||
cache_thp._CHANNELS[2].channel_id, new_new_channel.channel_id
|
||||
)
|
||||
|
||||
@cache.stored(KEY)
|
||||
def func():
|
||||
nonlocal run_count
|
||||
run_count += 1
|
||||
return b"foo"
|
||||
def test_session_capacity_in_cache(self):
|
||||
self.assertTrue(cache_thp._MAX_SESSIONS_COUNT >= 4)
|
||||
channel_cache_A = thp_common.get_new_channel(self.interface).channel_cache
|
||||
channel_cache_B = thp_common.get_new_channel(self.interface).channel_cache
|
||||
|
||||
# cache is empty
|
||||
self.assertIsNone(cache.get(KEY))
|
||||
self.assertEqual(run_count, 0)
|
||||
self.assertEqual(func(), b"foo")
|
||||
# function was run
|
||||
self.assertEqual(run_count, 1)
|
||||
self.assertEqual(cache.get(KEY), b"foo")
|
||||
# function does not run again but returns cached value
|
||||
self.assertEqual(func(), b"foo")
|
||||
self.assertEqual(run_count, 1)
|
||||
sesions_A = []
|
||||
cid = []
|
||||
sid = []
|
||||
for i in range(3):
|
||||
sesions_A.append(cache_thp.get_new_session(channel_cache_A))
|
||||
cid.append(sesions_A[i].channel_id)
|
||||
sid.append(sesions_A[i].session_id)
|
||||
|
||||
@cache.stored_async(KEY)
|
||||
async def async_func():
|
||||
nonlocal run_count
|
||||
run_count += 1
|
||||
return b"bar"
|
||||
sessions_B = []
|
||||
for i in range(cache_thp._MAX_SESSIONS_COUNT - 3):
|
||||
sessions_B.append(cache_thp.get_new_session(channel_cache_B))
|
||||
|
||||
# cache is still full
|
||||
self.assertEqual(await_result(async_func()), b"foo")
|
||||
self.assertEqual(run_count, 1)
|
||||
for i in range(3):
|
||||
self.assertEqual(sesions_A[i], cache_thp._SESSIONS[i])
|
||||
self.assertEqual(cid[i], cache_thp._SESSIONS[i].channel_id)
|
||||
self.assertEqual(sid[i], cache_thp._SESSIONS[i].session_id)
|
||||
for i in range(3, cache_thp._MAX_SESSIONS_COUNT):
|
||||
self.assertEqual(sessions_B[i - 3], cache_thp._SESSIONS[i])
|
||||
|
||||
cache.start_session()
|
||||
self.assertEqual(await_result(async_func()), b"bar")
|
||||
self.assertEqual(run_count, 2)
|
||||
# awaitable is also run only once
|
||||
self.assertEqual(await_result(async_func()), b"bar")
|
||||
self.assertEqual(run_count, 2)
|
||||
# Assert that new session replaces the oldest (least used) one (_SESSOIONS[0])
|
||||
new_session = cache_thp.get_new_session(channel_cache_B)
|
||||
self.assertEqual(new_session, cache_thp._SESSIONS[0])
|
||||
self.assertNotEqual(new_session.channel_id, cid[0])
|
||||
self.assertNotEqual(new_session.session_id, sid[0])
|
||||
|
||||
def test_empty_value(self):
|
||||
cache.start_session()
|
||||
# Assert that updating "last used" for session on channel A increases also
|
||||
# the "last usage" of channel A.
|
||||
self.assertTrue(channel_cache_A.last_usage < channel_cache_B.last_usage)
|
||||
cache_thp.update_session_last_used(
|
||||
channel_cache_A.channel_id, sesions_A[1].session_id
|
||||
)
|
||||
self.assertTrue(channel_cache_A.last_usage > channel_cache_B.last_usage)
|
||||
|
||||
self.assertIsNone(cache.get(KEY))
|
||||
cache.set(KEY, b"")
|
||||
self.assertEqual(cache.get(KEY), b"")
|
||||
new_new_session = cache_thp.get_new_session(channel_cache_B)
|
||||
|
||||
cache.delete(KEY)
|
||||
run_count = 0
|
||||
# Assert that creating a new session on channel B shifts the "last usage" again
|
||||
# and that _SESSIONS[1] was not replaced, but that _SESSIONS[2] was replaced
|
||||
self.assertTrue(channel_cache_A.last_usage < channel_cache_B.last_usage)
|
||||
self.assertEqual(sesions_A[1], cache_thp._SESSIONS[1])
|
||||
self.assertNotEqual(sesions_A[2], cache_thp._SESSIONS[2])
|
||||
self.assertEqual(new_new_session, cache_thp._SESSIONS[2])
|
||||
|
||||
@cache.stored(KEY)
|
||||
def func():
|
||||
nonlocal run_count
|
||||
run_count += 1
|
||||
return b""
|
||||
def test_clear(self):
|
||||
channel_A = thp_common.get_new_channel(self.interface)
|
||||
channel_B = thp_common.get_new_channel(self.interface)
|
||||
cid_A = channel_A.channel_id
|
||||
cid_B = channel_B.channel_id
|
||||
sessions = []
|
||||
|
||||
self.assertEqual(func(), b"")
|
||||
# function gets called once
|
||||
self.assertEqual(run_count, 1)
|
||||
self.assertEqual(func(), b"")
|
||||
# function is not called for a second time
|
||||
self.assertEqual(run_count, 1)
|
||||
for i in range(3):
|
||||
sessions.append(cache_thp.get_new_session(channel_A.channel_cache))
|
||||
sessions.append(cache_thp.get_new_session(channel_B.channel_cache))
|
||||
|
||||
@mock_storage
|
||||
def test_Initialize(self):
|
||||
def call_Initialize(**kwargs):
|
||||
msg = Initialize(**kwargs)
|
||||
return await_result(handle_Initialize(msg))
|
||||
self.assertEqual(cache_thp._SESSIONS[2 * i].channel_id, cid_A)
|
||||
self.assertNotEqual(cache_thp._SESSIONS[2 * i].last_usage, 0)
|
||||
|
||||
# calling Initialize without an ID allocates a new one
|
||||
session_id = cache.start_session()
|
||||
features = call_Initialize()
|
||||
self.assertNotEqual(session_id, features.session_id)
|
||||
self.assertEqual(cache_thp._SESSIONS[2 * i + 1].channel_id, cid_B)
|
||||
self.assertNotEqual(cache_thp._SESSIONS[2 * i + 1].last_usage, 0)
|
||||
|
||||
# calling Initialize with the current ID does not allocate a new one
|
||||
features = call_Initialize(session_id=session_id)
|
||||
self.assertEqual(session_id, features.session_id)
|
||||
# Assert that clearing of channel A works
|
||||
self.assertNotEqual(channel_A.channel_cache.channel_id, b"")
|
||||
self.assertNotEqual(channel_A.channel_cache.last_usage, 0)
|
||||
self.assertEqual(channel_A.get_channel_state(), ChannelState.TH1)
|
||||
|
||||
# store "hello"
|
||||
cache.set(KEY, b"hello")
|
||||
# check that it is cleared
|
||||
features = call_Initialize()
|
||||
session_id = features.session_id
|
||||
self.assertIsNone(cache.get(KEY))
|
||||
# store "hello" again
|
||||
cache.set(KEY, b"hello")
|
||||
self.assertEqual(cache.get(KEY), b"hello")
|
||||
channel_A.clear()
|
||||
|
||||
# supplying a different session ID starts a new cache
|
||||
call_Initialize(session_id=b"A" * cache._SESSION_ID_LENGTH)
|
||||
self.assertIsNone(cache.get(KEY))
|
||||
self.assertEqual(channel_A.channel_cache.channel_id, b"")
|
||||
self.assertEqual(channel_A.channel_cache.last_usage, 0)
|
||||
self.assertEqual(channel_A.get_channel_state(), ChannelState.UNALLOCATED)
|
||||
|
||||
# but resuming a session loads the previous one
|
||||
call_Initialize(session_id=session_id)
|
||||
self.assertEqual(cache.get(KEY), b"hello")
|
||||
# Assert that clearing channel A also cleared all its sessions
|
||||
for i in range(3):
|
||||
self.assertEqual(cache_thp._SESSIONS[2 * i].last_usage, 0)
|
||||
self.assertEqual(cache_thp._SESSIONS[2 * i].channel_id, b"")
|
||||
|
||||
def test_EndSession(self):
|
||||
self.assertRaises(cache.InvalidSessionError, cache.get, KEY)
|
||||
cache.start_session()
|
||||
self.assertTrue(is_session_started())
|
||||
self.assertIsNone(cache.get(KEY))
|
||||
await_result(handle_EndSession(EndSession()))
|
||||
self.assertFalse(is_session_started())
|
||||
self.assertRaises(cache.InvalidSessionError, cache.get, KEY)
|
||||
self.assertNotEqual(cache_thp._SESSIONS[2 * i + 1].last_usage, 0)
|
||||
self.assertEqual(cache_thp._SESSIONS[2 * i + 1].channel_id, cid_B)
|
||||
|
||||
cache.clear_all()
|
||||
for session in cache_thp._SESSIONS:
|
||||
self.assertEqual(session.last_usage, 0)
|
||||
self.assertEqual(session.channel_id, b"")
|
||||
for channel in cache_thp._CHANNELS:
|
||||
self.assertEqual(channel.channel_id, b"")
|
||||
self.assertEqual(channel.last_usage, 0)
|
||||
self.assertEqual(
|
||||
cache_thp._get_channel_state(channel), ChannelState.UNALLOCATED
|
||||
)
|
||||
|
||||
def test_get_set(self):
|
||||
channel = thp_common.get_new_channel(self.interface)
|
||||
|
||||
session_1 = cache_thp.get_new_session(channel.channel_cache)
|
||||
session_1.set(KEY, b"hello")
|
||||
self.assertEqual(session_1.get(KEY), b"hello")
|
||||
|
||||
session_2 = cache_thp.get_new_session(channel.channel_cache)
|
||||
session_2.set(KEY, b"world")
|
||||
self.assertEqual(session_2.get(KEY), b"world")
|
||||
|
||||
self.assertEqual(session_1.get(KEY), b"hello")
|
||||
|
||||
cache.clear_all()
|
||||
self.assertIsNone(session_1.get(KEY))
|
||||
self.assertIsNone(session_2.get(KEY))
|
||||
|
||||
def test_get_set_int(self):
|
||||
channel = thp_common.get_new_channel(self.interface)
|
||||
|
||||
session_1 = cache_thp.get_new_session(channel.channel_cache)
|
||||
session_1.set_int(KEY, 1234)
|
||||
|
||||
self.assertEqual(session_1.get_int(KEY), 1234)
|
||||
|
||||
session_2 = cache_thp.get_new_session(channel.channel_cache)
|
||||
session_2.set_int(KEY, 5678)
|
||||
self.assertEqual(session_2.get_int(KEY), 5678)
|
||||
|
||||
self.assertEqual(session_1.get_int(KEY), 1234)
|
||||
|
||||
cache.clear_all()
|
||||
self.assertIsNone(session_1.get_int(KEY))
|
||||
self.assertIsNone(session_2.get_int(KEY))
|
||||
|
||||
def test_get_set_bool(self):
|
||||
channel = thp_common.get_new_channel(self.interface)
|
||||
|
||||
session_1 = cache_thp.get_new_session(channel.channel_cache)
|
||||
with self.assertRaises(AssertionError) as e:
|
||||
session_1.set_bool(KEY, True)
|
||||
self.assertEqual(e.value.value, "Field does not have zero length!")
|
||||
|
||||
# Change length of first session field to 0 so that the length check passes
|
||||
session_1.fields = (0,) + session_1.fields[1:]
|
||||
|
||||
# with self.assertRaises(AssertionError) as e:
|
||||
session_1.set_bool(KEY, True)
|
||||
self.assertEqual(session_1.get_bool(KEY), True)
|
||||
|
||||
session_2 = cache_thp.get_new_session(channel.channel_cache)
|
||||
session_2.fields = session_2.fields = (0,) + session_2.fields[1:]
|
||||
session_2.set_bool(KEY, False)
|
||||
self.assertEqual(session_2.get_bool(KEY), False)
|
||||
|
||||
self.assertEqual(session_1.get_bool(KEY), True)
|
||||
|
||||
cache.clear_all()
|
||||
|
||||
# Default value is False
|
||||
self.assertFalse(session_1.get_bool(KEY))
|
||||
self.assertFalse(session_2.get_bool(KEY))
|
||||
|
||||
def test_delete(self):
|
||||
channel = thp_common.get_new_channel(self.interface)
|
||||
session_1 = cache_thp.get_new_session(channel.channel_cache)
|
||||
|
||||
self.assertIsNone(session_1.get(KEY))
|
||||
session_1.set(KEY, b"hello")
|
||||
self.assertEqual(session_1.get(KEY), b"hello")
|
||||
session_1.delete(KEY)
|
||||
self.assertIsNone(session_1.get(KEY))
|
||||
|
||||
session_1.set(KEY, b"hello")
|
||||
session_2 = cache_thp.get_new_session(channel.channel_cache)
|
||||
|
||||
self.assertIsNone(session_2.get(KEY))
|
||||
session_2.set(KEY, b"hello")
|
||||
self.assertEqual(session_2.get(KEY), b"hello")
|
||||
session_2.delete(KEY)
|
||||
self.assertIsNone(session_2.get(KEY))
|
||||
|
||||
self.assertEqual(session_1.get(KEY), b"hello")
|
||||
|
||||
else:
|
||||
|
||||
def __init__(self):
|
||||
# Context is needed to test decorators and handleInitialize
|
||||
# It allows access to codec cache from different parts of the code
|
||||
from trezor.wire import context
|
||||
|
||||
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
|
||||
super().__init__()
|
||||
|
||||
def test_start_session(self):
|
||||
session_id_a = cache_codec.start_session()
|
||||
self.assertIsNotNone(session_id_a)
|
||||
session_id_b = cache_codec.start_session()
|
||||
self.assertNotEqual(session_id_a, session_id_b)
|
||||
|
||||
cache.clear_all()
|
||||
self.assertIsNone(get_active_session())
|
||||
for session in cache_codec._SESSIONS:
|
||||
self.assertEqual(session.session_id, b"")
|
||||
self.assertEqual(session.last_usage, 0)
|
||||
|
||||
def test_end_session(self):
|
||||
session_id = cache_codec.start_session()
|
||||
self.assertTrue(is_session_started())
|
||||
get_active_session().set(KEY, b"A")
|
||||
cache_codec.end_current_session()
|
||||
self.assertFalse(is_session_started())
|
||||
self.assertIsNone(get_active_session())
|
||||
|
||||
# ending an ended session should be a no-op
|
||||
cache_codec.end_current_session()
|
||||
self.assertFalse(is_session_started())
|
||||
|
||||
session_id_a = cache_codec.start_session(session_id)
|
||||
# original session no longer exists
|
||||
self.assertNotEqual(session_id_a, session_id)
|
||||
# original session data no longer exists
|
||||
self.assertIsNone(get_active_session().get(KEY))
|
||||
|
||||
# create a new session
|
||||
session_id_b = cache_codec.start_session()
|
||||
# switch back to original session
|
||||
session_id = cache_codec.start_session(session_id_a)
|
||||
self.assertEqual(session_id, session_id_a)
|
||||
# end original session
|
||||
cache_codec.end_current_session()
|
||||
# switch back to B
|
||||
session_id = cache_codec.start_session(session_id_b)
|
||||
self.assertEqual(session_id, session_id_b)
|
||||
|
||||
def test_session_queue(self):
|
||||
session_id = cache_codec.start_session()
|
||||
self.assertEqual(cache_codec.start_session(session_id), session_id)
|
||||
get_active_session().set(KEY, b"A")
|
||||
for i in range(_PROTOCOL_CACHE._MAX_SESSIONS_COUNT):
|
||||
cache_codec.start_session()
|
||||
self.assertNotEqual(cache_codec.start_session(session_id), session_id)
|
||||
self.assertIsNone(get_active_session().get(KEY))
|
||||
|
||||
def test_get_set(self):
|
||||
session_id1 = cache_codec.start_session()
|
||||
cache_codec.get_active_session().set(KEY, b"hello")
|
||||
self.assertEqual(cache_codec.get_active_session().get(KEY), b"hello")
|
||||
|
||||
session_id2 = cache_codec.start_session()
|
||||
cache_codec.get_active_session().set(KEY, b"world")
|
||||
self.assertEqual(cache_codec.get_active_session().get(KEY), b"world")
|
||||
|
||||
cache_codec.start_session(session_id2)
|
||||
self.assertEqual(cache_codec.get_active_session().get(KEY), b"world")
|
||||
cache_codec.start_session(session_id1)
|
||||
self.assertEqual(cache_codec.get_active_session().get(KEY), b"hello")
|
||||
|
||||
cache.clear_all()
|
||||
self.assertIsNone(cache_codec.get_active_session())
|
||||
|
||||
def test_get_set_int(self):
|
||||
session_id1 = cache_codec.start_session()
|
||||
get_active_session().set_int(KEY, 1234)
|
||||
self.assertEqual(get_active_session().get_int(KEY), 1234)
|
||||
|
||||
session_id2 = cache_codec.start_session()
|
||||
get_active_session().set_int(KEY, 5678)
|
||||
self.assertEqual(get_active_session().get_int(KEY), 5678)
|
||||
|
||||
cache_codec.start_session(session_id2)
|
||||
self.assertEqual(get_active_session().get_int(KEY), 5678)
|
||||
cache_codec.start_session(session_id1)
|
||||
self.assertEqual(get_active_session().get_int(KEY), 1234)
|
||||
|
||||
cache.clear_all()
|
||||
self.assertIsNone(get_active_session())
|
||||
|
||||
def test_delete(self):
|
||||
session_id1 = cache_codec.start_session()
|
||||
self.assertIsNone(get_active_session().get(KEY))
|
||||
get_active_session().set(KEY, b"hello")
|
||||
self.assertEqual(get_active_session().get(KEY), b"hello")
|
||||
get_active_session().delete(KEY)
|
||||
self.assertIsNone(get_active_session().get(KEY))
|
||||
|
||||
get_active_session().set(KEY, b"hello")
|
||||
cache_codec.start_session()
|
||||
self.assertIsNone(get_active_session().get(KEY))
|
||||
get_active_session().set(KEY, b"hello")
|
||||
self.assertEqual(get_active_session().get(KEY), b"hello")
|
||||
get_active_session().delete(KEY)
|
||||
self.assertIsNone(get_active_session().get(KEY))
|
||||
|
||||
cache_codec.start_session(session_id1)
|
||||
self.assertEqual(get_active_session().get(KEY), b"hello")
|
||||
|
||||
def test_decorators(self):
|
||||
run_count = 0
|
||||
cache_codec.start_session()
|
||||
from apps.common.cache import stored
|
||||
|
||||
@stored(KEY)
|
||||
def func():
|
||||
nonlocal run_count
|
||||
run_count += 1
|
||||
return b"foo"
|
||||
|
||||
# cache is empty
|
||||
self.assertIsNone(get_active_session().get(KEY))
|
||||
self.assertEqual(run_count, 0)
|
||||
self.assertEqual(func(), b"foo")
|
||||
# function was run
|
||||
self.assertEqual(run_count, 1)
|
||||
self.assertEqual(get_active_session().get(KEY), b"foo")
|
||||
# function does not run again but returns cached value
|
||||
self.assertEqual(func(), b"foo")
|
||||
self.assertEqual(run_count, 1)
|
||||
|
||||
def test_empty_value(self):
|
||||
cache_codec.start_session()
|
||||
|
||||
self.assertIsNone(get_active_session().get(KEY))
|
||||
get_active_session().set(KEY, b"")
|
||||
self.assertEqual(get_active_session().get(KEY), b"")
|
||||
|
||||
get_active_session().delete(KEY)
|
||||
run_count = 0
|
||||
|
||||
from apps.common.cache import stored
|
||||
|
||||
@stored(KEY)
|
||||
def func():
|
||||
nonlocal run_count
|
||||
run_count += 1
|
||||
return b""
|
||||
|
||||
self.assertEqual(func(), b"")
|
||||
# function gets called once
|
||||
self.assertEqual(run_count, 1)
|
||||
self.assertEqual(func(), b"")
|
||||
# function is not called for a second time
|
||||
self.assertEqual(run_count, 1)
|
||||
|
||||
if not utils.USE_THP:
|
||||
|
||||
@mock_storage
|
||||
def test_Initialize(self):
|
||||
from apps.base import handle_Initialize
|
||||
|
||||
def call_Initialize(**kwargs):
|
||||
msg = Initialize(**kwargs)
|
||||
return await_result(handle_Initialize(msg))
|
||||
|
||||
# calling Initialize without an ID allocates a new one
|
||||
session_id = cache_codec.start_session()
|
||||
features = call_Initialize()
|
||||
self.assertNotEqual(session_id, features.session_id)
|
||||
|
||||
# calling Initialize with the current ID does not allocate a new one
|
||||
features = call_Initialize(session_id=session_id)
|
||||
self.assertEqual(session_id, features.session_id)
|
||||
|
||||
# store "hello"
|
||||
get_active_session().set(KEY, b"hello")
|
||||
# check that it is cleared
|
||||
features = call_Initialize()
|
||||
session_id = features.session_id
|
||||
self.assertIsNone(get_active_session().get(KEY))
|
||||
# store "hello" again
|
||||
get_active_session().set(KEY, b"hello")
|
||||
self.assertEqual(get_active_session().get(KEY), b"hello")
|
||||
|
||||
# supplying a different session ID starts a new session
|
||||
call_Initialize(session_id=b"A" * _PROTOCOL_CACHE.SESSION_ID_LENGTH)
|
||||
self.assertIsNone(get_active_session().get(KEY))
|
||||
|
||||
# but resuming a session loads the previous one
|
||||
call_Initialize(session_id=session_id)
|
||||
self.assertEqual(get_active_session().get(KEY), b"hello")
|
||||
|
||||
def test_EndSession(self):
|
||||
|
||||
self.assertIsNone(get_active_session())
|
||||
cache_codec.start_session()
|
||||
self.assertTrue(is_session_started())
|
||||
self.assertIsNone(get_active_session().get(KEY))
|
||||
await_result(handle_EndSession(EndSession()))
|
||||
self.assertFalse(is_session_started())
|
||||
self.assertIsNone(cache_codec.get_active_session())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -2,27 +2,10 @@ from common import * # isort:skip
|
||||
|
||||
import ustruct
|
||||
|
||||
from mock_wire_interface import MockHID
|
||||
from trezor import io
|
||||
from trezor.loop import wait
|
||||
from trezor.utils import chunks
|
||||
from trezor.wire import codec_v1
|
||||
|
||||
|
||||
class MockHID:
|
||||
def __init__(self, num):
|
||||
self.num = num
|
||||
self.data = []
|
||||
|
||||
def iface_num(self):
|
||||
return self.num
|
||||
|
||||
def write(self, msg):
|
||||
self.data.append(bytearray(msg))
|
||||
return len(msg)
|
||||
|
||||
def wait_object(self, mode):
|
||||
return wait(mode | self.num)
|
||||
|
||||
from trezor.wire.codec import codec_v1
|
||||
|
||||
MESSAGE_TYPE = 0x4242
|
||||
|
||||
|
94
core/tests/test_trezor.wire.thp.checksum.py
Normal file
94
core/tests/test_trezor.wire.thp.checksum.py
Normal file
@ -0,0 +1,94 @@
|
||||
from common import * # isort:skip
|
||||
|
||||
if utils.USE_THP:
|
||||
from trezor.wire.thp import checksum
|
||||
|
||||
|
||||
@unittest.skipUnless(utils.USE_THP, "only needed for THP")
|
||||
class TestTrezorHostProtocolChecksum(unittest.TestCase):
|
||||
vectors_correct = [
|
||||
(
|
||||
b"",
|
||||
b"\x00\x00\x00\x00",
|
||||
),
|
||||
(
|
||||
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
b"\x19\x0A\x55\xAD",
|
||||
),
|
||||
(
|
||||
bytes("a", "ascii"),
|
||||
b"\xE8\xB7\xBE\x43",
|
||||
),
|
||||
(
|
||||
bytes("abc", "ascii"),
|
||||
b"\x35\x24\x41\xC2",
|
||||
),
|
||||
(
|
||||
bytes("123456789", "ascii"),
|
||||
b"\xCB\xF4\x39\x26",
|
||||
),
|
||||
(
|
||||
bytes(
|
||||
"12345678901234567890123456789012345678901234567890123456789012345678901234567890",
|
||||
"ascii",
|
||||
),
|
||||
b"\x7C\xA9\x4A\x72",
|
||||
),
|
||||
(
|
||||
b"\x76\x61\x72\x69\x6F\x75\x73\x20\x43\x52\x43\x20\x61\x6C\x67\x6F\x72\x69\x74\x68\x6D\x73\x20\x69\x6E\x70\x75\x74\x20\x64\x61\x74\x61",
|
||||
b"\x9B\xD3\x66\xAE",
|
||||
),
|
||||
(
|
||||
b"\x67\x3a\x5f\x0e\x39\xc0\x3c\x79\x58\x22\x74\x76\x64\x9e\x36\xe9\x0b\x04\x8c\xd2\xc0\x4d\x76\x63\x1a\xa2\x17\x85\xe8\x50\xa7\x14\x18\xfb\x86\xed\xa3\x59\x2d\x62\x62\x49\x64\x62\x26\x12\xdb\x95\x3d\xd6\xb5\xca\x4b\x22\x0d\xc5\x78\xb2\x12\x97\x8e\x54\x4e\x06\xb7\x9c\x90\xf5\xa0\x21\xa6\xc7\xd8\x39\xfd\xea\x3a\xf1\x7b\xa2\xe8\x71\x41\xd6\xcb\x1e\x5b\x0e\x29\xf7\x0c\xc7\x57\x8b\x53\x20\x1d\x2b\x41\x1c\x25\xf9\x07\xbb\xb4\x37\x79\x6a\x13\x1f\x6c\x43\x71\xc1\x1e\x70\xe6\x74\xd3\x9c\xbf\x32\x15\xee\xf2\xa7\x86\xbe\x59\x99\xc4\x10\x09\x8a\x6a\xaa\xd4\xd1\xd0\x71\xd2\x06\x1a\xdd\x2a\xa0\x08\xeb\x08\x6c\xfb\xd2\x2d\xfb\xaa\x72\x56\xeb\xd1\x92\x92\xe5\x0e\x95\x67\xf8\x38\xc3\xab\x59\x37\xe6\xfd\x42\xb0\xd0\x31\xd0\xcb\x8a\x66\xce\x2d\x53\x72\x1e\x72\xd3\x84\x25\xb0\xb8\x93\xd2\x61\x5b\x32\xd5\xe7\xe4\x0e\x31\x11\xaf\xdc\xb4\xb8\xee\xa4\x55\x16\x5f\x78\x86\x8b\x50\x4d\xc5\x6d\x6e\xfc\xe1\x6b\x06\x5b\x37\x84\x2a\x67\x95\x28\x00\xa4\xd1\x32\x9f\xbf\xe1\x64\xf8\x17\x47\xe1\xad\x8b\x72\xd2\xd9\x45\x5b\x73\x43\x3c\xe6\x21\xf7\x53\xa3\x73\xf9\x2a\xb0\xe9\x75\x5e\xa6\xbe\x9a\xad\xfc\xed\xb5\x46\x5b\x9f\xa9\x5a\x4f\xcb\xb6\x60\x96\x31\x91\x42\xca\xaf\xee\xa5\x0c\xe0\xab\x3e\x83\xb8\xac\x88\x10\x2c\x63\xd3\xc9\xd2\xf2\x44\xef\xea\x3d\x19\x24\x3c\x5b\xe7\x0c\x52\xfd\xfe\x47\x41\x14\xd5\x4c\x67\x8d\xdb\xe5\xd9\xfa\x67\x9c\x06\x31\x01\x92\xba\x96\xc4\x0d\xef\xf7\xc1\xe9\x23\x28\x0f\xae\x27\x9b\xff\x28\x0b\x3e\x85\x0c\xae\x02\xda\x27\xb6\x04\x51\x04\x43\x04\x99\x8c\xa3\x97\x1d\x84\xec\x55\x59\xfb\xf3\x84\xe5\xf8\x40\xf8\x5f\x81\x65\x92\x4c\x92\x7a\x07\x51\x8d\x6f\xff\x8d\x15\x36\x5c\x57\x7a\x5b\x3a\x63\x1c\x87\x65\xee\x54\xd5\x96\x50\x73\x1a\x9c\xff\x59\xe5\xea\x6f\x89\xd2\xbb\xa9\x6a\x12\x21\xf5\x08\x8e\x8a\xc0\xd8\xf5\x14\xe9\x9d\x7e\x99\x13\x88\x29\xa8\xb4\x22\x2a\x41\x7c\xc5\x10\xdf\x11\x5e\xf8\x8d\x0e\xd9\x98\xd5\xaf\xa8\xf9\x55\x1e\xe3\x29\xcd\x2c\x51\x7b\x8a\x8d\x52\xaa\x8b\x87\xae\x8e\xb2\xfa\x31\x27\x60\x90\xcb\x01\x6f\x7a\x79\x38\x04\x05\x7c\x11\x79\x10\x40\x33\x70\x75\xfd\x0b\x88\xa5\xcd\x35\xd8\xa6\x3b\xb0\x45\x82\x64\xd1\xb5\xdc\x06\xc9\x89\xf4\x16\x3e\xc7\xb3\xf1\x9d\xd3\xc5\xe3\xaf\xe8\x25\x86\x7a\x4a\xfd\x10\x5d\x20\xe5\x76\x5a\x22\x5f\x8f\xbc\xaa\x97\xee\xf2\xc2\x4c\x0e\xdc\x7b\xc4\xee\x53\xa3\xe0\xfa\xcd\x1e\x4e\x54\x1d\x5e\xe1\x51\x17\x1f\x1a\x75\x7f\xed\x12\xd7\xf7\xe3\x18\x56\x24\xcf\xc6\x96\x30\x77\x0d\x73\x98\x9c\x09\x69\xa3\xbc\x96\x5e\xaf\xde\x76\xa4\x66\x04\x6b\x36\x2a\xac\x6d\x37\xf8\x1e\xe1\x2a\x3e\x42\x2d\x1d\xe6\x46\xdd\x28\xb9\x08\x44\xa1\x9e\xb2\x22\x7a\x45\x8a\x37\x39\x74\xb4\xae\xc8\x3b\x40\xf7\xec\xbf\xfd\xe5\xde\xb2\x83\x5e\xa4\x46\x19\xa6\x9d\xb0\xe8\x76\x80\xbd\xc1\x80\x7a\xd9\xeb\xe7\x90\x5b\x81\x25\x21\xd9\x5b\x4a\x80\x48\x92\x71\x77\x04\xb2\xac\x05\xc9\xdf\x5e\x44\x5a\xae\x6e\xb3\xd8\x30\x5e\xdc\x77\x2f\x79\xc2\x8e\x8b\x28\x24\x06\x1b\x6f\x8d\x88\x53\x80\x55\x0c\x3a\x7b\x85\xb8\x96\x85\xe9\xf0\x57\x63\xfe\x32\x80\xff\x57\xc9\x3c\xdb\xf6\xcd\x67\x14\x47\x6c\x43\x3d\x6d\x48\x3f\x9c\x00\x60\x0e\xf5\x94\xe4\x52\x97\x86\xcd\xac\xbc\xe4\xe3\xe7\xee\xa2\x91\x6e\x92\xbb\xd1\x55\x0c\x5c\x0d\x63\xdb\x6b\xb8\x6e\x45\x48\x0f\xdf\x44\x48\xd2\xf5\xf7\x4d\x7b\xd4\x4d\xd3\xcd\xcd\x5b\x40\x60\xb1\xb2\x8e\xc9\x9a\x65\xc5\x06\x24\xcf\xe9\xcc\x5e\x2c\x49\x47\x38\x45\x5d\xc5\xc0\x0d\x8a\x07\x1c\xb3\xbb\xb1\x69\xf5\x6d\x0e\x9c\x96\x14\x93\x58\x0c\xc9\x48\x74\xfc\x35\xda\x7d\x4e\x32\x73\xa3\x77\x4a\x9e\xc5\xd1\x08\xfe\xa6\xa0\xf1\x66\x72\xea\xc7\xae\x21\x81\x0e\x8a\xba\x99\x06\x97\xfc\xc6\x2b\x69\x53\xc6\x67\xec\x5d\xa1\xfc\xa1\x3b\xdd\x2a\xd6\x8f\x31\xa7\x8d\xec\xfe\x0a\x3b\x6b\x39\x70\x70\x09\x72\x12\xbc\x84\x67\xca\xd2\x4a\x17\x33\x94\x45\x25\xc7\xfd\x1e\xa2\x4a\x9e\x27\x9d\xfb\x87\xea\xe4\xfd\xb0\x11\x06\x9d\x72\xb9\x1d\xea\x9b\x81\x2e\x6a\x36\x76\x62\xfa\xbe\x96\x67\x7d\x35\xdd\x5e\x5c\x4f\x41\x0d\xce\xdb\x13\xb0\x46\x89\x92\x45\x02\x39\x0f\xe6\xd1\x20\x96\x1c\x34\x00\x8c\xc9\xdf\xe3\xf0\xb6\x92\x3a\xda\x5c\x96\xd9\x0b\x7d\x57\xf5\x78\x11\xc0\xcf\xbf\xb0\x92\x3d\xe5\x6a\x67\x34\xce\xd9\x16\x08\xa0\x09\x42\x0b\x07\x13\x7c\x73\x0c\xc6\x50\x17\x42\xcf\xd9\x85\xd9\x23\x3c\xb1\x40\x40\x0f\x94\x20\xed\x2d\xbf\x10\x44\x6e\x64\x65\xe5\x1d\x5f\xec\x24\xd8\x4b\xe8\xc2\xfb\x06\x11\x24\x3f\xdf\x54\x2d\xe8\x4d\xc2\x1c\x27\x11\xb8\xb3\xd4",
|
||||
b"\x6B\xA4\xEC\x92",
|
||||
),
|
||||
]
|
||||
vectors_incorrect = [
|
||||
(
|
||||
b"",
|
||||
b"\x00\x00\x00\x00\x00",
|
||||
),
|
||||
(
|
||||
b"",
|
||||
b"",
|
||||
),
|
||||
(
|
||||
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
b"\x19\x0A\x55\xAE",
|
||||
),
|
||||
(
|
||||
bytes("A", "ascii"),
|
||||
b"\xE8\xB7\xBE\x43",
|
||||
),
|
||||
(
|
||||
bytes("abc ", "ascii"),
|
||||
b"\x35\x24\x41\xC2",
|
||||
),
|
||||
(
|
||||
bytes("1234567890", "ascii"),
|
||||
b"\xCB\xF4\x39\x26",
|
||||
),
|
||||
(
|
||||
bytes(
|
||||
"1234567890123456789012345678901234567890123456789012345678901234567890123456789",
|
||||
"ascii",
|
||||
),
|
||||
b"\x7C\xA9\x4A\x72",
|
||||
),
|
||||
]
|
||||
|
||||
def test_computation(self):
|
||||
for data, chksum in self.vectors_correct:
|
||||
self.assertEqual(checksum.compute(data), chksum)
|
||||
|
||||
def test_validation_correct(self):
|
||||
for data, chksum in self.vectors_correct:
|
||||
self.assertTrue(checksum.is_valid(chksum, data))
|
||||
|
||||
def test_validation_incorrect(self):
|
||||
for data, chksum in self.vectors_incorrect:
|
||||
self.assertFalse(checksum.is_valid(chksum, data))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
66
core/tests/test_trezor.wire.thp.credential_manager.py
Normal file
66
core/tests/test_trezor.wire.thp.credential_manager.py
Normal file
@ -0,0 +1,66 @@
|
||||
from common import * # isort:skip
|
||||
|
||||
|
||||
if utils.USE_THP:
|
||||
import thp_common
|
||||
from trezor import config
|
||||
from trezor.messages import ThpCredentialMetadata
|
||||
|
||||
from apps.thp import credential_manager
|
||||
|
||||
def _issue_credential(host_name: str, host_static_pubkey: bytes) -> bytes:
|
||||
metadata = ThpCredentialMetadata(host_name=host_name)
|
||||
return credential_manager.issue_credential(host_static_pubkey, metadata)
|
||||
|
||||
|
||||
@unittest.skipUnless(utils.USE_THP, "only needed for THP")
|
||||
class TestTrezorHostProtocolCredentialManager(unittest.TestCase):
|
||||
|
||||
def __init__(self):
|
||||
if __debug__:
|
||||
thp_common.suppres_debug_log()
|
||||
super().__init__()
|
||||
|
||||
def setUp(self):
|
||||
config.init()
|
||||
config.wipe()
|
||||
|
||||
def test_derive_cred_auth_key(self):
|
||||
key1 = credential_manager.derive_cred_auth_key()
|
||||
key2 = credential_manager.derive_cred_auth_key()
|
||||
self.assertEqual(len(key1), 32)
|
||||
self.assertEqual(key1, key2)
|
||||
|
||||
def test_invalidate_cred_auth_key(self):
|
||||
key1 = credential_manager.derive_cred_auth_key()
|
||||
credential_manager.invalidate_cred_auth_key()
|
||||
key2 = credential_manager.derive_cred_auth_key()
|
||||
self.assertNotEqual(key1, key2)
|
||||
|
||||
def test_credentials(self):
|
||||
DUMMY_KEY_1 = b"\x00\x00"
|
||||
DUMMY_KEY_2 = b"\xff\xff"
|
||||
HOST_NAME_1 = "host_name"
|
||||
HOST_NAME_2 = "different host_name"
|
||||
|
||||
cred_1 = _issue_credential(HOST_NAME_1, DUMMY_KEY_1)
|
||||
cred_2 = _issue_credential(HOST_NAME_1, DUMMY_KEY_1)
|
||||
self.assertEqual(cred_1, cred_2)
|
||||
|
||||
cred_3 = _issue_credential(HOST_NAME_2, DUMMY_KEY_1)
|
||||
self.assertNotEqual(cred_1, cred_3)
|
||||
|
||||
self.assertTrue(credential_manager.validate_credential(cred_1, DUMMY_KEY_1))
|
||||
self.assertTrue(credential_manager.validate_credential(cred_3, DUMMY_KEY_1))
|
||||
self.assertFalse(credential_manager.validate_credential(cred_1, DUMMY_KEY_2))
|
||||
|
||||
credential_manager.invalidate_cred_auth_key()
|
||||
cred_4 = _issue_credential(HOST_NAME_1, DUMMY_KEY_1)
|
||||
self.assertNotEqual(cred_1, cred_4)
|
||||
self.assertFalse(credential_manager.validate_credential(cred_1, DUMMY_KEY_1))
|
||||
self.assertFalse(credential_manager.validate_credential(cred_3, DUMMY_KEY_1))
|
||||
self.assertTrue(credential_manager.validate_credential(cred_4, DUMMY_KEY_1))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
156
core/tests/test_trezor.wire.thp.crypto.py
Normal file
156
core/tests/test_trezor.wire.thp.crypto.py
Normal file
@ -0,0 +1,156 @@
|
||||
from common import * # isort:skip
|
||||
from trezorcrypto import aesgcm, curve25519
|
||||
|
||||
import storage
|
||||
|
||||
if utils.USE_THP:
|
||||
import thp_common
|
||||
from trezor.wire.thp import crypto
|
||||
from trezor.wire.thp.crypto import IV_1, IV_2, Handshake
|
||||
|
||||
def get_dummy_device_secret():
|
||||
return b"\x01\x02\x03\x04\x05\x06\x07\x08\x01\x02\x03\x04\x05\x06\x07\x08"
|
||||
|
||||
|
||||
@unittest.skipUnless(utils.USE_THP, "only needed for THP")
|
||||
class TestTrezorHostProtocolCrypto(unittest.TestCase):
|
||||
if utils.USE_THP:
|
||||
handshake = Handshake()
|
||||
key_1 = b"\x00\x01\x02\x03\x04\x05\x06\x07\x00\x01\x02\x03\x04\x05\x06\x07\x00\x01\x02\x03\x04\x05\x06\x07\x00\x01\x02\x03\x04\x05\x06\x07"
|
||||
# 0:key, 1:nonce, 2:auth_data, 3:plaintext, 4:expected_ciphertext, 5:expected_tag
|
||||
vectors_enc = [
|
||||
(
|
||||
key_1,
|
||||
0,
|
||||
b"\x55\x64",
|
||||
b"\x00\x01\x02\x03\x04\05\x06\x07\x08\x09",
|
||||
b"e2c9dd152fbee5821ea7",
|
||||
b"10625812de81b14a46b9f1e5100a6d0c",
|
||||
),
|
||||
(
|
||||
key_1,
|
||||
1,
|
||||
b"\x55\x64",
|
||||
b"\x00\x01\x02\x03\x04\05\x06\x07\x08\x09",
|
||||
b"79811619ddb07c2b99f8",
|
||||
b"71c6b872cdc499a7e9a3c7441f053214",
|
||||
),
|
||||
(
|
||||
key_1,
|
||||
369,
|
||||
b"\x55\x64",
|
||||
b"\x00\x01\x02\x03\x04\05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
|
||||
b"03bd030390f2dfe815a61c2b157a064f",
|
||||
b"c1200f8a7ae9a6d32cef0fff878d55c2",
|
||||
),
|
||||
(
|
||||
key_1,
|
||||
369,
|
||||
b"\x55\x64\x73\x82\x91",
|
||||
b"\x00\x01\x02\x03\x04\05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
|
||||
b"03bd030390f2dfe815a61c2b157a064f",
|
||||
b"693ac160cd93a20f7fc255f049d808d0",
|
||||
),
|
||||
]
|
||||
# 0:chaining key, 1:input, 2:output_1, 3:output:2
|
||||
vectors_hkdf = [
|
||||
(
|
||||
crypto.PROTOCOL_NAME,
|
||||
b"\x01\x02",
|
||||
b"c784373a217d6be057cddc6068e6748f255fc8beb6f99b7b90cbc64aad947514",
|
||||
b"12695451e29bf08ffe5e4e6ab734b0c3d7cdd99b16cd409f57bd4eaa874944ba",
|
||||
),
|
||||
(
|
||||
b"\xc7\x84\x37\x3a\x21\x7d\x6b\xe0\x57\xcd\xdc\x60\x68\xe6\x74\x8f\x25\x5f\xc8\xbe\xb6\xf9\x9b\x7b\x90\xcb\xc6\x4a\xad\x94\x75\x14",
|
||||
b"\x31\x41\x59\x26\x52\x12\x34\x56\x78\x89\x04\xaa",
|
||||
b"f88c1e08d5c3bae8f6e4a3d3324c8cbc60a805603e399e69c4bf4eacb27c2f48",
|
||||
b"5f0216bdb7110ee05372286974da8c9c8b96e2efa15b4af430755f462bd79a76",
|
||||
),
|
||||
]
|
||||
vectors_iv = [
|
||||
(0, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"),
|
||||
(1, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"),
|
||||
(7, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x07"),
|
||||
(1025, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04\x01"),
|
||||
(4294967295, b"\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff"),
|
||||
(0xFFFFFFFFFFFFFFFF, b"\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff"),
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
if __debug__:
|
||||
thp_common.suppres_debug_log()
|
||||
super().__init__()
|
||||
|
||||
def setUp(self):
|
||||
utils.DISABLE_ENCRYPTION = False
|
||||
|
||||
def test_encryption(self):
|
||||
for v in self.vectors_enc:
|
||||
buffer = bytearray(v[3])
|
||||
tag = crypto.enc(buffer, v[0], v[1], v[2])
|
||||
self.assertEqual(hexlify(buffer), v[4])
|
||||
self.assertEqual(hexlify(tag), v[5])
|
||||
self.assertTrue(crypto.dec(buffer, tag, v[0], v[1], v[2]))
|
||||
self.assertEqual(buffer, v[3])
|
||||
|
||||
def test_hkdf(self):
|
||||
for v in self.vectors_hkdf:
|
||||
ck, k = crypto._hkdf(v[0], v[1])
|
||||
self.assertEqual(hexlify(ck), v[2])
|
||||
self.assertEqual(hexlify(k), v[3])
|
||||
|
||||
def test_iv_from_nonce(self):
|
||||
for v in self.vectors_iv:
|
||||
x = v[0]
|
||||
y = x.to_bytes(8, "big")
|
||||
iv = crypto._get_iv_from_nonce(v[0])
|
||||
self.assertEqual(iv, v[1])
|
||||
with self.assertRaises(AssertionError) as e:
|
||||
iv = crypto._get_iv_from_nonce(0xFFFFFFFFFFFFFFFF + 1)
|
||||
self.assertEqual(e.value.value, "Nonce overflow, terminate the channel")
|
||||
|
||||
def test_incorrect_vectors(self):
|
||||
pass
|
||||
|
||||
def test_th1_crypto(self):
|
||||
storage.device.get_device_secret = get_dummy_device_secret
|
||||
handshake = self.handshake
|
||||
|
||||
host_ephemeral_privkey = curve25519.generate_secret()
|
||||
host_ephemeral_pubkey = curve25519.publickey(host_ephemeral_privkey)
|
||||
handshake.handle_th1_crypto(b"", host_ephemeral_pubkey)
|
||||
|
||||
def test_th2_crypto(self):
|
||||
handshake = self.handshake
|
||||
|
||||
host_static_privkey = curve25519.generate_secret()
|
||||
host_static_pubkey = curve25519.publickey(host_static_privkey)
|
||||
aes_ctx = aesgcm(handshake.k, IV_2)
|
||||
aes_ctx.auth(handshake.h)
|
||||
encrypted_host_static_pubkey = bytearray(
|
||||
aes_ctx.encrypt(host_static_pubkey) + aes_ctx.finish()
|
||||
)
|
||||
|
||||
# Code to encrypt Host's noise encrypted payload correctly:
|
||||
protomsg = bytearray(b"\x10\x02\x10\x03")
|
||||
temp_k = handshake.k
|
||||
temp_h = handshake.h
|
||||
|
||||
temp_h = crypto._hash_of_two(temp_h, encrypted_host_static_pubkey)
|
||||
_, temp_k = crypto._hkdf(
|
||||
handshake.ck,
|
||||
curve25519.multiply(handshake.trezor_ephemeral_privkey, host_static_pubkey),
|
||||
)
|
||||
aes_ctx = aesgcm(temp_k, IV_1)
|
||||
aes_ctx.encrypt_in_place(protomsg)
|
||||
aes_ctx.auth(temp_h)
|
||||
tag = aes_ctx.finish()
|
||||
encrypted_payload = bytearray(protomsg + tag)
|
||||
# end of encrypted payload generation
|
||||
|
||||
handshake.handle_th2_crypto(encrypted_host_static_pubkey, encrypted_payload)
|
||||
self.assertEqual(encrypted_payload[:4], b"\x10\x02\x10\x03")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
377
core/tests/test_trezor.wire.thp.py
Normal file
377
core/tests/test_trezor.wire.thp.py
Normal file
@ -0,0 +1,377 @@
|
||||
from common import * # isort:skip
|
||||
from mock_wire_interface import MockHID
|
||||
from trezor import config, io, protobuf
|
||||
from trezor.crypto.curve import curve25519
|
||||
from trezor.enums import ThpMessageType
|
||||
from trezor.wire.errors import UnexpectedMessage
|
||||
from trezor.wire.protocol_common import Message
|
||||
|
||||
if utils.USE_THP:
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import thp_common
|
||||
from storage import cache_thp
|
||||
from storage.cache_common import (
|
||||
CHANNEL_HANDSHAKE_HASH,
|
||||
CHANNEL_KEY_RECEIVE,
|
||||
CHANNEL_KEY_SEND,
|
||||
CHANNEL_NONCE_RECEIVE,
|
||||
CHANNEL_NONCE_SEND,
|
||||
)
|
||||
from trezor.crypto import elligator2
|
||||
from trezor.enums import ThpPairingMethod
|
||||
from trezor.messages import (
|
||||
ThpCodeEntryChallenge,
|
||||
ThpCodeEntryCpaceHost,
|
||||
ThpCodeEntryTag,
|
||||
ThpCredentialRequest,
|
||||
ThpEndRequest,
|
||||
ThpStartPairingRequest,
|
||||
)
|
||||
from trezor.wire.thp import thp_main
|
||||
from trezor.wire.thp import ChannelState, checksum, interface_manager
|
||||
from trezor.wire.thp.crypto import Handshake
|
||||
from trezor.wire.thp.pairing_context import PairingContext
|
||||
|
||||
from apps.thp import pairing
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezor.wire import WireInterface
|
||||
|
||||
def get_dummy_key() -> bytes:
|
||||
return b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x01\x02\x03\x04\x05\x06\x07\x08\x09\x20\x01\x02\x03\x04\x05\x06\x07\x08\x09\x30\x31"
|
||||
|
||||
def dummy_encode_iface(iface: WireInterface):
|
||||
return thp_common._MOCK_INTERFACE_HID
|
||||
|
||||
def send_channel_allocation_request(
|
||||
interface: WireInterface, nonce: bytes | None = None
|
||||
) -> bytes:
|
||||
if nonce is None or len(nonce) != 8:
|
||||
nonce = b"\x00\x11\x22\x33\x44\x55\x66\x77"
|
||||
header = b"\x40\xff\xff\x00\x0c"
|
||||
chksum = checksum.compute(header + nonce)
|
||||
cid_req = header + nonce + chksum
|
||||
gen = thp_main.thp_main_loop(interface)
|
||||
expected_channel_index = cache_thp._get_next_channel_index()
|
||||
gen.send(None)
|
||||
gen.send(cid_req)
|
||||
gen.send(None)
|
||||
response_data = (
|
||||
b"\x0a\x04\x54\x32\x54\x31\x10\x00\x18\x00\x20\x02\x28\x02\x28\x03\x28\x04"
|
||||
)
|
||||
response_without_crc = (
|
||||
b"\x41\xff\xff\x00\x20"
|
||||
+ nonce
|
||||
+ cache_thp._CHANNELS[expected_channel_index].channel_id
|
||||
+ response_data
|
||||
)
|
||||
chkcsum = checksum.compute(response_without_crc)
|
||||
expected_response = response_without_crc + chkcsum + b"\x00" * 27
|
||||
return expected_response
|
||||
|
||||
def get_channel_id_from_response(channel_allocation_response: bytes) -> int:
|
||||
return int.from_bytes(channel_allocation_response[13:15], "big")
|
||||
|
||||
def get_ack(channel_id: bytes) -> bytes:
|
||||
if len(channel_id) != 2:
|
||||
raise Exception("Channel id should by two bytes long")
|
||||
return (
|
||||
b"\x20"
|
||||
+ channel_id
|
||||
+ b"\x00\x04"
|
||||
+ checksum.compute(b"\x20" + channel_id + b"\x00\x04")
|
||||
+ b"\x00" * 55
|
||||
)
|
||||
|
||||
|
||||
@unittest.skipUnless(utils.USE_THP, "only needed for THP")
|
||||
class TestTrezorHostProtocol(unittest.TestCase):
|
||||
|
||||
def __init__(self):
|
||||
if __debug__:
|
||||
thp_common.suppres_debug_log()
|
||||
interface_manager.encode_iface = dummy_encode_iface
|
||||
super().__init__()
|
||||
|
||||
def setUp(self):
|
||||
self.interface = MockHID(0xDEADBEEF)
|
||||
buffer = bytearray(64)
|
||||
buffer2 = bytearray(256)
|
||||
thp_main.set_read_buffer(buffer)
|
||||
thp_main.set_write_buffer(buffer2)
|
||||
interface_manager.decode_iface = thp_common.dummy_decode_iface
|
||||
|
||||
def test_codec_message(self):
|
||||
self.assertEqual(len(self.interface.data), 0)
|
||||
gen = thp_main.thp_main_loop(self.interface)
|
||||
gen.send(None)
|
||||
|
||||
# There should be a failiure response to received init packet (starts with "?##")
|
||||
test_codec_message = b"?## Some data"
|
||||
gen.send(test_codec_message)
|
||||
gen.send(None)
|
||||
self.assertEqual(len(self.interface.data), 1)
|
||||
|
||||
expected_response = b"?##\x00\x03\x00\x00\x00\x14\x08\x10"
|
||||
self.assertEqual(
|
||||
self.interface.data[-1][: len(expected_response)], expected_response
|
||||
)
|
||||
|
||||
# There should be no response for continuation packet (starts with "?" only)
|
||||
test_codec_message_2 = b"? Cont packet"
|
||||
gen.send(test_codec_message_2)
|
||||
with self.assertRaises(TypeError) as e:
|
||||
gen.send(None)
|
||||
self.assertEqual(e.value.value, "object with buffer protocol required")
|
||||
self.assertEqual(len(self.interface.data), 1)
|
||||
|
||||
def test_message_on_unallocated_channel(self):
|
||||
gen = thp_main.thp_main_loop(self.interface)
|
||||
query = gen.send(None)
|
||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||
message_to_channel_789a = (
|
||||
b"\x04\x78\x9a\x00\x0c\x00\x11\x22\x33\x44\x55\x66\x77\x96\x64\x3c\x6c"
|
||||
)
|
||||
gen.send(message_to_channel_789a)
|
||||
gen.send(None)
|
||||
unallocated_chanel_error_on_channel_789a = "42789a0005027b743563000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
|
||||
self.assertEqual(
|
||||
utils.get_bytes_as_str(self.interface.data[-1]),
|
||||
unallocated_chanel_error_on_channel_789a,
|
||||
)
|
||||
|
||||
def test_channel_allocation(self):
|
||||
self.assertEqual(len(thp_main._CHANNELS), 0)
|
||||
for c in cache_thp._CHANNELS:
|
||||
self.assertEqual(int.from_bytes(c.state, "big"), ChannelState.UNALLOCATED)
|
||||
|
||||
expected_channel_index = cache_thp._get_next_channel_index()
|
||||
expected_response = send_channel_allocation_request(self.interface)
|
||||
self.assertEqual(self.interface.data[-1], expected_response)
|
||||
|
||||
cid = cache_thp._CHANNELS[expected_channel_index].channel_id
|
||||
self.assertTrue(int.from_bytes(cid, "big") in thp_main._CHANNELS)
|
||||
self.assertEqual(len(thp_main._CHANNELS), 1)
|
||||
|
||||
# test channel's default state is TH1:
|
||||
cid = get_channel_id_from_response(self.interface.data[-1])
|
||||
self.assertEqual(thp_main._CHANNELS[cid].get_channel_state(), ChannelState.TH1)
|
||||
|
||||
def test_invalid_encrypted_tag(self):
|
||||
gen = thp_main.thp_main_loop(self.interface)
|
||||
gen.send(None)
|
||||
# prepare 2 new channels
|
||||
expected_response_1 = send_channel_allocation_request(self.interface)
|
||||
expected_response_2 = send_channel_allocation_request(self.interface)
|
||||
self.assertEqual(self.interface.data[-2], expected_response_1)
|
||||
self.assertEqual(self.interface.data[-1], expected_response_2)
|
||||
|
||||
# test invalid encryption tag
|
||||
config.init()
|
||||
config.wipe()
|
||||
cid_1 = get_channel_id_from_response(expected_response_1)
|
||||
channel = thp_main._CHANNELS[cid_1]
|
||||
channel.iface = self.interface
|
||||
channel.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
|
||||
header = b"\x04" + channel.channel_id + b"\x00\x14"
|
||||
|
||||
tag = b"\x00" * 16
|
||||
chksum = checksum.compute(header + tag)
|
||||
message_with_invalid_tag = header + tag + chksum
|
||||
|
||||
channel.channel_cache.set(CHANNEL_KEY_RECEIVE, get_dummy_key())
|
||||
channel.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, 0)
|
||||
|
||||
cid_1_bytes = int.to_bytes(cid_1, 2, "big")
|
||||
expected_ack_on_received_message = get_ack(cid_1_bytes)
|
||||
|
||||
gen.send(message_with_invalid_tag)
|
||||
gen.send(None)
|
||||
|
||||
self.assertEqual(
|
||||
self.interface.data[-1],
|
||||
expected_ack_on_received_message,
|
||||
)
|
||||
error_without_crc = b"\x42" + cid_1_bytes + b"\x00\x05\x03"
|
||||
chksum_err = checksum.compute(error_without_crc)
|
||||
gen.send(None)
|
||||
|
||||
decryption_failed_error = error_without_crc + chksum_err + b"\x00" * 54
|
||||
|
||||
self.assertEqual(
|
||||
self.interface.data[-1],
|
||||
decryption_failed_error,
|
||||
)
|
||||
|
||||
def test_channel_errors(self):
|
||||
gen = thp_main.thp_main_loop(self.interface)
|
||||
gen.send(None)
|
||||
# prepare 2 new channels
|
||||
expected_response_1 = send_channel_allocation_request(self.interface)
|
||||
expected_response_2 = send_channel_allocation_request(self.interface)
|
||||
self.assertEqual(self.interface.data[-2], expected_response_1)
|
||||
self.assertEqual(self.interface.data[-1], expected_response_2)
|
||||
|
||||
# test invalid encryption tag
|
||||
config.init()
|
||||
config.wipe()
|
||||
cid_1 = get_channel_id_from_response(expected_response_1)
|
||||
channel = thp_main._CHANNELS[cid_1]
|
||||
channel.iface = self.interface
|
||||
channel.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
|
||||
header = b"\x04" + channel.channel_id + b"\x00\x14"
|
||||
|
||||
tag = b"\x00" * 16
|
||||
chksum = checksum.compute(header + tag)
|
||||
message_with_invalid_tag = header + tag + chksum
|
||||
|
||||
channel.channel_cache.set(CHANNEL_KEY_RECEIVE, get_dummy_key())
|
||||
channel.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, 0)
|
||||
|
||||
cid_1_bytes = int.to_bytes(cid_1, 2, "big")
|
||||
expected_ack_on_received_message = get_ack(cid_1_bytes)
|
||||
|
||||
gen.send(message_with_invalid_tag)
|
||||
gen.send(None)
|
||||
|
||||
self.assertEqual(
|
||||
self.interface.data[-1],
|
||||
expected_ack_on_received_message,
|
||||
)
|
||||
error_without_crc = b"\x42" + cid_1_bytes + b"\x00\x05\x03"
|
||||
chksum_err = checksum.compute(error_without_crc)
|
||||
gen.send(None)
|
||||
|
||||
decryption_failed_error = error_without_crc + chksum_err + b"\x00" * 54
|
||||
|
||||
self.assertEqual(
|
||||
self.interface.data[-1],
|
||||
decryption_failed_error,
|
||||
)
|
||||
|
||||
# test invalid tag in handshake phase
|
||||
cid_2 = get_channel_id_from_response(expected_response_1)
|
||||
cid_2_bytes = cid_2.to_bytes(2, "big")
|
||||
channel = thp_main._CHANNELS[cid_2]
|
||||
channel.iface = self.interface
|
||||
|
||||
channel.set_channel_state(ChannelState.TH2)
|
||||
|
||||
message_with_invalid_tag = b"\x0a\x12\x36\x00\x14\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\x91\x65\x4c\xf9"
|
||||
|
||||
channel.channel_cache.set(CHANNEL_KEY_RECEIVE, get_dummy_key())
|
||||
channel.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, 0)
|
||||
|
||||
# gen.send(message_with_invalid_tag)
|
||||
# gen.send(None)
|
||||
# gen.send(None)
|
||||
# for i in self.interface.data:
|
||||
# print(utils.get_bytes_as_str(i))
|
||||
|
||||
def test_skip_pairing(self):
|
||||
config.init()
|
||||
config.wipe()
|
||||
channel = next(iter(thp_main._CHANNELS.values()))
|
||||
channel.selected_pairing_methods = [
|
||||
ThpPairingMethod.NoMethod,
|
||||
ThpPairingMethod.CodeEntry,
|
||||
ThpPairingMethod.NFC_Unidirectional,
|
||||
ThpPairingMethod.QrCode,
|
||||
]
|
||||
pairing_ctx = PairingContext(channel)
|
||||
request_message = ThpStartPairingRequest()
|
||||
channel.set_channel_state(ChannelState.TP1)
|
||||
gen = pairing.handle_pairing_request(pairing_ctx, request_message)
|
||||
|
||||
with self.assertRaises(StopIteration):
|
||||
gen.send(None)
|
||||
self.assertEqual(channel.get_channel_state(), ChannelState.ENCRYPTED_TRANSPORT)
|
||||
|
||||
# Teardown: set back initial channel state value
|
||||
channel.set_channel_state(ChannelState.TH1)
|
||||
|
||||
def TODO_test_pairing(self):
|
||||
config.init()
|
||||
config.wipe()
|
||||
cid = get_channel_id_from_response(
|
||||
send_channel_allocation_request(self.interface)
|
||||
)
|
||||
channel = thp_main._CHANNELS[cid]
|
||||
channel.selected_pairing_methods = [
|
||||
ThpPairingMethod.CodeEntry,
|
||||
ThpPairingMethod.NFC_Unidirectional,
|
||||
ThpPairingMethod.QrCode,
|
||||
]
|
||||
pairing_ctx = PairingContext(channel)
|
||||
request_message = ThpStartPairingRequest()
|
||||
with self.assertRaises(UnexpectedMessage) as e:
|
||||
pairing.handle_pairing_request(pairing_ctx, request_message)
|
||||
print(e.value.message)
|
||||
channel.set_channel_state(ChannelState.TP1)
|
||||
gen = pairing.handle_pairing_request(pairing_ctx, request_message)
|
||||
|
||||
channel.channel_cache.set(CHANNEL_KEY_SEND, get_dummy_key())
|
||||
channel.channel_cache.set_int(CHANNEL_NONCE_SEND, 0)
|
||||
channel.channel_cache.set(CHANNEL_HANDSHAKE_HASH, b"")
|
||||
|
||||
gen.send(None)
|
||||
|
||||
async def _dummy(ctx: PairingContext, expected_types):
|
||||
return await ctx.read([1018, 1024])
|
||||
|
||||
pairing.show_display_data = _dummy
|
||||
|
||||
msg_code_entry = ThpCodeEntryChallenge(challenge=b"\x12\x34")
|
||||
buffer: bytearray = bytearray(protobuf.encoded_length(msg_code_entry))
|
||||
protobuf.encode(buffer, msg_code_entry)
|
||||
code_entry_challenge = Message(ThpMessageType.ThpCodeEntryChallenge, buffer)
|
||||
gen.send(code_entry_challenge)
|
||||
|
||||
# tag_qrc = b"\x55\xdf\x6c\xba\x0b\xe9\x5e\xd1\x4b\x78\x61\xec\xfa\x07\x9b\x5d\x37\x60\xd8\x79\x9c\xd7\x89\xb4\x22\xc1\x6f\x39\xde\x8f\x3b\xc3"
|
||||
# tag_nfc = b"\x8f\xf0\xfa\x37\x0a\x5b\xdb\x29\x32\x21\xd8\x2f\x95\xdd\xb6\xb8\xee\xfd\x28\x6f\x56\x9f\xa9\x0b\x64\x8c\xfc\x62\x46\x5a\xdd\xd0"
|
||||
|
||||
pregenerator_host = b"\xf6\x94\xc3\x6f\xb3\xbd\xfb\xba\x2f\xfd\x0c\xd0\x71\xed\x54\x76\x73\x64\x37\xfa\x25\x85\x12\x8d\xcf\xb5\x6c\x02\xaf\x9d\xe8\xbe"
|
||||
generator_host = elligator2.map_to_curve25519(pregenerator_host)
|
||||
cpace_host_private_key = b"\x02\x80\x70\x3c\x06\x45\x19\x75\x87\x0c\x82\xe1\x64\x11\xc0\x18\x13\xb2\x29\x04\xb3\xf0\xe4\x1e\x6b\xfd\x77\x63\x11\x73\x07\xa9"
|
||||
cpace_host_public_key: bytes = curve25519.multiply(
|
||||
cpace_host_private_key, generator_host
|
||||
)
|
||||
msg = ThpCodeEntryCpaceHost(cpace_host_public_key=cpace_host_public_key)
|
||||
|
||||
# msg = ThpQrCodeTag(tag=tag_qrc)
|
||||
# msg = ThpNfcUnidirectionalTag(tag=tag_nfc)
|
||||
buffer: bytearray = bytearray(protobuf.encoded_length(msg))
|
||||
|
||||
protobuf.encode(buffer, msg)
|
||||
user_message = Message(ThpMessageType.ThpCodeEntryCpaceHost, buffer)
|
||||
gen.send(user_message)
|
||||
|
||||
tag_ent = b"\xd0\x15\xd6\x72\x7c\xa6\x9b\x2a\x07\xfa\x30\xee\x03\xf0\x2d\x04\xdc\x96\x06\x77\x0c\xbd\xb4\xaa\x77\xc7\x68\x6f\xae\xa9\xdd\x81"
|
||||
msg = ThpCodeEntryTag(tag=tag_ent)
|
||||
|
||||
buffer: bytearray = bytearray(protobuf.encoded_length(msg))
|
||||
|
||||
protobuf.encode(buffer, msg)
|
||||
user_message = Message(ThpMessageType.ThpCodeEntryTag, buffer)
|
||||
gen.send(user_message)
|
||||
|
||||
host_static_pubkey = b"\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77"
|
||||
msg = ThpCredentialRequest(host_static_pubkey=host_static_pubkey)
|
||||
buffer: bytearray = bytearray(protobuf.encoded_length(msg))
|
||||
protobuf.encode(buffer, msg)
|
||||
credential_request = Message(ThpMessageType.ThpCredentialRequest, buffer)
|
||||
gen.send(credential_request)
|
||||
|
||||
msg = ThpEndRequest()
|
||||
|
||||
buffer: bytearray = bytearray(protobuf.encoded_length(msg))
|
||||
protobuf.encode(buffer, msg)
|
||||
end_request = Message(1012, buffer)
|
||||
with self.assertRaises(StopIteration) as e:
|
||||
gen.send(end_request)
|
||||
print("response message:", e.value.value.MESSAGE_NAME)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
151
core/tests/test_trezor.wire.thp.writer.py
Normal file
151
core/tests/test_trezor.wire.thp.writer.py
Normal file
@ -0,0 +1,151 @@
|
||||
from common import * # isort:skip
|
||||
|
||||
from typing import Any, Awaitable
|
||||
|
||||
|
||||
if utils.USE_THP:
|
||||
import thp_common
|
||||
from mock_wire_interface import MockHID
|
||||
from trezor.wire.thp import writer
|
||||
from trezor.wire.thp import ENCRYPTED, PacketHeader
|
||||
|
||||
|
||||
@unittest.skipUnless(utils.USE_THP, "only needed for THP")
|
||||
class TestTrezorHostProtocolWriter(unittest.TestCase):
|
||||
short_payload_expected = b"04123400050700000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
|
||||
longer_payload_expected = [
|
||||
b"0412340100000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a",
|
||||
b"8012343b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f7071727374757677",
|
||||
b"80123478797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4",
|
||||
b"801234b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1",
|
||||
b"801234f2f3f4f5f6f7f8f9fafbfcfdfeff0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
|
||||
]
|
||||
eight_longer_payloads_expected = [
|
||||
b"0412340800000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a",
|
||||
b"8012343b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f7071727374757677",
|
||||
b"80123478797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4",
|
||||
b"801234b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1",
|
||||
b"801234f2f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e",
|
||||
b"8012342f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b",
|
||||
b"8012346c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8",
|
||||
b"801234a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5",
|
||||
b"801234e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122",
|
||||
b"801234232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f",
|
||||
b"801234606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c",
|
||||
b"8012349d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9",
|
||||
b"801234dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a0b0c0d0e0f10111213141516",
|
||||
b"8012341718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f50515253",
|
||||
b"8012345455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f90",
|
||||
b"8012349192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccd",
|
||||
b"801234cecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a",
|
||||
b"8012340b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f4041424344454647",
|
||||
b"80123448494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f8081828384",
|
||||
b"80123485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1",
|
||||
b"801234c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfe",
|
||||
b"801234ff000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b",
|
||||
b"8012343c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f707172737475767778",
|
||||
b"801234797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5",
|
||||
b"801234b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2",
|
||||
b"801234f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f",
|
||||
b"801234303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c",
|
||||
b"8012346d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9",
|
||||
b"801234aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6",
|
||||
b"801234e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20212223",
|
||||
b"8012342425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f60",
|
||||
b"8012346162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d",
|
||||
b"8012349e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9da",
|
||||
b"801234dbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000000000000000000000000000000000000000000000000",
|
||||
]
|
||||
empty_payload_with_checksum_expected = b"0412340004edbd479c00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
|
||||
longer_payload_with_checksum_expected = [
|
||||
b"0412340100000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a",
|
||||
b"8012343b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f7071727374757677",
|
||||
b"80123478797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4",
|
||||
b"801234b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1",
|
||||
b"801234f2f3f4f5f6f7f8f9fafbfcfdfefff40c65ee00000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
|
||||
]
|
||||
|
||||
def await_until_result(self, task: Awaitable) -> Any:
|
||||
with self.assertRaises(StopIteration):
|
||||
while True:
|
||||
task.send(None)
|
||||
|
||||
def __init__(self):
|
||||
if __debug__:
|
||||
thp_common.suppres_debug_log()
|
||||
super().__init__()
|
||||
|
||||
def setUp(self):
|
||||
self.interface = MockHID(0xDEADBEEF)
|
||||
|
||||
def test_write_empty_packet(self):
|
||||
self.await_until_result(writer.write_packet_to_wire(self.interface, b""))
|
||||
|
||||
print(self.interface.data[0])
|
||||
self.assertEqual(len(self.interface.data), 1)
|
||||
self.assertEqual(self.interface.data[0], b"")
|
||||
|
||||
def test_write_empty_payload(self):
|
||||
header = PacketHeader(ENCRYPTED, 4660, 4)
|
||||
await_result(writer.write_payloads_to_wire(self.interface, header, (b"",)))
|
||||
self.assertEqual(len(self.interface.data), 0)
|
||||
|
||||
def test_write_short_payload(self):
|
||||
header = PacketHeader(ENCRYPTED, 4660, 5)
|
||||
data = b"\x07"
|
||||
self.await_until_result(
|
||||
writer.write_payloads_to_wire(self.interface, header, (data,))
|
||||
)
|
||||
self.assertEqual(hexlify(self.interface.data[0]), self.short_payload_expected)
|
||||
|
||||
def test_write_longer_payload(self):
|
||||
data = bytearray(range(256))
|
||||
header = PacketHeader(ENCRYPTED, 4660, 256)
|
||||
self.await_until_result(
|
||||
writer.write_payloads_to_wire(self.interface, header, (data,))
|
||||
)
|
||||
|
||||
for i in range(len(self.longer_payload_expected)):
|
||||
self.assertEqual(
|
||||
hexlify(self.interface.data[i]), self.longer_payload_expected[i]
|
||||
)
|
||||
|
||||
def test_write_eight_longer_payloads(self):
|
||||
data = bytearray(range(256))
|
||||
header = PacketHeader(ENCRYPTED, 4660, 2048)
|
||||
self.await_until_result(
|
||||
writer.write_payloads_to_wire(
|
||||
self.interface, header, (data, data, data, data, data, data, data, data)
|
||||
)
|
||||
)
|
||||
for i in range(len(self.eight_longer_payloads_expected)):
|
||||
self.assertEqual(
|
||||
hexlify(self.interface.data[i]), self.eight_longer_payloads_expected[i]
|
||||
)
|
||||
|
||||
def test_write_empty_payload_with_checksum(self):
|
||||
header = PacketHeader(ENCRYPTED, 4660, 4)
|
||||
self.await_until_result(
|
||||
writer.write_payload_to_wire_and_add_checksum(self.interface, header, b"")
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
hexlify(self.interface.data[0]), self.empty_payload_with_checksum_expected
|
||||
)
|
||||
|
||||
def test_write_longer_payload_with_checksum(self):
|
||||
data = bytearray(range(256))
|
||||
header = PacketHeader(ENCRYPTED, 4660, 256)
|
||||
self.await_until_result(
|
||||
writer.write_payload_to_wire_and_add_checksum(self.interface, header, data)
|
||||
)
|
||||
|
||||
for i in range(len(self.longer_payload_with_checksum_expected)):
|
||||
self.assertEqual(
|
||||
hexlify(self.interface.data[i]),
|
||||
self.longer_payload_with_checksum_expected[i],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
338
core/tests/test_trezor.wire.thp_deprecated.py
Normal file
338
core/tests/test_trezor.wire.thp_deprecated.py
Normal file
@ -0,0 +1,338 @@
|
||||
from common import * # isort:skip
|
||||
import ustruct
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from mock_wire_interface import MockHID
|
||||
from storage.cache_thp import BROADCAST_CHANNEL_ID
|
||||
from trezor import io
|
||||
from trezor.utils import chunks
|
||||
from trezor.wire.protocol_common import Message
|
||||
|
||||
if utils.USE_THP:
|
||||
import thp_common
|
||||
import trezor.wire.thp
|
||||
from trezor.wire.thp import thp_main
|
||||
from trezor.wire.thp import alternating_bit_protocol as ABP
|
||||
from trezor.wire.thp import checksum
|
||||
from trezor.wire.thp.checksum import CHECKSUM_LENGTH
|
||||
from trezor.wire.thp.writer import PACKET_LENGTH
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface
|
||||
|
||||
|
||||
MESSAGE_TYPE = 0x4242
|
||||
MESSAGE_TYPE_BYTES = b"\x42\x42"
|
||||
_MESSAGE_TYPE_LEN = 2
|
||||
PLAINTEXT_0 = 0x01
|
||||
PLAINTEXT_1 = 0x11
|
||||
COMMON_CID = 4660
|
||||
CONT = 0x80
|
||||
|
||||
HEADER_INIT_LENGTH = 5
|
||||
HEADER_CONT_LENGTH = 3
|
||||
if utils.USE_THP:
|
||||
INIT_MESSAGE_DATA_LENGTH = PACKET_LENGTH - HEADER_INIT_LENGTH - _MESSAGE_TYPE_LEN
|
||||
|
||||
|
||||
def make_header(ctrl_byte, cid, length):
|
||||
return ustruct.pack(">BHH", ctrl_byte, cid, length)
|
||||
|
||||
|
||||
def make_cont_header():
|
||||
return ustruct.pack(">BH", CONT, COMMON_CID)
|
||||
|
||||
|
||||
def makeSimpleMessage(header, message_type, message_data):
|
||||
return header + ustruct.pack(">H", message_type) + message_data
|
||||
|
||||
|
||||
def makeCidRequest(header, message_data):
|
||||
return header + message_data
|
||||
|
||||
|
||||
def getPlaintext() -> bytes:
|
||||
if ABP.get_expected_receive_seq_bit(THP.get_active_session()) == 1:
|
||||
return PLAINTEXT_1
|
||||
return PLAINTEXT_0
|
||||
|
||||
|
||||
async def deprecated_read_message(
|
||||
iface: WireInterface, buffer: utils.BufferType
|
||||
) -> Message:
|
||||
return Message(-1, b"\x00")
|
||||
|
||||
|
||||
async def deprecated_write_message(
|
||||
iface: WireInterface, message: Message, is_retransmission: bool = False
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
# This test suite is an adaptation of test_trezor.wire.codec_v1
|
||||
@unittest.skipUnless(utils.USE_THP, "only needed for THP")
|
||||
class TestWireTrezorHostProtocolV1(unittest.TestCase):
|
||||
|
||||
def __init__(self):
|
||||
if __debug__:
|
||||
thp_common.suppres_debug_log()
|
||||
super().__init__()
|
||||
|
||||
def setUp(self):
|
||||
self.interface = MockHID(0xDEADBEEF)
|
||||
|
||||
def _simple(self):
|
||||
cid_req_header = make_header(
|
||||
ctrl_byte=0x40, cid=BROADCAST_CHANNEL_ID, length=12
|
||||
)
|
||||
cid_request_dummy_data = b"\x00\x11\x22\x33\x44\x55\x66\x77\x96\x64\x3c\x6c"
|
||||
cid_req_message = makeCidRequest(cid_req_header, cid_request_dummy_data)
|
||||
|
||||
message_header = make_header(ctrl_byte=0x01, cid=COMMON_CID, length=18)
|
||||
cid_request_dummy_data_checksum = b"\x67\x8e\xac\xe0"
|
||||
message = makeSimpleMessage(
|
||||
message_header,
|
||||
MESSAGE_TYPE,
|
||||
cid_request_dummy_data + cid_request_dummy_data_checksum,
|
||||
)
|
||||
|
||||
buffer = bytearray(64)
|
||||
gen = deprecated_read_message(self.interface, buffer)
|
||||
query = gen.send(None)
|
||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||
gen.send(cid_req_message)
|
||||
gen.send(None)
|
||||
gen.send(message)
|
||||
with self.assertRaises(StopIteration) as e:
|
||||
gen.send(None)
|
||||
|
||||
# e.value is StopIteration. e.value.value is the return value of the call
|
||||
result = e.value.value
|
||||
self.assertEqual(result.type, MESSAGE_TYPE)
|
||||
self.assertEqual(result.data, cid_request_dummy_data)
|
||||
|
||||
buffer_without_zeroes = buffer[: len(message) - 5]
|
||||
message_without_header = message[5:]
|
||||
# message should have been read into the buffer
|
||||
self.assertEqual(buffer_without_zeroes, message_without_header)
|
||||
|
||||
def _read_one_packet(self):
|
||||
# zero length message - just a header
|
||||
PLAINTEXT = getPlaintext()
|
||||
header = make_header(
|
||||
PLAINTEXT, cid=COMMON_CID, length=_MESSAGE_TYPE_LEN + CHECKSUM_LENGTH
|
||||
)
|
||||
chksum = checksum.compute(header + MESSAGE_TYPE_BYTES)
|
||||
message = header + MESSAGE_TYPE_BYTES + chksum
|
||||
|
||||
buffer = bytearray(64)
|
||||
gen = deprecated_read_message(self.interface, buffer)
|
||||
|
||||
query = gen.send(None)
|
||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||
gen.send(message)
|
||||
with self.assertRaises(StopIteration) as e:
|
||||
gen.send(None)
|
||||
|
||||
# e.value is StopIteration. e.value.value is the return value of the call
|
||||
result = e.value.value
|
||||
self.assertEqual(result.type, MESSAGE_TYPE)
|
||||
self.assertEqual(result.data, b"")
|
||||
|
||||
# message should have been read into the buffer
|
||||
self.assertEqual(buffer, MESSAGE_TYPE_BYTES + chksum + b"\x00" * 58)
|
||||
|
||||
def _read_many_packets(self):
|
||||
message = bytes(range(256))
|
||||
header = make_header(
|
||||
getPlaintext(),
|
||||
COMMON_CID,
|
||||
len(message) + _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH,
|
||||
)
|
||||
chksum = checksum.compute(header + MESSAGE_TYPE_BYTES + message)
|
||||
# message = MESSAGE_TYPE_BYTES + message + checksum
|
||||
|
||||
# first packet is init header + 59 bytes of data
|
||||
# other packets are cont header + 61 bytes of data
|
||||
cont_header = make_cont_header()
|
||||
packets = [header + MESSAGE_TYPE_BYTES + message[:INIT_MESSAGE_DATA_LENGTH]] + [
|
||||
cont_header + chunk
|
||||
for chunk in chunks(
|
||||
message[INIT_MESSAGE_DATA_LENGTH:] + chksum,
|
||||
64 - HEADER_CONT_LENGTH,
|
||||
)
|
||||
]
|
||||
buffer = bytearray(262)
|
||||
gen = deprecated_read_message(self.interface, buffer)
|
||||
query = gen.send(None)
|
||||
for packet in packets:
|
||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||
query = gen.send(packet)
|
||||
|
||||
# last packet will stop
|
||||
with self.assertRaises(StopIteration) as e:
|
||||
gen.send(None)
|
||||
|
||||
# e.value is StopIteration. e.value.value is the return value of the call
|
||||
result = e.value.value
|
||||
|
||||
self.assertEqual(result.type, MESSAGE_TYPE)
|
||||
self.assertEqual(result.data, message)
|
||||
|
||||
# message should have been read into the buffer )
|
||||
self.assertEqual(buffer, MESSAGE_TYPE_BYTES + message + chksum)
|
||||
|
||||
def _read_large_message(self):
|
||||
message = b"hello world"
|
||||
header = make_header(
|
||||
getPlaintext(),
|
||||
COMMON_CID,
|
||||
_MESSAGE_TYPE_LEN + len(message) + CHECKSUM_LENGTH,
|
||||
)
|
||||
|
||||
packet = (
|
||||
header
|
||||
+ MESSAGE_TYPE_BYTES
|
||||
+ message
|
||||
+ checksum.compute(header + MESSAGE_TYPE_BYTES + message)
|
||||
)
|
||||
|
||||
# make sure we fit into one packet, to make this easier
|
||||
self.assertTrue(len(packet) <= thp_main.PACKET_LENGTH)
|
||||
|
||||
buffer = bytearray(1)
|
||||
self.assertTrue(len(buffer) <= len(packet))
|
||||
|
||||
gen = deprecated_read_message(self.interface, buffer)
|
||||
query = gen.send(None)
|
||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||
gen.send(packet)
|
||||
with self.assertRaises(StopIteration) as e:
|
||||
gen.send(None)
|
||||
|
||||
# e.value is StopIteration. e.value.value is the return value of the call
|
||||
result = e.value.value
|
||||
self.assertEqual(result.type, MESSAGE_TYPE)
|
||||
self.assertEqual(result.data, message)
|
||||
|
||||
# read should have allocated its own buffer and not touch ours
|
||||
self.assertEqual(buffer, b"\x00")
|
||||
|
||||
def _roundtrip(self):
|
||||
message_payload = bytes(range(256))
|
||||
message = Message(
|
||||
MESSAGE_TYPE, message_payload, 1
|
||||
) # TODO use different session id
|
||||
gen = deprecated_write_message(self.interface, message)
|
||||
# exhaust the iterator:
|
||||
# (XXX we can only do this because the iterator is only accepting None and returns None)
|
||||
for query in gen:
|
||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
|
||||
|
||||
buffer = bytearray(1024)
|
||||
gen = deprecated_read_message(self.interface, buffer)
|
||||
query = gen.send(None)
|
||||
for packet in self.interface.data:
|
||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||
print(utils.get_bytes_as_str(packet))
|
||||
query = gen.send(packet)
|
||||
|
||||
with self.assertRaises(StopIteration) as e:
|
||||
gen.send(None)
|
||||
|
||||
result = e.value.value
|
||||
self.assertEqual(result.type, MESSAGE_TYPE)
|
||||
self.assertEqual(result.data, message.data)
|
||||
|
||||
def _write_one_packet(self):
|
||||
message = Message(MESSAGE_TYPE, b"")
|
||||
gen = deprecated_write_message(self.interface, message)
|
||||
|
||||
query = gen.send(None)
|
||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
|
||||
with self.assertRaises(StopIteration):
|
||||
gen.send(None)
|
||||
|
||||
header = make_header(
|
||||
getPlaintext(), COMMON_CID, _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH
|
||||
)
|
||||
expected_message = (
|
||||
header
|
||||
+ MESSAGE_TYPE_BYTES
|
||||
+ checksum.compute(header + MESSAGE_TYPE_BYTES)
|
||||
+ b"\x00" * (INIT_MESSAGE_DATA_LENGTH - CHECKSUM_LENGTH)
|
||||
)
|
||||
self.assertTrue(self.interface.data == [expected_message])
|
||||
|
||||
def _write_multiple_packets(self):
|
||||
message_payload = bytes(range(256))
|
||||
message = Message(MESSAGE_TYPE, message_payload)
|
||||
gen = deprecated_write_message(self.interface, message)
|
||||
|
||||
header = make_header(
|
||||
PLAINTEXT_1,
|
||||
COMMON_CID,
|
||||
len(message.data) + _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH,
|
||||
)
|
||||
cont_header = make_cont_header()
|
||||
chksum = checksum.compute(
|
||||
header + message.type.to_bytes(2, "big") + message.data
|
||||
)
|
||||
packets = [
|
||||
header + MESSAGE_TYPE_BYTES + message.data[:INIT_MESSAGE_DATA_LENGTH]
|
||||
] + [
|
||||
cont_header + chunk
|
||||
for chunk in chunks(
|
||||
message.data[INIT_MESSAGE_DATA_LENGTH:] + chksum,
|
||||
thp_main.PACKET_LENGTH - HEADER_CONT_LENGTH,
|
||||
)
|
||||
]
|
||||
|
||||
for _ in packets:
|
||||
# we receive as many queries as there are packets
|
||||
query = gen.send(None)
|
||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
|
||||
|
||||
# the first sent None only started the generator. the len(packets)-th None
|
||||
# will finish writing and raise StopIteration
|
||||
with self.assertRaises(StopIteration):
|
||||
gen.send(None)
|
||||
|
||||
# packets must be identical up to the last one
|
||||
self.assertListEqual(packets[:-1], self.interface.data[:-1])
|
||||
# last packet must be identical up to message length. remaining bytes in
|
||||
# the 64-byte packets are garbage -- in particular, it's the bytes of the
|
||||
# previous packet
|
||||
last_packet = packets[-1] + packets[-2][len(packets[-1]) :]
|
||||
self.assertEqual(last_packet, self.interface.data[-1])
|
||||
|
||||
def _read_huge_packet(self):
|
||||
PACKET_COUNT = 1180
|
||||
# message that takes up 1 180 USB packets
|
||||
message_size = (PACKET_COUNT - 1) * (
|
||||
PACKET_LENGTH - HEADER_CONT_LENGTH - CHECKSUM_LENGTH - _MESSAGE_TYPE_LEN
|
||||
) + INIT_MESSAGE_DATA_LENGTH
|
||||
|
||||
# ensure that a message this big won't fit into memory
|
||||
# Note: this control is changed, because THP has only 2 byte length field
|
||||
self.assertTrue(message_size > thp_main.MAX_PAYLOAD_LEN)
|
||||
# self.assertRaises(MemoryError, bytearray, message_size)
|
||||
header = make_header(PLAINTEXT_1, COMMON_CID, message_size)
|
||||
packet = header + MESSAGE_TYPE_BYTES + (b"\x00" * INIT_MESSAGE_DATA_LENGTH)
|
||||
buffer = bytearray(65536)
|
||||
gen = deprecated_read_message(self.interface, buffer)
|
||||
|
||||
query = gen.send(None)
|
||||
|
||||
# THP returns "Message too large" error after reading the message size,
|
||||
# it is different from codec_v1 as it does not allow big enough messages
|
||||
# to raise MemoryError in this test
|
||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||
with self.assertRaises(trezor.wire.thp.ThpError) as e:
|
||||
query = gen.send(packet)
|
||||
|
||||
self.assertEqual(e.value.args[0], "Message too large")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
44
core/tests/thp_common.py
Normal file
44
core/tests/thp_common.py
Normal file
@ -0,0 +1,44 @@
|
||||
from trezor import utils
|
||||
from trezor.wire.thp import ChannelState
|
||||
|
||||
if utils.USE_THP:
|
||||
import unittest
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from mock_wire_interface import MockHID
|
||||
from storage import cache_thp
|
||||
from trezor.wire import context
|
||||
from trezor.wire.thp import interface_manager
|
||||
from trezor.wire.thp.channel import Channel
|
||||
from trezor.wire.thp.session_context import SessionContext
|
||||
|
||||
_MOCK_INTERFACE_HID = b"\x00"
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezor.wire import WireInterface
|
||||
|
||||
def dummy_decode_iface(cached_iface: bytes):
|
||||
return MockHID(0xDEADBEEF)
|
||||
|
||||
def get_new_channel(channel_iface: WireInterface | None = None) -> Channel:
|
||||
interface_manager.decode_iface = dummy_decode_iface
|
||||
channel_cache = cache_thp.get_new_channel(_MOCK_INTERFACE_HID)
|
||||
channel = Channel(channel_cache)
|
||||
channel.set_channel_state(ChannelState.TH1)
|
||||
if channel_iface is not None:
|
||||
channel.iface = channel_iface
|
||||
return channel
|
||||
|
||||
def prepare_context() -> None:
|
||||
channel = get_new_channel()
|
||||
session_cache = cache_thp.get_new_session(channel.channel_cache)
|
||||
session_ctx = SessionContext(channel, session_cache)
|
||||
context.CURRENT_CONTEXT = session_ctx
|
||||
|
||||
|
||||
if __debug__:
|
||||
# Disable log.debug
|
||||
def suppres_debug_log() -> None:
|
||||
from trezor import log
|
||||
|
||||
log.debug = lambda name, msg, *args: None
|
@ -2,7 +2,7 @@
|
||||
|
||||
import binascii
|
||||
from trezorlib.client import TrezorClient
|
||||
from trezorlib.transport_hid import HidTransport
|
||||
from trezorlib.transport.hid import HidTransport
|
||||
|
||||
devices = HidTransport.enumerate()
|
||||
if len(devices) > 0:
|
||||
|
@ -106,44 +106,44 @@ Frozen version. That means you do not need any other files to run it,
|
||||
it is just a single binary file that you can execute directly.
|
||||
**Are you looking for a Trezor T emulator? This is most likely it.**
|
||||
|
||||
### [core unix frozen R debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L317)
|
||||
### [core unix frozen R debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L318)
|
||||
|
||||
### [core unix frozen T3T1 debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L332)
|
||||
### [core unix frozen T3T1 debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L333)
|
||||
|
||||
### [core unix frozen R debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L346)
|
||||
### [core unix frozen R debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L347)
|
||||
|
||||
### [core unix frozen T3T1 debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L369)
|
||||
### [core unix frozen T3T1 debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L370)
|
||||
|
||||
### [core unix frozen debug asan build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L392)
|
||||
### [core unix frozen debug asan build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L393)
|
||||
|
||||
### [core unix frozen debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L408)
|
||||
### [core unix frozen debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L409)
|
||||
|
||||
### [core macos frozen regular build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L430)
|
||||
### [core macos frozen regular build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L431)
|
||||
|
||||
### [crypto build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L455)
|
||||
### [crypto build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L456)
|
||||
Build of our cryptographic library, which is then incorporated into the other builds.
|
||||
|
||||
### [legacy fw regular build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L485)
|
||||
### [legacy fw regular build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L486)
|
||||
|
||||
### [legacy fw regular debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L501)
|
||||
### [legacy fw regular debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L502)
|
||||
|
||||
### [legacy fw btconly build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L518)
|
||||
### [legacy fw btconly build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L519)
|
||||
|
||||
### [legacy fw btconly debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L537)
|
||||
### [legacy fw btconly debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L538)
|
||||
|
||||
### [legacy emu regular debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L558)
|
||||
### [legacy emu regular debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L559)
|
||||
Regular version (not only Bitcoin) of above.
|
||||
**Are you looking for a Trezor One emulator? This is most likely it.**
|
||||
|
||||
### [legacy emu regular debug asan build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L573)
|
||||
### [legacy emu regular debug asan build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L574)
|
||||
|
||||
### [legacy emu regular debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L591)
|
||||
### [legacy emu regular debug build arm](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L592)
|
||||
|
||||
### [legacy emu btconly debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L617)
|
||||
### [legacy emu btconly debug build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L618)
|
||||
Build of Legacy into UNIX emulator. Use keyboard arrows to emulate button presses.
|
||||
Bitcoin-only version.
|
||||
|
||||
### [legacy emu btconly debug asan build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L634)
|
||||
### [legacy emu btconly debug asan build](https://github.com/trezor/trezor-firmware/blob/master/ci/build.yml#L635)
|
||||
|
||||
---
|
||||
## TEST stage - [test.yml](https://github.com/trezor/trezor-firmware/blob/master/ci/test.yml)
|
||||
|
@ -191,6 +191,12 @@ void fsm_sendFailure(FailureType code, const char *text)
|
||||
case FailureType_Failure_InvalidSession:
|
||||
text = _("Invalid session");
|
||||
break;
|
||||
case FailureType_Failure_ThpUnallocatedSession:
|
||||
text = _("Unallocated session");
|
||||
break;
|
||||
case FailureType_Failure_InvalidProtocol:
|
||||
text = _("Invalid protocol");
|
||||
break;
|
||||
case FailureType_Failure_FirmwareError:
|
||||
text = _("Firmware error");
|
||||
break;
|
||||
|
@ -10,7 +10,7 @@ SKIPPED_MESSAGES := Binance Cardano DebugMonero Eos Monero Ontology Ripple SdPro
|
||||
EthereumTypedDataValueRequest EthereumTypedDataValueAck ShowDeviceTutorial \
|
||||
UnlockBootloader AuthenticateDevice AuthenticityProof \
|
||||
Solana StellarClaimClaimableBalanceOp \
|
||||
ChangeLanguage TranslationDataRequest TranslationDataAck \
|
||||
ChangeLanguage TranslationDataRequest TranslationDataAck Thp \
|
||||
SetBrightness DebugLinkOptigaSetSecMax \
|
||||
BenchmarkListNames BenchmarkRun BenchmarkNames BenchmarkResult
|
||||
|
||||
|
1
legacy/firmware/protob/messages-thp.proto
Symbolic link
1
legacy/firmware/protob/messages-thp.proto
Symbolic link
@ -0,0 +1 @@
|
||||
../../vendor/trezor-common/protob/messages-thp.proto
|
1555
poetry.lock
generated
1555
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -11,6 +11,7 @@ trezor = {path = "./python", develop = true}
|
||||
scons = "*"
|
||||
protobuf = "*"
|
||||
nanopb = "^0.4.3"
|
||||
appdirs ="*"
|
||||
|
||||
## test tools
|
||||
pytest = "^8"
|
||||
|
@ -6,3 +6,4 @@ libusb1>=1.6.4
|
||||
construct>=2.9,!=2.10.55
|
||||
typing_extensions>=4.7.1
|
||||
construct-classes>=0.1.2
|
||||
appdirs>=1.4.4
|
||||
|
@ -112,7 +112,7 @@ class Emulator:
|
||||
start = time.monotonic()
|
||||
try:
|
||||
while True:
|
||||
if transport._ping():
|
||||
if transport.ping():
|
||||
break
|
||||
if self.process.poll() is not None:
|
||||
raise RuntimeError("Emulator process died")
|
||||
|
@ -7,7 +7,7 @@ import typing as t
|
||||
from importlib import metadata
|
||||
|
||||
from . import device
|
||||
from .client import TrezorClient
|
||||
from .transport.session import Session
|
||||
|
||||
try:
|
||||
cryptography_version = metadata.version("cryptography")
|
||||
@ -361,7 +361,7 @@ def verify_authentication_response(
|
||||
|
||||
|
||||
def authenticate_device(
|
||||
client: TrezorClient,
|
||||
session: Session,
|
||||
challenge: bytes | None = None,
|
||||
*,
|
||||
whitelist: t.Collection[bytes] | None = None,
|
||||
@ -371,7 +371,7 @@ def authenticate_device(
|
||||
if challenge is None:
|
||||
challenge = secrets.token_bytes(16)
|
||||
|
||||
resp = device.authenticate(client, challenge)
|
||||
resp = device.authenticate(session, challenge)
|
||||
|
||||
return verify_authentication_response(
|
||||
challenge,
|
||||
|
@ -20,17 +20,17 @@ from . import messages
|
||||
from .tools import expect
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .protobuf import MessageType
|
||||
from .transport.session import Session
|
||||
|
||||
|
||||
@expect(messages.BenchmarkNames)
|
||||
def list_names(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
) -> "MessageType":
|
||||
return client.call(messages.BenchmarkListNames())
|
||||
return session.call(messages.BenchmarkListNames())
|
||||
|
||||
|
||||
@expect(messages.BenchmarkResult)
|
||||
def run(client: "TrezorClient", name: str) -> "MessageType":
|
||||
return client.call(messages.BenchmarkRun(name=name))
|
||||
def run(session: "Session", name: str) -> "MessageType":
|
||||
return session.call(messages.BenchmarkRun(name=name))
|
||||
|
@ -18,22 +18,22 @@ from typing import TYPE_CHECKING
|
||||
|
||||
from . import messages
|
||||
from .protobuf import dict_to_proto
|
||||
from .tools import expect, session
|
||||
from .tools import expect
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .protobuf import MessageType
|
||||
from .tools import Address
|
||||
from .transport.session import Session
|
||||
|
||||
|
||||
@expect(messages.BinanceAddress, field="address", ret_type=str)
|
||||
def get_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address_n: "Address",
|
||||
show_display: bool = False,
|
||||
chunkify: bool = False,
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.BinanceGetAddress(
|
||||
address_n=address_n, show_display=show_display, chunkify=chunkify
|
||||
)
|
||||
@ -42,16 +42,15 @@ def get_address(
|
||||
|
||||
@expect(messages.BinancePublicKey, field="public_key", ret_type=bytes)
|
||||
def get_public_key(
|
||||
client: "TrezorClient", address_n: "Address", show_display: bool = False
|
||||
session: "Session", address_n: "Address", show_display: bool = False
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.BinanceGetPublicKey(address_n=address_n, show_display=show_display)
|
||||
)
|
||||
|
||||
|
||||
@session
|
||||
def sign_tx(
|
||||
client: "TrezorClient", address_n: "Address", tx_json: dict, chunkify: bool = False
|
||||
session: "Session", address_n: "Address", tx_json: dict, chunkify: bool = False
|
||||
) -> messages.BinanceSignedTx:
|
||||
msg = tx_json["msgs"][0]
|
||||
tx_msg = tx_json.copy()
|
||||
@ -60,7 +59,7 @@ def sign_tx(
|
||||
tx_msg["chunkify"] = chunkify
|
||||
envelope = dict_to_proto(messages.BinanceSignTx, tx_msg)
|
||||
|
||||
response = client.call(envelope)
|
||||
response = session.call(envelope)
|
||||
|
||||
if not isinstance(response, messages.BinanceTxRequest):
|
||||
raise RuntimeError(
|
||||
@ -77,7 +76,7 @@ def sign_tx(
|
||||
else:
|
||||
raise ValueError("can not determine msg type")
|
||||
|
||||
response = client.call(msg)
|
||||
response = session.call(msg)
|
||||
|
||||
if not isinstance(response, messages.BinanceSignedTx):
|
||||
raise RuntimeError(
|
||||
|
@ -13,7 +13,6 @@
|
||||
#
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
import warnings
|
||||
from copy import copy
|
||||
from decimal import Decimal
|
||||
@ -23,12 +22,12 @@ from typing import TYPE_CHECKING, Any, AnyStr, List, Optional, Sequence, Tuple
|
||||
from typing_extensions import Protocol, TypedDict
|
||||
|
||||
from . import exceptions, messages
|
||||
from .tools import expect, prepare_message_bytes, session
|
||||
from .tools import expect, prepare_message_bytes
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .protobuf import MessageType
|
||||
from .tools import Address
|
||||
from .transport.session import Session
|
||||
|
||||
class ScriptSig(TypedDict):
|
||||
asm: str
|
||||
@ -105,7 +104,7 @@ def from_json(json_dict: "Transaction") -> messages.TransactionType:
|
||||
|
||||
@expect(messages.PublicKey)
|
||||
def get_public_node(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
n: "Address",
|
||||
ecdsa_curve_name: Optional[str] = None,
|
||||
show_display: bool = False,
|
||||
@ -116,13 +115,13 @@ def get_public_node(
|
||||
unlock_path_mac: Optional[bytes] = None,
|
||||
) -> "MessageType":
|
||||
if unlock_path:
|
||||
res = client.call(
|
||||
res = session.call(
|
||||
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac)
|
||||
)
|
||||
if not isinstance(res, messages.UnlockedPathRequest):
|
||||
raise exceptions.TrezorException("Unexpected message")
|
||||
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.GetPublicKey(
|
||||
address_n=n,
|
||||
ecdsa_curve_name=ecdsa_curve_name,
|
||||
@ -141,7 +140,7 @@ def get_address(*args: Any, **kwargs: Any):
|
||||
|
||||
@expect(messages.Address)
|
||||
def get_authenticated_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coin_name: str,
|
||||
n: "Address",
|
||||
show_display: bool = False,
|
||||
@ -153,13 +152,13 @@ def get_authenticated_address(
|
||||
chunkify: bool = False,
|
||||
) -> "MessageType":
|
||||
if unlock_path:
|
||||
res = client.call(
|
||||
res = session.call(
|
||||
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac)
|
||||
)
|
||||
if not isinstance(res, messages.UnlockedPathRequest):
|
||||
raise exceptions.TrezorException("Unexpected message")
|
||||
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.GetAddress(
|
||||
address_n=n,
|
||||
coin_name=coin_name,
|
||||
@ -172,15 +171,16 @@ def get_authenticated_address(
|
||||
)
|
||||
|
||||
|
||||
# TODO this is used by tests only
|
||||
@expect(messages.OwnershipId, field="ownership_id", ret_type=bytes)
|
||||
def get_ownership_id(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coin_name: str,
|
||||
n: "Address",
|
||||
multisig: Optional[messages.MultisigRedeemScriptType] = None,
|
||||
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.GetOwnershipId(
|
||||
address_n=n,
|
||||
coin_name=coin_name,
|
||||
@ -190,8 +190,9 @@ def get_ownership_id(
|
||||
)
|
||||
|
||||
|
||||
# TODO this is used by tests only
|
||||
def get_ownership_proof(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coin_name: str,
|
||||
n: "Address",
|
||||
multisig: Optional[messages.MultisigRedeemScriptType] = None,
|
||||
@ -202,11 +203,11 @@ def get_ownership_proof(
|
||||
preauthorized: bool = False,
|
||||
) -> Tuple[bytes, bytes]:
|
||||
if preauthorized:
|
||||
res = client.call(messages.DoPreauthorized())
|
||||
res = session.call(messages.DoPreauthorized())
|
||||
if not isinstance(res, messages.PreauthorizedRequest):
|
||||
raise exceptions.TrezorException("Unexpected message")
|
||||
|
||||
res = client.call(
|
||||
res = session.call(
|
||||
messages.GetOwnershipProof(
|
||||
address_n=n,
|
||||
coin_name=coin_name,
|
||||
@ -226,7 +227,7 @@ def get_ownership_proof(
|
||||
|
||||
@expect(messages.MessageSignature)
|
||||
def sign_message(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coin_name: str,
|
||||
n: "Address",
|
||||
message: AnyStr,
|
||||
@ -234,7 +235,7 @@ def sign_message(
|
||||
no_script_type: bool = False,
|
||||
chunkify: bool = False,
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.SignMessage(
|
||||
coin_name=coin_name,
|
||||
address_n=n,
|
||||
@ -247,7 +248,7 @@ def sign_message(
|
||||
|
||||
|
||||
def verify_message(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coin_name: str,
|
||||
address: str,
|
||||
signature: bytes,
|
||||
@ -255,7 +256,7 @@ def verify_message(
|
||||
chunkify: bool = False,
|
||||
) -> bool:
|
||||
try:
|
||||
resp = client.call(
|
||||
resp = session.call(
|
||||
messages.VerifyMessage(
|
||||
address=address,
|
||||
signature=signature,
|
||||
@ -269,9 +270,9 @@ def verify_message(
|
||||
return isinstance(resp, messages.Success)
|
||||
|
||||
|
||||
@session
|
||||
# @session
|
||||
def sign_tx(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coin_name: str,
|
||||
inputs: Sequence[messages.TxInputType],
|
||||
outputs: Sequence[messages.TxOutputType],
|
||||
@ -319,17 +320,17 @@ def sign_tx(
|
||||
setattr(signtx, name, value)
|
||||
|
||||
if unlock_path:
|
||||
res = client.call(
|
||||
res = session.call(
|
||||
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac)
|
||||
)
|
||||
if not isinstance(res, messages.UnlockedPathRequest):
|
||||
raise exceptions.TrezorException("Unexpected message")
|
||||
elif preauthorized:
|
||||
res = client.call(messages.DoPreauthorized())
|
||||
res = session.call(messages.DoPreauthorized())
|
||||
if not isinstance(res, messages.PreauthorizedRequest):
|
||||
raise exceptions.TrezorException("Unexpected message")
|
||||
|
||||
res = client.call(signtx)
|
||||
res = session.call(signtx)
|
||||
|
||||
# Prepare structure for signatures
|
||||
signatures: List[Optional[bytes]] = [None] * len(inputs)
|
||||
@ -388,7 +389,7 @@ def sign_tx(
|
||||
if res.request_type == R.TXPAYMENTREQ:
|
||||
assert res.details.request_index is not None
|
||||
msg = payment_reqs[res.details.request_index]
|
||||
res = client.call(msg)
|
||||
res = session.call(msg)
|
||||
else:
|
||||
msg = messages.TransactionType()
|
||||
if res.request_type == R.TXMETA:
|
||||
@ -418,7 +419,7 @@ def sign_tx(
|
||||
f"Unknown request type - {res.request_type}."
|
||||
)
|
||||
|
||||
res = client.call(messages.TxAck(tx=msg))
|
||||
res = session.call(messages.TxAck(tx=msg))
|
||||
|
||||
if not isinstance(res, messages.TxRequest):
|
||||
raise exceptions.TrezorException("Unexpected message")
|
||||
@ -432,7 +433,7 @@ def sign_tx(
|
||||
|
||||
@expect(messages.Success, field="message", ret_type=str)
|
||||
def authorize_coinjoin(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coordinator: str,
|
||||
max_rounds: int,
|
||||
max_coordinator_fee_rate: int,
|
||||
@ -441,7 +442,7 @@ def authorize_coinjoin(
|
||||
coin_name: str,
|
||||
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.AuthorizeCoinJoin(
|
||||
coordinator=coordinator,
|
||||
max_rounds=max_rounds,
|
||||
|
@ -35,8 +35,8 @@ from . import exceptions, messages, tools
|
||||
from .tools import expect
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .protobuf import MessageType
|
||||
from .transport.session import Session
|
||||
|
||||
PROTOCOL_MAGICS = {
|
||||
"mainnet": 764824073,
|
||||
@ -825,7 +825,7 @@ def _get_collateral_inputs_items(
|
||||
|
||||
@expect(messages.CardanoAddress, field="address", ret_type=str)
|
||||
def get_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address_parameters: messages.CardanoAddressParametersType,
|
||||
protocol_magic: int = PROTOCOL_MAGICS["mainnet"],
|
||||
network_id: int = NETWORK_IDS["mainnet"],
|
||||
@ -833,7 +833,7 @@ def get_address(
|
||||
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
|
||||
chunkify: bool = False,
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.CardanoGetAddress(
|
||||
address_parameters=address_parameters,
|
||||
protocol_magic=protocol_magic,
|
||||
@ -847,12 +847,12 @@ def get_address(
|
||||
|
||||
@expect(messages.CardanoPublicKey)
|
||||
def get_public_key(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address_n: List[int],
|
||||
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
|
||||
show_display: bool = False,
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.CardanoGetPublicKey(
|
||||
address_n=address_n,
|
||||
derivation_type=derivation_type,
|
||||
@ -863,12 +863,12 @@ def get_public_key(
|
||||
|
||||
@expect(messages.CardanoNativeScriptHash)
|
||||
def get_native_script_hash(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
native_script: messages.CardanoNativeScript,
|
||||
display_format: messages.CardanoNativeScriptHashDisplayFormat = messages.CardanoNativeScriptHashDisplayFormat.HIDE,
|
||||
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.CardanoGetNativeScriptHash(
|
||||
script=native_script,
|
||||
display_format=display_format,
|
||||
@ -878,7 +878,7 @@ def get_native_script_hash(
|
||||
|
||||
|
||||
def sign_tx(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
signing_mode: messages.CardanoTxSigningMode,
|
||||
inputs: List[InputWithPath],
|
||||
outputs: List[OutputWithData],
|
||||
@ -915,7 +915,7 @@ def sign_tx(
|
||||
signing_mode,
|
||||
)
|
||||
|
||||
response = client.call(
|
||||
response = session.call(
|
||||
messages.CardanoSignTxInit(
|
||||
signing_mode=signing_mode,
|
||||
inputs_count=len(inputs),
|
||||
@ -951,14 +951,14 @@ def sign_tx(
|
||||
_get_certificates_items(certificates),
|
||||
withdrawals,
|
||||
):
|
||||
response = client.call(tx_item)
|
||||
response = session.call(tx_item)
|
||||
if not isinstance(response, messages.CardanoTxItemAck):
|
||||
raise UNEXPECTED_RESPONSE_ERROR
|
||||
|
||||
sign_tx_response: Dict[str, Any] = {}
|
||||
|
||||
if auxiliary_data is not None:
|
||||
auxiliary_data_supplement = client.call(auxiliary_data)
|
||||
auxiliary_data_supplement = session.call(auxiliary_data)
|
||||
if not isinstance(
|
||||
auxiliary_data_supplement, messages.CardanoTxAuxiliaryDataSupplement
|
||||
):
|
||||
@ -971,7 +971,7 @@ def sign_tx(
|
||||
auxiliary_data_supplement.__dict__
|
||||
)
|
||||
|
||||
response = client.call(messages.CardanoTxHostAck())
|
||||
response = session.call(messages.CardanoTxHostAck())
|
||||
if not isinstance(response, messages.CardanoTxItemAck):
|
||||
raise UNEXPECTED_RESPONSE_ERROR
|
||||
|
||||
@ -980,24 +980,24 @@ def sign_tx(
|
||||
_get_collateral_inputs_items(collateral_inputs),
|
||||
required_signers,
|
||||
):
|
||||
response = client.call(tx_item)
|
||||
response = session.call(tx_item)
|
||||
if not isinstance(response, messages.CardanoTxItemAck):
|
||||
raise UNEXPECTED_RESPONSE_ERROR
|
||||
|
||||
if collateral_return is not None:
|
||||
for tx_item in _get_output_items(collateral_return):
|
||||
response = client.call(tx_item)
|
||||
response = session.call(tx_item)
|
||||
if not isinstance(response, messages.CardanoTxItemAck):
|
||||
raise UNEXPECTED_RESPONSE_ERROR
|
||||
|
||||
for reference_input in reference_inputs:
|
||||
response = client.call(reference_input)
|
||||
response = session.call(reference_input)
|
||||
if not isinstance(response, messages.CardanoTxItemAck):
|
||||
raise UNEXPECTED_RESPONSE_ERROR
|
||||
|
||||
sign_tx_response["witnesses"] = []
|
||||
for witness_request in witness_requests:
|
||||
response = client.call(witness_request)
|
||||
response = session.call(witness_request)
|
||||
if not isinstance(response, messages.CardanoTxWitnessResponse):
|
||||
raise UNEXPECTED_RESPONSE_ERROR
|
||||
sign_tx_response["witnesses"].append(
|
||||
@ -1009,12 +1009,12 @@ def sign_tx(
|
||||
}
|
||||
)
|
||||
|
||||
response = client.call(messages.CardanoTxHostAck())
|
||||
response = session.call(messages.CardanoTxHostAck())
|
||||
if not isinstance(response, messages.CardanoTxBodyHash):
|
||||
raise UNEXPECTED_RESPONSE_ERROR
|
||||
sign_tx_response["tx_hash"] = response.tx_hash
|
||||
|
||||
response = client.call(messages.CardanoTxHostAck())
|
||||
response = session.call(messages.CardanoTxHostAck())
|
||||
if not isinstance(response, messages.CardanoSignTxFinished):
|
||||
raise UNEXPECTED_RESPONSE_ERROR
|
||||
|
||||
|
@ -14,33 +14,41 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import typing as t
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
|
||||
|
||||
import click
|
||||
|
||||
from .. import exceptions, transport
|
||||
from .. import exceptions, transport, ui
|
||||
from ..client import TrezorClient
|
||||
from ..ui import ClickUI, ScriptUI
|
||||
from ..messages import Capability
|
||||
from ..transport import Transport
|
||||
from ..transport.thp import channel_database
|
||||
|
||||
if TYPE_CHECKING:
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
# Needed to enforce a return value from decorators
|
||||
# More details: https://www.python.org/dev/peps/pep-0612/
|
||||
from typing import TypeVar
|
||||
|
||||
from typing_extensions import Concatenate, ParamSpec
|
||||
|
||||
from ..transport import Transport
|
||||
from ..ui import TrezorClientUI
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class ChoiceType(click.Choice):
|
||||
def __init__(self, typemap: Dict[str, Any], case_sensitive: bool = True) -> None:
|
||||
|
||||
def __init__(
|
||||
self, typemap: t.Dict[str, t.Any], case_sensitive: bool = True
|
||||
) -> None:
|
||||
super().__init__(list(typemap.keys()))
|
||||
self.case_sensitive = case_sensitive
|
||||
if case_sensitive:
|
||||
@ -48,7 +56,7 @@ class ChoiceType(click.Choice):
|
||||
else:
|
||||
self.typemap = {k.lower(): v for k, v in typemap.items()}
|
||||
|
||||
def convert(self, value: Any, param: Any, ctx: click.Context) -> Any:
|
||||
def convert(self, value: t.Any, param: t.Any, ctx: click.Context) -> t.Any:
|
||||
if value in self.typemap.values():
|
||||
return value
|
||||
value = super().convert(value, param, ctx)
|
||||
@ -57,11 +65,48 @@ class ChoiceType(click.Choice):
|
||||
return self.typemap[value]
|
||||
|
||||
|
||||
class TrezorConnection:
|
||||
def get_passphrase(
|
||||
passphrase_on_host: bool, available_on_device: bool
|
||||
) -> t.Union[str, object]:
|
||||
if available_on_device and not passphrase_on_host:
|
||||
return ui.PASSPHRASE_ON_DEVICE
|
||||
|
||||
env_passphrase = os.getenv("PASSPHRASE")
|
||||
if env_passphrase is not None:
|
||||
ui.echo("Passphrase required. Using PASSPHRASE environment variable.")
|
||||
return env_passphrase
|
||||
|
||||
while True:
|
||||
try:
|
||||
passphrase = ui.prompt(
|
||||
"Passphrase required",
|
||||
hide_input=True,
|
||||
default="",
|
||||
show_default=False,
|
||||
)
|
||||
# In case user sees the input on the screen, we do not need confirmation
|
||||
if not ui.CAN_HANDLE_HIDDEN_INPUT:
|
||||
return passphrase
|
||||
second = ui.prompt(
|
||||
"Confirm your passphrase",
|
||||
hide_input=True,
|
||||
default="",
|
||||
show_default=False,
|
||||
)
|
||||
if passphrase == second:
|
||||
return passphrase
|
||||
else:
|
||||
ui.echo("Passphrase did not match. Please try again.")
|
||||
except click.Abort:
|
||||
raise exceptions.Cancelled from None
|
||||
|
||||
|
||||
class NewTrezorConnection:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
session_id: Optional[bytes],
|
||||
session_id: bytes | None,
|
||||
passphrase_on_host: bool,
|
||||
script: bool,
|
||||
) -> None:
|
||||
@ -70,6 +115,29 @@ class TrezorConnection:
|
||||
self.passphrase_on_host = passphrase_on_host
|
||||
self.script = script
|
||||
|
||||
def get_session(self, derive_cardano: bool = False):
|
||||
client = self.get_client()
|
||||
|
||||
if self.session_id is not None:
|
||||
pass # TODO Try resume - be careful of cardano derivation settings!
|
||||
features = client.protocol.get_features()
|
||||
|
||||
passphrase_enabled = True # TODO what to do here?
|
||||
|
||||
if not passphrase_enabled:
|
||||
return client.get_session(derive_cardano=derive_cardano)
|
||||
|
||||
# TODO Passphrase empty by default - ???
|
||||
available_on_device = Capability.PassphraseEntry in features.capabilities
|
||||
passphrase = get_passphrase(available_on_device, self.passphrase_on_host)
|
||||
# TODO handle case when PASSPHRASE_ON_DEVICE is returned from get_passphrase func
|
||||
if not isinstance(passphrase, str):
|
||||
raise RuntimeError("Passphrase must be a str")
|
||||
session = client.get_session(
|
||||
passphrase=passphrase, derive_cardano=derive_cardano
|
||||
)
|
||||
return session
|
||||
|
||||
def get_transport(self) -> "Transport":
|
||||
try:
|
||||
# look for transport without prefix search
|
||||
@ -82,19 +150,33 @@ class TrezorConnection:
|
||||
# if this fails, we want the exception to bubble up to the caller
|
||||
return transport.get_transport(self.path, prefix_search=True)
|
||||
|
||||
def get_ui(self) -> "TrezorClientUI":
|
||||
if self.script:
|
||||
# It is alright to return just the class object instead of instance,
|
||||
# as the ScriptUI class object itself is the implementation of TrezorClientUI
|
||||
# (ScriptUI is just a set of staticmethods)
|
||||
return ScriptUI
|
||||
else:
|
||||
return ClickUI(passphrase_on_host=self.passphrase_on_host)
|
||||
|
||||
def get_client(self) -> TrezorClient:
|
||||
transport = self.get_transport()
|
||||
ui = self.get_ui()
|
||||
return TrezorClient(transport, ui=ui, session_id=self.session_id)
|
||||
|
||||
stored_channels = channel_database.load_stored_channels()
|
||||
stored_transport_paths = [ch.transport_path for ch in stored_channels]
|
||||
path = transport.get_path()
|
||||
if path in stored_transport_paths:
|
||||
stored_channel_with_correct_transport_path = next(
|
||||
ch for ch in stored_channels if ch.transport_path == path
|
||||
)
|
||||
try:
|
||||
client = TrezorClient.resume(
|
||||
transport, stored_channel_with_correct_transport_path
|
||||
)
|
||||
except Exception:
|
||||
LOG.debug("Failed to resume a channel. Replacing by a new one.")
|
||||
channel_database.remove_channel(path)
|
||||
client = TrezorClient(transport)
|
||||
else:
|
||||
client = TrezorClient(transport)
|
||||
|
||||
return client
|
||||
|
||||
def get_management_session(self) -> Session:
|
||||
client = self.get_client()
|
||||
management_session = client.get_management_session()
|
||||
return management_session
|
||||
|
||||
@contextmanager
|
||||
def client_context(self):
|
||||
@ -128,7 +210,131 @@ class TrezorConnection:
|
||||
# other exceptions may cause a traceback
|
||||
|
||||
|
||||
def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[P, R]":
|
||||
# class TrezorConnection:
|
||||
|
||||
# def __init__(
|
||||
# self,
|
||||
# path: str,
|
||||
# session_id: bytes | None,
|
||||
# passphrase_on_host: bool,
|
||||
# script: bool,
|
||||
# ) -> None:
|
||||
# self.path = path
|
||||
# self.session_id = session_id
|
||||
# self.passphrase_on_host = passphrase_on_host
|
||||
# self.script = script
|
||||
|
||||
# def get_transport(self) -> "Transport":
|
||||
# try:
|
||||
# # look for transport without prefix search
|
||||
# return transport.get_transport(self.path, prefix_search=False)
|
||||
# except Exception:
|
||||
# # most likely not found. try again below.
|
||||
# pass
|
||||
|
||||
# # look for transport with prefix search
|
||||
# # if this fails, we want the exception to bubble up to the caller
|
||||
# return transport.get_transport(self.path, prefix_search=True)
|
||||
|
||||
# def get_ui(self) -> "TrezorClientUI":
|
||||
# if self.script:
|
||||
# # It is alright to return just the class object instead of instance,
|
||||
# # as the ScriptUI class object itself is the implementation of TrezorClientUI
|
||||
# # (ScriptUI is just a set of staticmethods)
|
||||
# return ScriptUI
|
||||
# else:
|
||||
# return ClickUI(passphrase_on_host=self.passphrase_on_host)
|
||||
|
||||
# def get_client(self) -> TrezorClient:
|
||||
# transport = self.get_transport()
|
||||
# ui = self.get_ui()
|
||||
# return TrezorClient(transport, ui=ui, session_id=self.session_id)
|
||||
|
||||
# @contextmanager
|
||||
# def client_context(self):
|
||||
# """Get a client instance as a context manager. Handle errors in a manner
|
||||
# appropriate for end-users.
|
||||
|
||||
# Usage:
|
||||
# >>> with obj.client_context() as client:
|
||||
# >>> do_your_actions_here()
|
||||
# """
|
||||
# try:
|
||||
# client = self.get_client()
|
||||
# except transport.DeviceIsBusy:
|
||||
# click.echo("Device is in use by another process.")
|
||||
# sys.exit(1)
|
||||
# except Exception:
|
||||
# click.echo("Failed to find a Trezor device.")
|
||||
# if self.path is not None:
|
||||
# click.echo(f"Using path: {self.path}")
|
||||
# sys.exit(1)
|
||||
|
||||
# try:
|
||||
# yield client
|
||||
# except exceptions.Cancelled:
|
||||
# # handle cancel action
|
||||
# click.echo("Action was cancelled.")
|
||||
# sys.exit(1)
|
||||
# except exceptions.TrezorException as e:
|
||||
# # handle any Trezor-sent exceptions as user-readable
|
||||
# raise click.ClickException(str(e)) from e
|
||||
# # other exceptions may cause a traceback
|
||||
|
||||
from ..transport.session import Session
|
||||
|
||||
|
||||
def with_cardano_session(
|
||||
func: "t.Callable[Concatenate[Session, P], R]",
|
||||
) -> "t.Callable[P, R]":
|
||||
return with_session(func=func, derive_cardano=True)
|
||||
|
||||
|
||||
def with_session(
|
||||
func: "t.Callable[Concatenate[Session, P], R]", derive_cardano: bool = False
|
||||
) -> "t.Callable[P, R]":
|
||||
|
||||
@click.pass_obj
|
||||
@functools.wraps(func)
|
||||
def function_with_session(
|
||||
obj: NewTrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
|
||||
) -> "R":
|
||||
session = obj.get_session(derive_cardano)
|
||||
try:
|
||||
return func(session, *args, **kwargs)
|
||||
finally:
|
||||
pass
|
||||
# TODO try end session if not resumed
|
||||
|
||||
# the return type of @click.pass_obj is improperly specified and pyright doesn't
|
||||
# understand that it converts f(obj, *args, **kwargs) to f(*args, **kwargs)
|
||||
return function_with_session # type: ignore [is incompatible with return type]
|
||||
|
||||
|
||||
def with_management_session(
|
||||
func: "t.Callable[Concatenate[Session, P], R]",
|
||||
) -> "t.Callable[P, R]":
|
||||
|
||||
@click.pass_obj
|
||||
@functools.wraps(func)
|
||||
def function_with_management_session(
|
||||
obj: NewTrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
|
||||
) -> "R":
|
||||
session = obj.get_management_session()
|
||||
try:
|
||||
return func(session, *args, **kwargs)
|
||||
finally:
|
||||
pass
|
||||
# TODO try end session if not resumed
|
||||
|
||||
# the return type of @click.pass_obj is improperly specified and pyright doesn't
|
||||
# understand that it converts f(obj, *args, **kwargs) to f(*args, **kwargs)
|
||||
return function_with_management_session # type: ignore [is incompatible with return type]
|
||||
|
||||
|
||||
def with_client(
|
||||
func: "t.Callable[Concatenate[TrezorClient, P], R]",
|
||||
) -> "t.Callable[P, R]":
|
||||
"""Wrap a Click command in `with obj.client_context() as client`.
|
||||
|
||||
Sessions are handled transparently. The user is warned when session did not resume
|
||||
@ -139,28 +345,66 @@ def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[
|
||||
@click.pass_obj
|
||||
@functools.wraps(func)
|
||||
def trezorctl_command_with_client(
|
||||
obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
|
||||
obj: NewTrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
|
||||
) -> "R":
|
||||
with obj.client_context() as client:
|
||||
session_was_resumed = obj.session_id == client.session_id
|
||||
if not session_was_resumed and obj.session_id is not None:
|
||||
# tried to resume but failed
|
||||
click.echo("Warning: failed to resume session.", err=True)
|
||||
|
||||
# session_was_resumed = obj.session_id == client.session_id
|
||||
# if not session_was_resumed and obj.session_id is not None:
|
||||
# # tried to resume but failed
|
||||
# click.echo("Warning: failed to resume session.", err=True)
|
||||
click.echo(
|
||||
"Warning: resume session detection is not implemented yet!", err=True
|
||||
)
|
||||
try:
|
||||
return func(client, *args, **kwargs)
|
||||
finally:
|
||||
if not session_was_resumed:
|
||||
try:
|
||||
client.end_session()
|
||||
except Exception:
|
||||
pass
|
||||
channel_database.save_channel(client.protocol)
|
||||
# if not session_was_resumed:
|
||||
# try:
|
||||
# client.end_session()
|
||||
# except Exception:
|
||||
# pass
|
||||
|
||||
# the return type of @click.pass_obj is improperly specified and pyright doesn't
|
||||
# understand that it converts f(obj, *args, **kwargs) to f(*args, **kwargs)
|
||||
return trezorctl_command_with_client # type: ignore [is incompatible with return type]
|
||||
|
||||
|
||||
# def with_client(
|
||||
# func: "t.Callable[Concatenate[TrezorClient, P], R]",
|
||||
# ) -> "t.Callable[P, R]":
|
||||
# """Wrap a Click command in `with obj.client_context() as client`.
|
||||
|
||||
# Sessions are handled transparently. The user is warned when session did not resume
|
||||
# cleanly. The session is closed after the command completes - unless the session
|
||||
# was resumed, in which case it should remain open.
|
||||
# """
|
||||
|
||||
# @click.pass_obj
|
||||
# @functools.wraps(func)
|
||||
# def trezorctl_command_with_client(
|
||||
# obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
|
||||
# ) -> "R":
|
||||
# with obj.client_context() as client:
|
||||
# session_was_resumed = obj.session_id == client.session_id
|
||||
# if not session_was_resumed and obj.session_id is not None:
|
||||
# # tried to resume but failed
|
||||
# click.echo("Warning: failed to resume session.", err=True)
|
||||
|
||||
# try:
|
||||
# return func(client, *args, **kwargs)
|
||||
# finally:
|
||||
# if not session_was_resumed:
|
||||
# try:
|
||||
# client.end_session()
|
||||
# except Exception:
|
||||
# pass
|
||||
|
||||
# # the return type of @click.pass_obj is improperly specified and pyright doesn't
|
||||
# # understand that it converts f(obj, *args, **kwargs) to f(*args, **kwargs)
|
||||
# return trezorctl_command_with_client
|
||||
|
||||
|
||||
class AliasedGroup(click.Group):
|
||||
"""Command group that handles aliases and Click 6.x compatibility.
|
||||
|
||||
@ -190,14 +434,14 @@ class AliasedGroup(click.Group):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
aliases: Optional[Dict[str, click.Command]] = None,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
aliases: t.Dict[str, click.Command] | None = None,
|
||||
*args: t.Any,
|
||||
**kwargs: t.Any,
|
||||
) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.aliases = aliases or {}
|
||||
|
||||
def get_command(self, ctx: click.Context, cmd_name: str) -> Optional[click.Command]:
|
||||
def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None:
|
||||
cmd_name = cmd_name.replace("_", "-")
|
||||
# try to look up the real name
|
||||
cmd = super().get_command(ctx, cmd_name)
|
||||
|
@ -20,17 +20,15 @@ from typing import TYPE_CHECKING, List, Optional
|
||||
import click
|
||||
|
||||
from .. import benchmark
|
||||
from . import with_client
|
||||
from . import with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
|
||||
|
||||
def list_names_patern(
|
||||
client: "TrezorClient", pattern: Optional[str] = None
|
||||
) -> List[str]:
|
||||
names = list(benchmark.list_names(client).names)
|
||||
def list_names_patern(session: "Session", pattern: Optional[str] = None) -> List[str]:
|
||||
names = list(benchmark.list_names(session).names)
|
||||
if pattern is None:
|
||||
return names
|
||||
return [name for name in names if fnmatch(name, pattern)]
|
||||
@ -43,10 +41,10 @@ def cli() -> None:
|
||||
|
||||
@cli.command()
|
||||
@click.argument("pattern", required=False)
|
||||
@with_client
|
||||
def list_names(client: "TrezorClient", pattern: Optional[str] = None) -> None:
|
||||
@with_session
|
||||
def list_names(session: "Session", pattern: Optional[str] = None) -> None:
|
||||
"""List names of all supported benchmarks"""
|
||||
names = list_names_patern(client, pattern)
|
||||
names = list_names_patern(session, pattern)
|
||||
if len(names) == 0:
|
||||
click.echo("No benchmark satisfies the pattern.")
|
||||
else:
|
||||
@ -56,13 +54,13 @@ def list_names(client: "TrezorClient", pattern: Optional[str] = None) -> None:
|
||||
|
||||
@cli.command()
|
||||
@click.argument("pattern", required=False)
|
||||
@with_client
|
||||
def run(client: "TrezorClient", pattern: Optional[str]) -> None:
|
||||
@with_session
|
||||
def run(session: "Session", pattern: Optional[str]) -> None:
|
||||
"""Run benchmark"""
|
||||
names = list_names_patern(client, pattern)
|
||||
names = list_names_patern(session, pattern)
|
||||
if len(names) == 0:
|
||||
click.echo("No benchmark satisfies the pattern.")
|
||||
else:
|
||||
for name in names:
|
||||
result = benchmark.run(client, name)
|
||||
result = benchmark.run(session, name)
|
||||
click.echo(f"{name}: {result.value} {result.unit}")
|
||||
|
@ -20,11 +20,11 @@ from typing import TYPE_CHECKING, TextIO
|
||||
import click
|
||||
|
||||
from .. import binance, tools
|
||||
from . import with_client
|
||||
from ..transport.session import Session
|
||||
from . import with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .. import messages
|
||||
from ..client import TrezorClient
|
||||
|
||||
|
||||
PATH_HELP = "BIP-32 path to key, e.g. m/44h/714h/0h/0/0"
|
||||
@ -39,23 +39,23 @@ def cli() -> None:
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def get_address(
|
||||
client: "TrezorClient", address: str, show_display: bool, chunkify: bool
|
||||
session: "Session", address: str, show_display: bool, chunkify: bool
|
||||
) -> str:
|
||||
"""Get Binance address for specified path."""
|
||||
address_n = tools.parse_path(address)
|
||||
return binance.get_address(client, address_n, show_display, chunkify)
|
||||
return binance.get_address(session, address_n, show_display, chunkify)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@with_client
|
||||
def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str:
|
||||
@with_session
|
||||
def get_public_key(session: "Session", address: str, show_display: bool) -> str:
|
||||
"""Get Binance public key."""
|
||||
address_n = tools.parse_path(address)
|
||||
return binance.get_public_key(client, address_n, show_display).hex()
|
||||
return binance.get_public_key(session, address_n, show_display).hex()
|
||||
|
||||
|
||||
@cli.command()
|
||||
@ -63,13 +63,13 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) ->
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def sign_tx(
|
||||
client: "TrezorClient", address: str, file: TextIO, chunkify: bool
|
||||
session: "Session", address: str, file: TextIO, chunkify: bool
|
||||
) -> "messages.BinanceSignedTx":
|
||||
"""Sign Binance transaction.
|
||||
|
||||
Transaction must be provided as a JSON file.
|
||||
"""
|
||||
address_n = tools.parse_path(address)
|
||||
return binance.sign_tx(client, address_n, json.load(file), chunkify=chunkify)
|
||||
return binance.sign_tx(session, address_n, json.load(file), chunkify=chunkify)
|
||||
|
@ -13,6 +13,7 @@
|
||||
#
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
@ -22,10 +23,10 @@ import click
|
||||
import construct as c
|
||||
|
||||
from .. import btc, messages, protobuf, tools
|
||||
from . import ChoiceType, with_client
|
||||
from . import ChoiceType, with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
|
||||
PURPOSE_BIP44 = 44
|
||||
PURPOSE_BIP48 = 48
|
||||
@ -168,15 +169,15 @@ def cli() -> None:
|
||||
default=2,
|
||||
)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def get_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coin: str,
|
||||
address: str,
|
||||
script_type: Optional[messages.InputScriptType],
|
||||
script_type: messages.InputScriptType | None,
|
||||
show_display: bool,
|
||||
multisig_xpub: List[str],
|
||||
multisig_threshold: Optional[int],
|
||||
multisig_threshold: int | None,
|
||||
multisig_suffix_length: int,
|
||||
chunkify: bool,
|
||||
) -> str:
|
||||
@ -220,7 +221,7 @@ def get_address(
|
||||
multisig = None
|
||||
|
||||
return btc.get_address(
|
||||
client,
|
||||
session,
|
||||
coin,
|
||||
address_n,
|
||||
show_display,
|
||||
@ -237,9 +238,9 @@ def get_address(
|
||||
@click.option("-e", "--curve")
|
||||
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS))
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def get_public_node(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coin: str,
|
||||
address: str,
|
||||
curve: Optional[str],
|
||||
@ -251,7 +252,7 @@ def get_public_node(
|
||||
if script_type is None:
|
||||
script_type = guess_script_type_from_path(address_n)
|
||||
result = btc.get_public_node(
|
||||
client,
|
||||
session,
|
||||
address_n,
|
||||
ecdsa_curve_name=curve,
|
||||
show_display=show_display,
|
||||
@ -277,7 +278,7 @@ def _append_descriptor_checksum(desc: str) -> str:
|
||||
|
||||
|
||||
def _get_descriptor(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coin: Optional[str],
|
||||
account: int,
|
||||
purpose: Optional[int],
|
||||
@ -311,7 +312,7 @@ def _get_descriptor(
|
||||
|
||||
n = tools.parse_path(path)
|
||||
pub = btc.get_public_node(
|
||||
client,
|
||||
session,
|
||||
n,
|
||||
show_display=show_display,
|
||||
coin_name=coin,
|
||||
@ -348,9 +349,9 @@ def _get_descriptor(
|
||||
@click.option("-a", "--account-type", type=ChoiceType(ACCOUNT_TYPE_TO_BIP_PURPOSE))
|
||||
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS))
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def get_descriptor(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coin: Optional[str],
|
||||
account: int,
|
||||
account_type: Optional[int],
|
||||
@ -360,7 +361,7 @@ def get_descriptor(
|
||||
"""Get descriptor of given account."""
|
||||
try:
|
||||
return _get_descriptor(
|
||||
client, coin, account, account_type, script_type, show_display
|
||||
session, coin, account, account_type, script_type, show_display
|
||||
)
|
||||
except ValueError as e:
|
||||
raise click.ClickException(str(e))
|
||||
@ -375,8 +376,8 @@ def get_descriptor(
|
||||
@click.option("-c", "--coin", is_flag=True, hidden=True, expose_value=False)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@click.argument("json_file", type=click.File())
|
||||
@with_client
|
||||
def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None:
|
||||
@with_session
|
||||
def sign_tx(session: "Session", json_file: TextIO, chunkify: bool) -> None:
|
||||
"""Sign transaction.
|
||||
|
||||
Transaction data must be provided in a JSON file. See `transaction-format.md` for
|
||||
@ -401,7 +402,7 @@ def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None:
|
||||
}
|
||||
|
||||
_, serialized_tx = btc.sign_tx(
|
||||
client,
|
||||
session,
|
||||
coin,
|
||||
inputs,
|
||||
outputs,
|
||||
@ -432,9 +433,9 @@ def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None:
|
||||
)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@click.argument("message")
|
||||
@with_client
|
||||
@with_session
|
||||
def sign_message(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coin: str,
|
||||
address: str,
|
||||
message: str,
|
||||
@ -447,7 +448,7 @@ def sign_message(
|
||||
if script_type is None:
|
||||
script_type = guess_script_type_from_path(address_n)
|
||||
res = btc.sign_message(
|
||||
client,
|
||||
session,
|
||||
coin,
|
||||
address_n,
|
||||
message,
|
||||
@ -468,9 +469,9 @@ def sign_message(
|
||||
@click.argument("address")
|
||||
@click.argument("signature")
|
||||
@click.argument("message")
|
||||
@with_client
|
||||
@with_session
|
||||
def verify_message(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coin: str,
|
||||
address: str,
|
||||
signature: str,
|
||||
@ -480,7 +481,7 @@ def verify_message(
|
||||
"""Verify message."""
|
||||
signature_bytes = base64.b64decode(signature)
|
||||
return btc.verify_message(
|
||||
client, coin, address, signature_bytes, message, chunkify=chunkify
|
||||
session, coin, address, signature_bytes, message, chunkify=chunkify
|
||||
)
|
||||
|
||||
|
||||
|
@ -20,10 +20,10 @@ from typing import TYPE_CHECKING, Optional, TextIO
|
||||
import click
|
||||
|
||||
from .. import cardano, messages, tools
|
||||
from . import ChoiceType, with_client
|
||||
from . import ChoiceType, with_cardano_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
|
||||
PATH_HELP = "BIP-32 path to key, e.g. m/44h/1815h/0h/0/0"
|
||||
|
||||
@ -62,9 +62,9 @@ def cli() -> None:
|
||||
@click.option("-i", "--include-network-id", is_flag=True)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@click.option("-T", "--tag-cbor-sets", is_flag=True)
|
||||
@with_client
|
||||
@with_cardano_session
|
||||
def sign_tx(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
file: TextIO,
|
||||
signing_mode: messages.CardanoTxSigningMode,
|
||||
protocol_magic: int,
|
||||
@ -123,9 +123,8 @@ def sign_tx(
|
||||
for p in transaction["additional_witness_requests"]
|
||||
]
|
||||
|
||||
client.init_device(derive_cardano=True)
|
||||
sign_tx_response = cardano.sign_tx(
|
||||
client,
|
||||
session,
|
||||
signing_mode,
|
||||
inputs,
|
||||
outputs,
|
||||
@ -209,9 +208,9 @@ def sign_tx(
|
||||
default=messages.CardanoDerivationType.ICARUS,
|
||||
)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@with_client
|
||||
@with_cardano_session
|
||||
def get_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address: str,
|
||||
address_type: messages.CardanoAddressType,
|
||||
staking_address: str,
|
||||
@ -262,9 +261,8 @@ def get_address(
|
||||
script_staking_hash_bytes,
|
||||
)
|
||||
|
||||
client.init_device(derive_cardano=True)
|
||||
return cardano.get_address(
|
||||
client,
|
||||
session,
|
||||
address_parameters,
|
||||
protocol_magic,
|
||||
network_id,
|
||||
@ -283,18 +281,17 @@ def get_address(
|
||||
default=messages.CardanoDerivationType.ICARUS,
|
||||
)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@with_client
|
||||
@with_cardano_session
|
||||
def get_public_key(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address: str,
|
||||
derivation_type: messages.CardanoDerivationType,
|
||||
show_display: bool,
|
||||
) -> messages.CardanoPublicKey:
|
||||
"""Get Cardano public key."""
|
||||
address_n = tools.parse_path(address)
|
||||
client.init_device(derive_cardano=True)
|
||||
return cardano.get_public_key(
|
||||
client, address_n, derivation_type=derivation_type, show_display=show_display
|
||||
session, address_n, derivation_type=derivation_type, show_display=show_display
|
||||
)
|
||||
|
||||
|
||||
@ -312,9 +309,9 @@ def get_public_key(
|
||||
type=ChoiceType({m.name: m for m in messages.CardanoDerivationType}),
|
||||
default=messages.CardanoDerivationType.ICARUS,
|
||||
)
|
||||
@with_client
|
||||
@with_cardano_session
|
||||
def get_native_script_hash(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
file: TextIO,
|
||||
display_format: messages.CardanoNativeScriptHashDisplayFormat,
|
||||
derivation_type: messages.CardanoDerivationType,
|
||||
@ -323,7 +320,6 @@ def get_native_script_hash(
|
||||
native_script_json = json.load(file)
|
||||
native_script = cardano.parse_native_script(native_script_json)
|
||||
|
||||
client.init_device(derive_cardano=True)
|
||||
return cardano.get_native_script_hash(
|
||||
client, native_script, display_format, derivation_type=derivation_type
|
||||
session, native_script, display_format, derivation_type=derivation_type
|
||||
)
|
||||
|
@ -19,10 +19,10 @@ from typing import TYPE_CHECKING, Tuple
|
||||
import click
|
||||
|
||||
from .. import misc, tools
|
||||
from . import ChoiceType, with_client
|
||||
from . import ChoiceType, with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
|
||||
|
||||
PROMPT_TYPE = ChoiceType(
|
||||
@ -42,10 +42,10 @@ def cli() -> None:
|
||||
|
||||
@cli.command()
|
||||
@click.argument("size", type=int)
|
||||
@with_client
|
||||
def get_entropy(client: "TrezorClient", size: int) -> str:
|
||||
@with_session
|
||||
def get_entropy(session: "Session", size: int) -> str:
|
||||
"""Get random bytes from device."""
|
||||
return misc.get_entropy(client, size).hex()
|
||||
return misc.get_entropy(session, size).hex()
|
||||
|
||||
|
||||
@cli.command()
|
||||
@ -55,9 +55,9 @@ def get_entropy(client: "TrezorClient", size: int) -> str:
|
||||
)
|
||||
@click.argument("key")
|
||||
@click.argument("value")
|
||||
@with_client
|
||||
@with_session
|
||||
def encrypt_keyvalue(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address: str,
|
||||
key: str,
|
||||
value: str,
|
||||
@ -75,7 +75,7 @@ def encrypt_keyvalue(
|
||||
ask_on_encrypt, ask_on_decrypt = prompt
|
||||
address_n = tools.parse_path(address)
|
||||
return misc.encrypt_keyvalue(
|
||||
client,
|
||||
session,
|
||||
address_n,
|
||||
key,
|
||||
value.encode(),
|
||||
@ -91,9 +91,9 @@ def encrypt_keyvalue(
|
||||
)
|
||||
@click.argument("key")
|
||||
@click.argument("value")
|
||||
@with_client
|
||||
@with_session
|
||||
def decrypt_keyvalue(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address: str,
|
||||
key: str,
|
||||
value: str,
|
||||
@ -112,7 +112,7 @@ def decrypt_keyvalue(
|
||||
ask_on_encrypt, ask_on_decrypt = prompt
|
||||
address_n = tools.parse_path(address)
|
||||
return misc.decrypt_keyvalue(
|
||||
client,
|
||||
session,
|
||||
address_n,
|
||||
key,
|
||||
bytes.fromhex(value),
|
||||
|
@ -18,16 +18,15 @@ from typing import TYPE_CHECKING, Union
|
||||
|
||||
import click
|
||||
|
||||
from .. import mapping, messages, protobuf
|
||||
from ..client import TrezorClient
|
||||
from ..debuglink import TrezorClientDebugLink
|
||||
from ..debuglink import optiga_set_sec_max as debuglink_optiga_set_sec_max
|
||||
from ..debuglink import prodtest_t1 as debuglink_prodtest_t1
|
||||
from ..debuglink import record_screen
|
||||
from . import with_client
|
||||
from ..transport.session import Session
|
||||
from . import with_management_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import TrezorConnection
|
||||
from . import NewTrezorConnection
|
||||
|
||||
|
||||
@click.group(name="debug")
|
||||
@ -35,58 +34,58 @@ def cli() -> None:
|
||||
"""Miscellaneous debug features."""
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("message_name_or_type")
|
||||
@click.argument("hex_data")
|
||||
@click.pass_obj
|
||||
def send_bytes(
|
||||
obj: "TrezorConnection", message_name_or_type: str, hex_data: str
|
||||
) -> None:
|
||||
"""Send raw bytes to Trezor.
|
||||
# @cli.command()
|
||||
# @click.argument("message_name_or_type")
|
||||
# @click.argument("hex_data")
|
||||
# @click.pass_obj
|
||||
# def send_bytes(
|
||||
# obj: "NewTrezorConnection", message_name_or_type: str, hex_data: str
|
||||
# ) -> None:
|
||||
# """Send raw bytes to Trezor.
|
||||
|
||||
Message type and message data must be specified separately, due to how message
|
||||
chunking works on the transport level. Message length is calculated and sent
|
||||
automatically, and it is currently impossible to explicitly specify invalid length.
|
||||
# Message type and message data must be specified separately, due to how message
|
||||
# chunking works on the transport level. Message length is calculated and sent
|
||||
# automatically, and it is currently impossible to explicitly specify invalid length.
|
||||
|
||||
MESSAGE_NAME_OR_TYPE can either be a number, or a name from the MessageType enum,
|
||||
in which case the value of that enum is used.
|
||||
"""
|
||||
if message_name_or_type.isdigit():
|
||||
message_type = int(message_name_or_type)
|
||||
else:
|
||||
message_type = getattr(messages.MessageType, message_name_or_type)
|
||||
# MESSAGE_NAME_OR_TYPE can either be a number, or a name from the MessageType enum,
|
||||
# in which case the value of that enum is used.
|
||||
# """
|
||||
# if message_name_or_type.isdigit():
|
||||
# message_type = int(message_name_or_type)
|
||||
# else:
|
||||
# message_type = getattr(messages.MessageType, message_name_or_type)
|
||||
|
||||
if not isinstance(message_type, int):
|
||||
raise click.ClickException("Invalid message type.")
|
||||
# if not isinstance(message_type, int):
|
||||
# raise click.ClickException("Invalid message type.")
|
||||
|
||||
try:
|
||||
message_data = bytes.fromhex(hex_data)
|
||||
except Exception as e:
|
||||
raise click.ClickException("Invalid hex data.") from e
|
||||
# try:
|
||||
# message_data = bytes.fromhex(hex_data)
|
||||
# except Exception as e:
|
||||
# raise click.ClickException("Invalid hex data.") from e
|
||||
|
||||
transport = obj.get_transport()
|
||||
transport.begin_session()
|
||||
transport.write(message_type, message_data)
|
||||
# transport = obj.get_transport()
|
||||
# transport.deprecated_begin_session()
|
||||
# transport.write(message_type, message_data)
|
||||
|
||||
response_type, response_data = transport.read()
|
||||
transport.end_session()
|
||||
# response_type, response_data = transport.read()
|
||||
# transport.deprecated_end_session()
|
||||
|
||||
click.echo(f"Response type: {response_type}")
|
||||
click.echo(f"Response data: {response_data.hex()}")
|
||||
# click.echo(f"Response type: {response_type}")
|
||||
# click.echo(f"Response data: {response_data.hex()}")
|
||||
|
||||
try:
|
||||
msg = mapping.DEFAULT_MAPPING.decode(response_type, response_data)
|
||||
click.echo("Parsed message:")
|
||||
click.echo(protobuf.format_message(msg))
|
||||
except Exception as e:
|
||||
click.echo(f"Could not parse response: {e}")
|
||||
# try:
|
||||
# msg = mapping.DEFAULT_MAPPING.decode(response_type, response_data)
|
||||
# click.echo("Parsed message:")
|
||||
# click.echo(protobuf.format_message(msg))
|
||||
# except Exception as e:
|
||||
# click.echo(f"Could not parse response: {e}")
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("directory", required=False)
|
||||
@click.option("-s", "--stop", is_flag=True, help="Stop the recording")
|
||||
@click.pass_obj
|
||||
def record(obj: "TrezorConnection", directory: Union[str, None], stop: bool) -> None:
|
||||
def record(obj: "NewTrezorConnection", directory: Union[str, None], stop: bool) -> None:
|
||||
"""Record screen changes into a specified directory.
|
||||
|
||||
Recording can be stopped with `-s / --stop` option.
|
||||
@ -95,7 +94,7 @@ def record(obj: "TrezorConnection", directory: Union[str, None], stop: bool) ->
|
||||
|
||||
|
||||
def record_screen_from_connection(
|
||||
obj: "TrezorConnection", directory: Union[str, None]
|
||||
obj: "NewTrezorConnection", directory: Union[str, None]
|
||||
) -> None:
|
||||
"""Record screen helper to transform TrezorConnection into TrezorClientDebugLink."""
|
||||
transport = obj.get_transport()
|
||||
@ -106,17 +105,17 @@ def record_screen_from_connection(
|
||||
|
||||
|
||||
@cli.command()
|
||||
@with_client
|
||||
def prodtest_t1(client: "TrezorClient") -> str:
|
||||
@with_management_session
|
||||
def prodtest_t1(session: "Session") -> str:
|
||||
"""Perform a prodtest on Model One.
|
||||
|
||||
Only available on PRODTEST firmware and on T1B1. Formerly named self-test.
|
||||
"""
|
||||
return debuglink_prodtest_t1(client)
|
||||
return debuglink_prodtest_t1(session)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@with_client
|
||||
def optiga_set_sec_max(client: "TrezorClient") -> str:
|
||||
@with_management_session
|
||||
def optiga_set_sec_max(session: "Session") -> str:
|
||||
"""Set Optiga's security event counter to maximum."""
|
||||
return debuglink_optiga_set_sec_max(client)
|
||||
return debuglink_optiga_set_sec_max(session)
|
||||
|
@ -24,12 +24,12 @@ import click
|
||||
import requests
|
||||
|
||||
from .. import debuglink, device, exceptions, messages, ui
|
||||
from . import ChoiceType, with_client
|
||||
from . import ChoiceType, with_management_session
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
from ..protobuf import MessageType
|
||||
from . import TrezorConnection
|
||||
from ..transport.session import Session
|
||||
from . import NewTrezorConnection
|
||||
|
||||
RECOVERY_DEVICE_INPUT_METHOD = {
|
||||
"scrambled": messages.RecoveryDeviceInputMethod.ScrambledWords,
|
||||
@ -64,17 +64,18 @@ def cli() -> None:
|
||||
help="Wipe device in bootloader mode. This also erases the firmware.",
|
||||
is_flag=True,
|
||||
)
|
||||
@with_client
|
||||
def wipe(client: "TrezorClient", bootloader: bool) -> str:
|
||||
@with_management_session
|
||||
def wipe(session: "Session", bootloader: bool) -> str:
|
||||
"""Reset device to factory defaults and remove all private data."""
|
||||
features = session.features
|
||||
if bootloader:
|
||||
if not client.features.bootloader_mode:
|
||||
if not features.bootloader_mode:
|
||||
click.echo("Please switch your device to bootloader mode.")
|
||||
sys.exit(1)
|
||||
else:
|
||||
click.echo("Wiping user data and firmware!")
|
||||
else:
|
||||
if client.features.bootloader_mode:
|
||||
if features.bootloader_mode:
|
||||
click.echo(
|
||||
"Your device is in bootloader mode. This operation would also erase firmware."
|
||||
)
|
||||
@ -87,7 +88,9 @@ def wipe(client: "TrezorClient", bootloader: bool) -> str:
|
||||
click.echo("Wiping user data!")
|
||||
|
||||
try:
|
||||
return device.wipe(client)
|
||||
return device.wipe(
|
||||
session
|
||||
) # TODO decide where the wipe should happen - management or regular session
|
||||
except exceptions.TrezorFailure as e:
|
||||
click.echo("Action failed: {} {}".format(*e.args))
|
||||
sys.exit(3)
|
||||
@ -103,9 +106,9 @@ def wipe(client: "TrezorClient", bootloader: bool) -> str:
|
||||
@click.option("-a", "--academic", is_flag=True)
|
||||
@click.option("-b", "--needs-backup", is_flag=True)
|
||||
@click.option("-n", "--no-backup", is_flag=True)
|
||||
@with_client
|
||||
@with_management_session
|
||||
def load(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
mnemonic: t.Sequence[str],
|
||||
pin: str,
|
||||
passphrase_protection: bool,
|
||||
@ -136,7 +139,7 @@ def load(
|
||||
|
||||
try:
|
||||
return debuglink.load_device(
|
||||
client,
|
||||
session,
|
||||
mnemonic=list(mnemonic),
|
||||
pin=pin,
|
||||
passphrase_protection=passphrase_protection,
|
||||
@ -171,9 +174,9 @@ def load(
|
||||
)
|
||||
@click.option("-d", "--dry-run", is_flag=True)
|
||||
@click.option("-b", "--unlock-repeated-backup", is_flag=True)
|
||||
@with_client
|
||||
@with_management_session
|
||||
def recover(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
words: str,
|
||||
expand: bool,
|
||||
pin_protection: bool,
|
||||
@ -201,7 +204,7 @@ def recover(
|
||||
type = messages.RecoveryType.UnlockRepeatedBackup
|
||||
|
||||
return device.recover(
|
||||
client,
|
||||
session,
|
||||
word_count=int(words),
|
||||
passphrase_protection=passphrase_protection,
|
||||
pin_protection=pin_protection,
|
||||
@ -222,9 +225,9 @@ def recover(
|
||||
@click.option("-s", "--skip-backup", is_flag=True)
|
||||
@click.option("-n", "--no-backup", is_flag=True)
|
||||
@click.option("-b", "--backup-type", type=ChoiceType(BACKUP_TYPE))
|
||||
@with_client
|
||||
@with_management_session
|
||||
def setup(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
strength: int | None,
|
||||
passphrase_protection: bool,
|
||||
pin_protection: bool,
|
||||
@ -241,7 +244,7 @@ def setup(
|
||||
BT = messages.BackupType
|
||||
|
||||
if backup_type is None:
|
||||
if client.version >= (2, 7, 1):
|
||||
if session.version >= (2, 7, 1):
|
||||
# SLIP39 extendable was introduced in 2.7.1
|
||||
backup_type = BT.Slip39_Single_Extendable
|
||||
else:
|
||||
@ -251,10 +254,10 @@ def setup(
|
||||
if (
|
||||
backup_type
|
||||
in (BT.Slip39_Single_Extendable, BT.Slip39_Basic, BT.Slip39_Basic_Extendable)
|
||||
and messages.Capability.Shamir not in client.features.capabilities
|
||||
and messages.Capability.Shamir not in session.features.capabilities
|
||||
) or (
|
||||
backup_type in (BT.Slip39_Advanced, BT.Slip39_Advanced_Extendable)
|
||||
and messages.Capability.ShamirGroups not in client.features.capabilities
|
||||
and messages.Capability.ShamirGroups not in session.features.capabilities
|
||||
):
|
||||
click.echo(
|
||||
"WARNING: Your Trezor device does not indicate support for the requested\n"
|
||||
@ -262,7 +265,7 @@ def setup(
|
||||
)
|
||||
|
||||
return device.reset(
|
||||
client,
|
||||
session,
|
||||
strength=strength,
|
||||
passphrase_protection=passphrase_protection,
|
||||
pin_protection=pin_protection,
|
||||
@ -277,23 +280,21 @@ def setup(
|
||||
@cli.command()
|
||||
@click.option("-t", "--group-threshold", type=int)
|
||||
@click.option("-g", "--group", "groups", type=(int, int), multiple=True, metavar="T N")
|
||||
@with_client
|
||||
@with_management_session
|
||||
def backup(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
group_threshold: int | None = None,
|
||||
groups: t.Sequence[tuple[int, int]] = (),
|
||||
) -> str:
|
||||
"""Perform device seed backup."""
|
||||
|
||||
return device.backup(client, group_threshold, groups)
|
||||
return device.backup(session, group_threshold, groups)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("operation", type=ChoiceType(SD_PROTECT_OPERATIONS))
|
||||
@with_client
|
||||
def sd_protect(
|
||||
client: "TrezorClient", operation: messages.SdProtectOperationType
|
||||
) -> str:
|
||||
@with_management_session
|
||||
def sd_protect(session: "Session", operation: messages.SdProtectOperationType) -> str:
|
||||
"""Secure the device with SD card protection.
|
||||
|
||||
When SD card protection is enabled, a randomly generated secret is stored
|
||||
@ -307,36 +308,36 @@ def sd_protect(
|
||||
off - Remove SD card secret protection.
|
||||
refresh - Replace the current SD card secret with a new one.
|
||||
"""
|
||||
if client.features.model == "1":
|
||||
if session.features.model == "1":
|
||||
raise click.ClickException("Trezor One does not support SD card protection.")
|
||||
return device.sd_protect(client, operation)
|
||||
return device.sd_protect(session, operation)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.pass_obj
|
||||
def reboot_to_bootloader(obj: "TrezorConnection") -> str:
|
||||
def reboot_to_bootloader(obj: "NewTrezorConnection") -> str:
|
||||
"""Reboot device into bootloader mode.
|
||||
|
||||
Currently only supported on Trezor Model One.
|
||||
"""
|
||||
# avoid using @with_client because it closes the session afterwards,
|
||||
# avoid using @with_management_session because it closes the session afterwards,
|
||||
# which triggers double prompt on device
|
||||
with obj.client_context() as client:
|
||||
return device.reboot_to_bootloader(client)
|
||||
return device.reboot_to_bootloader(client.get_management_session())
|
||||
|
||||
|
||||
@cli.command()
|
||||
@with_client
|
||||
def tutorial(client: "TrezorClient") -> str:
|
||||
@with_management_session
|
||||
def tutorial(session: "Session") -> str:
|
||||
"""Show on-device tutorial."""
|
||||
return device.show_device_tutorial(client)
|
||||
return device.show_device_tutorial(session)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@with_client
|
||||
def unlock_bootloader(client: "TrezorClient") -> str:
|
||||
@with_management_session
|
||||
def unlock_bootloader(session: "Session") -> str:
|
||||
"""Unlocks bootloader. Irreversible."""
|
||||
return device.unlock_bootloader(client)
|
||||
return device.unlock_bootloader(session)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@ -347,11 +348,11 @@ def unlock_bootloader(client: "TrezorClient") -> str:
|
||||
type=int,
|
||||
help="Dialog expiry in seconds.",
|
||||
)
|
||||
@with_client
|
||||
def set_busy(client: "TrezorClient", enable: bool | None, expiry: int | None) -> str:
|
||||
@with_management_session
|
||||
def set_busy(session: "Session", enable: bool | None, expiry: int | None) -> str:
|
||||
"""Show a "Do not disconnect" dialog."""
|
||||
if enable is False:
|
||||
return device.set_busy(client, None)
|
||||
return device.set_busy(session, None)
|
||||
|
||||
if expiry is None:
|
||||
raise click.ClickException("Missing option '-e' / '--expiry'.")
|
||||
@ -361,7 +362,7 @@ def set_busy(client: "TrezorClient", enable: bool | None, expiry: int | None) ->
|
||||
f"Invalid value for '-e' / '--expiry': '{expiry}' is not a positive integer."
|
||||
)
|
||||
|
||||
return device.set_busy(client, expiry * 1000)
|
||||
return device.set_busy(session, expiry * 1000)
|
||||
|
||||
|
||||
PUBKEY_WHITELIST_URL_TEMPLATE = (
|
||||
@ -381,9 +382,9 @@ PUBKEY_WHITELIST_URL_TEMPLATE = (
|
||||
is_flag=True,
|
||||
help="Do not check intermediate certificates against the whitelist.",
|
||||
)
|
||||
@with_client
|
||||
@with_management_session
|
||||
def authenticate(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
hex_challenge: str | None,
|
||||
root: t.BinaryIO | None,
|
||||
raw: bool | None,
|
||||
@ -408,7 +409,7 @@ def authenticate(
|
||||
challenge = bytes.fromhex(hex_challenge)
|
||||
|
||||
if raw:
|
||||
msg = device.authenticate(client, challenge)
|
||||
msg = device.authenticate(session, challenge)
|
||||
|
||||
click.echo(f"Challenge: {hex_challenge}")
|
||||
click.echo(f"Signature of challenge: {msg.signature.hex()}")
|
||||
@ -456,14 +457,14 @@ def authenticate(
|
||||
else:
|
||||
whitelist_json = requests.get(
|
||||
PUBKEY_WHITELIST_URL_TEMPLATE.format(
|
||||
model=client.model.internal_name.lower()
|
||||
model=session.model.internal_name.lower()
|
||||
)
|
||||
).json()
|
||||
whitelist = [bytes.fromhex(pk) for pk in whitelist_json["ca_pubkeys"]]
|
||||
|
||||
try:
|
||||
authentication.authenticate_device(
|
||||
client, challenge, root_pubkey=root_bytes, whitelist=whitelist
|
||||
session, challenge, root_pubkey=root_bytes, whitelist=whitelist
|
||||
)
|
||||
except authentication.DeviceNotAuthentic:
|
||||
click.echo("Device is not authentic.")
|
||||
|
@ -20,11 +20,11 @@ from typing import TYPE_CHECKING, TextIO
|
||||
import click
|
||||
|
||||
from .. import eos, tools
|
||||
from . import with_client
|
||||
from . import with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .. import messages
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
|
||||
PATH_HELP = "BIP-32 path, e.g. m/44h/194h/0h/0/0"
|
||||
|
||||
@ -37,11 +37,11 @@ def cli() -> None:
|
||||
@cli.command()
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@with_client
|
||||
def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str:
|
||||
@with_session
|
||||
def get_public_key(session: "Session", address: str, show_display: bool) -> str:
|
||||
"""Get Eos public key in base58 encoding."""
|
||||
address_n = tools.parse_path(address)
|
||||
res = eos.get_public_key(client, address_n, show_display)
|
||||
res = eos.get_public_key(session, address_n, show_display)
|
||||
return f"WIF: {res.wif_public_key}\nRaw: {res.raw_public_key.hex()}"
|
||||
|
||||
|
||||
@ -50,16 +50,16 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) ->
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def sign_transaction(
|
||||
client: "TrezorClient", address: str, file: TextIO, chunkify: bool
|
||||
session: "Session", address: str, file: TextIO, chunkify: bool
|
||||
) -> "messages.EosSignedTx":
|
||||
"""Sign EOS transaction."""
|
||||
tx_json = json.load(file)
|
||||
|
||||
address_n = tools.parse_path(address)
|
||||
return eos.sign_tx(
|
||||
client,
|
||||
session,
|
||||
address_n,
|
||||
tx_json["transaction"],
|
||||
tx_json["chain_id"],
|
||||
|
@ -26,14 +26,14 @@ import click
|
||||
|
||||
from .. import _rlp, definitions, ethereum, tools
|
||||
from ..messages import EthereumDefinitions
|
||||
from . import with_client
|
||||
from . import with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import web3
|
||||
from eth_typing import ChecksumAddress # noqa: I900
|
||||
from web3.types import Wei
|
||||
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
|
||||
PATH_HELP = "BIP-32 path, e.g. m/44h/60h/0h/0/0"
|
||||
|
||||
@ -268,24 +268,24 @@ def cli(
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def get_address(
|
||||
client: "TrezorClient", address: str, show_display: bool, chunkify: bool
|
||||
session: "Session", address: str, show_display: bool, chunkify: bool
|
||||
) -> str:
|
||||
"""Get Ethereum address in hex encoding."""
|
||||
address_n = tools.parse_path(address)
|
||||
network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE)
|
||||
return ethereum.get_address(client, address_n, show_display, network, chunkify)
|
||||
return ethereum.get_address(session, address_n, show_display, network, chunkify)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@with_client
|
||||
def get_public_node(client: "TrezorClient", address: str, show_display: bool) -> dict:
|
||||
@with_session
|
||||
def get_public_node(session: "Session", address: str, show_display: bool) -> dict:
|
||||
"""Get Ethereum public node of given path."""
|
||||
address_n = tools.parse_path(address)
|
||||
result = ethereum.get_public_node(client, address_n, show_display=show_display)
|
||||
result = ethereum.get_public_node(session, address_n, show_display=show_display)
|
||||
return {
|
||||
"node": {
|
||||
"depth": result.node.depth,
|
||||
@ -344,9 +344,9 @@ def get_public_node(client: "TrezorClient", address: str, show_display: bool) ->
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@click.argument("to_address")
|
||||
@click.argument("amount", callback=_amount_to_int)
|
||||
@with_client
|
||||
@with_session
|
||||
def sign_tx(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
chain_id: int,
|
||||
address: str,
|
||||
amount: int,
|
||||
@ -400,7 +400,7 @@ def sign_tx(
|
||||
encoded_network = DEFINITIONS_SOURCE.get_network(chain_id)
|
||||
address_n = tools.parse_path(address)
|
||||
from_address = ethereum.get_address(
|
||||
client, address_n, encoded_network=encoded_network
|
||||
session, address_n, encoded_network=encoded_network
|
||||
)
|
||||
|
||||
if token:
|
||||
@ -446,7 +446,7 @@ def sign_tx(
|
||||
assert max_gas_fee is not None
|
||||
assert max_priority_fee is not None
|
||||
sig = ethereum.sign_tx_eip1559(
|
||||
client,
|
||||
session,
|
||||
n=address_n,
|
||||
nonce=nonce,
|
||||
gas_limit=gas_limit,
|
||||
@ -465,7 +465,7 @@ def sign_tx(
|
||||
gas_price = _get_web3().eth.gas_price
|
||||
assert gas_price is not None
|
||||
sig = ethereum.sign_tx(
|
||||
client,
|
||||
session,
|
||||
n=address_n,
|
||||
tx_type=tx_type,
|
||||
nonce=nonce,
|
||||
@ -526,14 +526,14 @@ def sign_tx(
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@click.argument("message")
|
||||
@with_client
|
||||
@with_session
|
||||
def sign_message(
|
||||
client: "TrezorClient", address: str, message: str, chunkify: bool
|
||||
session: "Session", address: str, message: str, chunkify: bool
|
||||
) -> Dict[str, str]:
|
||||
"""Sign message with Ethereum address."""
|
||||
address_n = tools.parse_path(address)
|
||||
network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE)
|
||||
ret = ethereum.sign_message(client, address_n, message, network, chunkify=chunkify)
|
||||
ret = ethereum.sign_message(session, address_n, message, network, chunkify=chunkify)
|
||||
output = {
|
||||
"message": message,
|
||||
"address": ret.address,
|
||||
@ -550,9 +550,9 @@ def sign_message(
|
||||
help="Be compatible with Metamask's signTypedData_v4 implementation",
|
||||
)
|
||||
@click.argument("file", type=click.File("r"))
|
||||
@with_client
|
||||
@with_session
|
||||
def sign_typed_data(
|
||||
client: "TrezorClient", address: str, metamask_v4_compat: bool, file: TextIO
|
||||
session: "Session", address: str, metamask_v4_compat: bool, file: TextIO
|
||||
) -> Dict[str, str]:
|
||||
"""Sign typed data (EIP-712) with Ethereum address.
|
||||
|
||||
@ -565,7 +565,7 @@ def sign_typed_data(
|
||||
defs = EthereumDefinitions(encoded_network=network)
|
||||
data = json.loads(file.read())
|
||||
ret = ethereum.sign_typed_data(
|
||||
client,
|
||||
session,
|
||||
address_n,
|
||||
data,
|
||||
metamask_v4_compat=metamask_v4_compat,
|
||||
@ -583,9 +583,9 @@ def sign_typed_data(
|
||||
@click.argument("address")
|
||||
@click.argument("signature")
|
||||
@click.argument("message")
|
||||
@with_client
|
||||
@with_session
|
||||
def verify_message(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address: str,
|
||||
signature: str,
|
||||
message: str,
|
||||
@ -594,7 +594,7 @@ def verify_message(
|
||||
"""Verify message signed with Ethereum address."""
|
||||
signature_bytes = ethereum.decode_hex(signature)
|
||||
return ethereum.verify_message(
|
||||
client, address, signature_bytes, message, chunkify=chunkify
|
||||
session, address, signature_bytes, message, chunkify=chunkify
|
||||
)
|
||||
|
||||
|
||||
@ -602,9 +602,9 @@ def verify_message(
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.argument("domain_hash_hex")
|
||||
@click.argument("message_hash_hex")
|
||||
@with_client
|
||||
@with_session
|
||||
def sign_typed_data_hash(
|
||||
client: "TrezorClient", address: str, domain_hash_hex: str, message_hash_hex: str
|
||||
session: "Session", address: str, domain_hash_hex: str, message_hash_hex: str
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
Sign hash of typed data (EIP-712) with Ethereum address.
|
||||
@ -618,7 +618,7 @@ def sign_typed_data_hash(
|
||||
message_hash = ethereum.decode_hex(message_hash_hex) if message_hash_hex else None
|
||||
network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE)
|
||||
ret = ethereum.sign_typed_data_hash(
|
||||
client, address_n, domain_hash, message_hash, network
|
||||
session, address_n, domain_hash, message_hash, network
|
||||
)
|
||||
output = {
|
||||
"domain_hash": domain_hash_hex,
|
||||
|
@ -19,10 +19,10 @@ from typing import TYPE_CHECKING
|
||||
import click
|
||||
|
||||
from .. import fido
|
||||
from . import with_client
|
||||
from . import with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
|
||||
ALGORITHM_NAME = {-7: "ES256 (ECDSA w/ SHA-256)", -8: "EdDSA"}
|
||||
|
||||
@ -40,10 +40,10 @@ def credentials() -> None:
|
||||
|
||||
|
||||
@credentials.command(name="list")
|
||||
@with_client
|
||||
def credentials_list(client: "TrezorClient") -> None:
|
||||
@with_session
|
||||
def credentials_list(session: "Session") -> None:
|
||||
"""List all resident credentials on the device."""
|
||||
creds = fido.list_credentials(client)
|
||||
creds = fido.list_credentials(session)
|
||||
for cred in creds:
|
||||
click.echo("")
|
||||
click.echo(f"WebAuthn credential at index {cred.index}:")
|
||||
@ -79,23 +79,23 @@ def credentials_list(client: "TrezorClient") -> None:
|
||||
|
||||
@credentials.command(name="add")
|
||||
@click.argument("hex_credential_id")
|
||||
@with_client
|
||||
def credentials_add(client: "TrezorClient", hex_credential_id: str) -> str:
|
||||
@with_session
|
||||
def credentials_add(session: "Session", hex_credential_id: str) -> str:
|
||||
"""Add the credential with the given ID as a resident credential.
|
||||
|
||||
HEX_CREDENTIAL_ID is the credential ID as a hexadecimal string.
|
||||
"""
|
||||
return fido.add_credential(client, bytes.fromhex(hex_credential_id))
|
||||
return fido.add_credential(session, bytes.fromhex(hex_credential_id))
|
||||
|
||||
|
||||
@credentials.command(name="remove")
|
||||
@click.option(
|
||||
"-i", "--index", required=True, type=click.IntRange(0, 99), help="Credential index."
|
||||
)
|
||||
@with_client
|
||||
def credentials_remove(client: "TrezorClient", index: int) -> str:
|
||||
@with_session
|
||||
def credentials_remove(session: "Session", index: int) -> str:
|
||||
"""Remove the resident credential at the given index."""
|
||||
return fido.remove_credential(client, index)
|
||||
return fido.remove_credential(session, index)
|
||||
|
||||
|
||||
#
|
||||
@ -110,19 +110,19 @@ def counter() -> None:
|
||||
|
||||
@counter.command(name="set")
|
||||
@click.argument("counter", type=int)
|
||||
@with_client
|
||||
def counter_set(client: "TrezorClient", counter: int) -> str:
|
||||
@with_session
|
||||
def counter_set(session: "Session", counter: int) -> str:
|
||||
"""Set FIDO/U2F counter value."""
|
||||
return fido.set_counter(client, counter)
|
||||
return fido.set_counter(session, counter)
|
||||
|
||||
|
||||
@counter.command(name="get-next")
|
||||
@with_client
|
||||
def counter_get_next(client: "TrezorClient") -> int:
|
||||
@with_session
|
||||
def counter_get_next(session: "Session") -> int:
|
||||
"""Get-and-increase value of FIDO/U2F counter.
|
||||
|
||||
FIDO counter value cannot be read directly. On each U2F exchange, the counter value
|
||||
is returned and atomically increased. This command performs the same operation
|
||||
and returns the counter value.
|
||||
"""
|
||||
return fido.get_next_counter(client)
|
||||
return fido.get_next_counter(session)
|
||||
|
@ -37,11 +37,12 @@ import requests
|
||||
from .. import device, exceptions, firmware, messages, models
|
||||
from ..firmware import models as fw_models
|
||||
from ..models import TrezorModel
|
||||
from . import ChoiceType, with_client
|
||||
from . import ChoiceType, with_management_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
from . import TrezorConnection
|
||||
from ..transport.session import Session
|
||||
from . import NewTrezorConnection
|
||||
|
||||
MODEL_CHOICE = ChoiceType(
|
||||
{
|
||||
@ -74,9 +75,9 @@ def _is_bootloader_onev2(client: "TrezorClient") -> bool:
|
||||
This is the case from bootloader version 1.8.0, and also holds for firmware version
|
||||
1.8.0 because that installs the appropriate bootloader.
|
||||
"""
|
||||
f = client.features
|
||||
version = (f.major_version, f.minor_version, f.patch_version)
|
||||
bootloader_onev2 = f.major_version == 1 and version >= (1, 8, 0)
|
||||
features = client.features
|
||||
version = client.version
|
||||
bootloader_onev2 = features.major_version == 1 and version >= (1, 8, 0)
|
||||
return bootloader_onev2
|
||||
|
||||
|
||||
@ -306,25 +307,26 @@ def find_best_firmware_version(
|
||||
If the specified version is not found, prints the closest available version
|
||||
(higher than the specified one, if existing).
|
||||
"""
|
||||
features = client.features
|
||||
model = client.model
|
||||
|
||||
if bitcoin_only is None:
|
||||
bitcoin_only = _should_use_bitcoin_only(client.features)
|
||||
bitcoin_only = _should_use_bitcoin_only(features)
|
||||
|
||||
def version_str(version: Iterable[int]) -> str:
|
||||
return ".".join(map(str, version))
|
||||
|
||||
f = client.features
|
||||
|
||||
releases = get_all_firmware_releases(client.model, bitcoin_only, beta)
|
||||
releases = get_all_firmware_releases(model, bitcoin_only, beta)
|
||||
highest_version = releases[0]["version"]
|
||||
|
||||
if version:
|
||||
want_version = [int(x) for x in version.split(".")]
|
||||
if len(want_version) != 3:
|
||||
click.echo("Please use the 'X.Y.Z' version format.")
|
||||
if want_version[0] != f.major_version:
|
||||
if want_version[0] != features.major_version:
|
||||
click.echo(
|
||||
f"Warning: Trezor {client.model.name} firmware version should be "
|
||||
f"{f.major_version}.X.Y (requested: {version})"
|
||||
f"Warning: Trezor {model.name} firmware version should be "
|
||||
f"{features.major_version}.X.Y (requested: {version})"
|
||||
)
|
||||
else:
|
||||
want_version = highest_version
|
||||
@ -359,8 +361,8 @@ def find_best_firmware_version(
|
||||
# to the newer one, in that case update to the minimal
|
||||
# compatible version first
|
||||
# Choosing the version key to compare based on (not) being in BL mode
|
||||
client_version = [f.major_version, f.minor_version, f.patch_version]
|
||||
if f.bootloader_mode:
|
||||
client_version = client.version
|
||||
if features.bootloader_mode:
|
||||
key_to_compare = "min_bootloader_version"
|
||||
else:
|
||||
key_to_compare = "min_firmware_version"
|
||||
@ -447,11 +449,11 @@ def extract_embedded_fw(
|
||||
|
||||
|
||||
def upload_firmware_into_device(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
firmware_data: bytes,
|
||||
) -> None:
|
||||
"""Perform the final act of loading the firmware into Trezor."""
|
||||
f = client.features
|
||||
f = session.features
|
||||
try:
|
||||
if f.major_version == 1 and f.firmware_present is not False:
|
||||
# Trezor One does not send ButtonRequest
|
||||
@ -461,7 +463,7 @@ def upload_firmware_into_device(
|
||||
with click.progressbar(
|
||||
label="Uploading", length=len(firmware_data), show_eta=False
|
||||
) as bar:
|
||||
firmware.update(client, firmware_data, bar.update)
|
||||
firmware.update(session, firmware_data, bar.update)
|
||||
except exceptions.Cancelled:
|
||||
click.echo("Update aborted on device.")
|
||||
except exceptions.TrezorException as e:
|
||||
@ -519,7 +521,7 @@ def cli() -> None:
|
||||
@click.pass_obj
|
||||
# fmt: on
|
||||
def verify(
|
||||
obj: "TrezorConnection",
|
||||
obj: "NewTrezorConnection",
|
||||
filename: BinaryIO,
|
||||
check_device: bool,
|
||||
fingerprint: Optional[str],
|
||||
@ -564,7 +566,7 @@ def verify(
|
||||
@click.pass_obj
|
||||
# fmt: on
|
||||
def download(
|
||||
obj: "TrezorConnection",
|
||||
obj: "NewTrezorConnection",
|
||||
output: Optional[BinaryIO],
|
||||
model: Optional[TrezorModel],
|
||||
version: Optional[str],
|
||||
@ -630,7 +632,7 @@ def download(
|
||||
# fmt: on
|
||||
@click.pass_obj
|
||||
def update(
|
||||
obj: "TrezorConnection",
|
||||
obj: "NewTrezorConnection",
|
||||
filename: Optional[BinaryIO],
|
||||
url: Optional[str],
|
||||
version: Optional[str],
|
||||
@ -654,6 +656,7 @@ def update(
|
||||
against data.trezor.io information, if available.
|
||||
"""
|
||||
with obj.client_context() as client:
|
||||
management_session = client.get_management_session()
|
||||
if sum(bool(x) for x in (filename, url, version)) > 1:
|
||||
click.echo("You can use only one of: filename, url, version.")
|
||||
sys.exit(1)
|
||||
@ -709,7 +712,7 @@ def update(
|
||||
if _is_strict_update(client, firmware_data):
|
||||
header_size = _get_firmware_header_size(firmware_data)
|
||||
device.reboot_to_bootloader(
|
||||
client,
|
||||
management_session,
|
||||
boot_command=messages.BootCommand.INSTALL_UPGRADE,
|
||||
firmware_header=firmware_data[:header_size],
|
||||
language_data=language_data,
|
||||
@ -719,7 +722,7 @@ def update(
|
||||
click.echo(
|
||||
"WARNING: Seamless installation not possible, language data will not be uploaded."
|
||||
)
|
||||
device.reboot_to_bootloader(client)
|
||||
device.reboot_to_bootloader(management_session)
|
||||
|
||||
click.echo("Waiting for bootloader...")
|
||||
while True:
|
||||
@ -735,13 +738,15 @@ def update(
|
||||
click.echo("Please switch your device to bootloader mode.")
|
||||
sys.exit(1)
|
||||
|
||||
upload_firmware_into_device(client=client, firmware_data=firmware_data)
|
||||
upload_firmware_into_device(
|
||||
session=client.get_management_session(), firmware_data=firmware_data
|
||||
)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("hex_challenge", required=False)
|
||||
@with_client
|
||||
def get_hash(client: "TrezorClient", hex_challenge: Optional[str]) -> str:
|
||||
@with_management_session
|
||||
def get_hash(session: "Session", hex_challenge: Optional[str]) -> str:
|
||||
"""Get a hash of the installed firmware combined with the optional challenge."""
|
||||
challenge = bytes.fromhex(hex_challenge) if hex_challenge else None
|
||||
return firmware.get_hash(client, challenge).hex()
|
||||
return firmware.get_hash(session, challenge).hex()
|
||||
|
@ -19,10 +19,10 @@ from typing import TYPE_CHECKING, Dict
|
||||
import click
|
||||
|
||||
from .. import messages, monero, tools
|
||||
from . import ChoiceType, with_client
|
||||
from . import ChoiceType, with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
|
||||
PATH_HELP = "BIP-32 path, e.g. m/44h/128h/0h"
|
||||
|
||||
@ -42,9 +42,9 @@ def cli() -> None:
|
||||
default=messages.MoneroNetworkType.MAINNET,
|
||||
)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def get_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address: str,
|
||||
show_display: bool,
|
||||
network_type: messages.MoneroNetworkType,
|
||||
@ -52,7 +52,7 @@ def get_address(
|
||||
) -> bytes:
|
||||
"""Get Monero address for specified path."""
|
||||
address_n = tools.parse_path(address)
|
||||
return monero.get_address(client, address_n, show_display, network_type, chunkify)
|
||||
return monero.get_address(session, address_n, show_display, network_type, chunkify)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@ -63,13 +63,13 @@ def get_address(
|
||||
type=ChoiceType({m.name: m for m in messages.MoneroNetworkType}),
|
||||
default=messages.MoneroNetworkType.MAINNET,
|
||||
)
|
||||
@with_client
|
||||
@with_session
|
||||
def get_watch_key(
|
||||
client: "TrezorClient", address: str, network_type: messages.MoneroNetworkType
|
||||
session: "Session", address: str, network_type: messages.MoneroNetworkType
|
||||
) -> Dict[str, str]:
|
||||
"""Get Monero watch key for specified path."""
|
||||
address_n = tools.parse_path(address)
|
||||
res = monero.get_watch_key(client, address_n, network_type)
|
||||
res = monero.get_watch_key(session, address_n, network_type)
|
||||
# TODO: could be made required in MoneroWatchKey
|
||||
assert res.address is not None
|
||||
assert res.watch_key is not None
|
||||
|
@ -21,10 +21,10 @@ import click
|
||||
import requests
|
||||
|
||||
from .. import nem, tools
|
||||
from . import with_client
|
||||
from . import with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
|
||||
PATH_HELP = "BIP-32 path, e.g. m/44h/134h/0h/0h"
|
||||
|
||||
@ -39,9 +39,9 @@ def cli() -> None:
|
||||
@click.option("-N", "--network", type=int, default=0x68)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def get_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address: str,
|
||||
network: int,
|
||||
show_display: bool,
|
||||
@ -49,7 +49,7 @@ def get_address(
|
||||
) -> str:
|
||||
"""Get NEM address for specified path."""
|
||||
address_n = tools.parse_path(address)
|
||||
return nem.get_address(client, address_n, network, show_display, chunkify)
|
||||
return nem.get_address(session, address_n, network, show_display, chunkify)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@ -58,9 +58,9 @@ def get_address(
|
||||
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
||||
@click.option("-b", "--broadcast", help="NIS to announce transaction to")
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def sign_tx(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address: str,
|
||||
file: TextIO,
|
||||
broadcast: Optional[str],
|
||||
@ -71,7 +71,7 @@ def sign_tx(
|
||||
Transaction file is expected in the NIS (RequestPrepareAnnounce) format.
|
||||
"""
|
||||
address_n = tools.parse_path(address)
|
||||
transaction = nem.sign_tx(client, address_n, json.load(file), chunkify=chunkify)
|
||||
transaction = nem.sign_tx(session, address_n, json.load(file), chunkify=chunkify)
|
||||
|
||||
payload = {"data": transaction.data.hex(), "signature": transaction.signature.hex()}
|
||||
|
||||
|
@ -20,10 +20,10 @@ from typing import TYPE_CHECKING, TextIO
|
||||
import click
|
||||
|
||||
from .. import ripple, tools
|
||||
from . import with_client
|
||||
from . import with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
|
||||
PATH_HELP = "BIP-32 path to key, e.g. m/44h/144h/0h/0/0"
|
||||
|
||||
@ -37,13 +37,13 @@ def cli() -> None:
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def get_address(
|
||||
client: "TrezorClient", address: str, show_display: bool, chunkify: bool
|
||||
session: "Session", address: str, show_display: bool, chunkify: bool
|
||||
) -> str:
|
||||
"""Get Ripple address"""
|
||||
address_n = tools.parse_path(address)
|
||||
return ripple.get_address(client, address_n, show_display, chunkify)
|
||||
return ripple.get_address(session, address_n, show_display, chunkify)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@ -51,13 +51,13 @@ def get_address(
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@with_client
|
||||
def sign_tx(client: "TrezorClient", address: str, file: TextIO, chunkify: bool) -> None:
|
||||
@with_session
|
||||
def sign_tx(session: "Session", address: str, file: TextIO, chunkify: bool) -> None:
|
||||
"""Sign Ripple transaction"""
|
||||
address_n = tools.parse_path(address)
|
||||
msg = ripple.create_sign_tx_msg(json.load(file))
|
||||
|
||||
result = ripple.sign_tx(client, address_n, msg, chunkify=chunkify)
|
||||
result = ripple.sign_tx(session, address_n, msg, chunkify=chunkify)
|
||||
click.echo("Signature:")
|
||||
click.echo(result.signature.hex())
|
||||
click.echo()
|
||||
|
@ -24,10 +24,11 @@ import click
|
||||
import requests
|
||||
|
||||
from .. import device, messages, toif
|
||||
from . import AliasedGroup, ChoiceType, with_client
|
||||
from ..transport.session import Session
|
||||
from . import AliasedGroup, ChoiceType, with_management_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
pass
|
||||
|
||||
try:
|
||||
from PIL import Image
|
||||
@ -180,18 +181,18 @@ def cli() -> None:
|
||||
@cli.command()
|
||||
@click.option("-r", "--remove", is_flag=True, hidden=True)
|
||||
@click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False)
|
||||
@with_client
|
||||
def pin(client: "TrezorClient", enable: Optional[bool], remove: bool) -> str:
|
||||
@with_management_session
|
||||
def pin(session: "Session", enable: Optional[bool], remove: bool) -> str:
|
||||
"""Set, change or remove PIN."""
|
||||
# Remove argument is there for backwards compatibility
|
||||
return device.change_pin(client, remove=_should_remove(enable, remove))
|
||||
return device.change_pin(session, remove=_should_remove(enable, remove))
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option("-r", "--remove", is_flag=True, hidden=True)
|
||||
@click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False)
|
||||
@with_client
|
||||
def wipe_code(client: "TrezorClient", enable: Optional[bool], remove: bool) -> str:
|
||||
@with_management_session
|
||||
def wipe_code(session: "Session", enable: Optional[bool], remove: bool) -> str:
|
||||
"""Set or remove the wipe code.
|
||||
|
||||
The wipe code functions as a "self-destruct PIN". If the wipe code is ever
|
||||
@ -199,32 +200,32 @@ def wipe_code(client: "TrezorClient", enable: Optional[bool], remove: bool) -> s
|
||||
removed and the device will be reset to factory defaults.
|
||||
"""
|
||||
# Remove argument is there for backwards compatibility
|
||||
return device.change_wipe_code(client, remove=_should_remove(enable, remove))
|
||||
return device.change_wipe_code(session, remove=_should_remove(enable, remove))
|
||||
|
||||
|
||||
@cli.command()
|
||||
# keep the deprecated -l/--label option, make it do nothing
|
||||
@click.option("-l", "--label", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
||||
@click.argument("label")
|
||||
@with_client
|
||||
def label(client: "TrezorClient", label: str) -> str:
|
||||
@with_management_session
|
||||
def label(session: "Session", label: str) -> str:
|
||||
"""Set new device label."""
|
||||
return device.apply_settings(client, label=label)
|
||||
return device.apply_settings(session, label=label)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@with_client
|
||||
def brightness(client: "TrezorClient") -> str:
|
||||
@with_management_session
|
||||
def brightness(session: "Session") -> str:
|
||||
"""Set display brightness."""
|
||||
return device.set_brightness(client)
|
||||
return device.set_brightness(session)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("enable", type=ChoiceType({"on": True, "off": False}))
|
||||
@with_client
|
||||
def haptic_feedback(client: "TrezorClient", enable: bool) -> str:
|
||||
@with_management_session
|
||||
def haptic_feedback(session: "Session", enable: bool) -> str:
|
||||
"""Enable or disable haptic feedback."""
|
||||
return device.apply_settings(client, haptic_feedback=enable)
|
||||
return device.apply_settings(session, haptic_feedback=enable)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@ -233,9 +234,9 @@ def haptic_feedback(client: "TrezorClient", enable: bool) -> str:
|
||||
"-r", "--remove", is_flag=True, default=False, help="Switch back to english."
|
||||
)
|
||||
@click.option("-d/-D", "--display/--no-display", default=None)
|
||||
@with_client
|
||||
@with_management_session
|
||||
def language(
|
||||
client: "TrezorClient", path_or_url: str | None, remove: bool, display: bool | None
|
||||
session: "Session", path_or_url: str | None, remove: bool, display: bool | None
|
||||
) -> str:
|
||||
"""Set new language with translations."""
|
||||
if remove != (path_or_url is None):
|
||||
@ -260,29 +261,29 @@ def language(
|
||||
f"Failed to load translations from {path_or_url}"
|
||||
) from None
|
||||
return device.change_language(
|
||||
client, language_data=language_data, show_display=display
|
||||
session, language_data=language_data, show_display=display
|
||||
)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("rotation", type=ChoiceType(ROTATION))
|
||||
@with_client
|
||||
def display_rotation(client: "TrezorClient", rotation: messages.DisplayRotation) -> str:
|
||||
@with_management_session
|
||||
def display_rotation(session: "Session", rotation: messages.DisplayRotation) -> str:
|
||||
"""Set display rotation.
|
||||
|
||||
Configure display rotation for Trezor Model T. The options are
|
||||
north, east, south or west.
|
||||
"""
|
||||
return device.apply_settings(client, display_rotation=rotation)
|
||||
return device.apply_settings(session, display_rotation=rotation)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("delay", type=str)
|
||||
@with_client
|
||||
def auto_lock_delay(client: "TrezorClient", delay: str) -> str:
|
||||
@with_management_session
|
||||
def auto_lock_delay(session: "Session", delay: str) -> str:
|
||||
"""Set auto-lock delay (in seconds)."""
|
||||
|
||||
if not client.features.pin_protection:
|
||||
if not session.features.pin_protection:
|
||||
raise click.ClickException("Set up a PIN first")
|
||||
|
||||
value, unit = delay[:-1], delay[-1:]
|
||||
@ -291,13 +292,13 @@ def auto_lock_delay(client: "TrezorClient", delay: str) -> str:
|
||||
seconds = float(value) * units[unit]
|
||||
else:
|
||||
seconds = float(delay) # assume seconds if no unit is specified
|
||||
return device.apply_settings(client, auto_lock_delay_ms=int(seconds * 1000))
|
||||
return device.apply_settings(session, auto_lock_delay_ms=int(seconds * 1000))
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("flags")
|
||||
@with_client
|
||||
def flags(client: "TrezorClient", flags: str) -> str:
|
||||
@with_management_session
|
||||
def flags(session: "Session", flags: str) -> str:
|
||||
"""Set device flags."""
|
||||
if flags.lower().startswith("0b"):
|
||||
flags_int = int(flags, 2)
|
||||
@ -305,7 +306,7 @@ def flags(client: "TrezorClient", flags: str) -> str:
|
||||
flags_int = int(flags, 16)
|
||||
else:
|
||||
flags_int = int(flags)
|
||||
return device.apply_flags(client, flags=flags_int)
|
||||
return device.apply_flags(session, flags=flags_int)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@ -314,8 +315,8 @@ def flags(client: "TrezorClient", flags: str) -> str:
|
||||
"-f", "--filename", "_ignore", is_flag=True, hidden=True, expose_value=False
|
||||
)
|
||||
@click.option("-q", "--quality", type=int, default=90, help="JPEG quality (0-100)")
|
||||
@with_client
|
||||
def homescreen(client: "TrezorClient", filename: str, quality: int) -> str:
|
||||
@with_management_session
|
||||
def homescreen(session: "Session", filename: str, quality: int) -> str:
|
||||
"""Set new homescreen.
|
||||
|
||||
To revert to default homescreen, use 'trezorctl set homescreen default'
|
||||
@ -327,39 +328,39 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> str:
|
||||
if not path.exists() or not path.is_file():
|
||||
raise click.ClickException("Cannot open file")
|
||||
|
||||
if client.features.model == "1":
|
||||
if session.features.model == "1":
|
||||
img = image_to_t1(path)
|
||||
else:
|
||||
if client.features.homescreen_format == messages.HomescreenFormat.Jpeg:
|
||||
if session.features.homescreen_format == messages.HomescreenFormat.Jpeg:
|
||||
width = (
|
||||
client.features.homescreen_width
|
||||
if client.features.homescreen_width is not None
|
||||
session.features.homescreen_width
|
||||
if session.features.homescreen_width is not None
|
||||
else 240
|
||||
)
|
||||
height = (
|
||||
client.features.homescreen_height
|
||||
if client.features.homescreen_height is not None
|
||||
session.features.homescreen_height
|
||||
if session.features.homescreen_height is not None
|
||||
else 240
|
||||
)
|
||||
img = image_to_jpeg(path, width, height, quality)
|
||||
elif client.features.homescreen_format == messages.HomescreenFormat.ToiG:
|
||||
width = client.features.homescreen_width
|
||||
height = client.features.homescreen_height
|
||||
elif session.features.homescreen_format == messages.HomescreenFormat.ToiG:
|
||||
width = session.features.homescreen_width
|
||||
height = session.features.homescreen_height
|
||||
if width is None or height is None:
|
||||
raise click.ClickException("Device did not report homescreen size.")
|
||||
img = image_to_toif(path, width, height, True)
|
||||
elif (
|
||||
client.features.homescreen_format == messages.HomescreenFormat.Toif
|
||||
or client.features.homescreen_format is None
|
||||
session.features.homescreen_format == messages.HomescreenFormat.Toif
|
||||
or session.features.homescreen_format is None
|
||||
):
|
||||
width = (
|
||||
client.features.homescreen_width
|
||||
if client.features.homescreen_width is not None
|
||||
session.features.homescreen_width
|
||||
if session.features.homescreen_width is not None
|
||||
else 144
|
||||
)
|
||||
height = (
|
||||
client.features.homescreen_height
|
||||
if client.features.homescreen_height is not None
|
||||
session.features.homescreen_height
|
||||
if session.features.homescreen_height is not None
|
||||
else 144
|
||||
)
|
||||
img = image_to_toif(path, width, height, False)
|
||||
@ -369,7 +370,7 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> str:
|
||||
"Unknown image format requested by the device."
|
||||
)
|
||||
|
||||
return device.apply_settings(client, homescreen=img)
|
||||
return device.apply_settings(session, homescreen=img)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@ -377,9 +378,9 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> str:
|
||||
"--always", is_flag=True, help='Persist the "prompt" setting across Trezor reboots.'
|
||||
)
|
||||
@click.argument("level", type=ChoiceType(SAFETY_LEVELS))
|
||||
@with_client
|
||||
@with_management_session
|
||||
def safety_checks(
|
||||
client: "TrezorClient", always: bool, level: messages.SafetyCheckLevel
|
||||
session: "Session", always: bool, level: messages.SafetyCheckLevel
|
||||
) -> str:
|
||||
"""Set safety check level.
|
||||
|
||||
@ -392,18 +393,18 @@ def safety_checks(
|
||||
"""
|
||||
if always and level == messages.SafetyCheckLevel.PromptTemporarily:
|
||||
level = messages.SafetyCheckLevel.PromptAlways
|
||||
return device.apply_settings(client, safety_checks=level)
|
||||
return device.apply_settings(session, safety_checks=level)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("enable", type=ChoiceType({"on": True, "off": False}))
|
||||
@with_client
|
||||
def experimental_features(client: "TrezorClient", enable: bool) -> str:
|
||||
@with_management_session
|
||||
def experimental_features(session: "Session", enable: bool) -> str:
|
||||
"""Enable or disable experimental message types.
|
||||
|
||||
This is a developer feature. Use with caution.
|
||||
"""
|
||||
return device.apply_settings(client, experimental_features=enable)
|
||||
return device.apply_settings(session, experimental_features=enable)
|
||||
|
||||
|
||||
#
|
||||
@ -426,25 +427,25 @@ passphrase = cast(AliasedGroup, passphrase_main)
|
||||
|
||||
@passphrase.command(name="on")
|
||||
@click.option("-f/-F", "--force-on-device/--no-force-on-device", default=None)
|
||||
@with_client
|
||||
def passphrase_on(client: "TrezorClient", force_on_device: Optional[bool]) -> str:
|
||||
@with_management_session
|
||||
def passphrase_on(session: "Session", force_on_device: Optional[bool]) -> str:
|
||||
"""Enable passphrase."""
|
||||
if client.features.passphrase_protection is not True:
|
||||
if session.features.passphrase_protection is not True:
|
||||
use_passphrase = True
|
||||
else:
|
||||
use_passphrase = None
|
||||
return device.apply_settings(
|
||||
client,
|
||||
session,
|
||||
use_passphrase=use_passphrase,
|
||||
passphrase_always_on_device=force_on_device,
|
||||
)
|
||||
|
||||
|
||||
@passphrase.command(name="off")
|
||||
@with_client
|
||||
def passphrase_off(client: "TrezorClient") -> str:
|
||||
@with_management_session
|
||||
def passphrase_off(session: "Session") -> str:
|
||||
"""Disable passphrase."""
|
||||
return device.apply_settings(client, use_passphrase=False)
|
||||
return device.apply_settings(session, use_passphrase=False)
|
||||
|
||||
|
||||
# Registering the aliases for backwards compatibility
|
||||
@ -457,10 +458,10 @@ passphrase.aliases = {
|
||||
|
||||
@passphrase.command(name="hide")
|
||||
@click.argument("hide", type=ChoiceType({"on": True, "off": False}))
|
||||
@with_client
|
||||
def hide_passphrase_from_host(client: "TrezorClient", hide: bool) -> str:
|
||||
@with_management_session
|
||||
def hide_passphrase_from_host(session: "Session", hide: bool) -> str:
|
||||
"""Enable or disable hiding passphrase coming from host.
|
||||
|
||||
This is a developer feature. Use with caution.
|
||||
"""
|
||||
return device.apply_settings(client, hide_passphrase_from_host=hide)
|
||||
return device.apply_settings(session, hide_passphrase_from_host=hide)
|
||||
|
@ -4,10 +4,10 @@ from typing import TYPE_CHECKING, Optional, TextIO
|
||||
import click
|
||||
|
||||
from .. import messages, solana, tools
|
||||
from . import with_client
|
||||
from . import with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
|
||||
PATH_HELP = "BIP-32 path to key, e.g. m/44h/501h/0h/0h"
|
||||
DEFAULT_PATH = "m/44h/501h/0h/0h"
|
||||
@ -21,40 +21,40 @@ def cli() -> None:
|
||||
@cli.command()
|
||||
@click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def get_public_key(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address: str,
|
||||
show_display: bool,
|
||||
) -> messages.SolanaPublicKey:
|
||||
"""Get Solana public key."""
|
||||
address_n = tools.parse_path(address)
|
||||
return solana.get_public_key(client, address_n, show_display)
|
||||
return solana.get_public_key(session, address_n, show_display)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def get_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address: str,
|
||||
show_display: bool,
|
||||
chunkify: bool,
|
||||
) -> messages.SolanaAddress:
|
||||
"""Get Solana address."""
|
||||
address_n = tools.parse_path(address)
|
||||
return solana.get_address(client, address_n, show_display, chunkify)
|
||||
return solana.get_address(session, address_n, show_display, chunkify)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("serialized_tx", type=str)
|
||||
@click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP)
|
||||
@click.option("-a", "--additional-information-file", type=click.File("r"))
|
||||
@with_client
|
||||
@with_session
|
||||
def sign_tx(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address: str,
|
||||
serialized_tx: str,
|
||||
additional_information_file: Optional[TextIO],
|
||||
@ -78,7 +78,7 @@ def sign_tx(
|
||||
)
|
||||
|
||||
return solana.sign_tx(
|
||||
client,
|
||||
session,
|
||||
address_n,
|
||||
bytes.fromhex(serialized_tx),
|
||||
additional_information,
|
||||
|
@ -21,10 +21,10 @@ from typing import TYPE_CHECKING
|
||||
import click
|
||||
|
||||
from .. import stellar, tools
|
||||
from . import with_client
|
||||
from . import with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
|
||||
try:
|
||||
from stellar_sdk import (
|
||||
@ -52,13 +52,13 @@ def cli() -> None:
|
||||
)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def get_address(
|
||||
client: "TrezorClient", address: str, show_display: bool, chunkify: bool
|
||||
session: "Session", address: str, show_display: bool, chunkify: bool
|
||||
) -> str:
|
||||
"""Get Stellar public address."""
|
||||
address_n = tools.parse_path(address)
|
||||
return stellar.get_address(client, address_n, show_display, chunkify)
|
||||
return stellar.get_address(session, address_n, show_display, chunkify)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@ -77,9 +77,9 @@ def get_address(
|
||||
help="Network passphrase (blank for public network).",
|
||||
)
|
||||
@click.argument("b64envelope")
|
||||
@with_client
|
||||
@with_session
|
||||
def sign_transaction(
|
||||
client: "TrezorClient", b64envelope: str, address: str, network_passphrase: str
|
||||
session: "Session", b64envelope: str, address: str, network_passphrase: str
|
||||
) -> bytes:
|
||||
"""Sign a base64-encoded transaction envelope.
|
||||
|
||||
@ -109,6 +109,6 @@ def sign_transaction(
|
||||
|
||||
address_n = tools.parse_path(address)
|
||||
tx, operations = stellar.from_envelope(envelope)
|
||||
resp = stellar.sign_tx(client, tx, operations, address_n, network_passphrase)
|
||||
resp = stellar.sign_tx(session, tx, operations, address_n, network_passphrase)
|
||||
|
||||
return base64.b64encode(resp.signature)
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user