1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-25 06:40:58 +00:00

feat(core): implement THP

This commit is contained in:
M1nd3r 2024-12-02 15:41:37 +01:00
parent 6b69019b39
commit b1451282d0
65 changed files with 6161 additions and 506 deletions

View File

@ -51,6 +51,8 @@ storage.cache_codec
import storage.cache_codec import storage.cache_codec
storage.cache_common storage.cache_common
import storage.cache_common import storage.cache_common
storage.cache_thp
import storage.cache_thp
storage.common storage.common
import storage.common import storage.common
storage.debug storage.debug
@ -419,10 +421,52 @@ apps.workflow_handlers
import apps.workflow_handlers import apps.workflow_handlers
if utils.USE_THP: if utils.USE_THP:
trezor.enums.ThpMessageType
import trezor.enums.ThpMessageType
trezor.enums.ThpPairingMethod
import trezor.enums.ThpPairingMethod
trezor.wire.thp
import trezor.wire.thp
trezor.wire.thp.alternating_bit_protocol
import trezor.wire.thp.alternating_bit_protocol
trezor.wire.thp.channel
import trezor.wire.thp.channel
trezor.wire.thp.channel_manager
import trezor.wire.thp.channel_manager
trezor.wire.thp.checksum
import trezor.wire.thp.checksum
trezor.wire.thp.control_byte
import trezor.wire.thp.control_byte
trezor.wire.thp.cpace
import trezor.wire.thp.cpace
trezor.wire.thp.crypto
import trezor.wire.thp.crypto
trezor.wire.thp.interface_manager
import trezor.wire.thp.interface_manager
trezor.wire.thp.memory_manager
import trezor.wire.thp.memory_manager
trezor.wire.thp.pairing_context
import trezor.wire.thp.pairing_context
trezor.wire.thp.received_message_handler
import trezor.wire.thp.received_message_handler
trezor.wire.thp.session_context
import trezor.wire.thp.session_context
trezor.wire.thp.session_manager
import trezor.wire.thp.session_manager
trezor.wire.thp.thp_main
import trezor.wire.thp.thp_main
trezor.wire.thp.transmission_loop
import trezor.wire.thp.transmission_loop
trezor.wire.thp.writer
import trezor.wire.thp.writer
apps.thp apps.thp
import apps.thp import apps.thp
apps.thp.create_new_session
import apps.thp.create_new_session
apps.thp.credential_manager apps.thp.credential_manager
import apps.thp.credential_manager import apps.thp.credential_manager
apps.thp.pairing
import apps.thp.pairing
if not utils.BITCOIN_ONLY: if not utils.BITCOIN_ONLY:
trezor.enums.BinanceOrderSide trezor.enums.BinanceOrderSide

View File

@ -204,33 +204,37 @@ def get_features() -> Features:
return f return f
async def handle_Initialize(msg: Initialize) -> Features: if not utils.USE_THP:
import storage.cache_codec as cache_codec
session_id = cache_codec.start_session(msg.session_id) async def handle_Initialize(msg: Initialize) -> Features:
import storage.cache_codec as cache_codec
if not utils.BITCOIN_ONLY: session_id = cache_codec.start_session(msg.session_id)
from storage.cache_common import APP_COMMON_DERIVE_CARDANO
derive_cardano = context.cache_get_bool(APP_COMMON_DERIVE_CARDANO) if not utils.BITCOIN_ONLY:
have_seed = context.cache_is_set(APP_COMMON_SEED) from storage.cache_common import APP_COMMON_DERIVE_CARDANO
if (
have_seed
and msg.derive_cardano is not None
and msg.derive_cardano != bool(derive_cardano)
):
# seed is already derived, and host wants to change derive_cardano setting
# => create a new session
cache_codec.end_current_session()
session_id = cache_codec.start_session()
have_seed = False
if not have_seed: derive_cardano = context.cache_get_bool(APP_COMMON_DERIVE_CARDANO)
context.cache_set_bool(APP_COMMON_DERIVE_CARDANO, bool(msg.derive_cardano)) have_seed = context.cache_is_set(APP_COMMON_SEED)
if (
have_seed
and msg.derive_cardano is not None
and msg.derive_cardano != bool(derive_cardano)
):
# seed is already derived, and host wants to change derive_cardano setting
# => create a new session
cache_codec.end_current_session()
session_id = cache_codec.start_session()
have_seed = False
features = get_features() if not have_seed:
features.session_id = session_id context.cache_set_bool(
return features APP_COMMON_DERIVE_CARDANO, bool(msg.derive_cardano)
)
features = get_features()
features.session_id = session_id
return features
async def handle_GetFeatures(msg: GetFeatures) -> Features: async def handle_GetFeatures(msg: GetFeatures) -> Features:
@ -464,8 +468,9 @@ def boot() -> None:
MT = MessageType # local_cache_global MT = MessageType # local_cache_global
# Register workflow handlers # Register workflow handlers
if not utils.USE_THP:
workflow_handlers.register(MT.Initialize, handle_Initialize)
for msg_type, handler in [ for msg_type, handler in [
(MT.Initialize, handle_Initialize),
(MT.GetFeatures, handle_GetFeatures), (MT.GetFeatures, handle_GetFeatures),
(MT.Cancel, handle_Cancel), (MT.Cancel, handle_Cancel),
(MT.LockDevice, handle_LockDevice), (MT.LockDevice, handle_LockDevice),

View File

@ -6,7 +6,7 @@ from storage.cache_common import (
APP_CARDANO_ICARUS_TREZOR_SECRET, APP_CARDANO_ICARUS_TREZOR_SECRET,
APP_COMMON_DERIVE_CARDANO, APP_COMMON_DERIVE_CARDANO,
) )
from trezor import wire from trezor import utils, wire
from trezor.crypto import cardano from trezor.crypto import cardano
from trezor.wire import context from trezor.wire import context
@ -21,6 +21,7 @@ if TYPE_CHECKING:
from trezor import messages from trezor import messages
from trezor.crypto import bip32 from trezor.crypto import bip32
from trezor.enums import CardanoDerivationType from trezor.enums import CardanoDerivationType
from trezor.wire.protocol_common import Context
from apps.common.keychain import Handler, MsgOut from apps.common.keychain import Handler, MsgOut
from apps.common.paths import Bip32Path from apps.common.paths import Bip32Path
@ -116,7 +117,7 @@ def is_minting_path(path: Bip32Path) -> bool:
return path[: len(MINTING_ROOT)] == MINTING_ROOT return path[: len(MINTING_ROOT)] == MINTING_ROOT
def derive_and_store_secrets(passphrase: str) -> None: def derive_and_store_secrets(ctx: Context, passphrase: str) -> None:
assert device.is_initialized() assert device.is_initialized()
assert context.cache_get_bool(APP_COMMON_DERIVE_CARDANO) assert context.cache_get_bool(APP_COMMON_DERIVE_CARDANO)
@ -144,8 +145,7 @@ def derive_and_store_secrets(passphrase: str) -> None:
async def _get_keychain_bip39(derivation_type: CardanoDerivationType) -> Keychain: async def _get_keychain_bip39(derivation_type: CardanoDerivationType) -> Keychain:
from trezor.enums import CardanoDerivationType from trezor.enums import CardanoDerivationType
from trezor.wire import context
from apps.common.seed import derive_and_store_roots
if not device.is_initialized(): if not device.is_initialized():
raise wire.NotInitialized("Device is not initialized") raise wire.NotInitialized("Device is not initialized")
@ -164,10 +164,13 @@ async def _get_keychain_bip39(derivation_type: CardanoDerivationType) -> Keychai
# _get_secret # _get_secret
secret = context.cache_get(cache_entry) secret = context.cache_get(cache_entry)
if secret is None: if not utils.USE_THP:
await derive_and_store_roots() if secret is None:
secret = context.cache_get(cache_entry) from apps.common.seed import derive_and_store_roots_legacy
assert secret is not None
await derive_and_store_roots_legacy()
secret = context.cache_get(cache_entry)
assert secret is not None
root = cardano.from_secret(secret) root = cardano.from_secret(secret)
return Keychain(root) return Keychain(root)

View File

@ -1,7 +1,7 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from storage.cache_common import APP_RECOVERY_REPEATED_BACKUP_UNLOCKED from storage.cache_common import APP_RECOVERY_REPEATED_BACKUP_UNLOCKED
from trezor import wire from trezor import utils, wire
from trezor.enums import MessageType from trezor.enums import MessageType
from trezor.wire import context from trezor.wire import context
from trezor.wire.message_handler import filters, remove_filter from trezor.wire.message_handler import filters, remove_filter
@ -24,14 +24,23 @@ def deactivate_repeated_backup() -> None:
remove_filter(_repeated_backup_filter) remove_filter(_repeated_backup_filter)
_ALLOW_WHILE_REPEATED_BACKUP_UNLOCKED = ( if utils.USE_THP:
MessageType.Initialize, _ALLOW_WHILE_REPEATED_BACKUP_UNLOCKED = (
MessageType.GetFeatures, MessageType.GetFeatures,
MessageType.EndSession, MessageType.EndSession,
MessageType.BackupDevice, MessageType.BackupDevice,
MessageType.WipeDevice, MessageType.WipeDevice,
MessageType.Cancel, MessageType.Cancel,
) )
else:
_ALLOW_WHILE_REPEATED_BACKUP_UNLOCKED = (
MessageType.Initialize,
MessageType.GetFeatures,
MessageType.EndSession,
MessageType.BackupDevice,
MessageType.WipeDevice,
MessageType.Cancel,
)
def _repeated_backup_filter(msg_type: int, prev_handler: Handler[Msg]) -> Handler[Msg]: def _repeated_backup_filter(msg_type: int, prev_handler: Handler[Msg]) -> Handler[Msg]:

View File

@ -1,5 +1,6 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor import utils
from trezor.crypto import bip32 from trezor.crypto import bip32
from trezor.wire import DataError from trezor.wire import DataError
@ -172,6 +173,9 @@ async def get_keychain(
) -> Keychain: ) -> Keychain:
from .seed import get_seed from .seed import get_seed
if not utils.USE_THP:
pass
# try to ask for passphrase here
seed = await get_seed() seed = await get_seed()
keychain = Keychain(seed, curve, schemas, slip21_namespaces) keychain = Keychain(seed, curve, schemas, slip21_namespaces)
return keychain return keychain

View File

@ -1,84 +1,122 @@
from micropython import const from micropython import const
from typing import TYPE_CHECKING
import storage.device as storage_device import storage.device as storage_device
from trezor import utils
from trezor.wire import DataError from trezor.wire import DataError
_MAX_PASSPHRASE_LEN = const(50) _MAX_PASSPHRASE_LEN = const(50)
if TYPE_CHECKING:
from trezor.messages import ThpCreateNewSession
def is_enabled() -> bool: def is_enabled() -> bool:
return storage_device.is_passphrase_enabled() return storage_device.is_passphrase_enabled()
async def get() -> str: async def get_passphrase(msg: ThpCreateNewSession) -> str:
from trezor import workflow
if not is_enabled(): if not is_enabled():
return "" return ""
if msg.on_device or storage_device.get_passphrase_always_on_device():
passphrase = await _get_on_device()
else: else:
workflow.close_others() # request exclusive UI access passphrase = msg.passphrase or ""
if storage_device.get_passphrase_always_on_device(): if passphrase:
from trezor.ui.layouts import request_passphrase_on_device await _handle_displaying_passphrase_from_host(passphrase)
passphrase = await request_passphrase_on_device(_MAX_PASSPHRASE_LEN) if len(passphrase.encode()) > _MAX_PASSPHRASE_LEN:
else: raise DataError(f"Maximum passphrase length is {_MAX_PASSPHRASE_LEN} bytes")
passphrase = await _request_on_host()
if len(passphrase.encode()) > _MAX_PASSPHRASE_LEN:
raise DataError(f"Maximum passphrase length is {_MAX_PASSPHRASE_LEN} bytes")
return passphrase
async def _request_on_host() -> str:
from trezor import TR
from trezor.messages import PassphraseAck, PassphraseRequest
from trezor.ui.layouts import request_passphrase_on_host
from trezor.wire.context import call
request_passphrase_on_host()
request = PassphraseRequest()
ack = await call(request, PassphraseAck)
passphrase = ack.passphrase # local_cache_attribute
if ack.on_device:
from trezor.ui.layouts import request_passphrase_on_device
if passphrase is not None:
raise DataError("Passphrase provided when it should not be")
return await request_passphrase_on_device(_MAX_PASSPHRASE_LEN)
if passphrase is None:
raise DataError(
"Passphrase not provided and on_device is False. Use empty string to set an empty passphrase."
)
# non-empty passphrase
if passphrase:
from trezor.ui.layouts import confirm_action, confirm_blob
# We want to hide the passphrase, or show it, according to settings.
if storage_device.get_hide_passphrase_from_host():
await confirm_action(
"passphrase_host1_hidden",
TR.passphrase__wallet,
description=TR.passphrase__from_host_not_shown,
prompt_screen=True,
prompt_title=TR.passphrase__access_wallet,
)
else:
await confirm_action(
"passphrase_host1",
TR.passphrase__wallet,
description=TR.passphrase__next_screen_will_show_passphrase,
verb=TR.buttons__continue,
)
await confirm_blob(
"passphrase_host2",
TR.passphrase__title_confirm,
passphrase,
info=False,
)
return passphrase return passphrase
async def _get_on_device() -> str:
from trezor import workflow
from trezor.ui.layouts import request_passphrase_on_device
workflow.close_others() # request exclusive UI access
passphrase = await request_passphrase_on_device(_MAX_PASSPHRASE_LEN)
return passphrase
async def _handle_displaying_passphrase_from_host(passphrase: str) -> None:
from trezor import TR
from trezor.ui.layouts import confirm_action, confirm_blob
# We want to hide the passphrase, or show it, according to settings.
if storage_device.get_hide_passphrase_from_host():
await confirm_action(
"passphrase_host1_hidden",
TR.passphrase__wallet,
description=TR.passphrase__from_host_not_shown,
prompt_screen=True,
prompt_title=TR.passphrase__access_wallet,
)
else:
await confirm_action(
"passphrase_host1",
TR.passphrase__wallet,
description=TR.passphrase__next_screen_will_show_passphrase,
verb=TR.buttons__continue,
)
await confirm_blob(
"passphrase_host2",
TR.passphrase__title_confirm,
passphrase,
)
if not utils.USE_THP:
async def get() -> str:
from trezor import workflow
if not is_enabled():
return ""
else:
workflow.close_others() # request exclusive UI access
if storage_device.get_passphrase_always_on_device():
from trezor.ui.layouts import request_passphrase_on_device
passphrase = await request_passphrase_on_device(_MAX_PASSPHRASE_LEN)
else:
passphrase = await _request_on_host()
if len(passphrase.encode()) > _MAX_PASSPHRASE_LEN:
raise DataError(
f"Maximum passphrase length is {_MAX_PASSPHRASE_LEN} bytes"
)
return passphrase
async def _request_on_host() -> str:
from trezor.messages import PassphraseAck, PassphraseRequest
from trezor.ui.layouts import request_passphrase_on_host
from trezor.wire.context import call
request_passphrase_on_host()
request = PassphraseRequest()
ack = await call(request, PassphraseAck)
passphrase = ack.passphrase # local_cache_attribute
if ack.on_device:
from trezor.ui.layouts import request_passphrase_on_device
if passphrase is not None:
raise DataError("Passphrase provided when it should not be")
return await request_passphrase_on_device(_MAX_PASSPHRASE_LEN)
if passphrase is None:
raise DataError(
"Passphrase not provided and on_device is False. Use empty string to set an empty passphrase."
)
# non-empty passphrase
if passphrase:
await _handle_displaying_passphrase_from_host(passphrase)
return passphrase

View File

@ -5,14 +5,18 @@ from storage.cache_common import APP_COMMON_SEED, APP_COMMON_SEED_WITHOUT_PASSPH
from trezor import utils from trezor import utils
from trezor.crypto import hmac from trezor.crypto import hmac
from trezor.wire import context from trezor.wire import context
from trezor.wire.context import get_context
from trezor.wire.errors import DataError
from apps.common import cache from apps.common import cache
from . import mnemonic from . import mnemonic
from .passphrase import get as get_passphrase from .passphrase import get_passphrase as get_passphrase
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.crypto import bip32 from trezor.crypto import bip32
from trezor.messages import ThpCreateNewSession
from trezor.wire.protocol_common import Context
from .paths import Bip32Path, Slip21Path from .paths import Bip32Path, Slip21Path
@ -22,6 +26,9 @@ if not utils.BITCOIN_ONLY:
APP_COMMON_DERIVE_CARDANO, APP_COMMON_DERIVE_CARDANO,
) )
if not utils.USE_THP:
from .passphrase import get as get_passphrase_legacy
class Slip21Node: class Slip21Node:
""" """
@ -54,51 +61,111 @@ class Slip21Node:
return Slip21Node(data=self.data) return Slip21Node(data=self.data)
if not utils.BITCOIN_ONLY: if utils.USE_THP:
# === Cardano variant ===
# We want to derive both the normal seed and the Cardano seed together, AND
# expose a method for Cardano to do the same
async def derive_and_store_roots() -> None: async def get_seed() -> bytes: # type: ignore [Function declaration "get_seed" is obscured by a declaration of the same name]
from trezor import wire
if not storage_device.is_initialized():
raise wire.NotInitialized("Device is not initialized")
need_seed = not context.cache_is_set(APP_COMMON_SEED)
need_cardano_secret = context.cache_get_bool(
APP_COMMON_DERIVE_CARDANO
) and not context.cache_is_set(APP_CARDANO_ICARUS_SECRET)
if not need_seed and not need_cardano_secret:
return
passphrase = await get_passphrase()
if need_seed:
common_seed = mnemonic.get_seed(passphrase)
context.cache_set(APP_COMMON_SEED, common_seed)
if need_cardano_secret:
from apps.cardano.seed import derive_and_store_secrets
derive_and_store_secrets(passphrase)
@cache.stored_async(APP_COMMON_SEED)
async def get_seed() -> bytes:
await derive_and_store_roots()
common_seed = context.cache_get(APP_COMMON_SEED) common_seed = context.cache_get(APP_COMMON_SEED)
assert common_seed is not None assert common_seed is not None
return common_seed return common_seed
else: if utils.BITCOIN_ONLY:
# === Bitcoin-only variant === # === Bitcoin_only variant ===
# We use the simple version of `get_seed` that never needs to derive anything else. # We want to derive the normal seed ONLY
@cache.stored_async(APP_COMMON_SEED) async def derive_and_store_roots(
async def get_seed() -> bytes: ctx: Context, msg: ThpCreateNewSession
passphrase = await get_passphrase() ) -> None:
return mnemonic.get_seed(passphrase)
if msg.passphrase is not None and msg.on_device:
raise DataError("Passphrase provided when it shouldn't be!")
if ctx.cache.is_set(APP_COMMON_SEED):
raise Exception("Seed is already set!")
from trezor import wire
if not storage_device.is_initialized():
raise wire.NotInitialized("Device is not initialized")
passphrase = await get_passphrase(msg)
common_seed = mnemonic.get_seed(passphrase)
ctx.cache.set(APP_COMMON_SEED, common_seed)
else:
# === Cardano variant ===
# We want to derive both the normal seed and the Cardano seed together
async def derive_and_store_roots(
ctx: Context, msg: ThpCreateNewSession
) -> None:
if msg.passphrase is not None and msg.on_device:
raise DataError("Passphrase provided when it shouldn't be!")
from trezor import wire
if not storage_device.is_initialized():
raise wire.NotInitialized("Device is not initialized")
if ctx.cache.is_set(APP_CARDANO_ICARUS_SECRET):
raise Exception("Cardano icarus secret is already set!")
passphrase = await get_passphrase(msg)
common_seed = mnemonic.get_seed(passphrase)
ctx.cache.set(APP_COMMON_SEED, common_seed)
if msg.derive_cardano:
from apps.cardano.seed import derive_and_store_secrets
ctx.cache.set_bool(APP_COMMON_DERIVE_CARDANO, True)
derive_and_store_secrets(ctx, passphrase)
else:
if utils.BITCOIN_ONLY:
# === Bitcoin-only variant ===
# We use the simple version of `get_seed` that never needs to derive anything else.
@cache.stored_async(APP_COMMON_SEED)
async def get_seed() -> bytes:
passphrase = await get_passphrase_legacy()
return mnemonic.get_seed(passphrase=passphrase)
else:
# === Cardano variant ===
# We want to derive both the normal seed and the Cardano seed together, AND
# expose a method for Cardano to do the same
@cache.stored_async(APP_COMMON_SEED)
async def get_seed() -> bytes:
await derive_and_store_roots_legacy()
common_seed = context.cache_get(APP_COMMON_SEED)
assert common_seed is not None
return common_seed
async def derive_and_store_roots_legacy() -> None:
from trezor import wire
if not storage_device.is_initialized():
raise wire.NotInitialized("Device is not initialized")
ctx = get_context()
need_seed = not ctx.cache.is_set(APP_COMMON_SEED)
need_cardano_secret = ctx.cache.get_bool(
APP_COMMON_DERIVE_CARDANO
) and not ctx.cache.is_set(APP_CARDANO_ICARUS_SECRET)
if not need_seed and not need_cardano_secret:
return
passphrase = await get_passphrase_legacy()
if need_seed:
common_seed = mnemonic.get_seed(passphrase)
ctx.cache.set(APP_COMMON_SEED, common_seed)
if need_cardano_secret:
from apps.cardano.seed import derive_and_store_secrets
derive_and_store_secrets(ctx, passphrase)
@cache.stored(APP_COMMON_SEED_WITHOUT_PASSPHRASE) @cache.stored(APP_COMMON_SEED_WITHOUT_PASSPHRASE)

View File

@ -1,3 +1,5 @@
from trezor.wire import message_handler
if not __debug__: if not __debug__:
from trezor.utils import halt from trezor.utils import halt
@ -29,13 +31,14 @@ if __debug__:
DebugLinkState, DebugLinkState,
) )
from trezor.ui import Layout from trezor.ui import Layout
from trezor.wire import WireInterface, context from trezor.wire import WireInterface
from trezor.wire.protocol_common import Context
Handler = Callable[[Any], Awaitable[Any]] Handler = Callable[[Any], Awaitable[Any]]
layout_change_box = loop.mailbox() layout_change_box = loop.mailbox()
DEBUG_CONTEXT: context.Context | None = None DEBUG_CONTEXT: Context | None = None
REFRESH_INDEX = 0 REFRESH_INDEX = 0
@ -70,9 +73,7 @@ if __debug__:
"layout deadlock detected (did you send a ButtonAck?)" "layout deadlock detected (did you send a ButtonAck?)"
) )
async def return_layout_change( async def return_layout_change(ctx: Context, detect_deadlock: bool = False) -> None:
ctx: wire.protocol_common.Context, detect_deadlock: bool = False
) -> None:
# set up the wait # set up the wait
storage.layout_watcher = True storage.layout_watcher = True
@ -212,12 +213,12 @@ if __debug__:
x = msg.x # local_cache_attribute x = msg.x # local_cache_attribute
y = msg.y # local_cache_attribute y = msg.y # local_cache_attribute
await wait_until_layout_is_running() await wait_until_layout_is_running()
assert isinstance(ui.CURRENT_LAYOUT, ui.Layout) assert isinstance(ui.CURRENT_LAYOUT, ui.Layout)
layout_change_box.clear() layout_change_box.clear()
try: try:
# click on specific coordinates, with possible hold # click on specific coordinates, with possible hold
if x is not None and y is not None: if x is not None and y is not None:
await _layout_click(x, y, msg.hold_ms or 0) await _layout_click(x, y, msg.hold_ms or 0)
@ -229,7 +230,11 @@ if __debug__:
elif msg.button is not None: elif msg.button is not None:
await _layout_event(msg.button) await _layout_event(msg.button)
elif msg.input is not None: elif msg.input is not None:
ui.CURRENT_LAYOUT._emit_message(msg.input) try:
ui.CURRENT_LAYOUT._emit_message(msg.input)
except Exception as e:
print(type(e))
else: else:
raise RuntimeError("Invalid DebugLinkDecision message") raise RuntimeError("Invalid DebugLinkDecision message")
@ -244,7 +249,11 @@ if __debug__:
# If no exception was raised, the layout did not shut down. That means that it # If no exception was raised, the layout did not shut down. That means that it
# just updated itself. The update is already live for the caller to retrieve. # just updated itself. The update is already live for the caller to retrieve.
def _state() -> DebugLinkState: def _state(
thp_pairing_code_entry_code: int | None = None,
thp_pairing_code_qr_code: bytes | None = None,
thp_pairing_code_nfc_unidirectional: bytes | None = None,
) -> DebugLinkState:
from trezor.messages import DebugLinkState from trezor.messages import DebugLinkState
from apps.common import mnemonic, passphrase from apps.common import mnemonic, passphrase
@ -263,13 +272,45 @@ if __debug__:
passphrase_protection=passphrase.is_enabled(), passphrase_protection=passphrase.is_enabled(),
reset_entropy=storage.reset_internal_entropy, reset_entropy=storage.reset_internal_entropy,
tokens=tokens, tokens=tokens,
thp_pairing_code_entry_code=thp_pairing_code_entry_code,
thp_pairing_code_qr_code=thp_pairing_code_qr_code,
thp_pairing_code_nfc_unidirectional=thp_pairing_code_nfc_unidirectional,
) )
async def dispatch_DebugLinkGetState( async def dispatch_DebugLinkGetState(
msg: DebugLinkGetState, msg: DebugLinkGetState,
) -> DebugLinkState | None: ) -> DebugLinkState | None:
thp_pairing_code_entry_code: int | None = None
thp_pairing_code_qr_code: bytes | None = None
thp_pairing_code_nfc_unidirectional: bytes | None = None
if utils.USE_THP and msg.thp_channel_id is not None:
channel_id = int.from_bytes(msg.thp_channel_id, "big")
from trezor.wire.thp.channel import Channel
from trezor.wire.thp.pairing_context import PairingContext
from trezor.wire.thp.thp_main import _CHANNELS
channel: Channel | None = None
ctx: PairingContext | None = None
try:
channel = _CHANNELS[channel_id]
ctx = channel.connection_context
except KeyError:
pass
if ctx is not None and isinstance(ctx, PairingContext):
thp_pairing_code_entry_code = ctx.display_data.code_code_entry
thp_pairing_code_qr_code = ctx.display_data.code_qr_code
thp_pairing_code_nfc_unidirectional = (
ctx.display_data.code_nfc_unidirectional
)
if msg.wait_layout == DebugWaitType.IMMEDIATE: if msg.wait_layout == DebugWaitType.IMMEDIATE:
return _state() return _state(
thp_pairing_code_entry_code,
thp_pairing_code_qr_code,
thp_pairing_code_nfc_unidirectional,
)
assert DEBUG_CONTEXT is not None assert DEBUG_CONTEXT is not None
if msg.wait_layout == DebugWaitType.NEXT_LAYOUT: if msg.wait_layout == DebugWaitType.NEXT_LAYOUT:
@ -280,7 +321,11 @@ if __debug__:
if not layout_is_ready(): if not layout_is_ready():
return await return_layout_change(DEBUG_CONTEXT, detect_deadlock=True) return await return_layout_change(DEBUG_CONTEXT, detect_deadlock=True)
else: else:
return _state() return _state(
thp_pairing_code_entry_code,
thp_pairing_code_qr_code,
thp_pairing_code_nfc_unidirectional,
)
async def dispatch_DebugLinkRecordScreen(msg: DebugLinkRecordScreen) -> Success: async def dispatch_DebugLinkRecordScreen(msg: DebugLinkRecordScreen) -> Success:
if msg.target_directory: if msg.target_directory:
@ -390,7 +435,6 @@ if __debug__:
ctx.iface.iface_num(), ctx.iface.iface_num(),
msg_type, msg_type,
) )
if msg.type not in WORKFLOW_HANDLERS: if msg.type not in WORKFLOW_HANDLERS:
await ctx.write(wire.message_handler.unexpected_message()) await ctx.write(wire.message_handler.unexpected_message())
continue continue
@ -403,7 +447,7 @@ if __debug__:
await ctx.write(Success()) await ctx.write(Success())
continue continue
req_msg = wire.message_handler.wrap_protobuf_load(msg.data, req_type) req_msg = message_handler.wrap_protobuf_load(msg.data, req_type)
try: try:
res_msg = await WORKFLOW_HANDLERS[msg.type](req_msg) res_msg = await WORKFLOW_HANDLERS[msg.type](req_msg)
except Exception as exc: except Exception as exc:

View File

@ -89,7 +89,7 @@ async def reboot_to_bootloader(msg: RebootToBootloader) -> NoReturn:
boot_args = None boot_args = None
ctx = get_context() ctx = get_context()
await ctx.write(Success(message="Rebooting")) await ctx.write_force(Success(message="Rebooting"))
# make sure the outgoing USB buffer is flushed # make sure the outgoing USB buffer is flushed
await loop.wait(ctx.iface.iface_num() | io.POLL_WRITE) await loop.wait(ctx.iface.iface_num() | io.POLL_WRITE)
# reboot to the bootloader, pass the firmware header hash if any # reboot to the bootloader, pass the firmware header hash if any

View File

@ -24,6 +24,7 @@ async def recovery_device(msg: RecoveryDevice) -> Success:
from trezor import TR, config, wire, workflow from trezor import TR, config, wire, workflow
from trezor.enums import BackupType, ButtonRequestType from trezor.enums import BackupType, ButtonRequestType
from trezor.ui.layouts import confirm_action, confirm_reset_device from trezor.ui.layouts import confirm_action, confirm_reset_device
from trezor.wire.context import try_get_ctx_ids
from apps.common import mnemonic from apps.common import mnemonic
from apps.common.request_pin import ( from apps.common.request_pin import (
@ -69,8 +70,8 @@ async def recovery_device(msg: RecoveryDevice) -> Success:
if recovery_type == RecoveryType.NormalRecovery: if recovery_type == RecoveryType.NormalRecovery:
await confirm_reset_device(recovery=True) await confirm_reset_device(recovery=True)
# wipe storage to make sure the device is in a clear state # wipe storage to make sure the device is in a clear state (except protocol cache)
storage.reset() storage.reset(excluded=try_get_ctx_ids())
# set up pin if requested # set up pin if requested
if msg.pin_protection: if msg.pin_protection:

View File

@ -3,8 +3,9 @@ from typing import TYPE_CHECKING
import storage.device as storage_device import storage.device as storage_device
import storage.recovery as storage_recovery import storage.recovery as storage_recovery
import storage.recovery_shares as storage_recovery_shares import storage.recovery_shares as storage_recovery_shares
from trezor import TR, wire from trezor import TR, utils, wire
from trezor.messages import Success from trezor.messages import Success
from trezor.wire import message_handler
from apps.common import backup_types from apps.common import backup_types
@ -38,18 +39,26 @@ async def recovery_process() -> Success:
recovery_type = storage_recovery.get_type() recovery_type = storage_recovery.get_type()
wire.message_handler.AVOID_RESTARTING_FOR = ( if utils.USE_THP:
MessageType.Initialize, message_handler.AVOID_RESTARTING_FOR = (
MessageType.GetFeatures, MessageType.GetFeatures,
MessageType.EndSession, MessageType.EndSession,
) )
else:
message_handler.AVOID_RESTARTING_FOR = (
MessageType.Initialize,
MessageType.GetFeatures,
MessageType.EndSession,
)
try: try:
return await _continue_recovery_process() return await _continue_recovery_process()
except recover.RecoveryAborted: except recover.RecoveryAborted:
storage_recovery.end_progress() storage_recovery.end_progress()
backup.deactivate_repeated_backup() backup.deactivate_repeated_backup()
if recovery_type == RecoveryType.NormalRecovery: if recovery_type == RecoveryType.NormalRecovery:
storage.wipe() from trezor.wire.context import try_get_ctx_ids
storage.wipe(excluded=try_get_ctx_ids())
raise wire.ActionCancelled raise wire.ActionCancelled
@ -59,11 +68,17 @@ async def _continue_repeated_backup() -> None:
from apps.common import backup from apps.common import backup
from apps.management.backup_device import perform_backup from apps.management.backup_device import perform_backup
wire.message_handler.AVOID_RESTARTING_FOR = ( if utils.USE_THP:
MessageType.Initialize, message_handler.AVOID_RESTARTING_FOR = (
MessageType.GetFeatures, MessageType.GetFeatures,
MessageType.EndSession, MessageType.EndSession,
) )
else:
message_handler.AVOID_RESTARTING_FOR = (
MessageType.Initialize,
MessageType.GetFeatures,
MessageType.EndSession,
)
try: try:
await perform_backup(is_repeated_backup=True) await perform_backup(is_repeated_backup=True)

View File

@ -38,7 +38,7 @@ async def reset_device(msg: ResetDevice) -> Success:
prompt_backup, prompt_backup,
show_wallet_created_success, show_wallet_created_success,
) )
from trezor.wire.context import call from trezor.wire.context import call, try_get_ctx_ids
from apps.common.request_pin import request_pin_confirm from apps.common.request_pin import request_pin_confirm
@ -60,8 +60,8 @@ async def reset_device(msg: ResetDevice) -> Success:
# Rendering empty loader so users do not feel a freezing screen # Rendering empty loader so users do not feel a freezing screen
render_empty_loader(config.StorageMessage.PROCESSING_MSG) render_empty_loader(config.StorageMessage.PROCESSING_MSG)
# wipe storage to make sure the device is in a clear state # wipe storage to make sure the device is in a clear state (except protocol cache)
storage.reset() storage.reset(excluded=try_get_ctx_ids())
# Check backup type, perform type-specific handling # Check backup type, perform type-specific handling
if backup_types.is_slip39_backup_type(backup_type): if backup_types.is_slip39_backup_type(backup_type):
@ -139,7 +139,7 @@ async def reset_device(msg: ResetDevice) -> Success:
if perform_backup: if perform_backup:
await layout.show_backup_success() await layout.show_backup_success()
return Success(message="Initialized") return Success(message="Initialized") # TODO: Why "Initialized?"
async def _entropy_check(secret: bytes) -> bool: async def _entropy_check(secret: bytes) -> bool:

View File

@ -1,12 +1,19 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor.wire.context import get_context, try_get_ctx_ids
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import Success, WipeDevice from typing import NoReturn
from trezor.messages import WipeDevice
if __debug__:
from trezor import log
async def wipe_device(msg: WipeDevice) -> Success: async def wipe_device(msg: WipeDevice) -> NoReturn:
import storage import storage
from trezor import TR, config, translations from trezor import TR, config, loop, translations
from trezor.enums import ButtonRequestType from trezor.enums import ButtonRequestType
from trezor.messages import Success from trezor.messages import Success
from trezor.pin import render_empty_loader from trezor.pin import render_empty_loader
@ -26,16 +33,22 @@ async def wipe_device(msg: WipeDevice) -> Success:
br_code=ButtonRequestType.WipeDevice, br_code=ButtonRequestType.WipeDevice,
) )
if __debug__:
log.debug(__name__, "Device wipe - start")
# start an empty progress screen so that the screen is not blank while waiting # start an empty progress screen so that the screen is not blank while waiting
render_empty_loader(config.StorageMessage.PROCESSING_MSG) render_empty_loader(config.StorageMessage.PROCESSING_MSG)
# wipe storage # wipe storage
storage.wipe() storage.wipe(excluded=try_get_ctx_ids())
# erase translations # erase translations
translations.deinit() translations.deinit()
translations.erase() translations.erase()
await get_context().write_force(Success(message="Device wiped"))
storage.wipe_cache()
# reload settings # reload settings
reload_settings_from_storage() reload_settings_from_storage()
loop.clear()
return Success(message="Device wiped") if __debug__:
log.debug(__name__, "Device wipe - finished")

View File

@ -0,0 +1,59 @@
from trezor import log, loop
from trezor.enums import FailureType
from trezor.messages import Failure, ThpCreateNewSession, ThpNewSession
from trezor.wire.context import get_context
from trezor.wire.errors import ActionCancelled, DataError
from trezor.wire.thp import SessionState
async def create_new_session(message: ThpCreateNewSession) -> ThpNewSession | Failure:
"""
Creates a new `ThpSession` based on the provided parameters and returns a
`ThpNewSession` message containing the new session ID.
Returns an appropriate `Failure` message if session creation fails.
"""
from trezor.wire import NotInitialized
from trezor.wire.thp.session_context import GenericSessionContext
from trezor.wire.thp.session_manager import create_new_session
from apps.common.seed import derive_and_store_roots
ctx = get_context()
# Assert that context `ctx` is `GenericSessionContext`
assert isinstance(ctx, GenericSessionContext)
channel = ctx.channel
# Do not use `ctx` beyond this point, as it is techically
# allowed to change in between await statements
new_session = create_new_session(channel)
try:
await derive_and_store_roots(new_session, message)
except DataError as e:
return Failure(code=FailureType.DataError, message=e.message)
except ActionCancelled as e:
return Failure(code=FailureType.ActionCancelled, message=e.message)
except NotInitialized as e:
return Failure(code=FailureType.NotInitialized, message=e.message)
# TODO handle other errors (`Exception`` when "Cardano icarus secret is already set!"
# and `RuntimeError` when accessing storage for mnemonic.get_secret - it actually
# happens for locked devices)
new_session.set_session_state(SessionState.ALLOCATED)
channel.sessions[new_session.session_id] = new_session
loop.schedule(new_session.handle())
new_session_id: int = new_session.session_id
if __debug__:
log.debug(
__name__,
"create_new_session - new session created. Passphrase: %s, Session id: %d\n%s",
message.passphrase if message.passphrase is not None else "",
new_session.session_id,
str(channel.sessions),
)
return ThpNewSession(new_session_id=new_session_id)

View File

@ -0,0 +1,403 @@
from typing import TYPE_CHECKING
from ubinascii import hexlify
from trezor import loop, protobuf
from trezor.crypto.hashlib import sha256
from trezor.enums import ThpMessageType, ThpPairingMethod
from trezor.messages import (
Cancel,
ThpCodeEntryChallenge,
ThpCodeEntryCommitment,
ThpCodeEntryCpaceHost,
ThpCodeEntryCpaceTrezor,
ThpCodeEntrySecret,
ThpCodeEntryTag,
ThpCredentialMetadata,
ThpCredentialRequest,
ThpCredentialResponse,
ThpEndRequest,
ThpEndResponse,
ThpNfcUnidirectionalSecret,
ThpNfcUnidirectionalTag,
ThpPairingPreparationsFinished,
ThpQrCodeSecret,
ThpQrCodeTag,
ThpStartPairingRequest,
)
from trezor.wire.errors import ActionCancelled, SilentError, UnexpectedMessage
from trezor.wire.thp import ChannelState, ThpError, crypto
from trezor.wire.thp.pairing_context import PairingContext
from .credential_manager import issue_credential
if __debug__:
from trezor import log
if TYPE_CHECKING:
from typing import Any, Callable, Concatenate, Container, ParamSpec, Tuple
P = ParamSpec("P")
FuncWithContext = Callable[Concatenate[PairingContext, P], Any]
#
# Helpers - decorators
def check_state_and_log(
*allowed_states: ChannelState,
) -> Callable[[FuncWithContext], FuncWithContext]:
def decorator(f: FuncWithContext) -> FuncWithContext:
def inner(context: PairingContext, *args: P.args, **kwargs: P.kwargs) -> object:
_check_state(context, *allowed_states)
if __debug__:
try:
log.debug(__name__, "started %s", f.__name__)
except AttributeError:
log.debug(
__name__,
"started a function that cannot be named, because it raises AttributeError, eg. closure",
)
return f(context, *args, **kwargs)
return inner
return decorator
def check_method_is_allowed(
pairing_method: ThpPairingMethod,
) -> Callable[[FuncWithContext], FuncWithContext]:
def decorator(f: FuncWithContext) -> FuncWithContext:
def inner(context: PairingContext, *args: P.args, **kwargs: P.kwargs) -> object:
_check_method_is_allowed(context, pairing_method)
return f(context, *args, **kwargs)
return inner
return decorator
#
# Pairing handlers
@check_state_and_log(ChannelState.TP1)
async def handle_pairing_request(
ctx: PairingContext, message: protobuf.MessageType
) -> ThpEndResponse:
if not ThpStartPairingRequest.is_type_of(message):
raise UnexpectedMessage("Unexpected message")
ctx.host_name = message.host_name or ""
skip_pairing = _is_method_included(ctx, ThpPairingMethod.NoMethod)
if skip_pairing:
return await _end_pairing(ctx)
await _prepare_pairing(ctx)
await ctx.write(ThpPairingPreparationsFinished())
ctx.channel_ctx.set_channel_state(ChannelState.TP3)
response = await show_display_data(
ctx, _get_possible_pairing_methods_and_cancel(ctx)
)
if Cancel.is_type_of(response):
ctx.channel_ctx.clear()
raise SilentError("Action was cancelled by the Host")
# TODO disable NFC (if enabled)
response = await _handle_different_pairing_methods(ctx, response)
while ThpCredentialRequest.is_type_of(response):
response = await _handle_credential_request(ctx, response)
return await _handle_end_request(ctx, response)
async def _prepare_pairing(ctx: PairingContext) -> None:
if _is_method_included(ctx, ThpPairingMethod.CodeEntry):
await _handle_code_entry_is_included(ctx)
if _is_method_included(ctx, ThpPairingMethod.QrCode):
_handle_qr_code_is_included(ctx)
if _is_method_included(ctx, ThpPairingMethod.NFC_Unidirectional):
_handle_nfc_unidirectional_is_included(ctx)
async def show_display_data(
ctx: PairingContext, expected_types: Container[int] = ()
) -> type[protobuf.MessageType]:
from trezorui_api import CANCELLED
read_task = ctx.read(expected_types)
cancel_task = ctx.display_data.get_display_layout()
race = loop.race(read_task, cancel_task.get_result())
result: type[protobuf.MessageType] = await race
if result is CANCELLED:
raise ActionCancelled
return result
@check_state_and_log(ChannelState.TP1)
async def _handle_code_entry_is_included(ctx: PairingContext) -> None:
commitment = sha256(ctx.secret).digest()
challenge_message = await ctx.call( # noqa: F841
ThpCodeEntryCommitment(commitment=commitment), ThpCodeEntryChallenge
)
ctx.channel_ctx.set_channel_state(ChannelState.TP2)
if not ThpCodeEntryChallenge.is_type_of(challenge_message):
raise UnexpectedMessage("Unexpected message")
if challenge_message.challenge is None:
raise Exception("Invalid message")
sha_ctx = sha256(ctx.channel_ctx.get_handshake_hash())
sha_ctx.update(ctx.secret)
sha_ctx.update(challenge_message.challenge)
sha_ctx.update(bytes("PairingMethod_CodeEntry", "utf-8"))
code_code_entry_hash = sha_ctx.digest()
ctx.display_data.code_code_entry = (
int.from_bytes(code_code_entry_hash, "big") % 1000000
)
@check_state_and_log(ChannelState.TP1, ChannelState.TP2)
def _handle_qr_code_is_included(ctx: PairingContext) -> None:
sha_ctx = sha256(ctx.channel_ctx.get_handshake_hash())
sha_ctx.update(ctx.secret)
sha_ctx.update(bytes("PairingMethod_QrCode", "utf-8"))
ctx.display_data.code_qr_code = sha_ctx.digest()[:16]
@check_state_and_log(ChannelState.TP1, ChannelState.TP2)
def _handle_nfc_unidirectional_is_included(ctx: PairingContext) -> None:
sha_ctx = sha256(ctx.channel_ctx.get_handshake_hash())
sha_ctx.update(ctx.secret)
sha_ctx.update(bytes("PairingMethod_NfcUnidirectional", "utf-8"))
ctx.display_data.code_nfc_unidirectional = sha_ctx.digest()[:16]
@check_state_and_log(ChannelState.TP3)
async def _handle_different_pairing_methods(
ctx: PairingContext, response: protobuf.MessageType
) -> protobuf.MessageType:
if ThpCodeEntryCpaceHost.is_type_of(response):
return await _handle_code_entry_cpace(ctx, response)
if ThpQrCodeTag.is_type_of(response):
return await _handle_qr_code_tag(ctx, response)
if ThpNfcUnidirectionalTag.is_type_of(response):
return await _handle_nfc_unidirectional_tag(ctx, response)
raise UnexpectedMessage("Unexpected message")
@check_state_and_log(ChannelState.TP3)
@check_method_is_allowed(ThpPairingMethod.CodeEntry)
async def _handle_code_entry_cpace(
ctx: PairingContext, message: protobuf.MessageType
) -> protobuf.MessageType:
from trezor.wire.thp.cpace import Cpace
# TODO check that ThpCodeEntryCpaceHost message is valid
if TYPE_CHECKING:
assert isinstance(message, ThpCodeEntryCpaceHost)
if message.cpace_host_public_key is None:
raise ThpError("Message ThpCodeEntryCpaceHost has no public key")
ctx.cpace = Cpace(
message.cpace_host_public_key,
ctx.channel_ctx.get_handshake_hash(),
)
assert ctx.display_data.code_code_entry is not None
ctx.cpace.generate_keys_and_secret(
ctx.display_data.code_code_entry.to_bytes(6, "big")
)
ctx.channel_ctx.set_channel_state(ChannelState.TP4)
response = await ctx.call(
ThpCodeEntryCpaceTrezor(cpace_trezor_public_key=ctx.cpace.trezor_public_key),
ThpCodeEntryTag,
)
return await _handle_code_entry_tag(ctx, response)
@check_state_and_log(ChannelState.TP4)
@check_method_is_allowed(ThpPairingMethod.CodeEntry)
async def _handle_code_entry_tag(
ctx: PairingContext, message: protobuf.MessageType
) -> protobuf.MessageType:
if TYPE_CHECKING:
assert isinstance(message, ThpCodeEntryTag)
expected_tag = sha256(ctx.cpace.shared_secret).digest()
if expected_tag != message.tag:
print(
"expected code entry tag:", hexlify(expected_tag).decode()
) # TODO remove after testing
print(
"expected code entry shared secret:",
hexlify(ctx.cpace.shared_secret).decode(),
) # TODO remove after testing
raise ThpError("Unexpected Code Entry Tag")
return await _handle_secret_reveal(
ctx,
msg=ThpCodeEntrySecret(secret=ctx.secret),
)
@check_state_and_log(ChannelState.TP3)
@check_method_is_allowed(ThpPairingMethod.QrCode)
async def _handle_qr_code_tag(
ctx: PairingContext, message: protobuf.MessageType
) -> protobuf.MessageType:
if TYPE_CHECKING:
assert isinstance(message, ThpQrCodeTag)
assert ctx.display_data.code_qr_code is not None
expected_tag = sha256(ctx.display_data.code_qr_code).digest()
if expected_tag != message.tag:
print(
"expected qr code tag:", hexlify(expected_tag).decode()
) # TODO remove after testing
print(
"expected code qr code tag:",
hexlify(ctx.display_data.code_qr_code).decode(),
) # TODO remove after testing
print(
"expected secret:", hexlify(ctx.secret).decode()
) # TODO remove after testing
raise ThpError("Unexpected QR Code Tag")
return await _handle_secret_reveal(
ctx,
msg=ThpQrCodeSecret(secret=ctx.secret),
)
@check_state_and_log(ChannelState.TP3)
@check_method_is_allowed(ThpPairingMethod.NFC_Unidirectional)
async def _handle_nfc_unidirectional_tag(
ctx: PairingContext, message: protobuf.MessageType
) -> protobuf.MessageType:
if TYPE_CHECKING:
assert isinstance(message, ThpNfcUnidirectionalTag)
expected_tag = sha256(ctx.display_data.code_nfc_unidirectional).digest()
if expected_tag != message.tag:
print(
"expected nfc tag:", hexlify(expected_tag).decode()
) # TODO remove after testing
raise ThpError("Unexpected NFC Unidirectional Tag")
return await _handle_secret_reveal(
ctx,
msg=ThpNfcUnidirectionalSecret(secret=ctx.secret),
)
@check_state_and_log(ChannelState.TP3, ChannelState.TP4)
async def _handle_secret_reveal(
ctx: PairingContext,
msg: protobuf.MessageType,
) -> protobuf.MessageType:
ctx.channel_ctx.set_channel_state(ChannelState.TC1)
return await ctx.call_any(
msg,
ThpMessageType.ThpCredentialRequest,
ThpMessageType.ThpEndRequest,
)
@check_state_and_log(ChannelState.TC1)
async def _handle_credential_request(
ctx: PairingContext, message: protobuf.MessageType
) -> protobuf.MessageType:
ctx.secret
if not ThpCredentialRequest.is_type_of(message):
raise UnexpectedMessage("Unexpected message")
if message.host_static_pubkey is None:
raise Exception("Invalid message") # TODO change failure type
trezor_static_pubkey = crypto.get_trezor_static_pubkey()
credential_metadata = ThpCredentialMetadata(host_name=ctx.host_name)
credential = issue_credential(message.host_static_pubkey, credential_metadata)
return await ctx.call_any(
ThpCredentialResponse(
trezor_static_pubkey=trezor_static_pubkey, credential=credential
),
ThpMessageType.ThpCredentialRequest,
ThpMessageType.ThpEndRequest,
)
@check_state_and_log(ChannelState.TC1)
async def _handle_end_request(
ctx: PairingContext, message: protobuf.MessageType
) -> ThpEndResponse:
if not ThpEndRequest.is_type_of(message):
raise UnexpectedMessage("Unexpected message")
return await _end_pairing(ctx)
async def _end_pairing(ctx: PairingContext) -> ThpEndResponse:
ctx.channel_ctx.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
return ThpEndResponse()
#
# Helpers - checkers
def _check_state(ctx: PairingContext, *allowed_states: ChannelState) -> None:
if ctx.channel_ctx.get_channel_state() not in allowed_states:
raise UnexpectedMessage("Unexpected message")
def _check_method_is_allowed(ctx: PairingContext, method: ThpPairingMethod) -> None:
if not _is_method_included(ctx, method):
raise ThpError("Unexpected pairing method")
def _is_method_included(ctx: PairingContext, method: ThpPairingMethod) -> bool:
return method in ctx.channel_ctx.selected_pairing_methods
#
# Helpers - getters
def _get_possible_pairing_methods_and_cancel(ctx: PairingContext) -> Tuple[int, ...]:
r = _get_possible_pairing_methods(ctx)
mtype = Cancel.MESSAGE_WIRE_TYPE
return r + ((mtype,) if mtype is not None else ())
def _get_possible_pairing_methods(ctx: PairingContext) -> Tuple[int, ...]:
r = tuple(
_get_message_type_for_method(method)
for method in ctx.channel_ctx.selected_pairing_methods
)
if __debug__:
from trezor.messages import DebugLinkGetState
mtype = DebugLinkGetState.MESSAGE_WIRE_TYPE
return r + ((mtype,) if mtype is not None else ())
return r
def _get_message_type_for_method(method: int) -> int:
if method is ThpPairingMethod.CodeEntry:
return ThpMessageType.ThpCodeEntryCpaceHost
if method is ThpPairingMethod.NFC_Unidirectional:
return ThpMessageType.ThpNfcUnidirectionalTag
if method is ThpPairingMethod.QrCode:
return ThpMessageType.ThpQrCodeTag
raise ValueError("Unexpected pairing method - no message type available")

View File

@ -35,6 +35,13 @@ def _find_message_handler_module(msg_type: int) -> str:
if __debug__ and msg_type == MessageType.BenchmarkRun: if __debug__ and msg_type == MessageType.BenchmarkRun:
return "apps.benchmark.run" return "apps.benchmark.run"
if utils.USE_THP:
from trezor.enums import ThpMessageType
# thp management
if msg_type == ThpMessageType.ThpCreateNewSession:
return "apps.thp.create_new_session"
# management # management
if msg_type == MessageType.ResetDevice: if msg_type == MessageType.ResetDevice:
return "apps.management.reset_device" return "apps.management.reset_device"

View File

@ -1,11 +1,27 @@
# make sure to import cache unconditionally at top level so that it is imported (and retained) together with the storage module # make sure to import cache unconditionally at top level so that it is imported (and retained) together with the storage module
from typing import TYPE_CHECKING
from storage import cache, common, device from storage import cache, common, device
if TYPE_CHECKING:
from typing import Tuple
def wipe() -> None: pass
def wipe(excluded: Tuple[bytes, bytes] | None) -> None:
"""
TODO REPHRASE SO THAT IT IS TRUE! Wipes the storage. Using `exclude_protocol=False` destroys the THP communication channel.
If the device should communicate after wipe, use `exclude_protocol=True` and clear cache manually later using
`wipe_cache()`.
"""
from trezor import config from trezor import config
config.wipe() config.wipe()
cache.clear_all(excluded)
def wipe_cache() -> None:
cache.clear_all() cache.clear_all()
@ -21,12 +37,12 @@ def init_unlocked() -> None:
common.set_bool(common.APP_DEVICE, device.INITIALIZED, True, public=True) common.set_bool(common.APP_DEVICE, device.INITIALIZED, True, public=True)
def reset() -> None: def reset(excluded: Tuple[bytes, bytes] | None) -> None:
""" """
Wipes storage but keeps the device id unchanged. Wipes storage but keeps the device id unchanged.
""" """
device_id = device.get_device_id() device_id = device.get_device_id()
wipe() wipe(excluded)
common.set(common.APP_DEVICE, device.DEVICE_ID, device_id.encode(), public=True) common.set(common.APP_DEVICE, device.DEVICE_ID, device_id.encode(), public=True)

View File

@ -1,26 +1,47 @@
import builtins import builtins
import gc import gc
from typing import TYPE_CHECKING
from storage import cache_codec
from storage.cache_common import SESSIONLESS_FLAG, SessionlessCache from storage.cache_common import SESSIONLESS_FLAG, SessionlessCache
from trezor import utils
if TYPE_CHECKING:
from typing import Tuple
pass
# Cache initialization # Cache initialization
_SESSIONLESS_CACHE = SessionlessCache() _SESSIONLESS_CACHE = SessionlessCache()
_PROTOCOL_CACHE = cache_codec
if utils.USE_THP:
from storage import cache_thp
_PROTOCOL_CACHE = cache_thp
else:
from storage import cache_codec
_PROTOCOL_CACHE = cache_codec
_PROTOCOL_CACHE.initialize() _PROTOCOL_CACHE.initialize()
_SESSIONLESS_CACHE.clear() _SESSIONLESS_CACHE.clear()
gc.collect() gc.collect()
def clear_all() -> None: def clear_all(excluded: Tuple[bytes, bytes] | None = None) -> None:
""" """
Clears all data from both the protocol cache and the sessionless cache. Clears all data from both the protocol cache and the sessionless cache.
""" """
global autolock_last_touch global autolock_last_touch
autolock_last_touch = None autolock_last_touch = None
_SESSIONLESS_CACHE.clear() _SESSIONLESS_CACHE.clear()
_PROTOCOL_CACHE.clear_all()
if utils.USE_THP and excluded is not None:
# If we want to keep THP connection alive, we do not clear communication keys
cache_thp.clear_all_except_one_session_keys(excluded)
else:
_PROTOCOL_CACHE.clear_all()
def get_int_all_sessions(key: int) -> builtins.set[int]: def get_int_all_sessions(key: int) -> builtins.set[int]:

View File

@ -14,6 +14,14 @@ if not utils.BITCOIN_ONLY:
APP_CARDANO_ICARUS_TREZOR_SECRET = const(6) APP_CARDANO_ICARUS_TREZOR_SECRET = const(6)
APP_MONERO_LIVE_REFRESH = const(7) APP_MONERO_LIVE_REFRESH = const(7)
# Cache keys for THP channel
if utils.USE_THP:
CHANNEL_HANDSHAKE_HASH = const(0)
CHANNEL_KEY_RECEIVE = const(1)
CHANNEL_KEY_SEND = const(2)
CHANNEL_NONCE_RECEIVE = const(3)
CHANNEL_NONCE_SEND = const(4)
# Keys that are valid across sessions # Keys that are valid across sessions
SESSIONLESS_FLAG = const(128) SESSIONLESS_FLAG = const(128)
APP_COMMON_SEED_WITHOUT_PASSPHRASE = const(0 | SESSIONLESS_FLAG) APP_COMMON_SEED_WITHOUT_PASSPHRASE = const(0 | SESSIONLESS_FLAG)

View File

@ -0,0 +1,363 @@
import builtins
from micropython import const
from typing import TYPE_CHECKING
from storage.cache_common import DataCache
if TYPE_CHECKING:
from typing import Tuple
pass
# THP specific constants
_MAX_CHANNELS_COUNT = const(10)
_MAX_SESSIONS_COUNT = const(20)
_CHANNEL_STATE_LENGTH = const(1)
_WIRE_INTERFACE_LENGTH = const(1)
_SESSION_STATE_LENGTH = const(1)
_CHANNEL_ID_LENGTH = const(2)
SESSION_ID_LENGTH = const(1)
BROADCAST_CHANNEL_ID = const(0xFFFF)
KEY_LENGTH = const(32)
TAG_LENGTH = const(16)
_UNALLOCATED_STATE = const(0)
_MANAGEMENT_STATE = const(2)
MANAGEMENT_SESSION_ID = const(0)
class ThpDataCache(DataCache):
def __init__(self) -> None:
self.channel_id = bytearray(_CHANNEL_ID_LENGTH)
self.last_usage = 0
super().__init__()
def clear(self) -> None:
self.channel_id[:] = b""
self.last_usage = 0
super().clear()
class ChannelCache(ThpDataCache):
def __init__(self) -> None:
self.host_ephemeral_pubkey = bytearray(KEY_LENGTH)
self.state = bytearray(_CHANNEL_STATE_LENGTH)
self.iface = bytearray(1) # TODO add decoding
self.sync = 0x80 # can_send_bit | sync_receive_bit | sync_send_bit | rfu(5)
self.session_id_counter = 0x00
self.fields = (
32, # CHANNEL_HANDSHAKE_HASH
32, # CHANNEL_KEY_RECEIVE
32, # CHANNEL_KEY_SEND
8, # CHANNEL_NONCE_RECEIVE
8, # CHANNEL_NONCE_SEND
)
super().__init__()
def clear(self) -> None:
self.state[:] = bytearray(
int.to_bytes(0, _CHANNEL_STATE_LENGTH, "big")
) # Set state to UNALLOCATED
self.host_ephemeral_pubkey[:] = bytearray(KEY_LENGTH)
self.state[:] = bytearray(_CHANNEL_STATE_LENGTH)
self.iface[:] = bytearray(1)
super().clear()
class SessionThpCache(ThpDataCache):
def __init__(self) -> None:
from trezor import utils
self.session_id = bytearray(SESSION_ID_LENGTH)
self.state = bytearray(_SESSION_STATE_LENGTH)
if utils.BITCOIN_ONLY:
self.fields = (
64, # APP_COMMON_SEED
2, # APP_COMMON_AUTHORIZATION_TYPE
128, # APP_COMMON_AUTHORIZATION_DATA
32, # APP_COMMON_NONCE
)
else:
self.fields = (
64, # APP_COMMON_SEED
2, # APP_COMMON_AUTHORIZATION_TYPE
128, # APP_COMMON_AUTHORIZATION_DATA
32, # APP_COMMON_NONCE
0, # APP_COMMON_DERIVE_CARDANO
96, # APP_CARDANO_ICARUS_SECRET
96, # APP_CARDANO_ICARUS_TREZOR_SECRET
0, # APP_MONERO_LIVE_REFRESH
)
super().__init__()
def clear(self) -> None:
self.state[:] = bytearray(int.to_bytes(0, 1, "big")) # Set state to UNALLOCATED
self.session_id[:] = b""
super().clear()
_CHANNELS: list[ChannelCache] = []
_SESSIONS: list[SessionThpCache] = []
cid_counter: int = 0
# Last-used counter
_usage_counter = 0
def initialize() -> None:
global _CHANNELS
global _SESSIONS
global cid_counter
for _ in range(_MAX_CHANNELS_COUNT):
_CHANNELS.append(ChannelCache())
for _ in range(_MAX_SESSIONS_COUNT):
_SESSIONS.append(SessionThpCache())
for channel in _CHANNELS:
channel.clear()
for session in _SESSIONS:
session.clear()
from trezorcrypto import random
cid_counter = random.uniform(0xFFFE)
def get_new_channel(iface: bytes) -> ChannelCache:
if len(iface) != _WIRE_INTERFACE_LENGTH:
raise Exception("Invalid WireInterface (encoded) length")
new_cid = get_next_channel_id()
index = _get_next_channel_index()
# clear sessions from replaced channel
if _get_channel_state(_CHANNELS[index]) != _UNALLOCATED_STATE:
old_cid = _CHANNELS[index].channel_id
clear_sessions_with_channel_id(old_cid)
_CHANNELS[index] = ChannelCache()
_CHANNELS[index].channel_id[:] = new_cid
_CHANNELS[index].last_usage = _get_usage_counter_and_increment()
_CHANNELS[index].state[:] = bytearray(
_UNALLOCATED_STATE.to_bytes(_CHANNEL_STATE_LENGTH, "big")
)
_CHANNELS[index].iface[:] = bytearray(iface)
return _CHANNELS[index]
def update_channel_last_used(channel_id: bytes) -> None:
for channel in _CHANNELS:
if channel.channel_id == channel_id:
channel.last_usage = _get_usage_counter_and_increment()
return
def update_session_last_used(channel_id: bytes, session_id: bytes) -> None:
for session in _SESSIONS:
if session.channel_id == channel_id and session.session_id == session_id:
session.last_usage = _get_usage_counter_and_increment()
update_channel_last_used(channel_id)
return
def get_all_allocated_channels() -> list[ChannelCache]:
_list: list[ChannelCache] = []
for channel in _CHANNELS:
if _get_channel_state(channel) != _UNALLOCATED_STATE:
_list.append(channel)
return _list
def get_allocated_session(
channel_id: bytes, session_id: bytes
) -> SessionThpCache | None:
"""
Finds and returns the first allocated session matching the given `channel_id` and `session_id`,
or `None` if no match is found.
Raises `Exception` if either channel_id or session_id has an invalid length.
"""
if len(channel_id) != _CHANNEL_ID_LENGTH or len(session_id) != SESSION_ID_LENGTH:
raise Exception("At least one of arguments has invalid length")
for session in _SESSIONS:
if _get_session_state(session) == _UNALLOCATED_STATE:
continue
if session.channel_id != channel_id:
continue
if session.session_id != session_id:
continue
return session
return None
def is_management_session(session_cache: SessionThpCache) -> bool:
return _get_session_state(session_cache) == _MANAGEMENT_STATE
def set_channel_host_ephemeral_key(channel: ChannelCache, key: bytearray) -> None:
if len(key) != KEY_LENGTH:
raise Exception("Invalid key length")
channel.host_ephemeral_pubkey = key
def get_new_session(channel: ChannelCache) -> SessionThpCache:
new_sid = get_next_session_id(channel)
index = _get_next_session_index()
_SESSIONS[index] = SessionThpCache()
_SESSIONS[index].channel_id[:] = channel.channel_id
_SESSIONS[index].session_id[:] = new_sid
_SESSIONS[index].last_usage = _get_usage_counter_and_increment()
channel.last_usage = (
_get_usage_counter_and_increment()
) # increment also use of the channel so it does not get replaced
_SESSIONS[index].state[:] = bytearray(
_UNALLOCATED_STATE.to_bytes(_SESSION_STATE_LENGTH, "big")
)
return _SESSIONS[index]
def _get_usage_counter_and_increment() -> int:
global _usage_counter
_usage_counter += 1
return _usage_counter
def _get_next_channel_index() -> int:
idx = _get_unallocated_channel_index()
if idx is not None:
return idx
return get_least_recently_used_item(_CHANNELS, max_count=_MAX_CHANNELS_COUNT)
def _get_next_session_index() -> int:
idx = _get_unallocated_session_index()
if idx is not None:
return idx
return get_least_recently_used_item(_SESSIONS, max_count=_MAX_SESSIONS_COUNT)
def _get_unallocated_channel_index() -> int | None:
for i in range(_MAX_CHANNELS_COUNT):
if _get_channel_state(_CHANNELS[i]) is _UNALLOCATED_STATE:
return i
return None
def _get_unallocated_session_index() -> int | None:
for i in range(_MAX_SESSIONS_COUNT):
if (_SESSIONS[i]) is _UNALLOCATED_STATE:
return i
return None
def _get_channel_state(channel: ChannelCache) -> int:
return int.from_bytes(channel.state, "big")
def _get_session_state(session: SessionThpCache) -> int:
return int.from_bytes(session.state, "big")
def get_next_channel_id() -> bytes:
global cid_counter
while True:
cid_counter += 1
if cid_counter >= BROADCAST_CHANNEL_ID:
cid_counter = 1
if _is_cid_unique():
break
return cid_counter.to_bytes(_CHANNEL_ID_LENGTH, "big")
def get_next_session_id(channel: ChannelCache) -> bytes:
while True:
if channel.session_id_counter >= 255:
channel.session_id_counter = 1
else:
channel.session_id_counter += 1
if _is_session_id_unique(channel):
break
new_sid = channel.session_id_counter
return new_sid.to_bytes(SESSION_ID_LENGTH, "big")
def _is_session_id_unique(channel: ChannelCache) -> bool:
for session in _SESSIONS:
if session.channel_id == channel.channel_id:
if session.session_id == channel.session_id_counter:
return False
return True
def _is_cid_unique() -> bool:
global cid_counter
cid_counter_bytes = cid_counter.to_bytes(_CHANNEL_ID_LENGTH, "big")
for channel in _CHANNELS:
if channel.channel_id == cid_counter_bytes:
return False
return True
def get_least_recently_used_item(
list: list[ChannelCache] | list[SessionThpCache], max_count: int
) -> int:
global _usage_counter
lru_counter = _usage_counter + 1
lru_item_index = 0
for i in range(max_count):
if list[i].last_usage < lru_counter:
lru_counter = list[i].last_usage
lru_item_index = i
return lru_item_index
def get_int_all_sessions(key: int) -> builtins.set[int]:
values = builtins.set()
for session in _SESSIONS:
encoded = session.get(key)
if encoded is not None:
values.add(int.from_bytes(encoded, "big"))
return values
def clear_sessions_with_channel_id(channel_id: bytes) -> None:
for session in _SESSIONS:
if session.channel_id == channel_id:
session.clear()
def clear_session(session: SessionThpCache) -> None:
for s in _SESSIONS:
if s.channel_id == session.channel_id and s.session_id == session.session_id:
session.clear()
def clear_all() -> None:
for session in _SESSIONS:
session.clear()
for channel in _CHANNELS:
channel.clear()
def clear_all_except_one_session_keys(excluded: Tuple[bytes, bytes]) -> None:
cid, sid = excluded
for channel in _CHANNELS:
if channel.channel_id != cid:
channel.clear()
for session in _SESSIONS:
if session.channel_id != cid and session.session_id != sid:
session.clear()
else:
s_last_usage = session.last_usage
session.clear()
session.last_usage = s_last_usage
session.state = bytearray(_MANAGEMENT_STATE.to_bytes(1, "big"))
session.session_id[:] = bytearray(sid)
session.channel_id[:] = bytearray(cid)

View File

@ -16,4 +16,6 @@ NotInitialized = 11
PinMismatch = 12 PinMismatch = 12
WipeCodeMismatch = 13 WipeCodeMismatch = 13
InvalidSession = 14 InvalidSession = 14
ThpUnallocatedSession = 15
InvalidProtocol = 16
FirmwareError = 99 FirmwareError = 99

22
core/src/trezor/enums/ThpMessageType.py generated Normal file
View File

@ -0,0 +1,22 @@
# Automatically generated by pb2py
# fmt: off
# isort:skip_file
ThpCreateNewSession = 1000
ThpNewSession = 1001
ThpStartPairingRequest = 1008
ThpPairingPreparationsFinished = 1009
ThpCredentialRequest = 1010
ThpCredentialResponse = 1011
ThpEndRequest = 1012
ThpEndResponse = 1013
ThpCodeEntryCommitment = 1016
ThpCodeEntryChallenge = 1017
ThpCodeEntryCpaceHost = 1018
ThpCodeEntryCpaceTrezor = 1019
ThpCodeEntryTag = 1020
ThpCodeEntrySecret = 1021
ThpQrCodeTag = 1024
ThpQrCodeSecret = 1025
ThpNfcUnidirectionalTag = 1032
ThpNfcUnidirectionalSecret = 1033

View File

@ -0,0 +1,8 @@
# Automatically generated by pb2py
# fmt: off
# isort:skip_file
NoMethod = 1
CodeEntry = 2
QrCode = 3
NFC_Unidirectional = 4

View File

@ -39,6 +39,8 @@ if TYPE_CHECKING:
PinMismatch = 12 PinMismatch = 12
WipeCodeMismatch = 13 WipeCodeMismatch = 13
InvalidSession = 14 InvalidSession = 14
ThpUnallocatedSession = 15
InvalidProtocol = 16
FirmwareError = 99 FirmwareError = 99
class ButtonRequestType(IntEnum): class ButtonRequestType(IntEnum):
@ -347,6 +349,32 @@ if TYPE_CHECKING:
Nay = 1 Nay = 1
Pass = 2 Pass = 2
class ThpMessageType(IntEnum):
ThpCreateNewSession = 1000
ThpNewSession = 1001
ThpStartPairingRequest = 1008
ThpPairingPreparationsFinished = 1009
ThpCredentialRequest = 1010
ThpCredentialResponse = 1011
ThpEndRequest = 1012
ThpEndResponse = 1013
ThpCodeEntryCommitment = 1016
ThpCodeEntryChallenge = 1017
ThpCodeEntryCpaceHost = 1018
ThpCodeEntryCpaceTrezor = 1019
ThpCodeEntryTag = 1020
ThpCodeEntrySecret = 1021
ThpQrCodeTag = 1024
ThpQrCodeSecret = 1025
ThpNfcUnidirectionalTag = 1032
ThpNfcUnidirectionalSecret = 1033
class ThpPairingMethod(IntEnum):
NoMethod = 1
CodeEntry = 2
QrCode = 3
NFC_Unidirectional = 4
class MessageType(IntEnum): class MessageType(IntEnum):
Initialize = 0 Initialize = 0
Ping = 1 Ping = 1

View File

@ -68,6 +68,8 @@ if TYPE_CHECKING:
from trezor.enums import StellarSignerType # noqa: F401 from trezor.enums import StellarSignerType # noqa: F401
from trezor.enums import TezosBallotType # noqa: F401 from trezor.enums import TezosBallotType # noqa: F401
from trezor.enums import TezosContractType # noqa: F401 from trezor.enums import TezosContractType # noqa: F401
from trezor.enums import ThpMessageType # noqa: F401
from trezor.enums import ThpPairingMethod # noqa: F401
from trezor.enums import WordRequestType # noqa: F401 from trezor.enums import WordRequestType # noqa: F401
class BenchmarkListNames(protobuf.MessageType): class BenchmarkListNames(protobuf.MessageType):
@ -2898,11 +2900,13 @@ if TYPE_CHECKING:
class DebugLinkGetState(protobuf.MessageType): class DebugLinkGetState(protobuf.MessageType):
wait_layout: "DebugWaitType" wait_layout: "DebugWaitType"
thp_channel_id: "bytes | None"
def __init__( def __init__(
self, self,
*, *,
wait_layout: "DebugWaitType | None" = None, wait_layout: "DebugWaitType | None" = None,
thp_channel_id: "bytes | None" = None,
) -> None: ) -> None:
pass pass
@ -2924,6 +2928,9 @@ if TYPE_CHECKING:
reset_word_pos: "int | None" reset_word_pos: "int | None"
mnemonic_type: "BackupType | None" mnemonic_type: "BackupType | None"
tokens: "list[str]" tokens: "list[str]"
thp_pairing_code_entry_code: "int | None"
thp_pairing_code_qr_code: "bytes | None"
thp_pairing_code_nfc_unidirectional: "bytes | None"
def __init__( def __init__(
self, self,
@ -2941,6 +2948,9 @@ if TYPE_CHECKING:
recovery_word_pos: "int | None" = None, recovery_word_pos: "int | None" = None,
reset_word_pos: "int | None" = None, reset_word_pos: "int | None" = None,
mnemonic_type: "BackupType | None" = None, mnemonic_type: "BackupType | None" = None,
thp_pairing_code_entry_code: "int | None" = None,
thp_pairing_code_qr_code: "bytes | None" = None,
thp_pairing_code_nfc_unidirectional: "bytes | None" = None,
) -> None: ) -> None:
pass pass
@ -6162,6 +6172,278 @@ if TYPE_CHECKING:
def is_type_of(cls, msg: Any) -> TypeGuard["TezosManagerTransfer"]: def is_type_of(cls, msg: Any) -> TypeGuard["TezosManagerTransfer"]:
return isinstance(msg, cls) return isinstance(msg, cls)
class ThpDeviceProperties(protobuf.MessageType):
internal_model: "str | None"
model_variant: "int | None"
bootloader_mode: "bool | None"
protocol_version: "int | None"
pairing_methods: "list[ThpPairingMethod]"
def __init__(
self,
*,
pairing_methods: "list[ThpPairingMethod] | None" = None,
internal_model: "str | None" = None,
model_variant: "int | None" = None,
bootloader_mode: "bool | None" = None,
protocol_version: "int | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpDeviceProperties"]:
return isinstance(msg, cls)
class ThpHandshakeCompletionReqNoisePayload(protobuf.MessageType):
host_pairing_credential: "bytes | None"
pairing_methods: "list[ThpPairingMethod]"
def __init__(
self,
*,
pairing_methods: "list[ThpPairingMethod] | None" = None,
host_pairing_credential: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpHandshakeCompletionReqNoisePayload"]:
return isinstance(msg, cls)
class ThpCreateNewSession(protobuf.MessageType):
passphrase: "str | None"
on_device: "bool | None"
derive_cardano: "bool | None"
def __init__(
self,
*,
passphrase: "str | None" = None,
on_device: "bool | None" = None,
derive_cardano: "bool | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCreateNewSession"]:
return isinstance(msg, cls)
class ThpNewSession(protobuf.MessageType):
new_session_id: "int | None"
def __init__(
self,
*,
new_session_id: "int | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpNewSession"]:
return isinstance(msg, cls)
class ThpStartPairingRequest(protobuf.MessageType):
host_name: "str | None"
def __init__(
self,
*,
host_name: "str | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpStartPairingRequest"]:
return isinstance(msg, cls)
class ThpPairingPreparationsFinished(protobuf.MessageType):
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpPairingPreparationsFinished"]:
return isinstance(msg, cls)
class ThpCodeEntryCommitment(protobuf.MessageType):
commitment: "bytes | None"
def __init__(
self,
*,
commitment: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntryCommitment"]:
return isinstance(msg, cls)
class ThpCodeEntryChallenge(protobuf.MessageType):
challenge: "bytes | None"
def __init__(
self,
*,
challenge: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntryChallenge"]:
return isinstance(msg, cls)
class ThpCodeEntryCpaceHost(protobuf.MessageType):
cpace_host_public_key: "bytes | None"
def __init__(
self,
*,
cpace_host_public_key: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntryCpaceHost"]:
return isinstance(msg, cls)
class ThpCodeEntryCpaceTrezor(protobuf.MessageType):
cpace_trezor_public_key: "bytes | None"
def __init__(
self,
*,
cpace_trezor_public_key: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntryCpaceTrezor"]:
return isinstance(msg, cls)
class ThpCodeEntryTag(protobuf.MessageType):
tag: "bytes | None"
def __init__(
self,
*,
tag: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntryTag"]:
return isinstance(msg, cls)
class ThpCodeEntrySecret(protobuf.MessageType):
secret: "bytes | None"
def __init__(
self,
*,
secret: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntrySecret"]:
return isinstance(msg, cls)
class ThpQrCodeTag(protobuf.MessageType):
tag: "bytes | None"
def __init__(
self,
*,
tag: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpQrCodeTag"]:
return isinstance(msg, cls)
class ThpQrCodeSecret(protobuf.MessageType):
secret: "bytes | None"
def __init__(
self,
*,
secret: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpQrCodeSecret"]:
return isinstance(msg, cls)
class ThpNfcUnidirectionalTag(protobuf.MessageType):
tag: "bytes | None"
def __init__(
self,
*,
tag: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpNfcUnidirectionalTag"]:
return isinstance(msg, cls)
class ThpNfcUnidirectionalSecret(protobuf.MessageType):
secret: "bytes | None"
def __init__(
self,
*,
secret: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpNfcUnidirectionalSecret"]:
return isinstance(msg, cls)
class ThpCredentialRequest(protobuf.MessageType):
host_static_pubkey: "bytes | None"
def __init__(
self,
*,
host_static_pubkey: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCredentialRequest"]:
return isinstance(msg, cls)
class ThpCredentialResponse(protobuf.MessageType):
trezor_static_pubkey: "bytes | None"
credential: "bytes | None"
def __init__(
self,
*,
trezor_static_pubkey: "bytes | None" = None,
credential: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpCredentialResponse"]:
return isinstance(msg, cls)
class ThpEndRequest(protobuf.MessageType):
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpEndRequest"]:
return isinstance(msg, cls)
class ThpEndResponse(protobuf.MessageType):
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["ThpEndResponse"]:
return isinstance(msg, cls)
class ThpCredentialMetadata(protobuf.MessageType): class ThpCredentialMetadata(protobuf.MessageType):
host_name: "str | None" host_name: "str | None"

View File

@ -33,6 +33,10 @@ from trezorutils import ( # noqa: F401
) )
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
DISABLE_ENCRYPTION: bool = False
ALLOW_DEBUG_MESSAGES: bool = True
if __debug__: if __debug__:
if EMULATOR: if EMULATOR:
import uos import uos

View File

@ -5,7 +5,7 @@ Handles on-the-wire communication with a host computer. The communication is:
- Request / response. - Request / response.
- Protobuf-encoded, see `protobuf.py`. - Protobuf-encoded, see `protobuf.py`.
- Wrapped in a simple envelope format, see `trezor/wire/codec/codec_v1.py`. - Wrapped in a simple envelope format, see `trezor/wire/codec/codec_v1.py` or `trezor/wire/thp/thp_main.py`.
- Transferred over USB interface, or UDP in case of Unix emulation. - Transferred over USB interface, or UDP in case of Unix emulation.
This module: This module:
@ -29,7 +29,12 @@ from typing import TYPE_CHECKING
from trezor import log, loop, protobuf, utils from trezor import log, loop, protobuf, utils
from . import message_handler, protocol_common from . import message_handler, protocol_common
from .codec.codec_context import CodecContext
if utils.USE_THP:
from .thp import thp_main
else:
from .codec.codec_context import CodecContext
from .context import UnexpectedMessageException from .context import UnexpectedMessageException
from .message_handler import failure from .message_handler import failure
@ -40,6 +45,8 @@ from .errors import * # isort:skip # noqa: F401,F403
_PROTOBUF_BUFFER_SIZE = const(8192) _PROTOBUF_BUFFER_SIZE = const(8192)
WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE) WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
if utils.USE_THP:
WIRE_BUFFER_2 = bytearray(_PROTOBUF_BUFFER_SIZE)
if TYPE_CHECKING: if TYPE_CHECKING:
from trezorio import WireInterface from trezorio import WireInterface
@ -57,57 +64,89 @@ def setup(iface: WireInterface) -> None:
loop.schedule(handle_session(iface)) loop.schedule(handle_session(iface))
async def handle_session(iface: WireInterface) -> None: if utils.USE_THP:
ctx = CodecContext(iface, WIRE_BUFFER)
next_msg: protocol_common.Message | None = None
# Take a mark of modules that are imported at this point, so we can async def handle_session(iface: WireInterface) -> None:
# roll back and un-import any others.
modules = utils.unimport_begin()
while True:
try:
if next_msg is None:
# If the previous run did not keep an unprocessed message for us,
# wait for a new one coming from the wire.
try:
msg = await ctx.read_from_wire()
except protocol_common.WireError as exc:
if __debug__:
log.exception(__name__, exc)
await ctx.write(failure(exc))
continue
else: thp_main.set_read_buffer(WIRE_BUFFER)
# Process the message from previous run. thp_main.set_write_buffer(WIRE_BUFFER_2)
msg = next_msg
next_msg = None
do_not_restart = False # Take a mark of modules that are imported at this point, so we can
# roll back and un-import any others.
modules = utils.unimport_begin()
while True:
try: try:
do_not_restart = await message_handler.handle_single_message(ctx, msg) await thp_main.thp_main_loop(iface)
except UnexpectedMessageException as unexpected:
# The workflow was interrupted by an unexpected message. We need to
# process it as if it was a new message...
next_msg = unexpected.msg
# ...and we must not restart because that would lose the message.
do_not_restart = True
continue
except Exception as exc: except Exception as exc:
# Log and ignore. The session handler can only exit explicitly in the # Log and try again.
# following finally block.
if __debug__: if __debug__:
log.exception(__name__, exc) log.exception(__name__, exc)
finally: finally:
# Unload modules imported by the workflow. Should not raise. # Unload modules imported by the workflow. Should not raise.
if __debug__:
log.debug(__name__, "utils.unimport_end(modules) and loop.clear()")
utils.unimport_end(modules) utils.unimport_end(modules)
loop.clear()
return # pylint: disable=lost-exception
if not do_not_restart: else:
# Let the session be restarted from `main`.
loop.clear()
return # pylint: disable=lost-exception
except Exception as exc: async def handle_session(iface: WireInterface) -> None:
# Log and try again. The session handler can only exit explicitly via ctx = CodecContext(iface, WIRE_BUFFER)
# loop.clear() above. next_msg: protocol_common.Message | None = None
if __debug__:
log.exception(__name__, exc) # Take a mark of modules that are imported at this point, so we can
# roll back and un-import any others.
modules = utils.unimport_begin()
while True:
try:
if next_msg is None:
# If the previous run did not keep an unprocessed message for us,
# wait for a new one coming from the wire.
try:
msg = await ctx.read_from_wire()
except protocol_common.WireError as exc:
if __debug__:
log.exception(__name__, exc)
await ctx.write(failure(exc))
continue
else:
# Process the message from previous run.
msg = next_msg
next_msg = None
do_not_restart = False
try:
do_not_restart = await message_handler.handle_single_message(
ctx, msg
)
except UnexpectedMessageException as unexpected:
# The workflow was interrupted by an unexpected message. We need to
# process it as if it was a new message...
next_msg = unexpected.msg
# ...and we must not restart because that would lose the message.
do_not_restart = True
continue
except Exception as exc:
# Log and ignore. The session handler can only exit explicitly in the
# following finally block.
if __debug__:
log.exception(__name__, exc)
finally:
# Unload modules imported by the workflow. Should not raise.
utils.unimport_end(modules)
if not do_not_restart:
# Let the session be restarted from `main`.
if __debug__:
log.debug(__name__, "loop.clear()")
loop.clear()
return # pylint: disable=lost-exception
except Exception as exc:
# Log and try again. The session handler can only exit explicitly via
# loop.clear() above.
if __debug__:
log.exception(__name__, exc)

View File

@ -17,7 +17,7 @@ from typing import TYPE_CHECKING
from storage import cache from storage import cache
from storage.cache_common import SESSIONLESS_FLAG from storage.cache_common import SESSIONLESS_FLAG
from trezor import loop, protobuf from trezor import loop, protobuf, utils
from .protocol_common import Context, Message from .protocol_common import Context, Message
@ -138,6 +138,17 @@ def with_context(ctx: Context, workflow: loop.Task) -> Generator:
send_exc = None send_exc = None
def try_get_ctx_ids() -> tuple[bytes, bytes] | None:
ids = None
if utils.USE_THP:
from trezor.wire.thp.session_context import GenericSessionContext
ctx = get_context()
if isinstance(ctx, GenericSessionContext):
ids = (ctx.channel_id, ctx.session_id.to_bytes(1, "big"))
return ids
# ACCESS TO CACHE # ACCESS TO CACHE
if TYPE_CHECKING: if TYPE_CHECKING:

View File

@ -8,6 +8,12 @@ class Error(Exception):
self.message = message self.message = message
class SilentError(Exception):
def __init__(self, message: str) -> None:
super().__init__()
self.message = message
class UnexpectedMessage(Error): class UnexpectedMessage(Error):
def __init__(self, message: str) -> None: def __init__(self, message: str) -> None:
super().__init__(FailureType.UnexpectedMessage, message) super().__init__(FailureType.UnexpectedMessage, message)

View File

@ -25,7 +25,12 @@ def wrap_protobuf_load(
expected_type: type[LoadedMessageType], expected_type: type[LoadedMessageType],
) -> LoadedMessageType: ) -> LoadedMessageType:
try: try:
if __debug__ and utils.EMULATOR and utils.USE_THP: if (
__debug__
and utils.EMULATOR
and utils.USE_THP
and utils.ALLOW_DEBUG_MESSAGES
):
log.debug( log.debug(
__name__, __name__,
"Buffer to be parsed to a LoadedMessage: %s", "Buffer to be parsed to a LoadedMessage: %s",
@ -38,7 +43,7 @@ def wrap_protobuf_load(
) )
return msg return msg
except Exception as e: except Exception as e:
if __debug__: if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.exception(__name__, e) log.exception(__name__, e)
if e.args: if e.args:
raise DataError("Failed to decode message: " + " ".join(e.args)) raise DataError("Failed to decode message: " + " ".join(e.args))
@ -46,6 +51,25 @@ def wrap_protobuf_load(
raise DataError("Failed to decode message") raise DataError("Failed to decode message")
if utils.USE_THP:
from trezor.enums import ThpMessageType
def get_msg_name(msg_type: int) -> str | None:
for name in dir(ThpMessageType):
if not name.startswith("__"): # Skip built-in attributes
value = getattr(ThpMessageType, name)
if isinstance(value, int):
if value == msg_type:
return name
return None
def get_msg_type(msg_name: str) -> int | None:
value = getattr(ThpMessageType, msg_name)
if isinstance(value, int):
return value
return None
async def handle_single_message(ctx: Context, msg: Message) -> bool: async def handle_single_message(ctx: Context, msg: Message) -> bool:
"""Handle a message that was loaded from a WireInterface by the caller. """Handle a message that was loaded from a WireInterface by the caller.
@ -60,17 +84,27 @@ async def handle_single_message(ctx: Context, msg: Message) -> bool:
the type of message is supposed to be optimized and not disrupt the running state, the type of message is supposed to be optimized and not disrupt the running state,
this function will return `True`. this function will return `True`.
""" """
if __debug__: if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
try: try:
msg_type = protobuf.type_for_wire(msg.type).MESSAGE_NAME msg_type = protobuf.type_for_wire(msg.type).MESSAGE_NAME
except Exception: except Exception:
msg_type = f"{msg.type} - unknown message type" msg_type = f"{msg.type} - unknown message type"
log.debug( if utils.USE_THP:
__name__, cid = int.from_bytes(ctx.channel_id, "big")
"%d receive: <%s>", log.debug(
ctx.iface.iface_num(), __name__,
msg_type, "%d:%d receive: <%s>",
) ctx.iface.iface_num(),
cid,
msg_type,
)
else:
log.debug(
__name__,
"%d receive: <%s>",
ctx.iface.iface_num(),
msg_type,
)
res_msg: protobuf.MessageType | None = None res_msg: protobuf.MessageType | None = None
@ -91,7 +125,15 @@ async def handle_single_message(ctx: Context, msg: Message) -> bool:
try: try:
# Find a protobuf.MessageType subclass that describes this # Find a protobuf.MessageType subclass that describes this
# message. Raises if the type is not found. # message. Raises if the type is not found.
req_type = protobuf.type_for_wire(msg.type)
if utils.USE_THP:
name = get_msg_name(msg.type)
if name is None:
req_type = protobuf.type_for_wire(msg.type)
else:
req_type = protobuf.type_for_name(name)
else:
req_type = protobuf.type_for_wire(msg.type)
# Try to decode the message according to schema from # Try to decode the message according to schema from
# `req_type`. Raises if the message is malformed. # `req_type`. Raises if the message is malformed.
@ -132,7 +174,7 @@ async def handle_single_message(ctx: Context, msg: Message) -> bool:
# - the message was not valid protobuf # - the message was not valid protobuf
# - workflow raised some kind of an exception while running # - workflow raised some kind of an exception while running
# - something canceled the workflow from the outside # - something canceled the workflow from the outside
if __debug__: if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
if isinstance(exc, ActionCancelled): if isinstance(exc, ActionCancelled):
log.debug(__name__, "cancelled: %s", exc.message) log.debug(__name__, "cancelled: %s", exc.message)
elif isinstance(exc, loop.TaskClosed): elif isinstance(exc, loop.TaskClosed):

View File

@ -4,7 +4,7 @@ from trezor import protobuf
if TYPE_CHECKING: if TYPE_CHECKING:
from trezorio import WireInterface from trezorio import WireInterface
from typing import Container, TypeVar, overload from typing import Awaitable, Container, TypeVar, overload
from storage.cache_common import DataCache from storage.cache_common import DataCache
@ -72,6 +72,9 @@ class Context:
"""Write a message to the wire.""" """Write a message to the wire."""
... ...
def write_force(self, msg: protobuf.MessageType) -> Awaitable[None]:
return self.write(msg)
async def call( async def call(
self, self,
msg: protobuf.MessageType, msg: protobuf.MessageType,

View File

@ -0,0 +1,184 @@
import ustruct
from micropython import const
from typing import TYPE_CHECKING
from storage.cache_thp import BROADCAST_CHANNEL_ID
from trezor import protobuf, utils
from trezor.enums import ThpPairingMethod
from trezor.messages import ThpDeviceProperties
from ..protocol_common import WireError
if TYPE_CHECKING:
from enum import IntEnum
from trezor.wire import WireInterface
from typing_extensions import Self
else:
IntEnum = object
CODEC_V1 = const(0x3F)
HANDSHAKE_INIT_REQ = const(0x00)
HANDSHAKE_INIT_RES = const(0x01)
HANDSHAKE_COMP_REQ = const(0x02)
HANDSHAKE_COMP_RES = const(0x03)
ENCRYPTED = const(0x04)
ACK_MESSAGE = const(0x20)
CHANNEL_ALLOCATION_REQ = const(0x40)
_CHANNEL_ALLOCATION_RES = const(0x41)
_ERROR = const(0x42)
CONTINUATION_PACKET = const(0x80)
class ThpError(WireError):
pass
class ThpDecryptionError(ThpError):
pass
class ThpInvalidDataError(ThpError):
pass
class ThpUnallocatedSessionError(ThpError):
def __init__(self, session_id: int) -> None:
self.session_id = session_id
class ThpErrorType(IntEnum):
TRANSPORT_BUSY = 1
UNALLOCATED_CHANNEL = 2
DECRYPTION_FAILED = 3
INVALID_DATA = 4
class ChannelState(IntEnum):
UNALLOCATED = 0
TH1 = 1
TH2 = 2
TP1 = 3
TP2 = 4
TP3 = 5
TP4 = 6
TC1 = 7
ENCRYPTED_TRANSPORT = 8
class SessionState(IntEnum):
UNALLOCATED = 0
ALLOCATED = 1
MANAGEMENT = 2
class PacketHeader:
format_str_init = ">BHH"
format_str_cont = ">BH"
def __init__(self, ctrl_byte: int, cid: int, length: int) -> None:
self.ctrl_byte = ctrl_byte
self.cid = cid
self.length = length
def to_bytes(self) -> bytes:
return ustruct.pack(self.format_str_init, self.ctrl_byte, self.cid, self.length)
def pack_to_init_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None:
"""
Packs header information in the form of **intial** packet
into the provided buffer.
"""
ustruct.pack_into(
self.format_str_init,
buffer,
buffer_offset,
self.ctrl_byte,
self.cid,
self.length,
)
def pack_to_cont_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None:
"""
Packs header information in the form of **continuation** packet header
into the provided buffer.
"""
ustruct.pack_into(
self.format_str_cont, buffer, buffer_offset, CONTINUATION_PACKET, self.cid
)
@classmethod
def get_error_header(cls, cid: int, length: int) -> Self:
"""
Returns header for protocol-level error messages.
"""
return cls(_ERROR, cid, length)
@classmethod
def get_channel_allocation_response_header(cls, length: int) -> Self:
"""
Returns header for allocation response handshake message.
"""
return cls(_CHANNEL_ALLOCATION_RES, BROADCAST_CHANNEL_ID, length)
_DEFAULT_ENABLED_PAIRING_METHODS = [
ThpPairingMethod.CodeEntry,
ThpPairingMethod.QrCode,
ThpPairingMethod.NFC_Unidirectional,
]
def get_enabled_pairing_methods(
iface: WireInterface | None = None,
) -> list[ThpPairingMethod]:
"""
Returns pairing methods that are currently allowed by the device
with respect to the wire interface the host communicates on.
"""
import usb
methods = _DEFAULT_ENABLED_PAIRING_METHODS.copy()
if iface is not None and iface is usb.iface_wire:
methods.append(ThpPairingMethod.NoMethod)
return methods
def _get_device_properties(iface: WireInterface) -> ThpDeviceProperties:
# TODO define model variants
return ThpDeviceProperties(
pairing_methods=get_enabled_pairing_methods(iface),
internal_model=utils.INTERNAL_MODEL,
model_variant=0,
bootloader_mode=False,
protocol_version=2,
)
def get_encoded_device_properties(iface: WireInterface) -> bytes:
props = _get_device_properties(iface)
length = protobuf.encoded_length(props)
encoded_properties = bytearray(length)
protobuf.encode(encoded_properties, props)
return encoded_properties
def get_channel_allocation_response(
nonce: bytes, new_cid: bytes, iface: WireInterface
) -> bytes:
props_msg = get_encoded_device_properties(iface)
return nonce + new_cid + props_msg
if __debug__:
def state_to_str(state: int) -> str:
name = {
v: k for k, v in ChannelState.__dict__.items() if not k.startswith("__")
}.get(state)
if name is not None:
return name
return "UNKNOWN_STATE"

View File

@ -0,0 +1,102 @@
from storage.cache_thp import ChannelCache
from trezor import log, utils
from trezor.wire.thp import ThpError
def is_ack_valid(cache: ChannelCache, ack_bit: int) -> bool:
"""
Checks if:
- an ACK message is expected
- the received ACK message acknowledges correct sequence number (bit)
"""
if not _is_ack_expected(cache):
return False
if not _has_ack_correct_sync_bit(cache, ack_bit):
return False
return True
def _is_ack_expected(cache: ChannelCache) -> bool:
is_expected: bool = not is_sending_allowed(cache)
if __debug__ and utils.ALLOW_DEBUG_MESSAGES and not is_expected:
log.debug(__name__, "Received unexpected ACK message")
return is_expected
def _has_ack_correct_sync_bit(cache: ChannelCache, sync_bit: int) -> bool:
is_correct: bool = get_send_seq_bit(cache) == sync_bit
if __debug__ and utils.ALLOW_DEBUG_MESSAGES and not is_correct:
log.debug(__name__, "Received ACK message with wrong ack bit")
return is_correct
def is_sending_allowed(cache: ChannelCache) -> bool:
"""
Checks whether sending a message in the provided channel is allowed.
Note: Sending a message in a channel before receipt of ACK message for the previously
sent message (in the channel) is prohibited, as it can lead to desynchronization.
"""
return bool(cache.sync >> 7)
def get_send_seq_bit(cache: ChannelCache) -> int:
"""
Returns the sequential number (bit) of the next message to be sent
in the provided channel.
"""
return (cache.sync & 0x20) >> 5
def get_expected_receive_seq_bit(cache: ChannelCache) -> int:
"""
Returns the (expected) sequential number (bit) of the next message
to be received in the provided channel.
"""
return (cache.sync & 0x40) >> 6
def set_sending_allowed(cache: ChannelCache, sending_allowed: bool) -> None:
"""
Set the flag whether sending a message in this channel is allowed or not.
"""
cache.sync &= 0x7F
if sending_allowed:
cache.sync |= 0x80
def set_expected_receive_seq_bit(cache: ChannelCache, seq_bit: int) -> None:
"""
Set the expected sequential number (bit) of the next message to be received
in the provided channel
"""
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "Set sync receive expected seq bit to %d", seq_bit)
if seq_bit not in (0, 1):
raise ThpError("Unexpected receive sync bit")
# set second bit to "seq_bit" value
cache.sync &= 0xBF
if seq_bit:
cache.sync |= 0x40
def _set_send_seq_bit(cache: ChannelCache, seq_bit: int) -> None:
if seq_bit not in (0, 1):
raise ThpError("Unexpected send seq bit")
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "setting sync send seq bit to %d", seq_bit)
# set third bit to "seq_bit" value
cache.sync &= 0xDF
if seq_bit:
cache.sync |= 0x20
def set_send_seq_bit_to_opposite(cache: ChannelCache) -> None:
"""
Set the sequential bit of the "next message to be send" to the opposite value,
i.e. 1 -> 0 and 0 -> 1
"""
_set_send_seq_bit(cache=cache, seq_bit=1 - get_send_seq_bit(cache))

View File

@ -0,0 +1,405 @@
import ustruct
from typing import TYPE_CHECKING
from storage.cache_common import (
CHANNEL_HANDSHAKE_HASH,
CHANNEL_KEY_RECEIVE,
CHANNEL_KEY_SEND,
CHANNEL_NONCE_RECEIVE,
CHANNEL_NONCE_SEND,
)
from storage.cache_thp import TAG_LENGTH, ChannelCache, clear_sessions_with_channel_id
from trezor import log, loop, protobuf, utils, workflow
from . import ENCRYPTED, ChannelState, PacketHeader, ThpDecryptionError, ThpError
from . import alternating_bit_protocol as ABP
from . import (
control_byte,
crypto,
interface_manager,
memory_manager,
received_message_handler,
)
from .checksum import CHECKSUM_LENGTH
from .transmission_loop import TransmissionLoop
from .writer import (
CONT_HEADER_LENGTH,
INIT_HEADER_LENGTH,
write_payload_to_wire_and_add_checksum,
)
if __debug__:
from ubinascii import hexlify
from . import state_to_str
if TYPE_CHECKING:
from trezorio import WireInterface
from typing import Awaitable
from .pairing_context import PairingContext
from .session_context import GenericSessionContext
class Channel:
"""
THP protocol encrypted communication channel.
"""
def __init__(self, channel_cache: ChannelCache) -> None:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "channel initialization")
self.iface: WireInterface = interface_manager.decode_iface(channel_cache.iface)
self.channel_cache: ChannelCache = channel_cache
self.is_cont_packet_expected: bool = False
self.expected_payload_length: int = 0
self.bytes_read: int = 0
self.buffer: utils.BufferType
self.channel_id: bytes = channel_cache.channel_id
self.selected_pairing_methods = []
self.sessions: dict[int, GenericSessionContext] = {}
self.write_task_spawn: loop.spawn | None = None
self.connection_context: PairingContext | None = None
self.transmission_loop: TransmissionLoop | None = None
self.handshake: crypto.Handshake | None = None
def clear(self) -> None:
clear_sessions_with_channel_id(self.channel_id)
self.channel_cache.clear()
# ACCESS TO CHANNEL_DATA
def get_channel_id_int(self) -> int:
return int.from_bytes(self.channel_id, "big")
def get_channel_state(self) -> int:
state = int.from_bytes(self.channel_cache.state, "big")
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"(cid: %s) get_channel_state: %s",
utils.get_bytes_as_str(self.channel_id),
state_to_str(state),
)
return state
def get_handshake_hash(self) -> bytes:
h = self.channel_cache.get(CHANNEL_HANDSHAKE_HASH)
assert h is not None
return h
def set_channel_state(self, state: ChannelState) -> None:
self.channel_cache.state = bytearray(state.to_bytes(1, "big"))
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"(cid: %s) set_channel_state: %s",
utils.get_bytes_as_str(self.channel_id),
state_to_str(state),
)
def set_buffer(self, buffer: utils.BufferType) -> None:
self.buffer = buffer
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"(cid: %s) set_buffer: %s",
utils.get_bytes_as_str(self.channel_id),
type(self.buffer),
)
# CALLED BY THP_MAIN_LOOP
def receive_packet(self, packet: utils.BufferType) -> Awaitable[None] | None:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"(cid: %s) receive_packet",
utils.get_bytes_as_str(self.channel_id),
)
self._handle_received_packet(packet)
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"(cid: %s) self.buffer: %s",
utils.get_bytes_as_str(self.channel_id),
utils.get_bytes_as_str(self.buffer),
)
if self.expected_payload_length + INIT_HEADER_LENGTH == self.bytes_read:
self._finish_message()
return received_message_handler.handle_received_message(self, self.buffer)
elif self.expected_payload_length + INIT_HEADER_LENGTH > self.bytes_read:
self.is_cont_packet_expected = True
else:
raise ThpError(
"Read more bytes than is the expected length of the message!"
)
return None
def _handle_received_packet(self, packet: utils.BufferType) -> None:
ctrl_byte = packet[0]
if control_byte.is_continuation(ctrl_byte):
return self._handle_cont_packet(packet)
return self._handle_init_packet(packet)
def _handle_init_packet(self, packet: utils.BufferType) -> None:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"(cid: %s) handle_init_packet",
utils.get_bytes_as_str(self.channel_id),
)
# ctrl_byte, _, payload_length = ustruct.unpack(">BHH", packet) # TODO use this with single packet decryption
_, _, payload_length = ustruct.unpack(">BHH", packet)
self.expected_payload_length = payload_length
packet_payload = memoryview(packet)[INIT_HEADER_LENGTH:]
# If the channel does not "own" the buffer lock, decrypt first packet
# TODO do it only when needed!
# TODO FIX: If "_decrypt_single_packet_payload" is implemented, it will (possibly) break "decrypt_buffer" and nonces incrementation.
# On the other hand, without the single packet decryption, the "advanced" buffer selection cannot be implemented
# in "memory_manager.select_buffer", because the session id is unknown (encrypted).
# if control_byte.is_encrypted_transport(ctrl_byte):
# packet_payload = self._decrypt_single_packet_payload(packet_payload)
self.buffer = memory_manager.select_buffer(
self.get_channel_state(),
self.buffer,
packet_payload,
payload_length,
)
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"(cid: %s) handle_init_packet - payload len: %d",
utils.get_bytes_as_str(self.channel_id),
payload_length,
)
log.debug(
__name__,
"(cid: %s) handle_init_packet - buffer len: %d",
utils.get_bytes_as_str(self.channel_id),
len(self.buffer),
)
return self._buffer_packet_data(self.buffer, packet, 0)
def _handle_cont_packet(self, packet: utils.BufferType) -> None:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"(cid: %s) handle_cont_packet",
utils.get_bytes_as_str(self.channel_id),
)
if not self.is_cont_packet_expected:
raise ThpError("Continuation packet is not expected, ignoring")
return self._buffer_packet_data(self.buffer, packet, CONT_HEADER_LENGTH)
def _decrypt_single_packet_payload(
self, payload: utils.BufferType
) -> utils.BufferType:
# crypto.decrypt(b"\x00", b"\x00", payload_buffer, INIT_DATA_OFFSET, len(payload))
return payload
def decrypt_buffer(
self, message_length: int, offset: int = INIT_HEADER_LENGTH
) -> None:
noise_buffer = memoryview(self.buffer)[
offset : message_length - CHECKSUM_LENGTH - TAG_LENGTH
]
tag = self.buffer[
message_length
- CHECKSUM_LENGTH
- TAG_LENGTH : message_length
- CHECKSUM_LENGTH
]
if utils.DISABLE_ENCRYPTION:
is_tag_valid = tag == crypto.DUMMY_TAG
else:
key_receive = self.channel_cache.get(CHANNEL_KEY_RECEIVE)
nonce_receive = self.channel_cache.get_int(CHANNEL_NONCE_RECEIVE)
assert key_receive is not None
assert nonce_receive is not None
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"(cid: %s) Buffer before decryption: %s",
utils.get_bytes_as_str(self.channel_id),
hexlify(noise_buffer),
)
is_tag_valid = crypto.dec(
noise_buffer, tag, key_receive, nonce_receive, b""
)
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"(cid: %s) Buffer after decryption: %s",
utils.get_bytes_as_str(self.channel_id),
hexlify(noise_buffer),
)
self.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, nonce_receive + 1)
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"(cid: %s) Is decrypted tag valid? %s",
utils.get_bytes_as_str(self.channel_id),
str(is_tag_valid),
)
log.debug(
__name__,
"(cid: %s) Received tag: %s",
utils.get_bytes_as_str(self.channel_id),
(hexlify(tag).decode()),
)
log.debug(
__name__,
"(cid: %s) New nonce_receive: %i",
utils.get_bytes_as_str(self.channel_id),
nonce_receive + 1,
)
if not is_tag_valid:
raise ThpDecryptionError()
def _encrypt(self, buffer: utils.BufferType, noise_payload_len: int) -> None:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__, "(cid: %s) encrypt", utils.get_bytes_as_str(self.channel_id)
)
assert len(buffer) >= noise_payload_len + TAG_LENGTH + CHECKSUM_LENGTH
noise_buffer = memoryview(buffer)[0:noise_payload_len]
if utils.DISABLE_ENCRYPTION:
tag = crypto.DUMMY_TAG
else:
key_send = self.channel_cache.get(CHANNEL_KEY_SEND)
nonce_send = self.channel_cache.get_int(CHANNEL_NONCE_SEND)
assert key_send is not None
assert nonce_send is not None
tag = crypto.enc(noise_buffer, key_send, nonce_send, b"")
self.channel_cache.set_int(CHANNEL_NONCE_SEND, nonce_send + 1)
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "New nonce_send: %i", nonce_send + 1)
buffer[noise_payload_len : noise_payload_len + TAG_LENGTH] = tag
def _buffer_packet_data(
self, payload_buffer: utils.BufferType, packet: utils.BufferType, offset: int
) -> None:
self.bytes_read += utils.memcpy(payload_buffer, self.bytes_read, packet, offset)
def _finish_message(self) -> None:
self.bytes_read = 0
self.expected_payload_length = 0
self.is_cont_packet_expected = False
# CALLED BY WORKFLOW / SESSION CONTEXT
async def write(
self,
msg: protobuf.MessageType,
session_id: int = 0,
force: bool = False,
) -> None:
if __debug__ and utils.EMULATOR:
log.debug(
__name__,
"(cid: %s) write message: %s\n%s",
utils.get_bytes_as_str(self.channel_id),
msg.MESSAGE_NAME,
utils.dump_protobuf(msg),
)
self.buffer = memory_manager.get_write_buffer(self.buffer, msg)
noise_payload_len = memory_manager.encode_into_buffer(
self.buffer, msg, session_id
)
task = self.write_and_encrypt(self.buffer[:noise_payload_len], force)
if task is not None:
await task
def write_error(self, err_type: int) -> Awaitable[None]:
msg_data = err_type.to_bytes(1, "big")
length = len(msg_data) + CHECKSUM_LENGTH
header = PacketHeader.get_error_header(self.get_channel_id_int(), length)
return write_payload_to_wire_and_add_checksum(self.iface, header, msg_data)
def write_and_encrypt(
self, payload: bytes, force: bool = False
) -> Awaitable[None] | None:
payload_length = len(payload)
self._encrypt(self.buffer, payload_length)
payload_length = payload_length + TAG_LENGTH
if self.write_task_spawn is not None:
self.write_task_spawn.close() # UPS TODO might break something
print("\nCLOSED\n")
self._prepare_write()
if force:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__, "Writing FORCE message (without async or retransmission)."
)
return self._write_encrypted_payload_loop(
ENCRYPTED, memoryview(self.buffer[:payload_length])
)
self.write_task_spawn = loop.spawn(
self._write_encrypted_payload_loop(
ENCRYPTED, memoryview(self.buffer[:payload_length])
)
)
return None
def write_handshake_message(self, ctrl_byte: int, payload: bytes) -> None:
self._prepare_write()
self.write_task_spawn = loop.spawn(
self._write_encrypted_payload_loop(ctrl_byte, payload)
)
def _prepare_write(self) -> None:
# TODO add condition that disallows to write when can_send_message is false
ABP.set_sending_allowed(self.channel_cache, False)
async def _write_encrypted_payload_loop(
self, ctrl_byte: int, payload: bytes
) -> None:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"(cid %s) write_encrypted_payload_loop",
utils.get_bytes_as_str(self.channel_id),
)
payload_len = len(payload) + CHECKSUM_LENGTH
sync_bit = ABP.get_send_seq_bit(self.channel_cache)
ctrl_byte = control_byte.add_seq_bit_to_ctrl_byte(ctrl_byte, sync_bit)
header = PacketHeader(ctrl_byte, self.get_channel_id_int(), payload_len)
self.transmission_loop = TransmissionLoop(self, header, payload)
await self.transmission_loop.start()
ABP.set_send_seq_bit_to_opposite(self.channel_cache)
# Let the main loop be restarted and clear loop, if there is no other
# workflow and the state is ENCRYPTED_TRANSPORT
if self._can_clear_loop():
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"(cid: %s) clearing loop from channel",
utils.get_bytes_as_str(self.channel_id),
)
loop.clear()
def _can_clear_loop(self) -> bool:
return (
not workflow.tasks
) and self.get_channel_state() is ChannelState.ENCRYPTED_TRANSPORT

View File

@ -0,0 +1,34 @@
from typing import TYPE_CHECKING
from storage import cache_thp
from trezor import utils
from . import ChannelState, interface_manager
from .channel import Channel
if TYPE_CHECKING:
from trezorio import WireInterface
def create_new_channel(iface: WireInterface, buffer: utils.BufferType) -> Channel:
"""
Creates a new channel for the interface `iface` with the buffer `buffer`.
"""
channel_cache = cache_thp.get_new_channel(interface_manager.encode_iface(iface))
r = Channel(channel_cache)
r.set_buffer(buffer)
r.set_channel_state(ChannelState.TH1)
return r
def load_cached_channels(buffer: utils.BufferType) -> dict[int, Channel]:
"""
Returns all allocated channels from cache.
"""
channels: dict[int, Channel] = {}
cached_channels = cache_thp.get_all_allocated_channels()
for c in cached_channels:
channels[int.from_bytes(c.channel_id, "big")] = Channel(c)
for c in channels.values():
c.set_buffer(buffer)
return channels

View File

@ -0,0 +1,22 @@
from micropython import const
from trezor import utils
from trezor.crypto import crc
CHECKSUM_LENGTH = const(4)
def compute(data: bytes | utils.BufferType) -> bytes:
"""
Returns a CRC-32 checksum of the provided `data`.
"""
return crc.crc32(data).to_bytes(CHECKSUM_LENGTH, "big")
def is_valid(checksum: bytes | utils.BufferType, data: bytes) -> bool:
"""
Checks whether the CRC-32 checksum of the `data` is the same
as the checksum provided in `checksum`.
"""
data_checksum = compute(data)
return checksum == data_checksum

View File

@ -0,0 +1,50 @@
from micropython import const
from . import (
ACK_MESSAGE,
CONTINUATION_PACKET,
ENCRYPTED,
HANDSHAKE_COMP_REQ,
HANDSHAKE_INIT_REQ,
ThpError,
)
_CONTINUATION_PACKET_MASK = const(0x80)
_ACK_MASK = const(0xF7)
_DATA_MASK = const(0xE7)
def add_seq_bit_to_ctrl_byte(ctrl_byte: int, seq_bit: int) -> int:
if seq_bit == 0:
return ctrl_byte & 0xEF
if seq_bit == 1:
return ctrl_byte | 0x10
raise ThpError("Unexpected sequence bit")
def add_ack_bit_to_ctrl_byte(ctrl_byte: int, ack_bit: int) -> int:
if ack_bit == 0:
return ctrl_byte & 0xF7
if ack_bit == 1:
return ctrl_byte | 0x08
raise ThpError("Unexpected acknowledgement bit")
def is_ack(ctrl_byte: int) -> bool:
return ctrl_byte & _ACK_MASK == ACK_MESSAGE
def is_continuation(ctrl_byte: int) -> bool:
return ctrl_byte & _CONTINUATION_PACKET_MASK == CONTINUATION_PACKET
def is_encrypted_transport(ctrl_byte: int) -> bool:
return ctrl_byte & _DATA_MASK == ENCRYPTED
def is_handshake_init_req(ctrl_byte: int) -> bool:
return ctrl_byte & _DATA_MASK == HANDSHAKE_INIT_REQ
def is_handshake_comp_req(ctrl_byte: int) -> bool:
return ctrl_byte & _DATA_MASK == HANDSHAKE_COMP_REQ

View File

@ -0,0 +1,36 @@
from trezor.crypto import elligator2, random
from trezor.crypto.curve import curve25519
from trezor.crypto.hashlib import sha512
_PREFIX = b"\x08\x43\x50\x61\x63\x65\x32\x35\x35\x06"
_PADDING = b"\x6f\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x20"
class Cpace:
"""
CPace, a balanced composable PAKE: https://datatracker.ietf.org/doc/draft-irtf-cfrg-cpace/
"""
def __init__(self, cpace_host_public_key: bytes, handshake_hash: bytes) -> None:
self.handshake_hash: bytes = handshake_hash
self.host_public_key: bytes = cpace_host_public_key
self.shared_secret: bytes
self.trezor_private_key: bytes
self.trezor_public_key: bytes
def generate_keys_and_secret(self, code_code_entry: bytes) -> None:
"""
Generate ephemeral key pair and a shared secret using Elligator2 with X25519.
"""
sha_ctx = sha512(_PREFIX)
sha_ctx.update(code_code_entry)
sha_ctx.update(_PADDING)
sha_ctx.update(self.handshake_hash)
sha_ctx.update(b"\x00")
pregenerator = sha_ctx.digest()[:32]
generator = elligator2.map_to_curve25519(pregenerator)
self.trezor_private_key = random.bytes(32)
self.trezor_public_key = curve25519.multiply(self.trezor_private_key, generator)
self.shared_secret = curve25519.multiply(
self.trezor_private_key, self.host_public_key
)

View File

@ -0,0 +1,211 @@
from micropython import const
from trezorcrypto import aesgcm, bip32, curve25519, hmac
from storage import device
from trezor import log, utils
from trezor.crypto.hashlib import sha256
from trezor.wire.thp import ThpDecryptionError
# The HARDENED flag is taken from apps.common.paths
# It is not imported to save on resources
HARDENED = const(0x8000_0000)
PUBKEY_LENGTH = const(32)
if utils.DISABLE_ENCRYPTION:
DUMMY_TAG = b"\xA0\xA1\xA2\xA3\xA4\xA5\xA6\xA7\xA8\xA9\xB0\xB1\xB2\xB3\xB4\xB5"
if __debug__:
from ubinascii import hexlify
def enc(buffer: utils.BufferType, key: bytes, nonce: int, auth_data: bytes) -> bytes:
"""
Encrypts the provided `buffer` with AES-GCM (in place).
Returns a 16-byte long encryption tag.
"""
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "enc (key: %s, nonce: %d)", hexlify(key), nonce)
iv = _get_iv_from_nonce(nonce)
aes_ctx = aesgcm(key, iv)
aes_ctx.auth(auth_data)
aes_ctx.encrypt_in_place(buffer)
return aes_ctx.finish()
def dec(
buffer: utils.BufferType, tag: bytes, key: bytes, nonce: int, auth_data: bytes
) -> bool:
"""
Decrypts the provided buffer (in place). Returns `True` if the provided authentication `tag` is the same as
the tag computed in decryption, otherwise it returns `False`.
"""
iv = _get_iv_from_nonce(nonce)
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "dec (key: %s, nonce: %d)", hexlify(key), nonce)
aes_ctx = aesgcm(key, iv)
aes_ctx.auth(auth_data)
aes_ctx.decrypt_in_place(buffer)
computed_tag = aes_ctx.finish()
return computed_tag == tag
class BusyDecoder:
def __init__(self, key: bytes, nonce: int, auth_data: bytes) -> None:
iv = _get_iv_from_nonce(nonce)
self.aes_ctx = aesgcm(key, iv)
self.aes_ctx.auth(auth_data)
def decrypt_part(self, part: utils.BufferType) -> None:
self.aes_ctx.decrypt_in_place(part)
def finish_and_check_tag(self, tag: bytes) -> bool:
computed_tag = self.aes_ctx.finish()
return computed_tag == tag
PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00"
IV_1 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
IV_2 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
class Handshake:
"""
`Handshake` holds (temporary) values and keys that are used during the creation of an encrypted channel.
The following values should be saved for future use before disposing of this object:
- `h` - handshake hash, can be used to bind other values to the channel
- `key_receive` - key for decrypting incoming communication
- `key_send` - key for encrypting outgoing communication
"""
def __init__(self) -> None:
self.trezor_ephemeral_privkey: bytes
self.ck: bytes
self.k: bytes
self.h: bytes
self.key_receive: bytes
self.key_send: bytes
def handle_th1_crypto(
self,
device_properties: bytes,
host_ephemeral_pubkey: bytes,
) -> tuple[bytes, bytes, bytes]:
trezor_static_privkey, trezor_static_pubkey = _derive_static_key_pair()
self.trezor_ephemeral_privkey = curve25519.generate_secret()
trezor_ephemeral_pubkey = curve25519.publickey(self.trezor_ephemeral_privkey)
self.h = _hash_of_two(PROTOCOL_NAME, device_properties)
self.h = _hash_of_two(self.h, host_ephemeral_pubkey)
self.h = _hash_of_two(self.h, trezor_ephemeral_pubkey)
point = curve25519.multiply(
self.trezor_ephemeral_privkey, host_ephemeral_pubkey
)
self.ck, self.k = _hkdf(PROTOCOL_NAME, point)
mask = _hash_of_two(trezor_static_pubkey, trezor_ephemeral_pubkey)
trezor_masked_static_pubkey = curve25519.multiply(mask, trezor_static_pubkey)
aes_ctx = aesgcm(self.k, IV_1)
encrypted_trezor_static_pubkey = aes_ctx.encrypt(trezor_masked_static_pubkey)
if __debug__:
log.debug(__name__, "th1 - enc (key: %s, nonce: %d)", hexlify(self.k), 0)
aes_ctx.auth(self.h)
tag_to_encrypted_key = aes_ctx.finish()
encrypted_trezor_static_pubkey = (
encrypted_trezor_static_pubkey + tag_to_encrypted_key
)
self.h = _hash_of_two(self.h, encrypted_trezor_static_pubkey)
point = curve25519.multiply(trezor_static_privkey, host_ephemeral_pubkey)
self.ck, self.k = _hkdf(self.ck, curve25519.multiply(mask, point))
aes_ctx = aesgcm(self.k, IV_1)
aes_ctx.auth(self.h)
tag = aes_ctx.finish()
self.h = _hash_of_two(self.h, tag)
return (trezor_ephemeral_pubkey, encrypted_trezor_static_pubkey, tag)
def handle_th2_crypto(
self,
encrypted_host_static_pubkey: utils.BufferType,
encrypted_payload: utils.BufferType,
) -> None:
aes_ctx = aesgcm(self.k, IV_2)
# The new value of hash `h` MUST be computed before the `encrypted_host_static_pubkey` is decrypted.
# However, decryption of `encrypted_host_static_pubkey` MUST use the previous value of `h` for
# authentication of the gcm tag.
aes_ctx.auth(self.h) # Authenticate with the previous value of `h`
self.h = _hash_of_two(self.h, encrypted_host_static_pubkey) # Compute new value
aes_ctx.decrypt_in_place(
memoryview(encrypted_host_static_pubkey)[:PUBKEY_LENGTH]
)
if __debug__:
log.debug(__name__, "th2 - dec (key: %s, nonce: %d)", hexlify(self.k), 1)
host_static_pubkey = memoryview(encrypted_host_static_pubkey)[:PUBKEY_LENGTH]
tag = aes_ctx.finish()
if tag != encrypted_host_static_pubkey[-16:]:
raise ThpDecryptionError()
self.ck, self.k = _hkdf(
self.ck,
curve25519.multiply(self.trezor_ephemeral_privkey, host_static_pubkey),
)
aes_ctx = aesgcm(self.k, IV_1)
aes_ctx.auth(self.h)
aes_ctx.decrypt_in_place(memoryview(encrypted_payload)[:-16])
if __debug__:
log.debug(__name__, "th2 - dec (key: %s, nonce: %d)", hexlify(self.k), 0)
tag = aes_ctx.finish()
if tag != encrypted_payload[-16:]:
raise ThpDecryptionError()
self.h = _hash_of_two(self.h, memoryview(encrypted_payload)[:-16])
self.key_receive, self.key_send = _hkdf(self.ck, b"")
if __debug__:
log.debug(
__name__,
"(key_receive: %s, key_send: %s)",
hexlify(self.key_receive),
hexlify(self.key_send),
)
def get_handshake_completion_response(self, trezor_state: bytes) -> bytes:
aes_ctx = aesgcm(self.key_send, IV_1)
encrypted_trezor_state = aes_ctx.encrypt(trezor_state)
tag = aes_ctx.finish()
return encrypted_trezor_state + tag
def _derive_static_key_pair() -> tuple[bytes, bytes]:
node_int = HARDENED | int.from_bytes(b"\x00THP", "big")
node = bip32.from_seed(device.get_device_secret(), "curve25519")
node.derive(node_int)
trezor_static_privkey = node.private_key()
trezor_static_pubkey = node.public_key()[1:33]
# Note: the first byte (\x01) of the public key is removed, as it
# only indicates the type of the elliptic curve used
return trezor_static_privkey, trezor_static_pubkey
def get_trezor_static_pubkey() -> bytes:
_, pubkey = _derive_static_key_pair()
return pubkey
def _hkdf(chaining_key: bytes, input: bytes) -> tuple[bytes, bytes]:
temp_key = hmac(hmac.SHA256, chaining_key, input).digest()
output_1 = hmac(hmac.SHA256, temp_key, b"\x01").digest()
ctx_output_2 = hmac(hmac.SHA256, temp_key, output_1)
ctx_output_2.update(b"\x02")
output_2 = ctx_output_2.digest()
return (output_1, output_2)
def _hash_of_two(part_1: bytes, part_2: bytes) -> bytes:
ctx = sha256(part_1)
ctx.update(part_2)
return ctx.digest()
def _get_iv_from_nonce(nonce: int) -> bytes:
utils.ensure(nonce <= 0xFFFFFFFFFFFFFFFF, "Nonce overflow, terminate the channel")
return bytes(4) + nonce.to_bytes(8, "big")

View File

@ -0,0 +1,28 @@
from typing import TYPE_CHECKING
import usb
_WIRE_INTERFACE_USB = b"\x01"
# TODO _WIRE_INTERFACE_BLE = b"\x02"
if TYPE_CHECKING:
from trezorio import WireInterface
def decode_iface(cached_iface: bytes) -> WireInterface:
"""Decode the cached wire interface."""
if cached_iface == _WIRE_INTERFACE_USB:
iface = usb.iface_wire
if iface is None:
raise RuntimeError("There is no valid USB WireInterface")
return iface
# TODO implement bluetooth interface
raise Exception("Unknown WireInterface")
def encode_iface(iface: WireInterface) -> bytes:
"""Encode wire interface into bytes."""
if iface is usb.iface_wire:
return _WIRE_INTERFACE_USB
# TODO implement bluetooth interface
raise Exception("Unknown WireInterface")

View File

@ -0,0 +1,179 @@
from storage.cache_thp import SESSION_ID_LENGTH, TAG_LENGTH
from trezor import log, protobuf, utils
from trezor.wire.message_handler import get_msg_type
from . import ChannelState, ThpError
from .checksum import CHECKSUM_LENGTH
from .writer import (
INIT_HEADER_LENGTH,
MAX_PAYLOAD_LEN,
MESSAGE_TYPE_LENGTH,
PACKET_LENGTH,
)
def select_buffer(
channel_state: int,
channel_buffer: utils.BufferType,
packet_payload: utils.BufferType,
payload_length: int,
) -> utils.BufferType:
if channel_state is ChannelState.ENCRYPTED_TRANSPORT:
session_id = packet_payload[0]
if session_id == 0:
pass
# TODO use small buffer
else:
pass
# TODO use big buffer but only if the channel owns the buffer lock.
# Otherwise send BUSY message and return
else:
pass
# TODO use small buffer
try:
# TODO for now, we create a new big buffer every time. It should be changed
buffer: utils.BufferType = _get_buffer_for_read(payload_length, channel_buffer)
return buffer
except Exception as e:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.exception(__name__, e)
raise Exception("Failed to create a buffer for channel") # TODO handle better
def get_write_buffer(
buffer: utils.BufferType, msg: protobuf.MessageType
) -> utils.BufferType:
msg_size = protobuf.encoded_length(msg)
payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size
required_min_size = payload_size + CHECKSUM_LENGTH + TAG_LENGTH
if required_min_size > len(buffer):
return _get_buffer_for_write(required_min_size, buffer)
return buffer
def encode_into_buffer(
buffer: utils.BufferType, msg: protobuf.MessageType, session_id: int
) -> int:
# cannot write message without wire type
msg_type = msg.MESSAGE_WIRE_TYPE
if msg_type is None:
msg_type = get_msg_type(msg.MESSAGE_NAME)
assert msg_type is not None
msg_size = protobuf.encoded_length(msg)
payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size
_encode_session_into_buffer(memoryview(buffer), session_id)
_encode_message_type_into_buffer(memoryview(buffer), msg_type, SESSION_ID_LENGTH)
_encode_message_into_buffer(
memoryview(buffer), msg, SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH
)
return payload_size
def _encode_session_into_buffer(
buffer: memoryview, session_id: int, buffer_offset: int = 0
) -> None:
session_id_bytes = int.to_bytes(session_id, SESSION_ID_LENGTH, "big")
utils.memcpy(buffer, buffer_offset, session_id_bytes, 0)
def _encode_message_type_into_buffer(
buffer: memoryview, message_type: int, offset: int = 0
) -> None:
msg_type_bytes = int.to_bytes(message_type, MESSAGE_TYPE_LENGTH, "big")
utils.memcpy(buffer, offset, msg_type_bytes, 0)
def _encode_message_into_buffer(
buffer: memoryview, message: protobuf.MessageType, buffer_offset: int = 0
) -> None:
protobuf.encode(memoryview(buffer[buffer_offset:]), message)
def _get_buffer_for_read(
payload_length: int,
existing_buffer: utils.BufferType,
max_length: int = MAX_PAYLOAD_LEN,
) -> utils.BufferType:
length = payload_length + INIT_HEADER_LENGTH
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"get_buffer_for_read - length: %d, %s %s",
length,
"existing buffer type:",
type(existing_buffer),
)
if length > max_length:
raise ThpError("Message too large")
if length > len(existing_buffer):
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "Allocating a new buffer")
from .thp_main import get_raw_read_buffer
if length > len(get_raw_read_buffer()):
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"Required length is %d, where raw buffer has capacity only %d",
length,
len(get_raw_read_buffer()),
)
raise ThpError("Message is too large")
try:
payload: utils.BufferType = memoryview(get_raw_read_buffer())[:length]
except MemoryError:
payload = memoryview(get_raw_read_buffer())[:PACKET_LENGTH]
raise ThpError("Message is too large")
return payload
# reuse a part of the supplied buffer
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "Reusing already allocated buffer")
return memoryview(existing_buffer)[:length]
def _get_buffer_for_write(
payload_length: int,
existing_buffer: utils.BufferType,
max_length: int = MAX_PAYLOAD_LEN,
) -> utils.BufferType:
length = payload_length + INIT_HEADER_LENGTH
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"get_buffer_for_write - length: %d, %s %s",
length,
"existing buffer type:",
type(existing_buffer),
)
if length > max_length:
raise ThpError("Message too large")
if length > len(existing_buffer):
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "Creating a new write buffer from raw write buffer")
from .thp_main import get_raw_write_buffer
if length > len(get_raw_write_buffer()):
raise ThpError("Message is too large")
try:
payload: utils.BufferType = memoryview(get_raw_write_buffer())[:length]
except MemoryError:
payload = memoryview(get_raw_write_buffer())[:PACKET_LENGTH]
raise ThpError("Message is too large")
return payload
# reuse a part of the supplied buffer
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "Reusing already allocated buffer")
return memoryview(existing_buffer)[:length]

View File

@ -0,0 +1,262 @@
from typing import TYPE_CHECKING
from ubinascii import hexlify
import trezorui_api
from trezor import loop, protobuf, workflow
from trezor.crypto import random
from trezor.wire import context, message_handler, protocol_common
from trezor.wire.context import UnexpectedMessageException
from trezor.wire.errors import ActionCancelled, SilentError
from trezor.wire.protocol_common import Context, Message
if TYPE_CHECKING:
from typing import Container
from trezor import ui
from .channel import Channel
from .cpace import Cpace
pass
if __debug__:
from trezor import log
class PairingDisplayData:
def __init__(self) -> None:
self.code_code_entry: int | None = None
self.code_qr_code: bytes | None = None
self.code_nfc_unidirectional: bytes | None = None
def get_display_layout(self) -> ui.Layout:
from trezor import ui
# TODO have different layouts when there is only QR code or only Code Entry
qr_str = ""
code_str = ""
if self.code_qr_code is not None:
qr_str = self._get_code_qr_code_str()
if self.code_code_entry is not None:
code_str = self._get_code_code_entry_str()
return ui.Layout(
trezorui_api.show_address_details( # noqa
qr_title="Scan QR code to pair",
address=qr_str,
case_sensitive=True,
details_title="",
account="Code to rewrite:\n" + code_str,
path="",
xpubs=[],
)
)
def _get_code_code_entry_str(self) -> str:
if self.code_code_entry is not None:
code_str = f"{self.code_code_entry:06}"
if __debug__:
log.debug(__name__, "code_code_entry: %s", code_str)
return code_str[:3] + " " + code_str[3:]
raise Exception("Code entry string is not available")
def _get_code_qr_code_str(self) -> str:
if self.code_qr_code is not None:
code_str = (hexlify(self.code_qr_code)).decode("utf-8")
if __debug__:
log.debug(__name__, "code_qr_code_hexlified: %s", code_str)
return code_str
raise Exception("QR code string is not available")
class PairingContext(Context):
def __init__(self, channel_ctx: Channel) -> None:
super().__init__(channel_ctx.iface, channel_ctx.channel_id)
self.channel_ctx: Channel = channel_ctx
self.incoming_message = loop.mailbox()
self.secret: bytes = random.bytes(16)
self.display_data: PairingDisplayData = PairingDisplayData()
self.cpace: Cpace
self.host_name: str
async def handle(self, is_debug_session: bool = False) -> None:
# if __debug__:
# log.debug(__name__, "handle - start")
# if is_debug_session:
# import apps.debug
# apps.debug.DEBUG_CONTEXT = self
next_message: Message | None = None
while True:
try:
if next_message is None:
# If the previous run did not keep an unprocessed message for us,
# wait for a new one.
try:
message: Message = await self.incoming_message
except protocol_common.WireError as e:
if __debug__:
log.exception(__name__, e)
await self.write(message_handler.failure(e))
continue
else:
# Process the message from previous run.
message = next_message
next_message = None
try:
next_message = await handle_pairing_request_message(self, message)
except Exception as exc:
# Log and ignore. The session handler can only exit explicitly in the
# following finally block.
if __debug__:
log.exception(__name__, exc)
finally:
# Unload modules imported by the workflow. Should not raise.
# This is not done for the debug session because the snapshot taken
# in a debug session would clear modules which are in use by the
# workflow running on wire.
# TODO utils.unimport_end(modules)
if next_message is None:
# Shut down the loop if there is no next message waiting.
return # pylint: disable=lost-exception
except Exception as exc:
# Log and try again. The session handler can only exit explicitly via
# loop.clear() above. # TODO not updated comments
if __debug__:
log.exception(__name__, exc)
async def read(
self,
expected_types: Container[int],
expected_type: type[protobuf.MessageType] | None = None,
) -> protobuf.MessageType:
if __debug__:
exp_type: str = str(expected_type)
if expected_type is not None:
exp_type = expected_type.MESSAGE_NAME
log.debug(
__name__,
"Read - with expected types %s and expected type %s",
str(expected_types),
exp_type,
)
message: Message = await self.incoming_message
if message.type not in expected_types:
raise UnexpectedMessageException(message)
if expected_type is None:
name = message_handler.get_msg_name(message.type)
if name is None:
expected_type = protobuf.type_for_wire(message.type)
else:
expected_type = protobuf.type_for_name(name)
return message_handler.wrap_protobuf_load(message.data, expected_type)
async def write(self, msg: protobuf.MessageType) -> None:
return await self.channel_ctx.write(msg)
async def call(
self, msg: protobuf.MessageType, expected_type: type[protobuf.MessageType]
) -> protobuf.MessageType:
expected_wire_type = message_handler.get_msg_type(expected_type.MESSAGE_NAME)
if expected_wire_type is None:
expected_wire_type = expected_type.MESSAGE_WIRE_TYPE
assert expected_wire_type is not None
await self.write(msg)
del msg
return await self.read((expected_wire_type,), expected_type)
async def call_any(
self, msg: protobuf.MessageType, *expected_types: int
) -> protobuf.MessageType:
await self.write(msg)
del msg
return await self.read(expected_types)
async def handle_pairing_request_message(
pairing_ctx: PairingContext,
msg: protocol_common.Message,
) -> protocol_common.Message | None:
res_msg: protobuf.MessageType | None = None
from apps.thp.pairing import handle_pairing_request
if msg.type in workflow.ALLOW_WHILE_LOCKED:
workflow.autolock_interrupts_workflow = False
# Here we make sure we always respond with a Failure response
# in case of any errors.
try:
# Find a protobuf.MessageType subclass that describes this
# message. Raises if the type is not found.
name = message_handler.get_msg_name(msg.type)
if name is None:
req_type = protobuf.type_for_wire(msg.type)
else:
req_type = protobuf.type_for_name(name)
# Try to decode the message according to schema from
# `req_type`. Raises if the message is malformed.
req_msg = message_handler.wrap_protobuf_load(msg.data, req_type)
# Create the handler task.
task = handle_pairing_request(pairing_ctx, req_msg)
# Run the workflow task. Workflow can do more on-the-wire
# communication inside, but it should eventually return a
# response message, or raise an exception (a rather common
# thing to do). Exceptions are handled in the code below.
res_msg = await workflow.spawn(context.with_context(pairing_ctx, task))
except UnexpectedMessageException as exc:
# Workflow was trying to read a message from the wire, and
# something unexpected came in. See Context.read() for
# example, which expects some particular message and raises
# UnexpectedMessage if another one comes in.
# In order not to lose the message, we return it to the caller.
# TODO:
# We might handle only the few common cases here, like
# Initialize and Cancel.
return exc.msg
except SilentError as exc:
if __debug__:
log.error(__name__, "SilentError: %s", exc.message)
except BaseException as exc:
# Either:
# - the message had a type that has a registered handler, but does not have
# a protobuf class
# - the message was not valid protobuf
# - workflow raised some kind of an exception while running
# - something canceled the workflow from the outside
if __debug__:
if isinstance(exc, ActionCancelled):
log.debug(__name__, "cancelled: %s", exc.message)
elif isinstance(exc, loop.TaskClosed):
log.debug(__name__, "cancelled: loop task was closed")
else:
log.exception(__name__, exc)
res_msg = message_handler.failure(exc)
if res_msg is not None:
# perform the write outside the big try-except block, so that usb write
# problem bubbles up
await pairing_ctx.write(res_msg)
return None

View File

@ -0,0 +1,446 @@
import ustruct
from typing import TYPE_CHECKING
from storage.cache_common import (
CHANNEL_HANDSHAKE_HASH,
CHANNEL_KEY_RECEIVE,
CHANNEL_KEY_SEND,
CHANNEL_NONCE_RECEIVE,
CHANNEL_NONCE_SEND,
)
from storage.cache_thp import (
KEY_LENGTH,
MANAGEMENT_SESSION_ID,
SESSION_ID_LENGTH,
TAG_LENGTH,
update_channel_last_used,
update_session_last_used,
)
from trezor import log, loop, protobuf, utils
from trezor.enums import FailureType
from trezor.messages import Failure
from .. import message_handler
from ..errors import DataError
from ..protocol_common import Message
from . import (
ACK_MESSAGE,
HANDSHAKE_COMP_RES,
HANDSHAKE_INIT_RES,
ChannelState,
PacketHeader,
SessionState,
ThpDecryptionError,
ThpError,
ThpErrorType,
ThpInvalidDataError,
ThpUnallocatedSessionError,
)
from . import alternating_bit_protocol as ABP
from . import (
checksum,
control_byte,
get_enabled_pairing_methods,
get_encoded_device_properties,
session_manager,
)
from .checksum import CHECKSUM_LENGTH
from .crypto import PUBKEY_LENGTH, Handshake
from .writer import (
INIT_HEADER_LENGTH,
MESSAGE_TYPE_LENGTH,
write_payload_to_wire_and_add_checksum,
)
if TYPE_CHECKING:
from typing import Awaitable
from trezor.messages import ThpHandshakeCompletionReqNoisePayload
from .channel import Channel
if __debug__:
from ubinascii import hexlify
from . import state_to_str
_TREZOR_STATE_UNPAIRED = b"\x00"
_TREZOR_STATE_PAIRED = b"\x01"
async def handle_received_message(
ctx: Channel, message_buffer: utils.BufferType
) -> None:
"""Handle a message received from the channel."""
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "handle_received_message")
if utils.ALLOW_DEBUG_MESSAGES: # TODO remove after performance tests are done
try:
import micropython
print("micropython.mem_info() from received_message_handler.py")
micropython.mem_info()
print("Allocation count:", micropython.alloc_count()) # type: ignore ["alloc_count" is not a known attribute of module "micropython"]
except AttributeError:
print(
"To show allocation count, create the build with TREZOR_MEMPERF=1"
)
ctrl_byte, _, payload_length = ustruct.unpack(">BHH", message_buffer)
message_length = payload_length + INIT_HEADER_LENGTH
_check_checksum(message_length, message_buffer)
# Synchronization process
seq_bit = (ctrl_byte & 0x10) >> 4
ack_bit = (ctrl_byte & 0x08) >> 3
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"handle_completed_message - seq bit of message: %d, ack bit of message: %d",
seq_bit,
ack_bit,
)
# 0: Update "last-time used"
update_channel_last_used(ctx.channel_id)
# 1: Handle ACKs
if control_byte.is_ack(ctrl_byte):
await _handle_ack(ctx, ack_bit)
return
if _should_have_ctrl_byte_encrypted_transport(
ctx
) and not control_byte.is_encrypted_transport(ctrl_byte):
raise ThpError("Message is not encrypted. Ignoring")
# 2: Handle message with unexpected sequential bit
if seq_bit != ABP.get_expected_receive_seq_bit(ctx.channel_cache):
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "Received message with an unexpected sequential bit")
await _send_ack(ctx, ack_bit=seq_bit)
raise ThpError("Received message with an unexpected sequential bit")
# 3: Send ACK in response
await _send_ack(ctx, ack_bit=seq_bit)
ABP.set_expected_receive_seq_bit(ctx.channel_cache, 1 - seq_bit)
try:
await _handle_message_to_app_or_channel(
ctx, payload_length, message_length, ctrl_byte
)
except ThpUnallocatedSessionError as e:
error_message = Failure(code=FailureType.ThpUnallocatedSession)
await ctx.write(error_message, e.session_id)
except ThpDecryptionError:
await ctx.write_error(ThpErrorType.DECRYPTION_FAILED)
ctx.clear()
except ThpInvalidDataError:
await ctx.write_error(ThpErrorType.INVALID_DATA)
ctx.clear()
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "handle_received_message - end")
def _send_ack(ctx: Channel, ack_bit: int) -> Awaitable[None]:
ctrl_byte = control_byte.add_ack_bit_to_ctrl_byte(ACK_MESSAGE, ack_bit)
header = PacketHeader(ctrl_byte, ctx.get_channel_id_int(), CHECKSUM_LENGTH)
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"Writing ACK message to a channel with id: %d, ack_bit: %d",
ctx.get_channel_id_int(),
ack_bit,
)
return write_payload_to_wire_and_add_checksum(ctx.iface, header, b"")
def _check_checksum(message_length: int, message_buffer: utils.BufferType) -> None:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "check_checksum")
if not checksum.is_valid(
checksum=message_buffer[message_length - CHECKSUM_LENGTH : message_length],
data=memoryview(message_buffer)[: message_length - CHECKSUM_LENGTH],
):
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "Invalid checksum, ignoring message.")
raise ThpError("Invalid checksum, ignoring message.")
async def _handle_ack(ctx: Channel, ack_bit: int) -> None:
if not ABP.is_ack_valid(ctx.channel_cache, ack_bit):
return
# ACK is expected and it has correct sync bit
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "Received ACK message with correct ack bit")
if ctx.transmission_loop is not None:
ctx.transmission_loop.stop_immediately()
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "Stopped transmission loop")
ABP.set_sending_allowed(ctx.channel_cache, True)
if ctx.write_task_spawn is not None:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, 'Control to "write_encrypted_payload_loop" task')
await ctx.write_task_spawn
# Note that no the write_task_spawn could result in loop.clear(),
# which will result in termination of this function - any code after
# this await might not be executed
def _handle_message_to_app_or_channel(
ctx: Channel,
payload_length: int,
message_length: int,
ctrl_byte: int,
) -> Awaitable[None]:
state = ctx.get_channel_state()
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "state: %s", state_to_str(state))
if state is ChannelState.ENCRYPTED_TRANSPORT:
return _handle_state_ENCRYPTED_TRANSPORT(ctx, message_length)
if state is ChannelState.TH1:
return _handle_state_TH1(ctx, payload_length, message_length, ctrl_byte)
if state is ChannelState.TH2:
return _handle_state_TH2(ctx, message_length, ctrl_byte)
if _is_channel_state_pairing(state):
return _handle_pairing(ctx, message_length)
raise ThpError("Unimplemented channel state")
async def _handle_state_TH1(
ctx: Channel,
payload_length: int,
message_length: int,
ctrl_byte: int,
) -> None:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "handle_state_TH1")
if not control_byte.is_handshake_init_req(ctrl_byte):
raise ThpError("Message received is not a handshake init request!")
if not payload_length == PUBKEY_LENGTH + CHECKSUM_LENGTH:
raise ThpError("Message received is not a valid handshake init request!")
ctx.handshake = Handshake()
host_ephemeral_pubkey = bytearray(
ctx.buffer[INIT_HEADER_LENGTH : message_length - CHECKSUM_LENGTH]
)
trezor_ephemeral_pubkey, encrypted_trezor_static_pubkey, tag = (
ctx.handshake.handle_th1_crypto(
get_encoded_device_properties(ctx.iface), host_ephemeral_pubkey
)
)
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"trezor ephemeral pubkey: %s",
hexlify(trezor_ephemeral_pubkey).decode(),
)
log.debug(
__name__,
"encrypted trezor masked static pubkey: %s",
hexlify(encrypted_trezor_static_pubkey).decode(),
)
log.debug(__name__, "tag: %s", hexlify(tag))
payload = trezor_ephemeral_pubkey + encrypted_trezor_static_pubkey + tag
# send handshake init response message
ctx.write_handshake_message(HANDSHAKE_INIT_RES, payload)
ctx.set_channel_state(ChannelState.TH2)
return
async def _handle_state_TH2(ctx: Channel, message_length: int, ctrl_byte: int) -> None:
from apps.thp.credential_manager import validate_credential
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "handle_state_TH2")
if not control_byte.is_handshake_comp_req(ctrl_byte):
raise ThpError("Message received is not a handshake completion request!")
if ctx.handshake is None:
raise Exception("Handshake object is not prepared. Retry handshake.")
host_encrypted_static_pubkey = memoryview(ctx.buffer)[
INIT_HEADER_LENGTH : INIT_HEADER_LENGTH + KEY_LENGTH + TAG_LENGTH
]
handshake_completion_request_noise_payload = memoryview(ctx.buffer)[
INIT_HEADER_LENGTH + KEY_LENGTH + TAG_LENGTH : message_length - CHECKSUM_LENGTH
]
ctx.handshake.handle_th2_crypto(
host_encrypted_static_pubkey, handshake_completion_request_noise_payload
)
ctx.channel_cache.set(CHANNEL_KEY_RECEIVE, ctx.handshake.key_receive)
ctx.channel_cache.set(CHANNEL_KEY_SEND, ctx.handshake.key_send)
ctx.channel_cache.set(CHANNEL_HANDSHAKE_HASH, ctx.handshake.h)
ctx.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, 0)
ctx.channel_cache.set_int(CHANNEL_NONCE_SEND, 1)
noise_payload = _decode_message(
ctx.buffer[
INIT_HEADER_LENGTH
+ KEY_LENGTH
+ TAG_LENGTH : message_length
- CHECKSUM_LENGTH
- TAG_LENGTH
],
0,
"ThpHandshakeCompletionReqNoisePayload",
)
if TYPE_CHECKING:
assert ThpHandshakeCompletionReqNoisePayload.is_type_of(noise_payload)
enabled_methods = get_enabled_pairing_methods(ctx.iface)
for method in noise_payload.pairing_methods:
if method not in enabled_methods:
raise ThpInvalidDataError()
if method not in ctx.selected_pairing_methods:
ctx.selected_pairing_methods.append(method)
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"host static pubkey: %s, noise payload: %s",
utils.get_bytes_as_str(host_encrypted_static_pubkey),
utils.get_bytes_as_str(handshake_completion_request_noise_payload),
)
# key is decoded in handshake._handle_th2_crypto
host_static_pubkey = host_encrypted_static_pubkey[:PUBKEY_LENGTH]
paired: bool = False
if noise_payload.host_pairing_credential is not None:
try: # TODO change try-except for something better
paired = validate_credential(
noise_payload.host_pairing_credential,
host_static_pubkey,
)
except DataError as e:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.exception(__name__, e)
pass
trezor_state = _TREZOR_STATE_UNPAIRED
if paired:
trezor_state = _TREZOR_STATE_PAIRED
# send hanshake completion response
ctx.write_handshake_message(
HANDSHAKE_COMP_RES,
ctx.handshake.get_handshake_completion_response(trezor_state),
)
ctx.handshake = None
if paired:
ctx.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
else:
ctx.set_channel_state(ChannelState.TP1)
async def _handle_state_ENCRYPTED_TRANSPORT(ctx: Channel, message_length: int) -> None:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "handle_state_ENCRYPTED_TRANSPORT")
ctx.decrypt_buffer(message_length)
session_id, message_type = ustruct.unpack(
">BH", memoryview(ctx.buffer)[INIT_HEADER_LENGTH:]
)
if session_id not in ctx.sessions:
if session_id == MANAGEMENT_SESSION_ID:
s = session_manager.create_new_management_session(ctx)
else:
s = session_manager.get_session_from_cache(ctx, session_id)
if s is None:
raise ThpUnallocatedSessionError(session_id)
ctx.sessions[session_id] = s
loop.schedule(s.handle())
elif ctx.sessions[session_id].get_session_state() is SessionState.UNALLOCATED:
raise ThpUnallocatedSessionError(session_id)
s = ctx.sessions[session_id]
update_session_last_used(s.channel_id, (s.session_id).to_bytes(1, "big"))
s.incoming_message.put(
Message(
message_type,
ctx.buffer[
INIT_HEADER_LENGTH
+ MESSAGE_TYPE_LENGTH
+ SESSION_ID_LENGTH : message_length
- CHECKSUM_LENGTH
- TAG_LENGTH
],
)
)
async def _handle_pairing(ctx: Channel, message_length: int) -> None:
from .pairing_context import PairingContext
if ctx.connection_context is None:
ctx.connection_context = PairingContext(ctx)
loop.schedule(ctx.connection_context.handle())
ctx.decrypt_buffer(message_length)
message_type = ustruct.unpack(
">H", ctx.buffer[INIT_HEADER_LENGTH + SESSION_ID_LENGTH :]
)[0]
ctx.connection_context.incoming_message.put(
Message(
message_type,
ctx.buffer[
INIT_HEADER_LENGTH
+ MESSAGE_TYPE_LENGTH
+ SESSION_ID_LENGTH : message_length
- CHECKSUM_LENGTH
- TAG_LENGTH
],
)
)
def _should_have_ctrl_byte_encrypted_transport(ctx: Channel) -> bool:
if ctx.get_channel_state() in [
ChannelState.UNALLOCATED,
ChannelState.TH1,
ChannelState.TH2,
]:
return False
return True
def _decode_message(
buffer: bytes, msg_type: int, message_name: str | None = None
) -> protobuf.MessageType:
if __debug__:
log.debug(__name__, "decode message")
if message_name is not None:
expected_type = protobuf.type_for_name(message_name)
else:
expected_type = protobuf.type_for_wire(msg_type)
return message_handler.wrap_protobuf_load(buffer, expected_type)
def _is_channel_state_pairing(state: int) -> bool:
if state in (
ChannelState.TP1,
ChannelState.TP2,
ChannelState.TP3,
ChannelState.TP4,
ChannelState.TC1,
):
return True
return False

View File

@ -0,0 +1,169 @@
from typing import TYPE_CHECKING
from storage import cache_thp
from storage.cache_thp import MANAGEMENT_SESSION_ID, SessionThpCache
from trezor import log, loop, protobuf, utils
from trezor.wire import message_handler, protocol_common
from trezor.wire.context import UnexpectedMessageException
from trezor.wire.message_handler import failure
from ..protocol_common import Context, Message
from . import SessionState
if TYPE_CHECKING:
from typing import Awaitable, Container
from storage.cache_common import DataCache
from .channel import Channel
pass
_EXIT_LOOP = True
_REPEAT_LOOP = False
if __debug__:
from trezor.utils import get_bytes_as_str
class GenericSessionContext(Context):
def __init__(self, channel: Channel, session_id: int) -> None:
super().__init__(channel.iface, channel.channel_id)
self.channel: Channel = channel
self.session_id: int = session_id
self.incoming_message = loop.mailbox()
async def handle(self) -> None:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
self._handle_debug()
next_message: Message | None = None
while True:
message = next_message
next_message = None
try:
if await self._handle_message(message):
loop.schedule(self.handle())
return
except UnexpectedMessageException as unexpected:
# The workflow was interrupted by an unexpected message. We need to
# process it as if it was a new message...
next_message = unexpected.msg
continue
except Exception as exc:
# Log and try again.
if __debug__:
log.exception(__name__, exc)
def _handle_debug(self) -> None:
log.debug(
__name__,
"handle - start (channel_id (bytes): %s, session_id: %d)",
get_bytes_as_str(self.channel_id),
self.session_id,
)
async def _handle_message(
self,
next_message: Message | None,
) -> bool:
try:
if next_message is not None:
# Process the message from previous run.
message = next_message
next_message = None
else:
# Wait for a new message from wire
message = await self.incoming_message
except protocol_common.WireError as e:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.exception(__name__, e)
await self.write(failure(e))
return _REPEAT_LOOP
await message_handler.handle_single_message(self, message)
return _EXIT_LOOP
async def read(
self,
expected_types: Container[int],
expected_type: type[protobuf.MessageType] | None = None,
) -> protobuf.MessageType:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
exp_type: str = str(expected_type)
if expected_type is not None:
exp_type = expected_type.MESSAGE_NAME
log.debug(
__name__,
"Read - with expected types %s and expected type %s",
str(expected_types),
exp_type,
)
message: Message = await self.incoming_message
if message.type not in expected_types:
if __debug__:
log.debug(
__name__,
"EXPECTED TYPES: %s\nRECEIVED TYPE: %s",
str(expected_types),
str(message.type),
)
raise UnexpectedMessageException(message)
if expected_type is None:
expected_type = protobuf.type_for_wire(message.type)
return message_handler.wrap_protobuf_load(message.data, expected_type)
async def write(self, msg: protobuf.MessageType) -> None:
return await self.channel.write(msg, self.session_id)
def write_force(self, msg: protobuf.MessageType) -> Awaitable[None]:
return self.channel.write(msg, self.session_id, force=True)
def get_session_state(self) -> SessionState: ...
class ManagementSessionContext(GenericSessionContext):
def __init__(
self, channel_ctx: Channel, session_id: int = MANAGEMENT_SESSION_ID
) -> None:
super().__init__(channel_ctx, session_id)
def get_session_state(self) -> SessionState:
return SessionState.MANAGEMENT
class SessionContext(GenericSessionContext):
def __init__(self, channel_ctx: Channel, session_cache: SessionThpCache) -> None:
if channel_ctx.channel_id != session_cache.channel_id:
raise Exception(
"The session has different channel id than the provided channel context!"
)
session_id = int.from_bytes(session_cache.session_id, "big")
super().__init__(channel_ctx, session_id)
self.session_cache = session_cache
# ACCESS TO SESSION DATA
def get_session_state(self) -> SessionState:
state = int.from_bytes(self.session_cache.state, "big")
return SessionState(state)
def set_session_state(self, state: SessionState) -> None:
self.session_cache.state = bytearray(state.to_bytes(1, "big"))
def release(self) -> None:
if self.session_cache is not None:
cache_thp.clear_session(self.session_cache)
# ACCESS TO CACHE
@property
def cache(self) -> DataCache:
return self.session_cache

View File

@ -0,0 +1,48 @@
from typing import TYPE_CHECKING
from storage import cache_thp
from .session_context import (
GenericSessionContext,
ManagementSessionContext,
SessionContext,
)
if TYPE_CHECKING:
from .channel import Channel
def create_new_session(channel_ctx: Channel) -> SessionContext:
"""
Creates new `SessionContext` backed by cache.
"""
session_cache = cache_thp.get_new_session(channel_ctx.channel_cache)
return SessionContext(channel_ctx, session_cache)
def create_new_management_session(
channel_ctx: Channel, session_id: int = cache_thp.MANAGEMENT_SESSION_ID
) -> ManagementSessionContext:
"""
Creates new `ManagementSessionContext` that is not backed by cache entry.
Seed cannot be derived with this type of session.
"""
return ManagementSessionContext(channel_ctx, session_id)
def get_session_from_cache(
channel_ctx: Channel, session_id: int
) -> GenericSessionContext | None:
"""
Returns a `SessionContext` (or `ManagementSessionContext`) reconstructed from a cache or `None` if backing cache is not found.
"""
session_id_bytes = session_id.to_bytes(1, "big")
session_cache = cache_thp.get_allocated_session(
channel_ctx.channel_id, session_id_bytes
)
if session_cache is None:
return None
elif cache_thp.is_management_session(session_cache):
return ManagementSessionContext(channel_ctx, session_id)
return SessionContext(channel_ctx, session_cache)

View File

@ -0,0 +1,187 @@
import ustruct
from micropython import const
from typing import TYPE_CHECKING
from storage.cache_thp import BROADCAST_CHANNEL_ID
from trezor import io, log, loop, utils
from . import (
CHANNEL_ALLOCATION_REQ,
CODEC_V1,
ChannelState,
PacketHeader,
ThpError,
ThpErrorType,
channel_manager,
checksum,
get_channel_allocation_response,
writer,
)
from .channel import Channel
from .checksum import CHECKSUM_LENGTH
from .writer import (
INIT_HEADER_LENGTH,
MAX_PAYLOAD_LEN,
PACKET_LENGTH,
write_payload_to_wire_and_add_checksum,
)
if TYPE_CHECKING:
from trezorio import WireInterface
_CID_REQ_PAYLOAD_LENGTH = const(12)
_READ_BUFFER: bytearray
_WRITE_BUFFER: bytearray
_CHANNELS: dict[int, Channel] = {}
def set_read_buffer(buffer: bytearray) -> None:
global _READ_BUFFER
_READ_BUFFER = buffer
def set_write_buffer(buffer: bytearray) -> None:
global _WRITE_BUFFER
_WRITE_BUFFER = buffer
def get_raw_read_buffer() -> bytearray:
global _READ_BUFFER
return _READ_BUFFER
def get_raw_write_buffer() -> bytearray:
global _WRITE_BUFFER
return _WRITE_BUFFER
async def thp_main_loop(iface: WireInterface) -> None:
global _CHANNELS
global _READ_BUFFER
_CHANNELS = channel_manager.load_cached_channels(_READ_BUFFER)
read = loop.wait(iface.iface_num() | io.POLL_READ)
while True:
try:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "thp_main_loop")
packet = await read
ctrl_byte, cid = ustruct.unpack(">BH", packet)
if ctrl_byte == CODEC_V1:
await _handle_codec_v1(iface, packet)
continue
if cid == BROADCAST_CHANNEL_ID:
await _handle_broadcast(iface, ctrl_byte, packet)
continue
if cid in _CHANNELS:
await _handle_allocated(iface, cid, packet)
else:
await _handle_unallocated(iface, cid)
except ThpError as e:
if __debug__:
log.exception(__name__, e)
async def _handle_codec_v1(iface: WireInterface, packet: bytes) -> None:
# If the received packet is not an initial codec_v1 packet, do not send error message
if not packet[1:3] == b"##":
return
if __debug__:
log.debug(__name__, "Received codec_v1 message, returning error")
error_message = _get_codec_v1_error_message()
await writer.write_packet_to_wire(iface, error_message)
async def _handle_broadcast(
iface: WireInterface, ctrl_byte: int, packet: utils.BufferType
) -> None:
global _READ_BUFFER
if ctrl_byte != CHANNEL_ALLOCATION_REQ:
raise ThpError("Unexpected ctrl_byte in a broadcast channel packet")
if __debug__:
log.debug(__name__, "Received valid message on the broadcast channel")
length, nonce = ustruct.unpack(">H8s", packet[3:])
payload = _get_buffer_for_payload(length, packet[5:], _CID_REQ_PAYLOAD_LENGTH)
if not checksum.is_valid(
payload[-4:],
packet[: _CID_REQ_PAYLOAD_LENGTH + INIT_HEADER_LENGTH - CHECKSUM_LENGTH],
):
raise ThpError("Checksum is not valid")
new_channel: Channel = channel_manager.create_new_channel(iface, _READ_BUFFER)
cid = int.from_bytes(new_channel.channel_id, "big")
_CHANNELS[cid] = new_channel
response_data = get_channel_allocation_response(
nonce, new_channel.channel_id, iface
)
response_header = PacketHeader.get_channel_allocation_response_header(
len(response_data) + CHECKSUM_LENGTH,
)
if __debug__:
log.debug(__name__, "New channel allocated with id %d", cid)
await write_payload_to_wire_and_add_checksum(iface, response_header, response_data)
async def _handle_allocated(
iface: WireInterface, cid: int, packet: utils.BufferType
) -> None:
channel = _CHANNELS[cid]
if channel is None:
await _handle_unallocated(iface, cid)
raise ThpError("Invalid state of a channel")
if channel.iface is not iface:
# TODO send error message to wire
raise ThpError("Channel has different WireInterface")
if channel.get_channel_state() != ChannelState.UNALLOCATED:
x = channel.receive_packet(packet)
if x is not None:
await x
async def _handle_unallocated(iface: WireInterface, cid: int) -> None:
data = (ThpErrorType.UNALLOCATED_CHANNEL).to_bytes(1, "big")
header = PacketHeader.get_error_header(cid, len(data) + CHECKSUM_LENGTH)
await write_payload_to_wire_and_add_checksum(iface, header, data)
def _get_buffer_for_payload(
payload_length: int,
existing_buffer: utils.BufferType,
max_length: int = MAX_PAYLOAD_LEN,
) -> utils.BufferType:
if payload_length > max_length:
raise ThpError("Message too large")
if payload_length > len(existing_buffer):
return _try_allocate_new_buffer(payload_length)
return _reuse_existing_buffer(payload_length, existing_buffer)
def _try_allocate_new_buffer(payload_length: int) -> utils.BufferType:
try:
payload: utils.BufferType = bytearray(payload_length)
except MemoryError:
payload = bytearray(PACKET_LENGTH)
raise ThpError("Message too large")
return payload
def _reuse_existing_buffer(
payload_length: int, existing_buffer: utils.BufferType
) -> utils.BufferType:
return memoryview(existing_buffer)[:payload_length]
def _get_codec_v1_error_message() -> bytes:
# Codec_v1 magic constant "?##" + Failure message type + msg_size
# + msg_data (code = "Failure_InvalidProtocol") + padding to 64 B
ERROR_MSG = b"\x3f\x23\x23\x00\x03\x00\x00\x00\x14\x08\x10\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
return ERROR_MSG

View File

@ -0,0 +1,54 @@
from micropython import const
from typing import TYPE_CHECKING
from trezor import loop
from .writer import write_payload_to_wire_and_add_checksum
if TYPE_CHECKING:
from . import PacketHeader
from .channel import Channel
MAX_RETRANSMISSION_COUNT = const(50)
MIN_RETRANSMISSION_COUNT = const(2)
class TransmissionLoop:
def __init__(
self, channel: Channel, header: PacketHeader, transport_payload: bytes
) -> None:
self.channel: Channel = channel
self.header: PacketHeader = header
self.transport_payload: bytes = transport_payload
self.wait_task: loop.spawn | None = None
self.min_retransmisson_count_achieved: bool = False
async def start(
self, max_retransmission_count: int = MAX_RETRANSMISSION_COUNT
) -> None:
self.min_retransmisson_count_achieved = False
for i in range(max_retransmission_count):
if i >= MIN_RETRANSMISSION_COUNT:
self.min_retransmisson_count_achieved = True
await write_payload_to_wire_and_add_checksum(
self.channel.iface, self.header, self.transport_payload
)
self.wait_task = loop.spawn(self._wait(i))
try:
await self.wait_task
except loop.TaskClosed:
self.wait_task = None
break
def stop_immediately(self) -> None:
if self.wait_task is not None:
self.wait_task.close()
self.wait_task = None
async def _wait(self, counter: int = 0) -> None:
timeout_ms = round(10200 - 1010000 / (counter + 100))
await loop.sleep(timeout_ms)
def __del__(self) -> None:
self.stop_immediately()

View File

@ -0,0 +1,92 @@
from micropython import const
from trezorcrypto import crc
from typing import TYPE_CHECKING
from trezor import io, log, loop, utils
from . import PacketHeader
INIT_HEADER_LENGTH = const(5)
CONT_HEADER_LENGTH = const(3)
PACKET_LENGTH = const(64)
CHECKSUM_LENGTH = const(4)
MAX_PAYLOAD_LEN = const(60000)
MESSAGE_TYPE_LENGTH = const(2)
if TYPE_CHECKING:
from trezorio import WireInterface
from typing import Awaitable, Sequence
def write_payload_to_wire_and_add_checksum(
iface: WireInterface, header: PacketHeader, transport_payload: bytes
) -> Awaitable[None]:
header_checksum: int = crc.crc32(header.to_bytes())
checksum: bytes = crc.crc32(transport_payload, header_checksum).to_bytes(
CHECKSUM_LENGTH, "big"
)
data = (transport_payload, checksum)
return write_payloads_to_wire(iface, header, data)
async def write_payloads_to_wire(
iface: WireInterface, header: PacketHeader, data: Sequence[bytes]
) -> None:
n_of_data = len(data)
total_length = sum(len(item) for item in data)
current_data_idx = 0
current_data_offset = 0
packet = bytearray(PACKET_LENGTH)
header.pack_to_init_buffer(packet)
packet_offset: int = INIT_HEADER_LENGTH
packet_number = 0
nwritten = 0
while nwritten < total_length:
if packet_number == 1:
header.pack_to_cont_buffer(packet)
if packet_number >= 1 and nwritten >= total_length - PACKET_LENGTH:
packet[:] = bytearray(PACKET_LENGTH)
header.pack_to_cont_buffer(packet)
while True:
n = utils.memcpy(
packet, packet_offset, data[current_data_idx], current_data_offset
)
packet_offset += n
current_data_offset += n
nwritten += n
if packet_offset < PACKET_LENGTH:
current_data_idx += 1
current_data_offset = 0
if current_data_idx >= n_of_data:
break
elif packet_offset == PACKET_LENGTH:
break
else:
raise Exception("Should not happen!!!")
packet_number += 1
packet_offset = CONT_HEADER_LENGTH
# write packet to wire (in-lined)
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__, "write_packet_to_wire: %s", utils.get_bytes_as_str(packet)
)
written_by_iface: int = 0
while written_by_iface < len(packet):
await loop.wait(iface.iface_num() | io.POLL_WRITE)
written_by_iface = iface.write(packet)
async def write_packet_to_wire(iface: WireInterface, packet: bytes) -> None:
while True:
await loop.wait(iface.iface_num() | io.POLL_WRITE)
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__, "write_packet_to_wire: %s", utils.get_bytes_as_str(packet)
)
n_written = iface.write(packet)
if n_written == len(packet):
return

View File

@ -3,7 +3,7 @@ from typing import TYPE_CHECKING
import storage.cache as storage_cache import storage.cache as storage_cache
from trezor import log, loop from trezor import log, loop
from trezor.enums import MessageType from trezor.enums import MessageType, ThpMessageType
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Callable from typing import Callable
@ -17,9 +17,14 @@ if __debug__:
from trezor import utils from trezor import utils
if utils.USE_THP:
protocol_specific = ThpMessageType.ThpCreateNewSession
else:
protocol_specific = MessageType.Initialize
ALLOW_WHILE_LOCKED = ( ALLOW_WHILE_LOCKED = (
MessageType.Initialize, protocol_specific,
MessageType.EndSession, MessageType.EndSession,
MessageType.GetFeatures, MessageType.GetFeatures,
MessageType.Cancel, MessageType.Cancel,

View File

@ -0,0 +1,17 @@
from trezor.loop import wait
class MockHID:
def __init__(self, num):
self.num = num
self.data = []
def iface_num(self):
return self.num
def write(self, msg):
self.data.append(bytearray(msg))
return len(msg)
def wait_object(self, mode):
return wait(mode | self.num)

42
core/tests/myTests.sh Executable file
View File

@ -0,0 +1,42 @@
#!/usr/bin/env bash
declare -a results
declare -i passed=0 failed=0 exit_code=0
declare COLOR_GREEN='\e[32m' COLOR_RED='\e[91m' COLOR_RESET='\e[39m'
MICROPYTHON="${MICROPYTHON:-../build/unix/trezor-emu-core -X heapsize=2M}"
print_summary() {
echo
echo 'Summary:'
echo '-------------------'
printf '%b\n' "${results[@]}"
if [ $exit_code == 0 ]; then
echo -e "${COLOR_GREEN}PASSED:${COLOR_RESET} $passed/$num_of_tests tests OK!"
else
echo -e "${COLOR_RED}FAILED:${COLOR_RESET} $failed/$num_of_tests tests failed!"
fi
}
trap 'print_summary; echo -e "${COLOR_RED}Interrupted by user!${COLOR_RESET}"; exit 1' SIGINT
cd $(dirname $0)
[ -z "$*" ] && tests=(test_trezor.wire.t*.py ) || tests=($*)
declare -i num_of_tests=${#tests[@]}
for test_case in ${tests[@]}; do
echo ${MICROPYTHON}
echo ${test_case}
echo
if $MICROPYTHON $test_case; then
results+=("${COLOR_GREEN}OK:${COLOR_RESET} $test_case")
((passed++))
else
results+=("${COLOR_RED}FAIL:${COLOR_RESET} $test_case")
((failed++))
exit_code=1
fi
done
print_summary
exit $exit_code

View File

@ -1,4 +1,4 @@
from common import H_, await_result, unittest # isort:skip from common import * # isort:skip
import storage.cache_codec import storage.cache_codec
from trezor import wire from trezor import wire
@ -12,19 +12,33 @@ from trezor.messages import (
TxOutput, TxOutput,
) )
from trezor.wire import context from trezor.wire import context
from trezor.wire.codec.codec_context import CodecContext
from apps.bitcoin.authorization import FEE_RATE_DECIMALS, CoinJoinAuthorization from apps.bitcoin.authorization import FEE_RATE_DECIMALS, CoinJoinAuthorization
from apps.bitcoin.sign_tx.approvers import CoinJoinApprover from apps.bitcoin.sign_tx.approvers import CoinJoinApprover
from apps.bitcoin.sign_tx.bitcoin import Bitcoin from apps.bitcoin.sign_tx.bitcoin import Bitcoin
from apps.bitcoin.sign_tx.tx_info import TxInfo from apps.bitcoin.sign_tx.tx_info import TxInfo
from apps.common import coins from apps.common import coins
from trezor.wire.codec.codec_context import CodecContext
if utils.USE_THP:
import thp_common
else:
import storage.cache_codec
from trezor.wire.codec.codec_context import CodecContext
class TestApprover(unittest.TestCase): class TestApprover(unittest.TestCase):
if utils.USE_THP:
def setUpClass(self): def setUpClass(self):
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) if __debug__:
thp_common.suppres_debug_log()
thp_common.prepare_context()
else:
def setUpClass(self):
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
def tearDownClass(self): def tearDownClass(self):
context.CURRENT_CONTEXT = None context.CURRENT_CONTEXT = None
@ -54,7 +68,8 @@ class TestApprover(unittest.TestCase):
coin_name=self.coin.coin_name, coin_name=self.coin.coin_name,
script_type=InputScriptType.SPENDTAPROOT, script_type=InputScriptType.SPENDTAPROOT,
) )
storage.cache_codec.start_session() if not utils.USE_THP:
storage.cache_codec.start_session()
def make_coinjoin_request(self, inputs): def make_coinjoin_request(self, inputs):
return CoinJoinRequest( return CoinJoinRequest(

View File

@ -1,23 +1,37 @@
from common import H_, unittest # isort:skip from common import * # isort:skip
import storage.cache_codec import storage.cache_codec
from trezor.enums import InputScriptType from trezor.enums import InputScriptType
from trezor.messages import AuthorizeCoinJoin, GetOwnershipProof, SignTx from trezor.messages import AuthorizeCoinJoin, GetOwnershipProof, SignTx
from trezor.wire import context from trezor.wire import context
from trezor.wire.codec.codec_context import CodecContext
from apps.bitcoin.authorization import CoinJoinAuthorization from apps.bitcoin.authorization import CoinJoinAuthorization
from apps.common import coins from apps.common import coins
_ROUND_ID_LEN = 32 _ROUND_ID_LEN = 32
if utils.USE_THP:
import thp_common
else:
import storage.cache_codec
from trezor.wire.codec.codec_context import CodecContext
class TestAuthorization(unittest.TestCase): class TestAuthorization(unittest.TestCase):
coin = coins.by_name("Bitcoin") coin = coins.by_name("Bitcoin")
def setUpClass(self): if utils.USE_THP:
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
def setUpClass(self):
if __debug__:
thp_common.suppres_debug_log()
thp_common.prepare_context()
else:
def setUpClass(self):
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
def tearDownClass(self): def tearDownClass(self):
context.CURRENT_CONTEXT = None context.CURRENT_CONTEXT = None
@ -34,7 +48,8 @@ class TestAuthorization(unittest.TestCase):
) )
self.authorization = CoinJoinAuthorization(self.msg_auth) self.authorization = CoinJoinAuthorization(self.msg_auth)
storage.cache_codec.start_session() if not utils.USE_THP:
storage.cache_codec.start_session()
def test_ownership_proof_account_depth_mismatch(self): def test_ownership_proof_account_depth_mismatch(self):
# Account depth mismatch. # Account depth mismatch.

View File

@ -1,7 +1,7 @@
# flake8: noqa: F403,F405 # flake8: noqa: F403,F405
from common import * # isort:skip from common import * # isort:skip
from storage import cache_codec, cache_common from storage import cache_common
from trezor import wire from trezor import wire
from trezor.crypto import bip39 from trezor.crypto import bip39
from trezor.wire import context from trezor.wire import context
@ -9,20 +9,38 @@ from trezor.wire.codec.codec_context import CodecContext
from apps.bitcoin.keychain import _get_coin_by_name, _get_keychain_for_coin from apps.bitcoin.keychain import _get_coin_by_name, _get_keychain_for_coin
if utils.USE_THP:
import thp_common
else:
from storage import cache_codec
class TestBitcoinKeychain(unittest.TestCase): class TestBitcoinKeychain(unittest.TestCase):
def setUpClass(self): if utils.USE_THP:
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
def setUpClass(self):
if __debug__:
thp_common.suppres_debug_log()
thp_common.prepare_context()
def setUp(self):
seed = bip39.seed(" ".join(["all"] * 12), "")
context.cache_set(cache_common.APP_COMMON_SEED, seed)
else:
def setUpClass(self):
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
def setUp(self):
cache_codec.start_session()
seed = bip39.seed(" ".join(["all"] * 12), "")
cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed)
def tearDownClass(self): def tearDownClass(self):
context.CURRENT_CONTEXT = None context.CURRENT_CONTEXT = None
def setUp(self):
cache_codec.start_session()
seed = bip39.seed(" ".join(["all"] * 12), "")
cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed)
def test_bitcoin(self): def test_bitcoin(self):
coin = _get_coin_by_name("Bitcoin") coin = _get_coin_by_name("Bitcoin")
keychain = await_result(_get_keychain_for_coin(coin)) keychain = await_result(_get_keychain_for_coin(coin))
@ -98,18 +116,30 @@ class TestBitcoinKeychain(unittest.TestCase):
@unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin") @unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin")
class TestAltcoinKeychains(unittest.TestCase): class TestAltcoinKeychains(unittest.TestCase):
if utils.USE_THP:
def setUpClass(self): def setUpClass(self):
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) if __debug__:
thp_common.suppres_debug_log()
thp_common.prepare_context()
def setUp(self):
seed = bip39.seed(" ".join(["all"] * 12), "")
context.cache_set(cache_common.APP_COMMON_SEED, seed)
else:
def setUpClass(self):
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
def setUp(self):
cache_codec.start_session()
seed = bip39.seed(" ".join(["all"] * 12), "")
cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed)
def tearDownClass(self): def tearDownClass(self):
context.CURRENT_CONTEXT = None context.CURRENT_CONTEXT = None
def setUp(self):
cache_codec.start_session()
seed = bip39.seed(" ".join(["all"] * 12), "")
cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed)
def test_bcash(self): def test_bcash(self):
coin = _get_coin_by_name("Bcash") coin = _get_coin_by_name("Bcash")
keychain = await_result(_get_keychain_for_coin(coin)) keychain = await_result(_get_keychain_for_coin(coin))

View File

@ -2,7 +2,7 @@
from common import * # isort:skip from common import * # isort:skip
from mock_storage import mock_storage from mock_storage import mock_storage
from storage import cache, cache_codec, cache_common from storage import cache, cache_common
from trezor import wire from trezor import wire
from trezor.crypto import bip39 from trezor.crypto import bip39
from trezor.enums import SafetyCheckLevel from trezor.enums import SafetyCheckLevel
@ -13,18 +13,32 @@ from apps.common import safety_checks
from apps.common.keychain import Keychain, LRUCache, get_keychain, with_slip44_keychain from apps.common.keychain import Keychain, LRUCache, get_keychain, with_slip44_keychain
from apps.common.paths import PATTERN_SEP5, PathSchema from apps.common.paths import PATTERN_SEP5, PathSchema
if utils.USE_THP:
import thp_common
if not utils.USE_THP:
from storage import cache_codec
class TestKeychain(unittest.TestCase): class TestKeychain(unittest.TestCase):
def setUpClass(self): if utils.USE_THP:
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
def setUpClass(self):
if __debug__:
thp_common.suppres_debug_log()
thp_common.prepare_context()
else:
def setUpClass(self):
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
def setUp(self):
cache_codec.start_session()
def tearDownClass(self): def tearDownClass(self):
context.CURRENT_CONTEXT = None context.CURRENT_CONTEXT = None
def setUp(self):
cache_codec.start_session()
def tearDown(self): def tearDown(self):
cache.clear_all() cache.clear_all()

View File

@ -3,7 +3,7 @@ from common import * # isort:skip
import unittest import unittest
from storage import cache_codec, cache_common from storage import cache_common
from trezor import wire from trezor import wire
from trezor.crypto import bip39 from trezor.crypto import bip39
from trezor.wire import context from trezor.wire import context
@ -12,6 +12,12 @@ from trezor.wire.codec.codec_context import CodecContext
from apps.common.keychain import get_keychain from apps.common.keychain import get_keychain
from apps.common.paths import HARDENED from apps.common.paths import HARDENED
if utils.USE_THP:
import thp_common
else:
from storage import cache_codec
if not utils.BITCOIN_ONLY: if not utils.BITCOIN_ONLY:
from ethereum_common import encode_network, make_network from ethereum_common import encode_network, make_network
from trezor.messages import ( from trezor.messages import (
@ -74,17 +80,30 @@ class TestEthereumKeychain(unittest.TestCase):
addr, addr,
) )
def setUpClass(self): if utils.USE_THP:
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
def setUpClass(self):
if __debug__:
thp_common.suppres_debug_log()
thp_common.prepare_context()
def setUp(self):
seed = bip39.seed(" ".join(["all"] * 12), "")
context.cache_set(cache_common.APP_COMMON_SEED, seed)
else:
def setUpClass(self):
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
def setUp(self):
cache_codec.start_session()
seed = bip39.seed(" ".join(["all"] * 12), "")
cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed)
def tearDownClass(self): def tearDownClass(self):
context.CURRENT_CONTEXT = None context.CURRENT_CONTEXT = None
def setUp(self):
cache_codec.start_session()
seed = bip39.seed(" ".join(["all"] * 12), "")
cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed)
def from_address_n(self, address_n): def from_address_n(self, address_n):
slip44 = _slip44_from_address_n(address_n) slip44 = _slip44_from_address_n(address_n)
network = make_network(slip44=slip44) network = make_network(slip44=slip44)

View File

@ -1,241 +1,519 @@
# flake8: noqa: F403,F405 # flake8: noqa: F403,F405
from common import * # isort:skip from common import * # isort:skip
from mock_storage import mock_storage
from storage import cache, cache_codec, cache_common
from trezor.messages import EndSession, Initialize
from trezor.wire import context
from trezor.wire.codec.codec_context import CodecContext
from apps.base import handle_EndSession, handle_Initialize
from apps.common.cache import stored, stored_async
KEY = 0 KEY = 0
if utils.USE_THP:
import thp_common
from mock_wire_interface import MockHID
from storage import cache, cache_thp
from trezor.wire.thp import ChannelState
from trezor.wire.thp.session_context import SessionContext
# Function moved from cache.py, as it was not used there _PROTOCOL_CACHE = cache_thp
def is_session_started() -> bool:
return cache_codec._active_session_idx is not None else:
from storage import cache, cache_codec
from trezor.messages import EndSession, Initialize
from apps.base import handle_EndSession
from mock_storage import mock_storage
_PROTOCOL_CACHE = cache_codec
def is_session_started() -> bool:
return cache_codec.get_active_session() is not None
def get_active_session():
return cache_codec.get_active_session()
class TestStorageCache(unittest.TestCase): class TestStorageCache(unittest.TestCase):
def setUpClass(self): if utils.USE_THP:
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
def tearDownClass(self): def setUpClass(self):
context.CURRENT_CONTEXT = None if __debug__:
thp_common.suppres_debug_log()
super().__init__()
def setUp(self): def setUp(self):
cache.clear_all() self.interface = MockHID(0xDEADBEEF)
cache.clear_all()
def test_start_session(self): def test_new_channel_and_session(self):
session_id_a = cache_codec.start_session() channel = thp_common.get_new_channel(self.interface)
self.assertIsNotNone(session_id_a)
session_id_b = cache_codec.start_session()
self.assertNotEqual(session_id_a, session_id_b)
cache.clear_all() # Assert that channel is created without any sessions
with self.assertRaises(cache_common.InvalidSessionError): self.assertEqual(len(channel.sessions), 0)
context.cache_set(KEY, "something")
with self.assertRaises(cache_common.InvalidSessionError):
context.cache_get(KEY)
def test_end_session(self): cid_1 = channel.channel_id
session_id = cache_codec.start_session() session_cache_1 = cache_thp.get_new_session(channel.channel_cache)
self.assertTrue(is_session_started()) session_1 = SessionContext(channel, session_cache_1)
context.cache_set(KEY, b"A") self.assertEqual(session_1.channel_id, cid_1)
cache_codec.end_current_session()
self.assertFalse(is_session_started())
self.assertRaises(cache_common.InvalidSessionError, context.cache_get, KEY)
# ending an ended session should be a no-op session_cache_2 = cache_thp.get_new_session(channel.channel_cache)
cache_codec.end_current_session() session_2 = SessionContext(channel, session_cache_2)
self.assertFalse(is_session_started()) self.assertEqual(session_2.channel_id, cid_1)
self.assertEqual(session_1.channel_id, session_2.channel_id)
self.assertNotEqual(session_1.session_id, session_2.session_id)
session_id_a = cache_codec.start_session(session_id) channel_2 = thp_common.get_new_channel(self.interface)
# original session no longer exists cid_2 = channel_2.channel_id
self.assertNotEqual(session_id_a, session_id) self.assertNotEqual(cid_1, cid_2)
# original session data no longer exists
self.assertIsNone(context.cache_get(KEY))
# create a new session session_cache_3 = cache_thp.get_new_session(channel_2.channel_cache)
session_id_b = cache_codec.start_session() session_3 = SessionContext(channel_2, session_cache_3)
# switch back to original session self.assertEqual(session_3.channel_id, cid_2)
session_id = cache_codec.start_session(session_id_a)
self.assertEqual(session_id, session_id_a)
# end original session
cache_codec.end_current_session()
# switch back to B
session_id = cache_codec.start_session(session_id_b)
self.assertEqual(session_id, session_id_b)
def test_session_queue(self): # Sessions 1 and 3 should have different channel_id, but the same session_id
session_id = cache_codec.start_session() self.assertNotEqual(session_1.channel_id, session_3.channel_id)
self.assertEqual(cache_codec.start_session(session_id), session_id) self.assertEqual(session_1.session_id, session_3.session_id)
context.cache_set(KEY, b"A")
for _ in range(cache_codec._MAX_SESSIONS_COUNT): self.assertEqual(cache_thp._SESSIONS[0], session_cache_1)
self.assertNotEqual(cache_thp._SESSIONS[0], session_cache_2)
self.assertEqual(cache_thp._SESSIONS[0].channel_id, session_1.channel_id)
# Check that session data IS in cache for created sessions ONLY
for i in range(3):
self.assertNotEqual(cache_thp._SESSIONS[i].channel_id, b"")
self.assertNotEqual(cache_thp._SESSIONS[i].session_id, b"")
self.assertNotEqual(cache_thp._SESSIONS[i].last_usage, 0)
for i in range(3, cache_thp._MAX_SESSIONS_COUNT):
self.assertEqual(cache_thp._SESSIONS[i].channel_id, b"")
self.assertEqual(cache_thp._SESSIONS[i].session_id, b"")
self.assertEqual(cache_thp._SESSIONS[i].last_usage, 0)
# Check that session data IS NOT in cache after cache.clear_all()
cache.clear_all()
for session in cache_thp._SESSIONS:
self.assertEqual(session.channel_id, b"")
self.assertEqual(session.session_id, b"")
self.assertEqual(session.last_usage, 0)
self.assertEqual(session.state, b"\x00")
def test_channel_capacity_in_cache(self):
self.assertTrue(cache_thp._MAX_CHANNELS_COUNT >= 3)
channels = []
for i in range(cache_thp._MAX_CHANNELS_COUNT):
channels.append(thp_common.get_new_channel(self.interface))
channel_ids = [channel.channel_cache.channel_id for channel in channels]
# Assert that each channel_id is unique and that cache and list of channels
# have the same "channels" on the same indexes
for i in range(len(channel_ids)):
self.assertEqual(cache_thp._CHANNELS[i].channel_id, channel_ids[i])
for j in range(i + 1, len(channel_ids)):
self.assertNotEqual(channel_ids[i], channel_ids[j])
# Create a new channel that is over the capacity
new_channel = thp_common.get_new_channel(self.interface)
for c in channels:
self.assertNotEqual(c.channel_id, new_channel.channel_id)
# Test that the oldest (least used) channel was replaced (_CHANNELS[0])
self.assertNotEqual(cache_thp._CHANNELS[0].channel_id, channel_ids[0])
self.assertEqual(cache_thp._CHANNELS[0].channel_id, new_channel.channel_id)
# Update the "last used" value of the second channel in cache (_CHANNELS[1]) and
# assert that it is not replaced when creating a new channel
cache_thp.update_channel_last_used(channel_ids[1])
new_new_channel = thp_common.get_new_channel(self.interface)
self.assertEqual(cache_thp._CHANNELS[1].channel_id, channel_ids[1])
# Assert that it was in fact the _CHANNEL[2] that was replaced
self.assertNotEqual(cache_thp._CHANNELS[2].channel_id, channel_ids[2])
self.assertEqual(
cache_thp._CHANNELS[2].channel_id, new_new_channel.channel_id
)
def test_session_capacity_in_cache(self):
self.assertTrue(cache_thp._MAX_SESSIONS_COUNT >= 4)
channel_cache_A = thp_common.get_new_channel(self.interface).channel_cache
channel_cache_B = thp_common.get_new_channel(self.interface).channel_cache
sesions_A = []
cid = []
sid = []
for i in range(3):
sesions_A.append(cache_thp.get_new_session(channel_cache_A))
cid.append(sesions_A[i].channel_id)
sid.append(sesions_A[i].session_id)
sessions_B = []
for i in range(cache_thp._MAX_SESSIONS_COUNT - 3):
sessions_B.append(cache_thp.get_new_session(channel_cache_B))
for i in range(3):
self.assertEqual(sesions_A[i], cache_thp._SESSIONS[i])
self.assertEqual(cid[i], cache_thp._SESSIONS[i].channel_id)
self.assertEqual(sid[i], cache_thp._SESSIONS[i].session_id)
for i in range(3, cache_thp._MAX_SESSIONS_COUNT):
self.assertEqual(sessions_B[i - 3], cache_thp._SESSIONS[i])
# Assert that new session replaces the oldest (least used) one (_SESSOIONS[0])
new_session = cache_thp.get_new_session(channel_cache_B)
self.assertEqual(new_session, cache_thp._SESSIONS[0])
self.assertNotEqual(new_session.channel_id, cid[0])
self.assertNotEqual(new_session.session_id, sid[0])
# Assert that updating "last used" for session on channel A increases also
# the "last usage" of channel A.
self.assertTrue(channel_cache_A.last_usage < channel_cache_B.last_usage)
cache_thp.update_session_last_used(
channel_cache_A.channel_id, sesions_A[1].session_id
)
self.assertTrue(channel_cache_A.last_usage > channel_cache_B.last_usage)
new_new_session = cache_thp.get_new_session(channel_cache_B)
# Assert that creating a new session on channel B shifts the "last usage" again
# and that _SESSIONS[1] was not replaced, but that _SESSIONS[2] was replaced
self.assertTrue(channel_cache_A.last_usage < channel_cache_B.last_usage)
self.assertEqual(sesions_A[1], cache_thp._SESSIONS[1])
self.assertNotEqual(sesions_A[2], cache_thp._SESSIONS[2])
self.assertEqual(new_new_session, cache_thp._SESSIONS[2])
def test_clear(self):
channel_A = thp_common.get_new_channel(self.interface)
channel_B = thp_common.get_new_channel(self.interface)
cid_A = channel_A.channel_id
cid_B = channel_B.channel_id
sessions = []
for i in range(3):
sessions.append(cache_thp.get_new_session(channel_A.channel_cache))
sessions.append(cache_thp.get_new_session(channel_B.channel_cache))
self.assertEqual(cache_thp._SESSIONS[2 * i].channel_id, cid_A)
self.assertNotEqual(cache_thp._SESSIONS[2 * i].last_usage, 0)
self.assertEqual(cache_thp._SESSIONS[2 * i + 1].channel_id, cid_B)
self.assertNotEqual(cache_thp._SESSIONS[2 * i + 1].last_usage, 0)
# Assert that clearing of channel A works
self.assertNotEqual(channel_A.channel_cache.channel_id, b"")
self.assertNotEqual(channel_A.channel_cache.last_usage, 0)
self.assertEqual(channel_A.get_channel_state(), ChannelState.TH1)
channel_A.clear()
self.assertEqual(channel_A.channel_cache.channel_id, b"")
self.assertEqual(channel_A.channel_cache.last_usage, 0)
self.assertEqual(channel_A.get_channel_state(), ChannelState.UNALLOCATED)
# Assert that clearing channel A also cleared all its sessions
for i in range(3):
self.assertEqual(cache_thp._SESSIONS[2 * i].last_usage, 0)
self.assertEqual(cache_thp._SESSIONS[2 * i].channel_id, b"")
self.assertNotEqual(cache_thp._SESSIONS[2 * i + 1].last_usage, 0)
self.assertEqual(cache_thp._SESSIONS[2 * i + 1].channel_id, cid_B)
cache.clear_all()
for session in cache_thp._SESSIONS:
self.assertEqual(session.last_usage, 0)
self.assertEqual(session.channel_id, b"")
for channel in cache_thp._CHANNELS:
self.assertEqual(channel.channel_id, b"")
self.assertEqual(channel.last_usage, 0)
self.assertEqual(
cache_thp._get_channel_state(channel), ChannelState.UNALLOCATED
)
def test_get_set(self):
channel = thp_common.get_new_channel(self.interface)
session_1 = cache_thp.get_new_session(channel.channel_cache)
session_1.set(KEY, b"hello")
self.assertEqual(session_1.get(KEY), b"hello")
session_2 = cache_thp.get_new_session(channel.channel_cache)
session_2.set(KEY, b"world")
self.assertEqual(session_2.get(KEY), b"world")
self.assertEqual(session_1.get(KEY), b"hello")
cache.clear_all()
self.assertIsNone(session_1.get(KEY))
self.assertIsNone(session_2.get(KEY))
def test_get_set_int(self):
channel = thp_common.get_new_channel(self.interface)
session_1 = cache_thp.get_new_session(channel.channel_cache)
session_1.set_int(KEY, 1234)
self.assertEqual(session_1.get_int(KEY), 1234)
session_2 = cache_thp.get_new_session(channel.channel_cache)
session_2.set_int(KEY, 5678)
self.assertEqual(session_2.get_int(KEY), 5678)
self.assertEqual(session_1.get_int(KEY), 1234)
cache.clear_all()
self.assertIsNone(session_1.get_int(KEY))
self.assertIsNone(session_2.get_int(KEY))
def test_get_set_bool(self):
channel = thp_common.get_new_channel(self.interface)
session_1 = cache_thp.get_new_session(channel.channel_cache)
with self.assertRaises(AssertionError):
session_1.set_bool(KEY, True)
# Change length of first session field to 0 so that the length check passes
session_1.fields = (0,) + session_1.fields[1:]
# with self.assertRaises(AssertionError) as e:
session_1.set_bool(KEY, True)
self.assertEqual(session_1.get_bool(KEY), True)
session_2 = cache_thp.get_new_session(channel.channel_cache)
session_2.fields = session_2.fields = (0,) + session_2.fields[1:]
session_2.set_bool(KEY, False)
self.assertEqual(session_2.get_bool(KEY), False)
self.assertEqual(session_1.get_bool(KEY), True)
cache.clear_all()
# Default value is False
self.assertFalse(session_1.get_bool(KEY))
self.assertFalse(session_2.get_bool(KEY))
def test_delete(self):
channel = thp_common.get_new_channel(self.interface)
session_1 = cache_thp.get_new_session(channel.channel_cache)
self.assertIsNone(session_1.get(KEY))
session_1.set(KEY, b"hello")
self.assertEqual(session_1.get(KEY), b"hello")
session_1.delete(KEY)
self.assertIsNone(session_1.get(KEY))
session_1.set(KEY, b"hello")
session_2 = cache_thp.get_new_session(channel.channel_cache)
self.assertIsNone(session_2.get(KEY))
session_2.set(KEY, b"hello")
self.assertEqual(session_2.get(KEY), b"hello")
session_2.delete(KEY)
self.assertIsNone(session_2.get(KEY))
self.assertEqual(session_1.get(KEY), b"hello")
else:
def setUpClass(self):
from trezor.wire.codec.codec_context import CodecContext
from trezor.wire import context
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
def tearDownClass(self):
from trezor.wire import context
context.CURRENT_CONTEXT = None
def setUp(self):
cache.clear_all()
def test_start_session(self):
session_id_a = cache_codec.start_session()
self.assertIsNotNone(session_id_a)
session_id_b = cache_codec.start_session()
self.assertNotEqual(session_id_a, session_id_b)
cache.clear_all()
self.assertIsNone(get_active_session())
for session in cache_codec._SESSIONS:
self.assertEqual(session.session_id, b"")
self.assertEqual(session.last_usage, 0)
def test_end_session(self):
session_id = cache_codec.start_session()
self.assertTrue(is_session_started())
get_active_session().set(KEY, b"A")
cache_codec.end_current_session()
self.assertFalse(is_session_started())
self.assertIsNone(get_active_session())
# ending an ended session should be a no-op
cache_codec.end_current_session()
self.assertFalse(is_session_started())
session_id_a = cache_codec.start_session(session_id)
# original session no longer exists
self.assertNotEqual(session_id_a, session_id)
# original session data no longer exists
self.assertIsNone(get_active_session().get(KEY))
# create a new session
session_id_b = cache_codec.start_session()
# switch back to original session
session_id = cache_codec.start_session(session_id_a)
self.assertEqual(session_id, session_id_a)
# end original session
cache_codec.end_current_session()
# switch back to B
session_id = cache_codec.start_session(session_id_b)
self.assertEqual(session_id, session_id_b)
def test_session_queue(self):
session_id = cache_codec.start_session()
self.assertEqual(cache_codec.start_session(session_id), session_id)
get_active_session().set(KEY, b"A")
for i in range(_PROTOCOL_CACHE._MAX_SESSIONS_COUNT):
cache_codec.start_session()
self.assertNotEqual(cache_codec.start_session(session_id), session_id)
self.assertIsNone(get_active_session().get(KEY))
def test_get_set(self):
session_id1 = cache_codec.start_session()
cache_codec.get_active_session().set(KEY, b"hello")
self.assertEqual(cache_codec.get_active_session().get(KEY), b"hello")
session_id2 = cache_codec.start_session()
cache_codec.get_active_session().set(KEY, b"world")
self.assertEqual(cache_codec.get_active_session().get(KEY), b"world")
cache_codec.start_session(session_id2)
self.assertEqual(cache_codec.get_active_session().get(KEY), b"world")
cache_codec.start_session(session_id1)
self.assertEqual(cache_codec.get_active_session().get(KEY), b"hello")
cache.clear_all()
self.assertIsNone(cache_codec.get_active_session())
def test_get_set_int(self):
session_id1 = cache_codec.start_session()
get_active_session().set_int(KEY, 1234)
self.assertEqual(get_active_session().get_int(KEY), 1234)
session_id2 = cache_codec.start_session()
get_active_session().set_int(KEY, 5678)
self.assertEqual(get_active_session().get_int(KEY), 5678)
cache_codec.start_session(session_id2)
self.assertEqual(get_active_session().get_int(KEY), 5678)
cache_codec.start_session(session_id1)
self.assertEqual(get_active_session().get_int(KEY), 1234)
cache.clear_all()
self.assertIsNone(get_active_session())
def test_delete(self):
session_id1 = cache_codec.start_session()
self.assertIsNone(get_active_session().get(KEY))
get_active_session().set(KEY, b"hello")
self.assertEqual(get_active_session().get(KEY), b"hello")
get_active_session().delete(KEY)
self.assertIsNone(get_active_session().get(KEY))
get_active_session().set(KEY, b"hello")
cache_codec.start_session() cache_codec.start_session()
self.assertNotEqual(cache_codec.start_session(session_id), session_id) self.assertIsNone(get_active_session().get(KEY))
self.assertIsNone(context.cache_get(KEY)) get_active_session().set(KEY, b"hello")
self.assertEqual(get_active_session().get(KEY), b"hello")
get_active_session().delete(KEY)
self.assertIsNone(get_active_session().get(KEY))
def test_get_set(self): cache_codec.start_session(session_id1)
session_id1 = cache_codec.start_session() self.assertEqual(get_active_session().get(KEY), b"hello")
context.cache_set(KEY, b"hello")
self.assertEqual(context.cache_get(KEY), b"hello")
session_id2 = cache_codec.start_session() def test_decorators(self):
context.cache_set(KEY, b"world") run_count = 0
self.assertEqual(context.cache_get(KEY), b"world") cache_codec.start_session()
from apps.common.cache import stored
cache_codec.start_session(session_id2) @stored(KEY)
self.assertEqual(context.cache_get(KEY), b"world") def func():
cache_codec.start_session(session_id1) nonlocal run_count
self.assertEqual(context.cache_get(KEY), b"hello") run_count += 1
return b"foo"
cache.clear_all() # cache is empty
with self.assertRaises(cache_common.InvalidSessionError): self.assertIsNone(get_active_session().get(KEY))
context.cache_get(KEY) self.assertEqual(run_count, 0)
self.assertEqual(func(), b"foo")
# function was run
self.assertEqual(run_count, 1)
self.assertEqual(get_active_session().get(KEY), b"foo")
# function does not run again but returns cached value
self.assertEqual(func(), b"foo")
self.assertEqual(run_count, 1)
def test_get_set_int(self): def test_empty_value(self):
session_id1 = cache_codec.start_session() cache_codec.start_session()
context.cache_set_int(KEY, 1234)
self.assertEqual(context.cache_get_int(KEY), 1234)
session_id2 = cache_codec.start_session() self.assertIsNone(get_active_session().get(KEY))
context.cache_set_int(KEY, 5678) get_active_session().set(KEY, b"")
self.assertEqual(context.cache_get_int(KEY), 5678) self.assertEqual(get_active_session().get(KEY), b"")
cache_codec.start_session(session_id2) get_active_session().delete(KEY)
self.assertEqual(context.cache_get_int(KEY), 5678) run_count = 0
cache_codec.start_session(session_id1)
self.assertEqual(context.cache_get_int(KEY), 1234)
cache.clear_all() from apps.common.cache import stored
with self.assertRaises(cache_common.InvalidSessionError):
context.cache_get_int(KEY)
def test_delete(self): @stored(KEY)
session_id1 = cache_codec.start_session() def func():
self.assertIsNone(context.cache_get(KEY)) nonlocal run_count
context.cache_set(KEY, b"hello") run_count += 1
self.assertEqual(context.cache_get(KEY), b"hello") return b""
context.cache_delete(KEY)
self.assertIsNone(context.cache_get(KEY))
context.cache_set(KEY, b"hello") self.assertEqual(func(), b"")
cache_codec.start_session() # function gets called once
self.assertIsNone(context.cache_get(KEY)) self.assertEqual(run_count, 1)
context.cache_set(KEY, b"hello") self.assertEqual(func(), b"")
self.assertEqual(context.cache_get(KEY), b"hello") # function is not called for a second time
context.cache_delete(KEY) self.assertEqual(run_count, 1)
self.assertIsNone(context.cache_get(KEY))
cache_codec.start_session(session_id1) @mock_storage
self.assertEqual(context.cache_get(KEY), b"hello") def test_Initialize(self):
from apps.base import handle_Initialize
def test_decorators(self): def call_Initialize(**kwargs):
run_count = 0 msg = Initialize(**kwargs)
cache_codec.start_session() return await_result(handle_Initialize(msg))
@stored(KEY) # calling Initialize without an ID allocates a new one
def func(): session_id = cache_codec.start_session()
nonlocal run_count features = call_Initialize()
run_count += 1 self.assertNotEqual(session_id, features.session_id)
return b"foo"
# cache is empty # calling Initialize with the current ID does not allocate a new one
self.assertIsNone(context.cache_get(KEY)) features = call_Initialize(session_id=session_id)
self.assertEqual(run_count, 0) self.assertEqual(session_id, features.session_id)
self.assertEqual(func(), b"foo")
# function was run
self.assertEqual(run_count, 1)
self.assertEqual(context.cache_get(KEY), b"foo")
# function does not run again but returns cached value
self.assertEqual(func(), b"foo")
self.assertEqual(run_count, 1)
@stored_async(KEY) # store "hello"
async def async_func(): get_active_session().set(KEY, b"hello")
nonlocal run_count # check that it is cleared
run_count += 1 features = call_Initialize()
return b"bar" session_id = features.session_id
self.assertIsNone(get_active_session().get(KEY))
# store "hello" again
get_active_session().set(KEY, b"hello")
self.assertEqual(get_active_session().get(KEY), b"hello")
# cache is still full # supplying a different session ID starts a new session
self.assertEqual(await_result(async_func()), b"foo") call_Initialize(session_id=b"A" * _PROTOCOL_CACHE.SESSION_ID_LENGTH)
self.assertEqual(run_count, 1) self.assertIsNone(get_active_session().get(KEY))
cache_codec.start_session() # but resuming a session loads the previous one
self.assertEqual(await_result(async_func()), b"bar") call_Initialize(session_id=session_id)
self.assertEqual(run_count, 2) self.assertEqual(get_active_session().get(KEY), b"hello")
# awaitable is also run only once
self.assertEqual(await_result(async_func()), b"bar")
self.assertEqual(run_count, 2)
def test_empty_value(self): def test_EndSession(self):
cache_codec.start_session()
self.assertIsNone(context.cache_get(KEY)) self.assertIsNone(get_active_session())
context.cache_set(KEY, b"") cache_codec.start_session()
self.assertEqual(context.cache_get(KEY), b"") self.assertTrue(is_session_started())
self.assertIsNone(get_active_session().get(KEY))
context.cache_delete(KEY) await_result(handle_EndSession(EndSession()))
run_count = 0 self.assertFalse(is_session_started())
self.assertIsNone(cache_codec.get_active_session())
@stored(KEY)
def func():
nonlocal run_count
run_count += 1
return b""
self.assertEqual(func(), b"")
# function gets called once
self.assertEqual(run_count, 1)
self.assertEqual(func(), b"")
# function is not called for a second time
self.assertEqual(run_count, 1)
@mock_storage
def test_Initialize(self):
def call_Initialize(**kwargs):
msg = Initialize(**kwargs)
return await_result(handle_Initialize(msg))
# calling Initialize without an ID allocates a new one
session_id = cache_codec.start_session()
features = call_Initialize()
self.assertNotEqual(session_id, features.session_id)
# calling Initialize with the current ID does not allocate a new one
features = call_Initialize(session_id=session_id)
self.assertEqual(session_id, features.session_id)
# store "hello"
context.cache_set(KEY, b"hello")
# check that it is cleared
features = call_Initialize()
session_id = features.session_id
self.assertIsNone(context.cache_get(KEY))
# store "hello" again
context.cache_set(KEY, b"hello")
self.assertEqual(context.cache_get(KEY), b"hello")
# supplying a different session ID starts a new cache
call_Initialize(session_id=b"A" * cache_codec.SESSION_ID_LENGTH)
self.assertIsNone(context.cache_get(KEY))
# but resuming a session loads the previous one
call_Initialize(session_id=session_id)
self.assertEqual(context.cache_get(KEY), b"hello")
def test_EndSession(self):
self.assertRaises(cache_common.InvalidSessionError, context.cache_get, KEY)
cache_codec.start_session()
self.assertTrue(is_session_started())
self.assertIsNone(context.cache_get(KEY))
await_result(handle_EndSession(EndSession()))
self.assertFalse(is_session_started())
self.assertRaises(cache_common.InvalidSessionError, context.cache_get, KEY)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -3,28 +3,11 @@ from common import * # isort:skip
import ustruct import ustruct
from mock_wire_interface import MockHID
from trezor import io from trezor import io
from trezor.loop import wait
from trezor.utils import chunks from trezor.utils import chunks
from trezor.wire.codec import codec_v1 from trezor.wire.codec import codec_v1
class MockHID:
def __init__(self, num):
self.num = num
self.data = []
def iface_num(self):
return self.num
def write(self, msg):
self.data.append(bytearray(msg))
return len(msg)
def wait_object(self, mode):
return wait(mode | self.num)
MESSAGE_TYPE = 0x4242 MESSAGE_TYPE = 0x4242
HEADER_PAYLOAD_LENGTH = codec_v1._REP_LEN - 3 - ustruct.calcsize(">HL") HEADER_PAYLOAD_LENGTH = codec_v1._REP_LEN - 3 - ustruct.calcsize(">HL")

View File

@ -0,0 +1,94 @@
from common import * # isort:skip
if utils.USE_THP:
from trezor.wire.thp import checksum
@unittest.skipUnless(utils.USE_THP, "only needed for THP")
class TestTrezorHostProtocolChecksum(unittest.TestCase):
vectors_correct = [
(
b"",
b"\x00\x00\x00\x00",
),
(
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
b"\x19\x0A\x55\xAD",
),
(
bytes("a", "ascii"),
b"\xE8\xB7\xBE\x43",
),
(
bytes("abc", "ascii"),
b"\x35\x24\x41\xC2",
),
(
bytes("123456789", "ascii"),
b"\xCB\xF4\x39\x26",
),
(
bytes(
"12345678901234567890123456789012345678901234567890123456789012345678901234567890",
"ascii",
),
b"\x7C\xA9\x4A\x72",
),
(
b"\x76\x61\x72\x69\x6F\x75\x73\x20\x43\x52\x43\x20\x61\x6C\x67\x6F\x72\x69\x74\x68\x6D\x73\x20\x69\x6E\x70\x75\x74\x20\x64\x61\x74\x61",
b"\x9B\xD3\x66\xAE",
),
(
b"\x67\x3a\x5f\x0e\x39\xc0\x3c\x79\x58\x22\x74\x76\x64\x9e\x36\xe9\x0b\x04\x8c\xd2\xc0\x4d\x76\x63\x1a\xa2\x17\x85\xe8\x50\xa7\x14\x18\xfb\x86\xed\xa3\x59\x2d\x62\x62\x49\x64\x62\x26\x12\xdb\x95\x3d\xd6\xb5\xca\x4b\x22\x0d\xc5\x78\xb2\x12\x97\x8e\x54\x4e\x06\xb7\x9c\x90\xf5\xa0\x21\xa6\xc7\xd8\x39\xfd\xea\x3a\xf1\x7b\xa2\xe8\x71\x41\xd6\xcb\x1e\x5b\x0e\x29\xf7\x0c\xc7\x57\x8b\x53\x20\x1d\x2b\x41\x1c\x25\xf9\x07\xbb\xb4\x37\x79\x6a\x13\x1f\x6c\x43\x71\xc1\x1e\x70\xe6\x74\xd3\x9c\xbf\x32\x15\xee\xf2\xa7\x86\xbe\x59\x99\xc4\x10\x09\x8a\x6a\xaa\xd4\xd1\xd0\x71\xd2\x06\x1a\xdd\x2a\xa0\x08\xeb\x08\x6c\xfb\xd2\x2d\xfb\xaa\x72\x56\xeb\xd1\x92\x92\xe5\x0e\x95\x67\xf8\x38\xc3\xab\x59\x37\xe6\xfd\x42\xb0\xd0\x31\xd0\xcb\x8a\x66\xce\x2d\x53\x72\x1e\x72\xd3\x84\x25\xb0\xb8\x93\xd2\x61\x5b\x32\xd5\xe7\xe4\x0e\x31\x11\xaf\xdc\xb4\xb8\xee\xa4\x55\x16\x5f\x78\x86\x8b\x50\x4d\xc5\x6d\x6e\xfc\xe1\x6b\x06\x5b\x37\x84\x2a\x67\x95\x28\x00\xa4\xd1\x32\x9f\xbf\xe1\x64\xf8\x17\x47\xe1\xad\x8b\x72\xd2\xd9\x45\x5b\x73\x43\x3c\xe6\x21\xf7\x53\xa3\x73\xf9\x2a\xb0\xe9\x75\x5e\xa6\xbe\x9a\xad\xfc\xed\xb5\x46\x5b\x9f\xa9\x5a\x4f\xcb\xb6\x60\x96\x31\x91\x42\xca\xaf\xee\xa5\x0c\xe0\xab\x3e\x83\xb8\xac\x88\x10\x2c\x63\xd3\xc9\xd2\xf2\x44\xef\xea\x3d\x19\x24\x3c\x5b\xe7\x0c\x52\xfd\xfe\x47\x41\x14\xd5\x4c\x67\x8d\xdb\xe5\xd9\xfa\x67\x9c\x06\x31\x01\x92\xba\x96\xc4\x0d\xef\xf7\xc1\xe9\x23\x28\x0f\xae\x27\x9b\xff\x28\x0b\x3e\x85\x0c\xae\x02\xda\x27\xb6\x04\x51\x04\x43\x04\x99\x8c\xa3\x97\x1d\x84\xec\x55\x59\xfb\xf3\x84\xe5\xf8\x40\xf8\x5f\x81\x65\x92\x4c\x92\x7a\x07\x51\x8d\x6f\xff\x8d\x15\x36\x5c\x57\x7a\x5b\x3a\x63\x1c\x87\x65\xee\x54\xd5\x96\x50\x73\x1a\x9c\xff\x59\xe5\xea\x6f\x89\xd2\xbb\xa9\x6a\x12\x21\xf5\x08\x8e\x8a\xc0\xd8\xf5\x14\xe9\x9d\x7e\x99\x13\x88\x29\xa8\xb4\x22\x2a\x41\x7c\xc5\x10\xdf\x11\x5e\xf8\x8d\x0e\xd9\x98\xd5\xaf\xa8\xf9\x55\x1e\xe3\x29\xcd\x2c\x51\x7b\x8a\x8d\x52\xaa\x8b\x87\xae\x8e\xb2\xfa\x31\x27\x60\x90\xcb\x01\x6f\x7a\x79\x38\x04\x05\x7c\x11\x79\x10\x40\x33\x70\x75\xfd\x0b\x88\xa5\xcd\x35\xd8\xa6\x3b\xb0\x45\x82\x64\xd1\xb5\xdc\x06\xc9\x89\xf4\x16\x3e\xc7\xb3\xf1\x9d\xd3\xc5\xe3\xaf\xe8\x25\x86\x7a\x4a\xfd\x10\x5d\x20\xe5\x76\x5a\x22\x5f\x8f\xbc\xaa\x97\xee\xf2\xc2\x4c\x0e\xdc\x7b\xc4\xee\x53\xa3\xe0\xfa\xcd\x1e\x4e\x54\x1d\x5e\xe1\x51\x17\x1f\x1a\x75\x7f\xed\x12\xd7\xf7\xe3\x18\x56\x24\xcf\xc6\x96\x30\x77\x0d\x73\x98\x9c\x09\x69\xa3\xbc\x96\x5e\xaf\xde\x76\xa4\x66\x04\x6b\x36\x2a\xac\x6d\x37\xf8\x1e\xe1\x2a\x3e\x42\x2d\x1d\xe6\x46\xdd\x28\xb9\x08\x44\xa1\x9e\xb2\x22\x7a\x45\x8a\x37\x39\x74\xb4\xae\xc8\x3b\x40\xf7\xec\xbf\xfd\xe5\xde\xb2\x83\x5e\xa4\x46\x19\xa6\x9d\xb0\xe8\x76\x80\xbd\xc1\x80\x7a\xd9\xeb\xe7\x90\x5b\x81\x25\x21\xd9\x5b\x4a\x80\x48\x92\x71\x77\x04\xb2\xac\x05\xc9\xdf\x5e\x44\x5a\xae\x6e\xb3\xd8\x30\x5e\xdc\x77\x2f\x79\xc2\x8e\x8b\x28\x24\x06\x1b\x6f\x8d\x88\x53\x80\x55\x0c\x3a\x7b\x85\xb8\x96\x85\xe9\xf0\x57\x63\xfe\x32\x80\xff\x57\xc9\x3c\xdb\xf6\xcd\x67\x14\x47\x6c\x43\x3d\x6d\x48\x3f\x9c\x00\x60\x0e\xf5\x94\xe4\x52\x97\x86\xcd\xac\xbc\xe4\xe3\xe7\xee\xa2\x91\x6e\x92\xbb\xd1\x55\x0c\x5c\x0d\x63\xdb\x6b\xb8\x6e\x45\x48\x0f\xdf\x44\x48\xd2\xf5\xf7\x4d\x7b\xd4\x4d\xd3\xcd\xcd\x5b\x40\x60\xb1\xb2\x8e\xc9\x9a\x65\xc5\x06\x24\xcf\xe9\xcc\x5e\x2c\x49\x47\x38\x45\x5d\xc5\xc0\x0d\x8a\x07\x1c\xb3\xbb\xb1\x69\xf5\x6d\x0e\x9c\x96\x14\x93\x58\x0c\xc9\x48\x74\xfc\x35\xda\x7d\x4e\x32\x73\xa3\x77\x4a\x9e\xc5\xd1\x08\xfe\xa6\xa0\xf1\x66\x72\xea\xc7\xae\x21\x81\x0e\x8a\xba\x99\x06\x97\xfc\xc6\x2b\x69\x53\xc6\x67\xec\x5d\xa1\xfc\xa1\x3b\xdd\x2a\xd6\x8f\x31\xa7\x8d\xec\xfe\x0a\x3b\x6b\x39\x70\x70\x09\x72\x12\xbc\x84\x67\xca\xd2\x4a\x17\x33\x94\x45\x25\xc7\xfd\x1e\xa2\x4a\x9e\x27\x9d\xfb\x87\xea\xe4\xfd\xb0\x11\x06\x9d\x72\xb9\x1d\xea\x9b\x81\x2e\x6a\x36\x76\x62\xfa\xbe\x96\x67\x7d\x35\xdd\x5e\x5c\x4f\x41\x0d\xce\xdb\x13\xb0\x46\x89\x92\x45\x02\x39\x0f\xe6\xd1\x20\x96\x1c\x34\x00\x8c\xc9\xdf\xe3\xf0\xb6\x92\x3a\xda\x5c\x96\xd9\x0b\x7d\x57\xf5\x78\x11\xc0\xcf\xbf\xb0\x92\x3d\xe5\x6a\x67\x34\xce\xd9\x16\x08\xa0\x09\x42\x0b\x07\x13\x7c\x73\x0c\xc6\x50\x17\x42\xcf\xd9\x85\xd9\x23\x3c\xb1\x40\x40\x0f\x94\x20\xed\x2d\xbf\x10\x44\x6e\x64\x65\xe5\x1d\x5f\xec\x24\xd8\x4b\xe8\xc2\xfb\x06\x11\x24\x3f\xdf\x54\x2d\xe8\x4d\xc2\x1c\x27\x11\xb8\xb3\xd4",
b"\x6B\xA4\xEC\x92",
),
]
vectors_incorrect = [
(
b"",
b"\x00\x00\x00\x00\x00",
),
(
b"",
b"",
),
(
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
b"\x19\x0A\x55\xAE",
),
(
bytes("A", "ascii"),
b"\xE8\xB7\xBE\x43",
),
(
bytes("abc ", "ascii"),
b"\x35\x24\x41\xC2",
),
(
bytes("1234567890", "ascii"),
b"\xCB\xF4\x39\x26",
),
(
bytes(
"1234567890123456789012345678901234567890123456789012345678901234567890123456789",
"ascii",
),
b"\x7C\xA9\x4A\x72",
),
]
def test_computation(self):
for data, chksum in self.vectors_correct:
self.assertEqual(checksum.compute(data), chksum)
def test_validation_correct(self):
for data, chksum in self.vectors_correct:
self.assertTrue(checksum.is_valid(chksum, data))
def test_validation_incorrect(self):
for data, chksum in self.vectors_incorrect:
self.assertFalse(checksum.is_valid(chksum, data))
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,156 @@
from common import * # isort:skip
from trezorcrypto import aesgcm, curve25519
import storage
if utils.USE_THP:
import thp_common
from trezor.wire.thp import crypto
from trezor.wire.thp.crypto import IV_1, IV_2, Handshake
def get_dummy_device_secret():
return b"\x01\x02\x03\x04\x05\x06\x07\x08\x01\x02\x03\x04\x05\x06\x07\x08"
@unittest.skipUnless(utils.USE_THP, "only needed for THP")
class TestTrezorHostProtocolCrypto(unittest.TestCase):
if utils.USE_THP:
handshake = Handshake()
key_1 = b"\x00\x01\x02\x03\x04\x05\x06\x07\x00\x01\x02\x03\x04\x05\x06\x07\x00\x01\x02\x03\x04\x05\x06\x07\x00\x01\x02\x03\x04\x05\x06\x07"
# 0:key, 1:nonce, 2:auth_data, 3:plaintext, 4:expected_ciphertext, 5:expected_tag
vectors_enc = [
(
key_1,
0,
b"\x55\x64",
b"\x00\x01\x02\x03\x04\05\x06\x07\x08\x09",
b"e2c9dd152fbee5821ea7",
b"10625812de81b14a46b9f1e5100a6d0c",
),
(
key_1,
1,
b"\x55\x64",
b"\x00\x01\x02\x03\x04\05\x06\x07\x08\x09",
b"79811619ddb07c2b99f8",
b"71c6b872cdc499a7e9a3c7441f053214",
),
(
key_1,
369,
b"\x55\x64",
b"\x00\x01\x02\x03\x04\05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
b"03bd030390f2dfe815a61c2b157a064f",
b"c1200f8a7ae9a6d32cef0fff878d55c2",
),
(
key_1,
369,
b"\x55\x64\x73\x82\x91",
b"\x00\x01\x02\x03\x04\05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
b"03bd030390f2dfe815a61c2b157a064f",
b"693ac160cd93a20f7fc255f049d808d0",
),
]
# 0:chaining key, 1:input, 2:output_1, 3:output:2
vectors_hkdf = [
(
crypto.PROTOCOL_NAME,
b"\x01\x02",
b"c784373a217d6be057cddc6068e6748f255fc8beb6f99b7b90cbc64aad947514",
b"12695451e29bf08ffe5e4e6ab734b0c3d7cdd99b16cd409f57bd4eaa874944ba",
),
(
b"\xc7\x84\x37\x3a\x21\x7d\x6b\xe0\x57\xcd\xdc\x60\x68\xe6\x74\x8f\x25\x5f\xc8\xbe\xb6\xf9\x9b\x7b\x90\xcb\xc6\x4a\xad\x94\x75\x14",
b"\x31\x41\x59\x26\x52\x12\x34\x56\x78\x89\x04\xaa",
b"f88c1e08d5c3bae8f6e4a3d3324c8cbc60a805603e399e69c4bf4eacb27c2f48",
b"5f0216bdb7110ee05372286974da8c9c8b96e2efa15b4af430755f462bd79a76",
),
]
vectors_iv = [
(0, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"),
(1, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"),
(7, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x07"),
(1025, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04\x01"),
(4294967295, b"\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff"),
(0xFFFFFFFFFFFFFFFF, b"\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff"),
]
def __init__(self):
if __debug__:
thp_common.suppres_debug_log()
super().__init__()
def setUp(self):
utils.DISABLE_ENCRYPTION = False
def test_encryption(self):
for v in self.vectors_enc:
buffer = bytearray(v[3])
tag = crypto.enc(buffer, v[0], v[1], v[2])
self.assertEqual(hexlify(buffer), v[4])
self.assertEqual(hexlify(tag), v[5])
self.assertTrue(crypto.dec(buffer, tag, v[0], v[1], v[2]))
self.assertEqual(buffer, v[3])
def test_hkdf(self):
for v in self.vectors_hkdf:
ck, k = crypto._hkdf(v[0], v[1])
self.assertEqual(hexlify(ck), v[2])
self.assertEqual(hexlify(k), v[3])
def test_iv_from_nonce(self):
for v in self.vectors_iv:
x = v[0]
y = x.to_bytes(8, "big")
iv = crypto._get_iv_from_nonce(v[0])
self.assertEqual(iv, v[1])
with self.assertRaises(AssertionError) as e:
iv = crypto._get_iv_from_nonce(0xFFFFFFFFFFFFFFFF + 1)
self.assertEqual(e.value.value, "Nonce overflow, terminate the channel")
def test_incorrect_vectors(self):
pass
def test_th1_crypto(self):
storage.device.get_device_secret = get_dummy_device_secret
handshake = self.handshake
host_ephemeral_privkey = curve25519.generate_secret()
host_ephemeral_pubkey = curve25519.publickey(host_ephemeral_privkey)
handshake.handle_th1_crypto(b"", host_ephemeral_pubkey)
def test_th2_crypto(self):
handshake = self.handshake
host_static_privkey = curve25519.generate_secret()
host_static_pubkey = curve25519.publickey(host_static_privkey)
aes_ctx = aesgcm(handshake.k, IV_2)
aes_ctx.auth(handshake.h)
encrypted_host_static_pubkey = bytearray(
aes_ctx.encrypt(host_static_pubkey) + aes_ctx.finish()
)
# Code to encrypt Host's noise encrypted payload correctly:
protomsg = bytearray(b"\x10\x02\x10\x03")
temp_k = handshake.k
temp_h = handshake.h
temp_h = crypto._hash_of_two(temp_h, encrypted_host_static_pubkey)
_, temp_k = crypto._hkdf(
handshake.ck,
curve25519.multiply(handshake.trezor_ephemeral_privkey, host_static_pubkey),
)
aes_ctx = aesgcm(temp_k, IV_1)
aes_ctx.encrypt_in_place(protomsg)
aes_ctx.auth(temp_h)
tag = aes_ctx.finish()
encrypted_payload = bytearray(protomsg + tag)
# end of encrypted payload generation
handshake.handle_th2_crypto(encrypted_host_static_pubkey, encrypted_payload)
self.assertEqual(encrypted_payload[:4], b"\x10\x02\x10\x03")
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,378 @@
from common import * # isort:skip
from mock_wire_interface import MockHID
from trezor import config, io, protobuf
from trezor.crypto.curve import curve25519
from trezor.enums import ThpMessageType
from trezor.wire.errors import UnexpectedMessage
from trezor.wire.protocol_common import Message
if utils.USE_THP:
from typing import TYPE_CHECKING
import thp_common
from storage import cache_thp
from storage.cache_common import (
CHANNEL_HANDSHAKE_HASH,
CHANNEL_KEY_RECEIVE,
CHANNEL_KEY_SEND,
CHANNEL_NONCE_RECEIVE,
CHANNEL_NONCE_SEND,
)
from trezor.crypto import elligator2
from trezor.enums import ThpPairingMethod
from trezor.messages import (
ThpCodeEntryChallenge,
ThpCodeEntryCpaceHost,
ThpCodeEntryTag,
ThpCredentialRequest,
ThpEndRequest,
ThpStartPairingRequest,
)
from trezor.wire.thp import thp_main
from trezor.wire.thp import ChannelState, checksum, interface_manager
from trezor.wire.thp.crypto import Handshake
from trezor.wire.thp.pairing_context import PairingContext
from apps.thp import pairing
if TYPE_CHECKING:
from trezor.wire import WireInterface
def get_dummy_key() -> bytes:
return b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x01\x02\x03\x04\x05\x06\x07\x08\x09\x20\x01\x02\x03\x04\x05\x06\x07\x08\x09\x30\x31"
def dummy_encode_iface(iface: WireInterface):
return thp_common._MOCK_INTERFACE_HID
def send_channel_allocation_request(
interface: WireInterface, nonce: bytes | None = None
) -> bytes:
if nonce is None or len(nonce) != 8:
nonce = b"\x00\x11\x22\x33\x44\x55\x66\x77"
header = b"\x40\xff\xff\x00\x0c"
chksum = checksum.compute(header + nonce)
cid_req = header + nonce + chksum
gen = thp_main.thp_main_loop(interface)
expected_channel_index = cache_thp._get_next_channel_index()
gen.send(None)
gen.send(cid_req)
gen.send(None)
model = bytes(utils.INTERNAL_MODEL, "big")
response_data = (
b"\x0a\x04" + model + "\x10\x00\x18\x00\x20\x02\x28\x02\x28\x03\x28\x04"
)
response_without_crc = (
b"\x41\xff\xff\x00\x20"
+ nonce
+ cache_thp._CHANNELS[expected_channel_index].channel_id
+ response_data
)
chkcsum = checksum.compute(response_without_crc)
expected_response = response_without_crc + chkcsum + b"\x00" * 27
return expected_response
def get_channel_id_from_response(channel_allocation_response: bytes) -> int:
return int.from_bytes(channel_allocation_response[13:15], "big")
def get_ack(channel_id: bytes) -> bytes:
if len(channel_id) != 2:
raise Exception("Channel id should by two bytes long")
return (
b"\x20"
+ channel_id
+ b"\x00\x04"
+ checksum.compute(b"\x20" + channel_id + b"\x00\x04")
+ b"\x00" * 55
)
@unittest.skipUnless(utils.USE_THP, "only needed for THP")
class TestTrezorHostProtocol(unittest.TestCase):
def __init__(self):
if __debug__:
thp_common.suppres_debug_log()
interface_manager.encode_iface = dummy_encode_iface
super().__init__()
def setUp(self):
self.interface = MockHID(0xDEADBEEF)
buffer = bytearray(64)
buffer2 = bytearray(256)
thp_main.set_read_buffer(buffer)
thp_main.set_write_buffer(buffer2)
interface_manager.decode_iface = thp_common.dummy_decode_iface
def test_codec_message(self):
self.assertEqual(len(self.interface.data), 0)
gen = thp_main.thp_main_loop(self.interface)
gen.send(None)
# There should be a failiure response to received init packet (starts with "?##")
test_codec_message = b"?## Some data"
gen.send(test_codec_message)
gen.send(None)
self.assertEqual(len(self.interface.data), 1)
expected_response = b"?##\x00\x03\x00\x00\x00\x14\x08\x10"
self.assertEqual(
self.interface.data[-1][: len(expected_response)], expected_response
)
# There should be no response for continuation packet (starts with "?" only)
test_codec_message_2 = b"? Cont packet"
gen.send(test_codec_message_2)
with self.assertRaises(TypeError) as e:
gen.send(None)
self.assertEqual(e.value.value, "object with buffer protocol required")
self.assertEqual(len(self.interface.data), 1)
def test_message_on_unallocated_channel(self):
gen = thp_main.thp_main_loop(self.interface)
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
message_to_channel_789a = (
b"\x04\x78\x9a\x00\x0c\x00\x11\x22\x33\x44\x55\x66\x77\x96\x64\x3c\x6c"
)
gen.send(message_to_channel_789a)
gen.send(None)
unallocated_chanel_error_on_channel_789a = "42789a0005027b743563000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
self.assertEqual(
utils.get_bytes_as_str(self.interface.data[-1]),
unallocated_chanel_error_on_channel_789a,
)
def test_channel_allocation(self):
self.assertEqual(len(thp_main._CHANNELS), 0)
for c in cache_thp._CHANNELS:
self.assertEqual(int.from_bytes(c.state, "big"), ChannelState.UNALLOCATED)
expected_channel_index = cache_thp._get_next_channel_index()
expected_response = send_channel_allocation_request(self.interface)
self.assertEqual(self.interface.data[-1], expected_response)
cid = cache_thp._CHANNELS[expected_channel_index].channel_id
self.assertTrue(int.from_bytes(cid, "big") in thp_main._CHANNELS)
self.assertEqual(len(thp_main._CHANNELS), 1)
# test channel's default state is TH1:
cid = get_channel_id_from_response(self.interface.data[-1])
self.assertEqual(thp_main._CHANNELS[cid].get_channel_state(), ChannelState.TH1)
def test_invalid_encrypted_tag(self):
gen = thp_main.thp_main_loop(self.interface)
gen.send(None)
# prepare 2 new channels
expected_response_1 = send_channel_allocation_request(self.interface)
expected_response_2 = send_channel_allocation_request(self.interface)
self.assertEqual(self.interface.data[-2], expected_response_1)
self.assertEqual(self.interface.data[-1], expected_response_2)
# test invalid encryption tag
config.init()
config.wipe()
cid_1 = get_channel_id_from_response(expected_response_1)
channel = thp_main._CHANNELS[cid_1]
channel.iface = self.interface
channel.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
header = b"\x04" + channel.channel_id + b"\x00\x14"
tag = b"\x00" * 16
chksum = checksum.compute(header + tag)
message_with_invalid_tag = header + tag + chksum
channel.channel_cache.set(CHANNEL_KEY_RECEIVE, get_dummy_key())
channel.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, 0)
cid_1_bytes = int.to_bytes(cid_1, 2, "big")
expected_ack_on_received_message = get_ack(cid_1_bytes)
gen.send(message_with_invalid_tag)
gen.send(None)
self.assertEqual(
self.interface.data[-1],
expected_ack_on_received_message,
)
error_without_crc = b"\x42" + cid_1_bytes + b"\x00\x05\x03"
chksum_err = checksum.compute(error_without_crc)
gen.send(None)
decryption_failed_error = error_without_crc + chksum_err + b"\x00" * 54
self.assertEqual(
self.interface.data[-1],
decryption_failed_error,
)
def test_channel_errors(self):
gen = thp_main.thp_main_loop(self.interface)
gen.send(None)
# prepare 2 new channels
expected_response_1 = send_channel_allocation_request(self.interface)
expected_response_2 = send_channel_allocation_request(self.interface)
self.assertEqual(self.interface.data[-2], expected_response_1)
self.assertEqual(self.interface.data[-1], expected_response_2)
# test invalid encryption tag
config.init()
config.wipe()
cid_1 = get_channel_id_from_response(expected_response_1)
channel = thp_main._CHANNELS[cid_1]
channel.iface = self.interface
channel.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
header = b"\x04" + channel.channel_id + b"\x00\x14"
tag = b"\x00" * 16
chksum = checksum.compute(header + tag)
message_with_invalid_tag = header + tag + chksum
channel.channel_cache.set(CHANNEL_KEY_RECEIVE, get_dummy_key())
channel.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, 0)
cid_1_bytes = int.to_bytes(cid_1, 2, "big")
expected_ack_on_received_message = get_ack(cid_1_bytes)
gen.send(message_with_invalid_tag)
gen.send(None)
self.assertEqual(
self.interface.data[-1],
expected_ack_on_received_message,
)
error_without_crc = b"\x42" + cid_1_bytes + b"\x00\x05\x03"
chksum_err = checksum.compute(error_without_crc)
gen.send(None)
decryption_failed_error = error_without_crc + chksum_err + b"\x00" * 54
self.assertEqual(
self.interface.data[-1],
decryption_failed_error,
)
# test invalid tag in handshake phase
cid_2 = get_channel_id_from_response(expected_response_1)
cid_2_bytes = cid_2.to_bytes(2, "big")
channel = thp_main._CHANNELS[cid_2]
channel.iface = self.interface
channel.set_channel_state(ChannelState.TH2)
message_with_invalid_tag = b"\x0a\x12\x36\x00\x14\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\x91\x65\x4c\xf9"
channel.channel_cache.set(CHANNEL_KEY_RECEIVE, get_dummy_key())
channel.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, 0)
# gen.send(message_with_invalid_tag)
# gen.send(None)
# gen.send(None)
# for i in self.interface.data:
# print(utils.get_bytes_as_str(i))
def test_skip_pairing(self):
config.init()
config.wipe()
channel = next(iter(thp_main._CHANNELS.values()))
channel.selected_pairing_methods = [
ThpPairingMethod.NoMethod,
ThpPairingMethod.CodeEntry,
ThpPairingMethod.NFC_Unidirectional,
ThpPairingMethod.QrCode,
]
pairing_ctx = PairingContext(channel)
request_message = ThpStartPairingRequest()
channel.set_channel_state(ChannelState.TP1)
gen = pairing.handle_pairing_request(pairing_ctx, request_message)
with self.assertRaises(StopIteration):
gen.send(None)
self.assertEqual(channel.get_channel_state(), ChannelState.ENCRYPTED_TRANSPORT)
# Teardown: set back initial channel state value
channel.set_channel_state(ChannelState.TH1)
def TODO_test_pairing(self):
config.init()
config.wipe()
cid = get_channel_id_from_response(
send_channel_allocation_request(self.interface)
)
channel = thp_main._CHANNELS[cid]
channel.selected_pairing_methods = [
ThpPairingMethod.CodeEntry,
ThpPairingMethod.NFC_Unidirectional,
ThpPairingMethod.QrCode,
]
pairing_ctx = PairingContext(channel)
request_message = ThpStartPairingRequest()
with self.assertRaises(UnexpectedMessage) as e:
pairing.handle_pairing_request(pairing_ctx, request_message)
print(e.value.message)
channel.set_channel_state(ChannelState.TP1)
gen = pairing.handle_pairing_request(pairing_ctx, request_message)
channel.channel_cache.set(CHANNEL_KEY_SEND, get_dummy_key())
channel.channel_cache.set_int(CHANNEL_NONCE_SEND, 0)
channel.channel_cache.set(CHANNEL_HANDSHAKE_HASH, b"")
gen.send(None)
async def _dummy(ctx: PairingContext, expected_types):
return await ctx.read([1018, 1024])
pairing.show_display_data = _dummy
msg_code_entry = ThpCodeEntryChallenge(challenge=b"\x12\x34")
buffer: bytearray = bytearray(protobuf.encoded_length(msg_code_entry))
protobuf.encode(buffer, msg_code_entry)
code_entry_challenge = Message(ThpMessageType.ThpCodeEntryChallenge, buffer)
gen.send(code_entry_challenge)
# tag_qrc = b"\x55\xdf\x6c\xba\x0b\xe9\x5e\xd1\x4b\x78\x61\xec\xfa\x07\x9b\x5d\x37\x60\xd8\x79\x9c\xd7\x89\xb4\x22\xc1\x6f\x39\xde\x8f\x3b\xc3"
# tag_nfc = b"\x8f\xf0\xfa\x37\x0a\x5b\xdb\x29\x32\x21\xd8\x2f\x95\xdd\xb6\xb8\xee\xfd\x28\x6f\x56\x9f\xa9\x0b\x64\x8c\xfc\x62\x46\x5a\xdd\xd0"
pregenerator_host = b"\xf6\x94\xc3\x6f\xb3\xbd\xfb\xba\x2f\xfd\x0c\xd0\x71\xed\x54\x76\x73\x64\x37\xfa\x25\x85\x12\x8d\xcf\xb5\x6c\x02\xaf\x9d\xe8\xbe"
generator_host = elligator2.map_to_curve25519(pregenerator_host)
cpace_host_private_key = b"\x02\x80\x70\x3c\x06\x45\x19\x75\x87\x0c\x82\xe1\x64\x11\xc0\x18\x13\xb2\x29\x04\xb3\xf0\xe4\x1e\x6b\xfd\x77\x63\x11\x73\x07\xa9"
cpace_host_public_key: bytes = curve25519.multiply(
cpace_host_private_key, generator_host
)
msg = ThpCodeEntryCpaceHost(cpace_host_public_key=cpace_host_public_key)
# msg = ThpQrCodeTag(tag=tag_qrc)
# msg = ThpNfcUnidirectionalTag(tag=tag_nfc)
buffer: bytearray = bytearray(protobuf.encoded_length(msg))
protobuf.encode(buffer, msg)
user_message = Message(ThpMessageType.ThpCodeEntryCpaceHost, buffer)
gen.send(user_message)
tag_ent = b"\xd0\x15\xd6\x72\x7c\xa6\x9b\x2a\x07\xfa\x30\xee\x03\xf0\x2d\x04\xdc\x96\x06\x77\x0c\xbd\xb4\xaa\x77\xc7\x68\x6f\xae\xa9\xdd\x81"
msg = ThpCodeEntryTag(tag=tag_ent)
buffer: bytearray = bytearray(protobuf.encoded_length(msg))
protobuf.encode(buffer, msg)
user_message = Message(ThpMessageType.ThpCodeEntryTag, buffer)
gen.send(user_message)
host_static_pubkey = b"\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77"
msg = ThpCredentialRequest(host_static_pubkey=host_static_pubkey)
buffer: bytearray = bytearray(protobuf.encoded_length(msg))
protobuf.encode(buffer, msg)
credential_request = Message(ThpMessageType.ThpCredentialRequest, buffer)
gen.send(credential_request)
msg = ThpEndRequest()
buffer: bytearray = bytearray(protobuf.encoded_length(msg))
protobuf.encode(buffer, msg)
end_request = Message(1012, buffer)
with self.assertRaises(StopIteration) as e:
gen.send(end_request)
print("response message:", e.value.value.MESSAGE_NAME)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,151 @@
from common import * # isort:skip
from typing import Any, Awaitable
if utils.USE_THP:
import thp_common
from mock_wire_interface import MockHID
from trezor.wire.thp import writer
from trezor.wire.thp import ENCRYPTED, PacketHeader
@unittest.skipUnless(utils.USE_THP, "only needed for THP")
class TestTrezorHostProtocolWriter(unittest.TestCase):
short_payload_expected = b"04123400050700000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
longer_payload_expected = [
b"0412340100000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a",
b"8012343b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f7071727374757677",
b"80123478797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4",
b"801234b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1",
b"801234f2f3f4f5f6f7f8f9fafbfcfdfeff0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
]
eight_longer_payloads_expected = [
b"0412340800000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a",
b"8012343b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f7071727374757677",
b"80123478797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4",
b"801234b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1",
b"801234f2f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e",
b"8012342f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b",
b"8012346c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8",
b"801234a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5",
b"801234e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122",
b"801234232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f",
b"801234606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c",
b"8012349d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9",
b"801234dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a0b0c0d0e0f10111213141516",
b"8012341718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f50515253",
b"8012345455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f90",
b"8012349192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccd",
b"801234cecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a",
b"8012340b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f4041424344454647",
b"80123448494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f8081828384",
b"80123485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1",
b"801234c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfe",
b"801234ff000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b",
b"8012343c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f707172737475767778",
b"801234797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5",
b"801234b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2",
b"801234f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f",
b"801234303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c",
b"8012346d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9",
b"801234aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6",
b"801234e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20212223",
b"8012342425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f60",
b"8012346162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d",
b"8012349e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9da",
b"801234dbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000000000000000000000000000000000000000000000000",
]
empty_payload_with_checksum_expected = b"0412340004edbd479c00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
longer_payload_with_checksum_expected = [
b"0412340100000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a",
b"8012343b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f7071727374757677",
b"80123478797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4",
b"801234b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1",
b"801234f2f3f4f5f6f7f8f9fafbfcfdfefff40c65ee00000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
]
def await_until_result(self, task: Awaitable) -> Any:
with self.assertRaises(StopIteration):
while True:
task.send(None)
def __init__(self):
if __debug__:
thp_common.suppres_debug_log()
super().__init__()
def setUp(self):
self.interface = MockHID(0xDEADBEEF)
def test_write_empty_packet(self):
self.await_until_result(writer.write_packet_to_wire(self.interface, b""))
print(self.interface.data[0])
self.assertEqual(len(self.interface.data), 1)
self.assertEqual(self.interface.data[0], b"")
def test_write_empty_payload(self):
header = PacketHeader(ENCRYPTED, 4660, 4)
await_result(writer.write_payloads_to_wire(self.interface, header, (b"",)))
self.assertEqual(len(self.interface.data), 0)
def test_write_short_payload(self):
header = PacketHeader(ENCRYPTED, 4660, 5)
data = b"\x07"
self.await_until_result(
writer.write_payloads_to_wire(self.interface, header, (data,))
)
self.assertEqual(hexlify(self.interface.data[0]), self.short_payload_expected)
def test_write_longer_payload(self):
data = bytearray(range(256))
header = PacketHeader(ENCRYPTED, 4660, 256)
self.await_until_result(
writer.write_payloads_to_wire(self.interface, header, (data,))
)
for i in range(len(self.longer_payload_expected)):
self.assertEqual(
hexlify(self.interface.data[i]), self.longer_payload_expected[i]
)
def test_write_eight_longer_payloads(self):
data = bytearray(range(256))
header = PacketHeader(ENCRYPTED, 4660, 2048)
self.await_until_result(
writer.write_payloads_to_wire(
self.interface, header, (data, data, data, data, data, data, data, data)
)
)
for i in range(len(self.eight_longer_payloads_expected)):
self.assertEqual(
hexlify(self.interface.data[i]), self.eight_longer_payloads_expected[i]
)
def test_write_empty_payload_with_checksum(self):
header = PacketHeader(ENCRYPTED, 4660, 4)
self.await_until_result(
writer.write_payload_to_wire_and_add_checksum(self.interface, header, b"")
)
self.assertEqual(
hexlify(self.interface.data[0]), self.empty_payload_with_checksum_expected
)
def test_write_longer_payload_with_checksum(self):
data = bytearray(range(256))
header = PacketHeader(ENCRYPTED, 4660, 256)
self.await_until_result(
writer.write_payload_to_wire_and_add_checksum(self.interface, header, data)
)
for i in range(len(self.longer_payload_with_checksum_expected)):
self.assertEqual(
hexlify(self.interface.data[i]),
self.longer_payload_with_checksum_expected[i],
)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,338 @@
from common import * # isort:skip
import ustruct
from typing import TYPE_CHECKING
from mock_wire_interface import MockHID
from storage.cache_thp import BROADCAST_CHANNEL_ID
from trezor import io
from trezor.utils import chunks
from trezor.wire.protocol_common import Message
if utils.USE_THP:
import thp_common
import trezor.wire.thp
from trezor.wire.thp import thp_main
from trezor.wire.thp import alternating_bit_protocol as ABP
from trezor.wire.thp import checksum
from trezor.wire.thp.checksum import CHECKSUM_LENGTH
from trezor.wire.thp.writer import PACKET_LENGTH
if TYPE_CHECKING:
from trezorio import WireInterface
MESSAGE_TYPE = 0x4242
MESSAGE_TYPE_BYTES = b"\x42\x42"
_MESSAGE_TYPE_LEN = 2
PLAINTEXT_0 = 0x01
PLAINTEXT_1 = 0x11
COMMON_CID = 4660
CONT = 0x80
HEADER_INIT_LENGTH = 5
HEADER_CONT_LENGTH = 3
if utils.USE_THP:
INIT_MESSAGE_DATA_LENGTH = PACKET_LENGTH - HEADER_INIT_LENGTH - _MESSAGE_TYPE_LEN
def make_header(ctrl_byte, cid, length):
return ustruct.pack(">BHH", ctrl_byte, cid, length)
def make_cont_header():
return ustruct.pack(">BH", CONT, COMMON_CID)
def makeSimpleMessage(header, message_type, message_data):
return header + ustruct.pack(">H", message_type) + message_data
def makeCidRequest(header, message_data):
return header + message_data
def getPlaintext() -> bytes:
if ABP.get_expected_receive_seq_bit(THP.get_active_session()) == 1:
return PLAINTEXT_1
return PLAINTEXT_0
async def deprecated_read_message(
iface: WireInterface, buffer: utils.BufferType
) -> Message:
return Message(-1, b"\x00")
async def deprecated_write_message(
iface: WireInterface, message: Message, is_retransmission: bool = False
) -> None:
pass
# This test suite is an adaptation of test_trezor.wire.codec_v1
@unittest.skipUnless(utils.USE_THP, "only needed for THP")
class TestWireTrezorHostProtocolV1(unittest.TestCase):
def __init__(self):
if __debug__:
thp_common.suppres_debug_log()
super().__init__()
def setUp(self):
self.interface = MockHID(0xDEADBEEF)
def _simple(self):
cid_req_header = make_header(
ctrl_byte=0x40, cid=BROADCAST_CHANNEL_ID, length=12
)
cid_request_dummy_data = b"\x00\x11\x22\x33\x44\x55\x66\x77\x96\x64\x3c\x6c"
cid_req_message = makeCidRequest(cid_req_header, cid_request_dummy_data)
message_header = make_header(ctrl_byte=0x01, cid=COMMON_CID, length=18)
cid_request_dummy_data_checksum = b"\x67\x8e\xac\xe0"
message = makeSimpleMessage(
message_header,
MESSAGE_TYPE,
cid_request_dummy_data + cid_request_dummy_data_checksum,
)
buffer = bytearray(64)
gen = deprecated_read_message(self.interface, buffer)
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
gen.send(cid_req_message)
gen.send(None)
gen.send(message)
with self.assertRaises(StopIteration) as e:
gen.send(None)
# e.value is StopIteration. e.value.value is the return value of the call
result = e.value.value
self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data, cid_request_dummy_data)
buffer_without_zeroes = buffer[: len(message) - 5]
message_without_header = message[5:]
# message should have been read into the buffer
self.assertEqual(buffer_without_zeroes, message_without_header)
def _read_one_packet(self):
# zero length message - just a header
PLAINTEXT = getPlaintext()
header = make_header(
PLAINTEXT, cid=COMMON_CID, length=_MESSAGE_TYPE_LEN + CHECKSUM_LENGTH
)
chksum = checksum.compute(header + MESSAGE_TYPE_BYTES)
message = header + MESSAGE_TYPE_BYTES + chksum
buffer = bytearray(64)
gen = deprecated_read_message(self.interface, buffer)
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
gen.send(message)
with self.assertRaises(StopIteration) as e:
gen.send(None)
# e.value is StopIteration. e.value.value is the return value of the call
result = e.value.value
self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data, b"")
# message should have been read into the buffer
self.assertEqual(buffer, MESSAGE_TYPE_BYTES + chksum + b"\x00" * 58)
def _read_many_packets(self):
message = bytes(range(256))
header = make_header(
getPlaintext(),
COMMON_CID,
len(message) + _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH,
)
chksum = checksum.compute(header + MESSAGE_TYPE_BYTES + message)
# message = MESSAGE_TYPE_BYTES + message + checksum
# first packet is init header + 59 bytes of data
# other packets are cont header + 61 bytes of data
cont_header = make_cont_header()
packets = [header + MESSAGE_TYPE_BYTES + message[:INIT_MESSAGE_DATA_LENGTH]] + [
cont_header + chunk
for chunk in chunks(
message[INIT_MESSAGE_DATA_LENGTH:] + chksum,
64 - HEADER_CONT_LENGTH,
)
]
buffer = bytearray(262)
gen = deprecated_read_message(self.interface, buffer)
query = gen.send(None)
for packet in packets:
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
query = gen.send(packet)
# last packet will stop
with self.assertRaises(StopIteration) as e:
gen.send(None)
# e.value is StopIteration. e.value.value is the return value of the call
result = e.value.value
self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data, message)
# message should have been read into the buffer )
self.assertEqual(buffer, MESSAGE_TYPE_BYTES + message + chksum)
def _read_large_message(self):
message = b"hello world"
header = make_header(
getPlaintext(),
COMMON_CID,
_MESSAGE_TYPE_LEN + len(message) + CHECKSUM_LENGTH,
)
packet = (
header
+ MESSAGE_TYPE_BYTES
+ message
+ checksum.compute(header + MESSAGE_TYPE_BYTES + message)
)
# make sure we fit into one packet, to make this easier
self.assertTrue(len(packet) <= thp_main.PACKET_LENGTH)
buffer = bytearray(1)
self.assertTrue(len(buffer) <= len(packet))
gen = deprecated_read_message(self.interface, buffer)
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
gen.send(packet)
with self.assertRaises(StopIteration) as e:
gen.send(None)
# e.value is StopIteration. e.value.value is the return value of the call
result = e.value.value
self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data, message)
# read should have allocated its own buffer and not touch ours
self.assertEqual(buffer, b"\x00")
def _roundtrip(self):
message_payload = bytes(range(256))
message = Message(
MESSAGE_TYPE, message_payload, 1
) # TODO use different session id
gen = deprecated_write_message(self.interface, message)
# exhaust the iterator:
# (XXX we can only do this because the iterator is only accepting None and returns None)
for query in gen:
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
buffer = bytearray(1024)
gen = deprecated_read_message(self.interface, buffer)
query = gen.send(None)
for packet in self.interface.data:
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
print(utils.get_bytes_as_str(packet))
query = gen.send(packet)
with self.assertRaises(StopIteration) as e:
gen.send(None)
result = e.value.value
self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data, message.data)
def _write_one_packet(self):
message = Message(MESSAGE_TYPE, b"")
gen = deprecated_write_message(self.interface, message)
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
with self.assertRaises(StopIteration):
gen.send(None)
header = make_header(
getPlaintext(), COMMON_CID, _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH
)
expected_message = (
header
+ MESSAGE_TYPE_BYTES
+ checksum.compute(header + MESSAGE_TYPE_BYTES)
+ b"\x00" * (INIT_MESSAGE_DATA_LENGTH - CHECKSUM_LENGTH)
)
self.assertTrue(self.interface.data == [expected_message])
def _write_multiple_packets(self):
message_payload = bytes(range(256))
message = Message(MESSAGE_TYPE, message_payload)
gen = deprecated_write_message(self.interface, message)
header = make_header(
PLAINTEXT_1,
COMMON_CID,
len(message.data) + _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH,
)
cont_header = make_cont_header()
chksum = checksum.compute(
header + message.type.to_bytes(2, "big") + message.data
)
packets = [
header + MESSAGE_TYPE_BYTES + message.data[:INIT_MESSAGE_DATA_LENGTH]
] + [
cont_header + chunk
for chunk in chunks(
message.data[INIT_MESSAGE_DATA_LENGTH:] + chksum,
thp_main.PACKET_LENGTH - HEADER_CONT_LENGTH,
)
]
for _ in packets:
# we receive as many queries as there are packets
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
# the first sent None only started the generator. the len(packets)-th None
# will finish writing and raise StopIteration
with self.assertRaises(StopIteration):
gen.send(None)
# packets must be identical up to the last one
self.assertListEqual(packets[:-1], self.interface.data[:-1])
# last packet must be identical up to message length. remaining bytes in
# the 64-byte packets are garbage -- in particular, it's the bytes of the
# previous packet
last_packet = packets[-1] + packets[-2][len(packets[-1]) :]
self.assertEqual(last_packet, self.interface.data[-1])
def _read_huge_packet(self):
PACKET_COUNT = 1180
# message that takes up 1 180 USB packets
message_size = (PACKET_COUNT - 1) * (
PACKET_LENGTH - HEADER_CONT_LENGTH - CHECKSUM_LENGTH - _MESSAGE_TYPE_LEN
) + INIT_MESSAGE_DATA_LENGTH
# ensure that a message this big won't fit into memory
# Note: this control is changed, because THP has only 2 byte length field
self.assertTrue(message_size > thp_main.MAX_PAYLOAD_LEN)
# self.assertRaises(MemoryError, bytearray, message_size)
header = make_header(PLAINTEXT_1, COMMON_CID, message_size)
packet = header + MESSAGE_TYPE_BYTES + (b"\x00" * INIT_MESSAGE_DATA_LENGTH)
buffer = bytearray(65536)
gen = deprecated_read_message(self.interface, buffer)
query = gen.send(None)
# THP returns "Message too large" error after reading the message size,
# it is different from codec_v1 as it does not allow big enough messages
# to raise MemoryError in this test
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
with self.assertRaises(trezor.wire.thp.ThpError) as e:
query = gen.send(packet)
self.assertEqual(e.value.args[0], "Message too large")
if __name__ == "__main__":
unittest.main()

44
core/tests/thp_common.py Normal file
View File

@ -0,0 +1,44 @@
from trezor import utils
from trezor.wire.thp import ChannelState
if utils.USE_THP:
import unittest
from typing import TYPE_CHECKING
from mock_wire_interface import MockHID
from storage import cache_thp
from trezor.wire import context
from trezor.wire.thp import interface_manager
from trezor.wire.thp.channel import Channel
from trezor.wire.thp.session_context import SessionContext
_MOCK_INTERFACE_HID = b"\x00"
if TYPE_CHECKING:
from trezor.wire import WireInterface
def dummy_decode_iface(cached_iface: bytes):
return MockHID(0xDEADBEEF)
def get_new_channel(channel_iface: WireInterface | None = None) -> Channel:
interface_manager.decode_iface = dummy_decode_iface
channel_cache = cache_thp.get_new_channel(_MOCK_INTERFACE_HID)
channel = Channel(channel_cache)
channel.set_channel_state(ChannelState.TH1)
if channel_iface is not None:
channel.iface = channel_iface
return channel
def prepare_context() -> None:
channel = get_new_channel()
session_cache = cache_thp.get_new_session(channel.channel_cache)
session_ctx = SessionContext(channel, session_cache)
context.CURRENT_CONTEXT = session_ctx
if __debug__:
# Disable log.debug
def suppres_debug_log() -> None:
from trezor import log
log.debug = lambda name, msg, *args: None

View File

@ -2,7 +2,7 @@
import binascii import binascii
from trezorlib.client import TrezorClient from trezorlib.client import TrezorClient
from trezorlib.transport_hid import HidTransport from trezorlib.transport.hid import HidTransport
devices = HidTransport.enumerate() devices = HidTransport.enumerate()
if len(devices) > 0: if len(devices) > 0: