refactor(core): get rid of passing Context around

pull/3138/head
matejcik 12 months ago committed by matejcik
parent fe80793b47
commit 8c5c2f4204

@ -141,7 +141,7 @@ def get_features() -> Features:
return f return f
async def handle_Initialize(ctx: wire.Context, msg: Initialize) -> Features: async def handle_Initialize(msg: Initialize) -> Features:
session_id = storage_cache.start_session(msg.session_id) session_id = storage_cache.start_session(msg.session_id)
if not utils.BITCOIN_ONLY: if not utils.BITCOIN_ONLY:
@ -170,20 +170,20 @@ async def handle_Initialize(ctx: wire.Context, msg: Initialize) -> Features:
return features return features
async def handle_GetFeatures(ctx: wire.Context, msg: GetFeatures) -> Features: async def handle_GetFeatures(msg: GetFeatures) -> Features:
return get_features() return get_features()
async def handle_Cancel(ctx: wire.Context, msg: Cancel) -> Success: async def handle_Cancel(msg: Cancel) -> Success:
raise wire.ActionCancelled raise wire.ActionCancelled
async def handle_LockDevice(ctx: wire.Context, msg: LockDevice) -> Success: async def handle_LockDevice(msg: LockDevice) -> Success:
lock_device() lock_device()
return Success() return Success()
async def handle_SetBusy(ctx: wire.Context, msg: SetBusy) -> Success: async def handle_SetBusy(msg: SetBusy) -> Success:
if not storage_device.is_initialized(): if not storage_device.is_initialized():
raise wire.NotInitialized("Device is not initialized") raise wire.NotInitialized("Device is not initialized")
@ -199,24 +199,23 @@ async def handle_SetBusy(ctx: wire.Context, msg: SetBusy) -> Success:
return Success() return Success()
async def handle_EndSession(ctx: wire.Context, msg: EndSession) -> Success: async def handle_EndSession(msg: EndSession) -> Success:
storage_cache.end_current_session() storage_cache.end_current_session()
return Success() return Success()
async def handle_Ping(ctx: wire.Context, msg: Ping) -> Success: async def handle_Ping(msg: Ping) -> Success:
if msg.button_protection: if msg.button_protection:
from trezor.ui.layouts import confirm_action from trezor.ui.layouts import confirm_action
from trezor.enums import ButtonRequestType as B from trezor.enums import ButtonRequestType as B
await confirm_action(ctx, "ping", "Confirm", "ping", br_code=B.ProtectCall) await confirm_action("ping", "Confirm", "ping", br_code=B.ProtectCall)
return Success(message=msg.message) return Success(message=msg.message)
async def handle_DoPreauthorized( async def handle_DoPreauthorized(msg: DoPreauthorized) -> protobuf.MessageType:
ctx: wire.Context, msg: DoPreauthorized
) -> protobuf.MessageType:
from trezor.messages import PreauthorizedRequest from trezor.messages import PreauthorizedRequest
from trezor.wire.context import call_any, get_context
from apps.common import authorization from apps.common import authorization
if not authorization.is_set(): if not authorization.is_set():
@ -225,22 +224,23 @@ async def handle_DoPreauthorized(
wire_types = authorization.get_wire_types() wire_types = authorization.get_wire_types()
utils.ensure(bool(wire_types), "Unsupported preauthorization found") utils.ensure(bool(wire_types), "Unsupported preauthorization found")
req = await ctx.call_any(PreauthorizedRequest(), *wire_types) req = await call_any(PreauthorizedRequest(), *wire_types)
assert req.MESSAGE_WIRE_TYPE is not None assert req.MESSAGE_WIRE_TYPE is not None
handler = workflow_handlers.find_registered_handler( handler = workflow_handlers.find_registered_handler(
ctx.iface, req.MESSAGE_WIRE_TYPE get_context().iface, req.MESSAGE_WIRE_TYPE
) )
if handler is None: if handler is None:
return wire.unexpected_message() return wire.unexpected_message()
return await handler(ctx, req, authorization.get()) # type: ignore [Expected 2 positional arguments] return await handler(req, authorization.get()) # type: ignore [Expected 1 positional argument]
async def handle_UnlockPath(ctx: wire.Context, msg: UnlockPath) -> protobuf.MessageType: async def handle_UnlockPath(msg: UnlockPath) -> protobuf.MessageType:
from trezor.crypto import hmac from trezor.crypto import hmac
from trezor.messages import UnlockedPathRequest from trezor.messages import UnlockedPathRequest
from trezor.ui.layouts import confirm_action from trezor.ui.layouts import confirm_action
from trezor.wire.context import call_any, get_context
from apps.common.paths import SLIP25_PURPOSE from apps.common.paths import SLIP25_PURPOSE
from apps.common.seed import Slip21Node, get_seed from apps.common.seed import Slip21Node, get_seed
from apps.common.writers import write_uint32_le from apps.common.writers import write_uint32_le
@ -253,7 +253,7 @@ async def handle_UnlockPath(ctx: wire.Context, msg: UnlockPath) -> protobuf.Mess
if msg.address_n != [SLIP25_PURPOSE]: if msg.address_n != [SLIP25_PURPOSE]:
raise wire.DataError("Invalid path") raise wire.DataError("Invalid path")
seed = await get_seed(ctx) seed = await get_seed()
node = Slip21Node(seed) node = Slip21Node(seed)
node.derive_path(_KEYCHAIN_MAC_KEY_PATH) node.derive_path(_KEYCHAIN_MAC_KEY_PATH)
mac = utils.HashWriter(hmac(hmac.SHA256, node.key())) mac = utils.HashWriter(hmac(hmac.SHA256, node.key()))
@ -269,7 +269,6 @@ async def handle_UnlockPath(ctx: wire.Context, msg: UnlockPath) -> protobuf.Mess
raise wire.DataError("Invalid MAC") raise wire.DataError("Invalid MAC")
else: else:
await confirm_action( await confirm_action(
ctx,
"confirm_coinjoin_access", "confirm_coinjoin_access",
title="Coinjoin", title="Coinjoin",
description="Access your coinjoin account?", description="Access your coinjoin account?",
@ -277,19 +276,17 @@ async def handle_UnlockPath(ctx: wire.Context, msg: UnlockPath) -> protobuf.Mess
) )
wire_types = (MessageType.GetAddress, MessageType.GetPublicKey, MessageType.SignTx) wire_types = (MessageType.GetAddress, MessageType.GetPublicKey, MessageType.SignTx)
req = await ctx.call_any(UnlockedPathRequest(mac=expected_mac), *wire_types) req = await call_any(UnlockedPathRequest(mac=expected_mac), *wire_types)
assert req.MESSAGE_WIRE_TYPE in wire_types assert req.MESSAGE_WIRE_TYPE in wire_types
handler = workflow_handlers.find_registered_handler( handler = workflow_handlers.find_registered_handler(
ctx.iface, req.MESSAGE_WIRE_TYPE get_context().iface, req.MESSAGE_WIRE_TYPE
) )
assert handler is not None assert handler is not None
return await handler(ctx, req, msg) # type: ignore [Expected 2 positional arguments] return await handler(req, msg) # type: ignore [Expected 1 positional argument]
async def handle_CancelAuthorization( async def handle_CancelAuthorization(msg: CancelAuthorization) -> protobuf.MessageType:
ctx: wire.Context, msg: CancelAuthorization
) -> protobuf.MessageType:
from apps.common import authorization from apps.common import authorization
authorization.clear() authorization.clear()
@ -337,7 +334,7 @@ def lock_device_if_unlocked() -> None:
lock_device(interrupt_workflow=workflow.autolock_interrupts_workflow) lock_device(interrupt_workflow=workflow.autolock_interrupts_workflow)
async def unlock_device(ctx: wire.GenericContext = wire.DUMMY_CONTEXT) -> None: async def unlock_device() -> None:
"""Ensure the device is in unlocked state. """Ensure the device is in unlocked state.
If the storage is locked, attempt to unlock it. Reset the homescreen and the wire If the storage is locked, attempt to unlock it. Reset the homescreen and the wire
@ -347,7 +344,7 @@ async def unlock_device(ctx: wire.GenericContext = wire.DUMMY_CONTEXT) -> None:
if not config.is_unlocked(): if not config.is_unlocked():
# verify_user_pin will raise if the PIN was invalid # verify_user_pin will raise if the PIN was invalid
await verify_user_pin(ctx) await verify_user_pin()
set_homescreen() set_homescreen()
wire.find_handler = workflow_handlers.find_registered_handler wire.find_handler = workflow_handlers.find_registered_handler
@ -369,9 +366,9 @@ def get_pinlocked_handler(
if msg_type in workflow.ALLOW_WHILE_LOCKED: if msg_type in workflow.ALLOW_WHILE_LOCKED:
return orig_handler return orig_handler
async def wrapper(ctx: wire.Context, msg: wire.Msg) -> protobuf.MessageType: async def wrapper(msg: wire.Msg) -> protobuf.MessageType:
await unlock_device(ctx) await unlock_device()
return await orig_handler(ctx, msg) return await orig_handler(msg)
return wrapper return wrapper
@ -383,7 +380,7 @@ def reload_settings_from_storage() -> None:
workflow.idle_timer.set( workflow.idle_timer.set(
storage_device.get_autolock_delay_ms(), lock_device_if_unlocked storage_device.get_autolock_delay_ms(), lock_device_if_unlocked
) )
wire.experimental_enabled = storage_device.get_experimental_features() wire.EXPERIMENTAL_ENABLED = storage_device.get_experimental_features()
ui.display.orientation(storage_device.get_rotation()) ui.display.orientation(storage_device.get_rotation())

@ -4,14 +4,11 @@ from apps.common.keychain import auto_keychain
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import BinanceGetAddress, BinanceAddress from trezor.messages import BinanceGetAddress, BinanceAddress
from trezor.wire import Context
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
@auto_keychain(__name__) @auto_keychain(__name__)
async def get_address( async def get_address(msg: BinanceGetAddress, keychain: Keychain) -> BinanceAddress:
ctx: Context, msg: BinanceGetAddress, keychain: Keychain
) -> BinanceAddress:
from trezor.messages import BinanceAddress from trezor.messages import BinanceAddress
from trezor.ui.layouts import show_address from trezor.ui.layouts import show_address
@ -22,12 +19,12 @@ async def get_address(
HRP = "bnb" HRP = "bnb"
address_n = msg.address_n # local_cache_attribute address_n = msg.address_n # local_cache_attribute
await paths.validate_path(ctx, keychain, address_n) await paths.validate_path(keychain, address_n)
node = keychain.derive(address_n) node = keychain.derive(address_n)
pubkey = node.public_key() pubkey = node.public_key()
address = address_from_public_key(pubkey, HRP) address = address_from_public_key(pubkey, HRP)
if msg.show_display: if msg.show_display:
await show_address(ctx, address, path=paths.address_n_to_str(address_n)) await show_address(address, path=paths.address_n_to_str(address_n))
return BinanceAddress(address=address) return BinanceAddress(address=address)

@ -4,13 +4,12 @@ from apps.common.keychain import auto_keychain
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import BinanceGetPublicKey, BinancePublicKey from trezor.messages import BinanceGetPublicKey, BinancePublicKey
from trezor.wire import Context
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
@auto_keychain(__name__) @auto_keychain(__name__)
async def get_public_key( async def get_public_key(
ctx: Context, msg: BinanceGetPublicKey, keychain: Keychain msg: BinanceGetPublicKey, keychain: Keychain
) -> BinancePublicKey: ) -> BinancePublicKey:
from ubinascii import hexlify from ubinascii import hexlify
@ -19,11 +18,11 @@ async def get_public_key(
from apps.common import paths from apps.common import paths
await paths.validate_path(ctx, keychain, msg.address_n) await paths.validate_path(keychain, msg.address_n)
node = keychain.derive(msg.address_n) node = keychain.derive(msg.address_n)
pubkey = node.public_key() pubkey = node.public_key()
if msg.show_display: if msg.show_display:
await show_pubkey(ctx, hexlify(pubkey).decode()) await show_pubkey(hexlify(pubkey).decode())
return BinancePublicKey(public_key=pubkey) return BinancePublicKey(public_key=pubkey)

@ -13,10 +13,9 @@ if TYPE_CHECKING:
BinanceOrderMsg, BinanceOrderMsg,
BinanceTransferMsg, BinanceTransferMsg,
) )
from trezor.wire import Context
async def require_confirm_transfer(ctx: Context, msg: BinanceTransferMsg) -> None: async def require_confirm_transfer(msg: BinanceTransferMsg) -> None:
items: list[tuple[str, str, str]] = [] items: list[tuple[str, str, str]] = []
def make_input_output_pages(msg: BinanceInputOutput, direction: str) -> None: def make_input_output_pages(msg: BinanceInputOutput, direction: str) -> None:
@ -35,19 +34,16 @@ async def require_confirm_transfer(ctx: Context, msg: BinanceTransferMsg) -> Non
for txoutput in msg.outputs: for txoutput in msg.outputs:
make_input_output_pages(txoutput, "Confirm output") make_input_output_pages(txoutput, "Confirm output")
await _confirm_transfer(ctx, items) await _confirm_transfer(items)
async def _confirm_transfer( async def _confirm_transfer(inputs_outputs: Sequence[tuple[str, str, str]]) -> None:
ctx: Context, inputs_outputs: Sequence[tuple[str, str, str]]
) -> None:
from trezor.ui.layouts import confirm_output from trezor.ui.layouts import confirm_output
for index, (title, amount, address) in enumerate(inputs_outputs): for index, (title, amount, address) in enumerate(inputs_outputs):
# Having hold=True on the last item # Having hold=True on the last item
hold = index == len(inputs_outputs) - 1 hold = index == len(inputs_outputs) - 1
await confirm_output( await confirm_output(
ctx,
address, address,
amount, amount,
title, title,
@ -55,9 +51,8 @@ async def _confirm_transfer(
) )
async def require_confirm_cancel(ctx: Context, msg: BinanceCancelMsg) -> None: async def require_confirm_cancel(msg: BinanceCancelMsg) -> None:
await confirm_properties( await confirm_properties(
ctx,
"confirm_cancel", "confirm_cancel",
"Confirm cancel", "Confirm cancel",
( (
@ -70,7 +65,7 @@ async def require_confirm_cancel(ctx: Context, msg: BinanceCancelMsg) -> None:
) )
async def require_confirm_order(ctx: Context, msg: BinanceOrderMsg) -> None: async def require_confirm_order(msg: BinanceOrderMsg) -> None:
from trezor.enums import BinanceOrderSide from trezor.enums import BinanceOrderSide
if msg.side == BinanceOrderSide.BUY: if msg.side == BinanceOrderSide.BUY:
@ -81,7 +76,6 @@ async def require_confirm_order(ctx: Context, msg: BinanceOrderMsg) -> None:
side = "Unknown" side = "Unknown"
await confirm_properties( await confirm_properties(
ctx,
"confirm_order", "confirm_order",
"Confirm order", "Confirm order",
( (

@ -5,14 +5,12 @@ from apps.common.keychain import auto_keychain
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import BinanceSignTx, BinanceSignedTx from trezor.messages import BinanceSignTx, BinanceSignedTx
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
from trezor.wire import Context
@auto_keychain(__name__) @auto_keychain(__name__)
async def sign_tx( async def sign_tx(envelope: BinanceSignTx, keychain: Keychain) -> BinanceSignedTx:
ctx: Context, envelope: BinanceSignTx, keychain: Keychain
) -> BinanceSignedTx:
from trezor import wire from trezor import wire
from trezor.wire.context import call_any
from trezor.crypto.curve import secp256k1 from trezor.crypto.curve import secp256k1
from trezor.crypto.hashlib import sha256 from trezor.crypto.hashlib import sha256
from trezor.enums import MessageType from trezor.enums import MessageType
@ -32,12 +30,12 @@ async def sign_tx(
if envelope.msg_count > 1: if envelope.msg_count > 1:
raise wire.DataError("Multiple messages not supported.") raise wire.DataError("Multiple messages not supported.")
await paths.validate_path(ctx, keychain, envelope.address_n) await paths.validate_path(keychain, envelope.address_n)
node = keychain.derive(envelope.address_n) node = keychain.derive(envelope.address_n)
tx_req = BinanceTxRequest() tx_req = BinanceTxRequest()
msg = await ctx.call_any( msg = await call_any(
tx_req, tx_req,
MessageType.BinanceCancelMsg, MessageType.BinanceCancelMsg,
MessageType.BinanceOrderMsg, MessageType.BinanceOrderMsg,
@ -50,11 +48,11 @@ async def sign_tx(
msg_json = helpers.produce_json_for_signing(envelope, msg) msg_json = helpers.produce_json_for_signing(envelope, msg)
if BinanceTransferMsg.is_type_of(msg): if BinanceTransferMsg.is_type_of(msg):
await layout.require_confirm_transfer(ctx, msg) await layout.require_confirm_transfer(msg)
elif BinanceOrderMsg.is_type_of(msg): elif BinanceOrderMsg.is_type_of(msg):
await layout.require_confirm_order(ctx, msg) await layout.require_confirm_order(msg)
elif BinanceCancelMsg.is_type_of(msg): elif BinanceCancelMsg.is_type_of(msg):
await layout.require_confirm_cancel(ctx, msg) await layout.require_confirm_cancel(msg)
else: else:
raise wire.ProcessError("input message unrecognized") raise wire.ProcessError("input message unrecognized")

@ -8,7 +8,6 @@ if TYPE_CHECKING:
from trezor.messages import AuthorizeCoinJoin, Success from trezor.messages import AuthorizeCoinJoin, Success
from apps.common.coininfo import CoinInfo from apps.common.coininfo import CoinInfo
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
from trezor.wire import Context
_MAX_COORDINATOR_LEN = const(36) _MAX_COORDINATOR_LEN = const(36)
_MAX_ROUNDS = const(500) _MAX_ROUNDS = const(500)
@ -17,7 +16,7 @@ _MAX_COORDINATOR_FEE_RATE = 5 * pow(10, FEE_RATE_DECIMALS) # 5 %
@with_keychain @with_keychain
async def authorize_coinjoin( async def authorize_coinjoin(
ctx: Context, msg: AuthorizeCoinJoin, keychain: Keychain, coin: CoinInfo msg: AuthorizeCoinJoin, keychain: Keychain, coin: CoinInfo
) -> Success: ) -> Success:
from trezor.enums import ButtonRequestType from trezor.enums import ButtonRequestType
from trezor.messages import Success from trezor.messages import Success
@ -64,11 +63,10 @@ async def authorize_coinjoin(
msg.max_fee_per_kvbyte / 1000, coin, include_shortcut=True msg.max_fee_per_kvbyte / 1000, coin, include_shortcut=True
) )
await confirm_coinjoin(ctx, msg.max_rounds, max_fee_per_vbyte) await confirm_coinjoin(msg.max_rounds, max_fee_per_vbyte)
validation_path = msg.address_n + [0] * BIP32_WALLET_DEPTH validation_path = msg.address_n + [0] * BIP32_WALLET_DEPTH
await validate_path( await validate_path(
ctx,
keychain, keychain,
validation_path, validation_path,
address_n[0] == SLIP25_PURPOSE, address_n[0] == SLIP25_PURPOSE,
@ -79,7 +77,6 @@ async def authorize_coinjoin(
if msg.max_fee_per_kvbyte > coin.maxfee_kb: if msg.max_fee_per_kvbyte > coin.maxfee_kb:
await confirm_metadata( await confirm_metadata(
ctx,
"fee_over_threshold", "fee_over_threshold",
"High mining fee", "High mining fee",
"The mining fee of\n{}\nis unexpectedly high.", "The mining fee of\n{}\nis unexpectedly high.",

@ -4,7 +4,6 @@ from .keychain import with_keychain
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import GetAddress, HDNodeType, Address from trezor.messages import GetAddress, HDNodeType, Address
from trezor import wire
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
from apps.common.coininfo import CoinInfo from apps.common.coininfo import CoinInfo
@ -30,9 +29,7 @@ def _get_xpubs(
@with_keychain @with_keychain
async def get_address( async def get_address(msg: GetAddress, keychain: Keychain, coin: CoinInfo) -> Address:
ctx: wire.Context, msg: GetAddress, keychain: Keychain, coin: CoinInfo
) -> Address:
from trezor.enums import InputScriptType from trezor.enums import InputScriptType
from trezor.messages import Address from trezor.messages import Address
from trezor.ui.layouts import show_address from trezor.ui.layouts import show_address
@ -51,7 +48,6 @@ async def get_address(
if msg.show_display: if msg.show_display:
# skip soft-validation for silent calls # skip soft-validation for silent calls
await validate_path( await validate_path(
ctx,
keychain, keychain,
address_n, address_n,
validate_path_against_script_type(coin, msg), validate_path_against_script_type(coin, msg),
@ -104,7 +100,6 @@ async def get_address(
multisig_index = multisig_pubkey_index(multisig, node.public_key()) multisig_index = multisig_pubkey_index(multisig, node.public_key())
await show_address( await show_address(
ctx,
address_short, address_short,
case_sensitive=address_case_sensitive, case_sensitive=address_case_sensitive,
path=path, path=path,
@ -121,7 +116,6 @@ async def get_address(
else: else:
account = f"{coin.coin_shortcut} {account_name}" account = f"{coin.coin_shortcut} {account_name}"
await show_address( await show_address(
ctx,
address_short, address_short,
address_qr=address, address_qr=address,
case_sensitive=address_case_sensitive, case_sensitive=address_case_sensitive,

@ -4,14 +4,13 @@ from .keychain import with_keychain
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import GetOwnershipId, OwnershipId from trezor.messages import GetOwnershipId, OwnershipId
from trezor.wire import Context
from apps.common.coininfo import CoinInfo from apps.common.coininfo import CoinInfo
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
@with_keychain @with_keychain
async def get_ownership_id( async def get_ownership_id(
ctx: Context, msg: GetOwnershipId, keychain: Keychain, coin: CoinInfo msg: GetOwnershipId, keychain: Keychain, coin: CoinInfo
) -> OwnershipId: ) -> OwnershipId:
from trezor.wire import DataError from trezor.wire import DataError
from trezor.enums import InputScriptType from trezor.enums import InputScriptType
@ -26,7 +25,6 @@ async def get_ownership_id(
script_type = msg.script_type # local_cache_attribute script_type = msg.script_type # local_cache_attribute
await validate_path( await validate_path(
ctx,
keychain, keychain,
msg.address_n, msg.address_n,
validate_path_against_script_type(coin, msg), validate_path_against_script_type(coin, msg),

@ -4,7 +4,6 @@ from .keychain import with_keychain
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import GetOwnershipProof, OwnershipProof from trezor.messages import GetOwnershipProof, OwnershipProof
from trezor.wire import Context
from apps.common.coininfo import CoinInfo from apps.common.coininfo import CoinInfo
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
from .authorization import CoinJoinAuthorization from .authorization import CoinJoinAuthorization
@ -12,7 +11,6 @@ if TYPE_CHECKING:
@with_keychain @with_keychain
async def get_ownership_proof( async def get_ownership_proof(
ctx: Context,
msg: GetOwnershipProof, msg: GetOwnershipProof,
keychain: Keychain, keychain: Keychain,
coin: CoinInfo, coin: CoinInfo,
@ -37,7 +35,6 @@ async def get_ownership_proof(
raise ProcessError("Unauthorized operation") raise ProcessError("Unauthorized operation")
else: else:
await validate_path( await validate_path(
ctx,
keychain, keychain,
msg.address_n, msg.address_n,
validate_path_against_script_type(coin, msg), validate_path_against_script_type(coin, msg),
@ -71,14 +68,12 @@ async def get_ownership_proof(
# In order to set the "user confirmation" bit in the proof, the user must actually confirm. # In order to set the "user confirmation" bit in the proof, the user must actually confirm.
if msg.user_confirmation and not authorization: if msg.user_confirmation and not authorization:
await confirm_action( await confirm_action(
ctx,
"confirm_ownership_proof", "confirm_ownership_proof",
"Proof of ownership", "Proof of ownership",
description="Do you want to create a proof of ownership?", description="Do you want to create a proof of ownership?",
) )
if msg.commitment_data: if msg.commitment_data:
await confirm_blob( await confirm_blob(
ctx,
"confirm_ownership_proof", "confirm_ownership_proof",
"Proof of ownership", "Proof of ownership",
msg.commitment_data, msg.commitment_data,

@ -3,11 +3,10 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import GetPublicKey, PublicKey from trezor.messages import GetPublicKey, PublicKey
from trezor.protobuf import MessageType from trezor.protobuf import MessageType
from trezor.wire import Context
async def get_public_key( async def get_public_key(
ctx: Context, msg: GetPublicKey, auth_msg: MessageType | None = None msg: GetPublicKey, auth_msg: MessageType | None = None
) -> PublicKey: ) -> PublicKey:
from trezor import wire from trezor import wire
from trezor.enums import InputScriptType from trezor.enums import InputScriptType
@ -33,7 +32,7 @@ async def get_public_key(
if auth_msg.address_n != address_n[: len(auth_msg.address_n)]: if auth_msg.address_n != address_n[: len(auth_msg.address_n)]:
raise FORBIDDEN_KEY_PATH raise FORBIDDEN_KEY_PATH
keychain = await get_keychain(ctx, curve_name, [paths.AlwaysMatchingSchema]) keychain = await get_keychain(curve_name, [paths.AlwaysMatchingSchema])
node = keychain.derive(address_n) node = keychain.derive(address_n)
@ -82,7 +81,7 @@ async def get_public_key(
if msg.show_display: if msg.show_display:
from trezor.ui.layouts import show_xpub from trezor.ui.layouts import show_xpub
await show_xpub(ctx, node_xpub, "XPUB") await show_xpub(node_xpub, "XPUB")
return PublicKey( return PublicKey(
node=node_type, node=node_type,

@ -14,7 +14,6 @@ if TYPE_CHECKING:
from typing_extensions import Protocol from typing_extensions import Protocol
from trezor.protobuf import MessageType from trezor.protobuf import MessageType
from trezor.wire import Context
from trezor.messages import ( from trezor.messages import (
GetAddress, GetAddress,
@ -265,7 +264,6 @@ def _get_coin_by_name(coin_name: str | None) -> coininfo.CoinInfo:
async def _get_keychain_for_coin( async def _get_keychain_for_coin(
ctx: Context,
coin: coininfo.CoinInfo, coin: coininfo.CoinInfo,
unlock_schemas: Iterable[PathSchema] = (), unlock_schemas: Iterable[PathSchema] = (),
) -> Keychain: ) -> Keychain:
@ -273,7 +271,7 @@ async def _get_keychain_for_coin(
schemas = _get_schemas_for_coin(coin, unlock_schemas) schemas = _get_schemas_for_coin(coin, unlock_schemas)
slip21_namespaces = [[b"SLIP-0019"], [b"SLIP-0024"]] slip21_namespaces = [[b"SLIP-0019"], [b"SLIP-0024"]]
keychain = await get_keychain(ctx, coin.curve_name, schemas, slip21_namespaces) keychain = await get_keychain(coin.curve_name, schemas, slip21_namespaces)
return keychain return keychain
@ -318,19 +316,18 @@ def _get_unlock_schemas(
def with_keychain(func: HandlerWithCoinInfo[MsgOut]) -> Handler[MsgIn, MsgOut]: def with_keychain(func: HandlerWithCoinInfo[MsgOut]) -> Handler[MsgIn, MsgOut]:
async def wrapper( async def wrapper(
ctx: Context,
msg: MsgIn, msg: MsgIn,
auth_msg: MessageType | None = None, auth_msg: MessageType | None = None,
) -> MsgOut: ) -> MsgOut:
coin = _get_coin_by_name(msg.coin_name) coin = _get_coin_by_name(msg.coin_name)
unlock_schemas = _get_unlock_schemas(msg, auth_msg, coin) unlock_schemas = _get_unlock_schemas(msg, auth_msg, coin)
keychain = await _get_keychain_for_coin(ctx, coin, unlock_schemas) keychain = await _get_keychain_for_coin(coin, unlock_schemas)
if AuthorizeCoinJoin.is_type_of(auth_msg): if AuthorizeCoinJoin.is_type_of(auth_msg):
auth_obj = authorization.from_cached_message(auth_msg) auth_obj = authorization.from_cached_message(auth_msg)
return await func(ctx, msg, keychain, coin, auth_obj) return await func(msg, keychain, coin, auth_obj)
else: else:
with keychain: with keychain:
return await func(ctx, msg, keychain, coin) return await func(msg, keychain, coin)
return wrapper return wrapper

@ -4,7 +4,6 @@ from .keychain import with_keychain
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import SignMessage, MessageSignature from trezor.messages import SignMessage, MessageSignature
from trezor.wire import Context
from apps.common.coininfo import CoinInfo from apps.common.coininfo import CoinInfo
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
@ -12,7 +11,7 @@ if TYPE_CHECKING:
@with_keychain @with_keychain
async def sign_message( async def sign_message(
ctx: Context, msg: SignMessage, keychain: Keychain, coin: CoinInfo msg: SignMessage, keychain: Keychain, coin: CoinInfo
) -> MessageSignature: ) -> MessageSignature:
from trezor import wire from trezor import wire
from trezor.crypto.curve import secp256k1 from trezor.crypto.curve import secp256k1
@ -31,13 +30,12 @@ async def sign_message(
script_type = msg.script_type or InputScriptType.SPENDADDRESS script_type = msg.script_type or InputScriptType.SPENDADDRESS
await validate_path( await validate_path(
ctx, keychain, address_n, validate_path_against_script_type(coin, msg) keychain, address_n, validate_path_against_script_type(coin, msg)
) )
node = keychain.derive(address_n) node = keychain.derive(address_n)
address = get_address(script_type, coin, node) address = get_address(script_type, coin, node)
await confirm_signverify( await confirm_signverify(
ctx,
coin.coin_shortcut, coin.coin_shortcut,
decode_message(message), decode_message(message),
address_short(coin, address), address_short(coin, address),

@ -11,7 +11,6 @@ if not utils.BITCOIN_ONLY:
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Protocol from typing import Protocol
from trezor.wire import Context
from trezor.messages import ( from trezor.messages import (
SignTx, SignTx,
TxAckInput, TxAckInput,
@ -54,7 +53,6 @@ if TYPE_CHECKING:
@with_keychain @with_keychain
async def sign_tx( async def sign_tx(
ctx: Context,
msg: SignTx, msg: SignTx,
keychain: Keychain, keychain: Keychain,
coin: CoinInfo, coin: CoinInfo,
@ -62,6 +60,7 @@ async def sign_tx(
) -> TxRequest: ) -> TxRequest:
from trezor.enums import RequestType from trezor.enums import RequestType
from trezor.messages import TxRequest from trezor.messages import TxRequest
from trezor.wire.context import call
from ..common import BITCOIN_NAMES from ..common import BITCOIN_NAMES
from . import approvers, bitcoin, helpers, progress from . import approvers, bitcoin, helpers, progress
@ -93,9 +92,9 @@ async def sign_tx(
assert TxRequest.is_type_of(req) assert TxRequest.is_type_of(req)
if req.request_type == RequestType.TXFINISHED: if req.request_type == RequestType.TXFINISHED:
return req return req
res = await ctx.call(req, request_class) res = await call(req, request_class)
elif isinstance(req, helpers.UiConfirm): elif isinstance(req, helpers.UiConfirm):
res = await req.confirm_dialog(ctx) res = await req.confirm_dialog()
progress.progress.report_init() progress.progress.report_init()
else: else:
raise TypeError("Invalid signing instruction") raise TypeError("Invalid signing instruction")

@ -11,7 +11,6 @@ from . import layout
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Any, Awaitable from typing import Any, Awaitable
from trezor.enums import AmountUnit from trezor.enums import AmountUnit
from trezor.wire import Context
from trezor.messages import ( from trezor.messages import (
PrevInput, PrevInput,
@ -31,7 +30,7 @@ if TYPE_CHECKING:
class UiConfirm: class UiConfirm:
def confirm_dialog(self, ctx: Context) -> Awaitable[Any]: def confirm_dialog(self) -> Awaitable[Any]:
raise NotImplementedError raise NotImplementedError
__eq__ = utils.obj_eq __eq__ = utils.obj_eq
@ -50,9 +49,8 @@ class UiConfirmOutput(UiConfirm):
self.amount_unit = amount_unit self.amount_unit = amount_unit
self.output_index = output_index self.output_index = output_index
def confirm_dialog(self, ctx: Context) -> Awaitable[Any]: def confirm_dialog(self) -> Awaitable[Any]:
return layout.confirm_output( return layout.confirm_output(
ctx,
self.output, self.output,
self.coin, self.coin,
self.amount_unit, self.amount_unit,
@ -66,9 +64,9 @@ class UiConfirmDecredSSTXSubmission(UiConfirm):
self.coin = coin self.coin = coin
self.amount_unit = amount_unit self.amount_unit = amount_unit
def confirm_dialog(self, ctx: Context) -> Awaitable[Any]: def confirm_dialog(self) -> Awaitable[Any]:
return layout.confirm_decred_sstx_submission( return layout.confirm_decred_sstx_submission(
ctx, self.output, self.coin, self.amount_unit self.output, self.coin, self.amount_unit
) )
@ -83,9 +81,9 @@ class UiConfirmPaymentRequest(UiConfirm):
self.amount_unit = amount_unit self.amount_unit = amount_unit
self.coin = coin self.coin = coin
def confirm_dialog(self, ctx: Context) -> Awaitable[Any]: def confirm_dialog(self) -> Awaitable[Any]:
return layout.confirm_payment_request( return layout.confirm_payment_request(
ctx, self.payment_req, self.coin, self.amount_unit self.payment_req, self.coin, self.amount_unit
) )
__eq__ = utils.obj_eq __eq__ = utils.obj_eq
@ -96,8 +94,8 @@ class UiConfirmReplacement(UiConfirm):
self.title = title self.title = title
self.txid = txid self.txid = txid
def confirm_dialog(self, ctx: Context) -> Awaitable[Any]: def confirm_dialog(self) -> Awaitable[Any]:
return layout.confirm_replacement(ctx, self.title, self.txid) return layout.confirm_replacement(self.title, self.txid)
class UiConfirmModifyOutput(UiConfirm): class UiConfirmModifyOutput(UiConfirm):
@ -113,9 +111,9 @@ class UiConfirmModifyOutput(UiConfirm):
self.coin = coin self.coin = coin
self.amount_unit = amount_unit self.amount_unit = amount_unit
def confirm_dialog(self, ctx: Context) -> Awaitable[Any]: def confirm_dialog(self) -> Awaitable[Any]:
return layout.confirm_modify_output( return layout.confirm_modify_output(
ctx, self.txo, self.orig_txo, self.coin, self.amount_unit self.txo, self.orig_txo, self.coin, self.amount_unit
) )
@ -136,9 +134,8 @@ class UiConfirmModifyFee(UiConfirm):
self.coin = coin self.coin = coin
self.amount_unit = amount_unit self.amount_unit = amount_unit
def confirm_dialog(self, ctx: Context) -> Awaitable[Any]: def confirm_dialog(self) -> Awaitable[Any]:
return layout.confirm_modify_fee( return layout.confirm_modify_fee(
ctx,
self.title, self.title,
self.user_fee_change, self.user_fee_change,
self.total_fee_new, self.total_fee_new,
@ -165,9 +162,8 @@ class UiConfirmTotal(UiConfirm):
self.amount_unit = amount_unit self.amount_unit = amount_unit
self.address_n = address_n self.address_n = address_n
def confirm_dialog(self, ctx: Context) -> Awaitable[Any]: def confirm_dialog(self) -> Awaitable[Any]:
return layout.confirm_total( return layout.confirm_total(
ctx,
self.spending, self.spending,
self.fee, self.fee,
self.fee_rate, self.fee_rate,
@ -186,9 +182,9 @@ class UiConfirmJointTotal(UiConfirm):
self.coin = coin self.coin = coin
self.amount_unit = amount_unit self.amount_unit = amount_unit
def confirm_dialog(self, ctx: Context) -> Awaitable[Any]: def confirm_dialog(self) -> Awaitable[Any]:
return layout.confirm_joint_total( return layout.confirm_joint_total(
ctx, self.spending, self.total, self.coin, self.amount_unit self.spending, self.total, self.coin, self.amount_unit
) )
@ -198,33 +194,31 @@ class UiConfirmFeeOverThreshold(UiConfirm):
self.coin = coin self.coin = coin
self.amount_unit = amount_unit self.amount_unit = amount_unit
def confirm_dialog(self, ctx: Context) -> Awaitable[Any]: def confirm_dialog(self) -> Awaitable[Any]:
return layout.confirm_feeoverthreshold( return layout.confirm_feeoverthreshold(self.fee, self.coin, self.amount_unit)
ctx, self.fee, self.coin, self.amount_unit
)
class UiConfirmChangeCountOverThreshold(UiConfirm): class UiConfirmChangeCountOverThreshold(UiConfirm):
def __init__(self, change_count: int): def __init__(self, change_count: int):
self.change_count = change_count self.change_count = change_count
def confirm_dialog(self, ctx: Context) -> Awaitable[Any]: def confirm_dialog(self) -> Awaitable[Any]:
return layout.confirm_change_count_over_threshold(ctx, self.change_count) return layout.confirm_change_count_over_threshold(self.change_count)
class UiConfirmUnverifiedExternalInput(UiConfirm): class UiConfirmUnverifiedExternalInput(UiConfirm):
def confirm_dialog(self, ctx: Context) -> Awaitable[Any]: def confirm_dialog(self) -> Awaitable[Any]:
return layout.confirm_unverified_external_input(ctx) return layout.confirm_unverified_external_input()
class UiConfirmForeignAddress(UiConfirm): class UiConfirmForeignAddress(UiConfirm):
def __init__(self, address_n: list): def __init__(self, address_n: list):
self.address_n = address_n self.address_n = address_n
def confirm_dialog(self, ctx: Context) -> Awaitable[Any]: def confirm_dialog(self) -> Awaitable[Any]:
from apps.common import paths from apps.common import paths
return paths.show_path_warning(ctx, self.address_n) return paths.show_path_warning(self.address_n)
class UiConfirmNonDefaultLocktime(UiConfirm): class UiConfirmNonDefaultLocktime(UiConfirm):
@ -232,9 +226,9 @@ class UiConfirmNonDefaultLocktime(UiConfirm):
self.lock_time = lock_time self.lock_time = lock_time
self.lock_time_disabled = lock_time_disabled self.lock_time_disabled = lock_time_disabled
def confirm_dialog(self, ctx: Context) -> Awaitable[Any]: def confirm_dialog(self) -> Awaitable[Any]:
return layout.confirm_nondefault_locktime( return layout.confirm_nondefault_locktime(
ctx, self.lock_time, self.lock_time_disabled self.lock_time, self.lock_time_disabled
) )

@ -22,7 +22,6 @@ if TYPE_CHECKING:
from trezor.messages import TxAckPaymentRequest, TxOutput from trezor.messages import TxAckPaymentRequest, TxOutput
from trezor.ui.layouts import LayoutType from trezor.ui.layouts import LayoutType
from trezor.enums import AmountUnit from trezor.enums import AmountUnit
from trezor.wire import Context
from apps.common.coininfo import CoinInfo from apps.common.coininfo import CoinInfo
from apps.common.paths import Bip32Path from apps.common.paths import Bip32Path
@ -59,7 +58,6 @@ def account_label(coin: CoinInfo, address_n: Bip32Path | None) -> str:
async def confirm_output( async def confirm_output(
ctx: Context,
output: TxOutput, output: TxOutput,
coin: CoinInfo, coin: CoinInfo,
amount_unit: AmountUnit, amount_unit: AmountUnit,
@ -74,7 +72,6 @@ async def confirm_output(
if omni.is_valid(data): if omni.is_valid(data):
# OMNI transaction # OMNI transaction
layout: LayoutType = confirm_metadata( layout: LayoutType = confirm_metadata(
ctx,
"omni_transaction", "omni_transaction",
"OMNI transaction", "OMNI transaction",
omni.parse(data), omni.parse(data),
@ -84,7 +81,6 @@ async def confirm_output(
else: else:
# generic OP_RETURN # generic OP_RETURN
layout = layouts.confirm_blob( layout = layouts.confirm_blob(
ctx,
"op_return", "op_return",
"OP_RETURN", "OP_RETURN",
data, data,
@ -107,7 +103,6 @@ async def confirm_output(
) )
layout = layouts.confirm_output( layout = layouts.confirm_output(
ctx,
address_short, address_short,
format_coin_amount(output.amount, coin, amount_unit), format_coin_amount(output.amount, coin, amount_unit),
title=title, title=title,
@ -119,14 +114,13 @@ async def confirm_output(
async def confirm_decred_sstx_submission( async def confirm_decred_sstx_submission(
ctx: Context, output: TxOutput, coin: CoinInfo, amount_unit: AmountUnit output: TxOutput, coin: CoinInfo, amount_unit: AmountUnit
) -> None: ) -> None:
assert output.address is not None assert output.address is not None
address_short = addresses.address_short(coin, output.address) address_short = addresses.address_short(coin, output.address)
amount = format_coin_amount(output.amount, coin, amount_unit) amount = format_coin_amount(output.amount, coin, amount_unit)
await layouts.confirm_value( await layouts.confirm_value(
ctx,
"Purchase ticket", "Purchase ticket",
amount, amount,
"Ticket amount:", "Ticket amount:",
@ -136,7 +130,6 @@ async def confirm_decred_sstx_submission(
) )
await layouts.confirm_value( await layouts.confirm_value(
ctx,
"Purchase ticket", "Purchase ticket",
address_short, address_short,
"Voting rights to:", "Voting rights to:",
@ -147,7 +140,6 @@ async def confirm_decred_sstx_submission(
async def confirm_payment_request( async def confirm_payment_request(
ctx: Context,
msg: TxAckPaymentRequest, msg: TxAckPaymentRequest,
coin: CoinInfo, coin: CoinInfo,
amount_unit: AmountUnit, amount_unit: AmountUnit,
@ -168,25 +160,22 @@ async def confirm_payment_request(
assert msg.amount is not None assert msg.amount is not None
return await layouts.confirm_payment_request( return await layouts.confirm_payment_request(
ctx,
msg.recipient_name, msg.recipient_name,
format_coin_amount(msg.amount, coin, amount_unit), format_coin_amount(msg.amount, coin, amount_unit),
memo_texts, memo_texts,
) )
async def confirm_replacement(ctx: Context, title: str, txid: bytes) -> None: async def confirm_replacement(title: str, txid: bytes) -> None:
from ubinascii import hexlify from ubinascii import hexlify
await layouts.confirm_replacement( await layouts.confirm_replacement(
ctx,
title, title,
hexlify(txid).decode(), hexlify(txid).decode(),
) )
async def confirm_modify_output( async def confirm_modify_output(
ctx: Context,
txo: TxOutput, txo: TxOutput,
orig_txo: TxOutput, orig_txo: TxOutput,
coin: CoinInfo, coin: CoinInfo,
@ -196,7 +185,6 @@ async def confirm_modify_output(
address_short = addresses.address_short(coin, txo.address) address_short = addresses.address_short(coin, txo.address)
amount_change = txo.amount - orig_txo.amount amount_change = txo.amount - orig_txo.amount
await layouts.confirm_modify_output( await layouts.confirm_modify_output(
ctx,
address_short, address_short,
amount_change, amount_change,
format_coin_amount(abs(amount_change), coin, amount_unit), format_coin_amount(abs(amount_change), coin, amount_unit),
@ -205,7 +193,6 @@ async def confirm_modify_output(
async def confirm_modify_fee( async def confirm_modify_fee(
ctx: Context,
title: str, title: str,
user_fee_change: int, user_fee_change: int,
total_fee_new: int, total_fee_new: int,
@ -214,7 +201,6 @@ async def confirm_modify_fee(
amount_unit: AmountUnit, amount_unit: AmountUnit,
) -> None: ) -> None:
await layouts.confirm_modify_fee( await layouts.confirm_modify_fee(
ctx,
title, title,
user_fee_change, user_fee_change,
format_coin_amount(abs(user_fee_change), coin, amount_unit), format_coin_amount(abs(user_fee_change), coin, amount_unit),
@ -224,21 +210,18 @@ async def confirm_modify_fee(
async def confirm_joint_total( async def confirm_joint_total(
ctx: Context,
spending: int, spending: int,
total: int, total: int,
coin: CoinInfo, coin: CoinInfo,
amount_unit: AmountUnit, amount_unit: AmountUnit,
) -> None: ) -> None:
await layouts.confirm_joint_total( await layouts.confirm_joint_total(
ctx,
spending_amount=format_coin_amount(spending, coin, amount_unit), spending_amount=format_coin_amount(spending, coin, amount_unit),
total_amount=format_coin_amount(total, coin, amount_unit), total_amount=format_coin_amount(total, coin, amount_unit),
) )
async def confirm_total( async def confirm_total(
ctx: Context,
spending: int, spending: int,
fee: int, fee: int,
fee_rate: float, fee_rate: float,
@ -248,7 +231,6 @@ async def confirm_total(
) -> None: ) -> None:
await layouts.confirm_total( await layouts.confirm_total(
ctx,
format_coin_amount(spending, coin, amount_unit), format_coin_amount(spending, coin, amount_unit),
format_coin_amount(fee, coin, amount_unit), format_coin_amount(fee, coin, amount_unit),
fee_rate_amount=format_fee_rate(fee_rate, coin) if fee_rate >= 0 else None, fee_rate_amount=format_fee_rate(fee_rate, coin) if fee_rate >= 0 else None,
@ -257,11 +239,10 @@ async def confirm_total(
async def confirm_feeoverthreshold( async def confirm_feeoverthreshold(
ctx: Context, fee: int, coin: CoinInfo, amount_unit: AmountUnit fee: int, coin: CoinInfo, amount_unit: AmountUnit
) -> None: ) -> None:
fee_amount = format_coin_amount(fee, coin, amount_unit) fee_amount = format_coin_amount(fee, coin, amount_unit)
await layouts.show_warning( await layouts.show_warning(
ctx,
"fee_over_threshold", "fee_over_threshold",
"Unusually high fee.", "Unusually high fee.",
fee_amount, fee_amount,
@ -269,9 +250,8 @@ async def confirm_feeoverthreshold(
) )
async def confirm_change_count_over_threshold(ctx: Context, change_count: int) -> None: async def confirm_change_count_over_threshold(change_count: int) -> None:
await layouts.show_warning( await layouts.show_warning(
ctx,
"change_count_over_threshold", "change_count_over_threshold",
"A lot of change-outputs.", "A lot of change-outputs.",
f"{str(change_count)} outputs", f"{str(change_count)} outputs",
@ -279,9 +259,8 @@ async def confirm_change_count_over_threshold(ctx: Context, change_count: int) -
) )
async def confirm_unverified_external_input(ctx: Context) -> None: async def confirm_unverified_external_input() -> None:
await layouts.show_warning( await layouts.show_warning(
ctx,
"unverified_external_input", "unverified_external_input",
"The transaction contains unverified external inputs.", "The transaction contains unverified external inputs.",
"Proceed anyway?", "Proceed anyway?",
@ -290,14 +269,11 @@ async def confirm_unverified_external_input(ctx: Context) -> None:
) )
async def confirm_nondefault_locktime( async def confirm_nondefault_locktime(lock_time: int, lock_time_disabled: bool) -> None:
ctx: Context, lock_time: int, lock_time_disabled: bool
) -> None:
from trezor.strings import format_timestamp from trezor.strings import format_timestamp
if lock_time_disabled: if lock_time_disabled:
await layouts.show_warning( await layouts.show_warning(
ctx,
"nondefault_locktime", "nondefault_locktime",
"Locktime is set but will have no effect.", "Locktime is set but will have no effect.",
"Proceed anyway?", "Proceed anyway?",
@ -312,7 +288,6 @@ async def confirm_nondefault_locktime(
text = "Locktime for this transaction is set to:" text = "Locktime for this transaction is set to:"
value = format_timestamp(lock_time) value = format_timestamp(lock_time)
await layouts.confirm_value( await layouts.confirm_value(
ctx,
"Confirm locktime", "Confirm locktime",
value, value,
text, text,

@ -3,7 +3,6 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from apps.common.coininfo import CoinInfo from apps.common.coininfo import CoinInfo
from trezor.messages import VerifyMessage, Success from trezor.messages import VerifyMessage, Success
from trezor.wire import Context
from trezor.enums import InputScriptType from trezor.enums import InputScriptType
@ -48,7 +47,7 @@ def _address_to_script_type(address: str, coin: CoinInfo) -> InputScriptType:
raise DataError("Invalid address") raise DataError("Invalid address")
async def verify_message(ctx: Context, msg: VerifyMessage) -> Success: async def verify_message(msg: VerifyMessage) -> Success:
from trezor import utils from trezor import utils
from trezor.wire import ProcessError from trezor.wire import ProcessError
from trezor.crypto.curve import secp256k1 from trezor.crypto.curve import secp256k1
@ -109,12 +108,11 @@ async def verify_message(ctx: Context, msg: VerifyMessage) -> Success:
raise ProcessError("Invalid signature") raise ProcessError("Invalid signature")
await confirm_signverify( await confirm_signverify(
ctx,
coin.coin_shortcut, coin.coin_shortcut,
decode_message(message), decode_message(message),
address_short(coin, address), address_short(coin, address),
verify=True, verify=True,
) )
await show_success(ctx, "verify_message", "The signature is valid.") await show_success("verify_message", "The signature is valid.")
return Success(message="Message verified") return Success(message="Message verified")

@ -16,7 +16,6 @@ if TYPE_CHECKING:
SignedCVoteRegistrationPayload = tuple[CVoteRegistrationPayload, bytes] SignedCVoteRegistrationPayload = tuple[CVoteRegistrationPayload, bytes]
from trezor import messages from trezor import messages
from trezor.wire import Context
from . import seed from . import seed
@ -115,7 +114,6 @@ def _get_voting_purpose_to_serialize(
async def show( async def show(
ctx: Context,
keychain: seed.Keychain, keychain: seed.Keychain,
auxiliary_data_hash: bytes, auxiliary_data_hash: bytes,
parameters: messages.CardanoCVoteRegistrationParametersType | None, parameters: messages.CardanoCVoteRegistrationParametersType | None,
@ -125,7 +123,6 @@ async def show(
) -> None: ) -> None:
if parameters: if parameters:
await _show_cvote_registration( await _show_cvote_registration(
ctx,
keychain, keychain,
parameters, parameters,
protocol_magic, protocol_magic,
@ -134,7 +131,7 @@ async def show(
) )
if should_show_details: if should_show_details:
await layout.show_auxiliary_data_hash(ctx, auxiliary_data_hash) await layout.show_auxiliary_data_hash(auxiliary_data_hash)
def _should_show_payment_warning(address_type: CardanoAddressType) -> bool: def _should_show_payment_warning(address_type: CardanoAddressType) -> bool:
@ -146,7 +143,6 @@ def _should_show_payment_warning(address_type: CardanoAddressType) -> bool:
async def _show_cvote_registration( async def _show_cvote_registration(
ctx: Context,
keychain: seed.Keychain, keychain: seed.Keychain,
parameters: messages.CardanoCVoteRegistrationParametersType, parameters: messages.CardanoCVoteRegistrationParametersType,
protocol_magic: int, protocol_magic: int,
@ -161,7 +157,7 @@ async def _show_cvote_registration(
bech32.HRP_CVOTE_PUBLIC_KEY, delegation.vote_public_key bech32.HRP_CVOTE_PUBLIC_KEY, delegation.vote_public_key
) )
await layout.confirm_cvote_registration_delegation( await layout.confirm_cvote_registration_delegation(
ctx, encoded_public_key, delegation.weight encoded_public_key, delegation.weight
) )
if parameters.payment_address: if parameters.payment_address:
@ -169,7 +165,7 @@ async def _show_cvote_registration(
addresses.get_type(addresses.get_bytes_unsafe(parameters.payment_address)) addresses.get_type(addresses.get_bytes_unsafe(parameters.payment_address))
) )
await layout.confirm_cvote_registration_payment_address( await layout.confirm_cvote_registration_payment_address(
ctx, parameters.payment_address, show_payment_warning parameters.payment_address, show_payment_warning
) )
else: else:
address_parameters = parameters.payment_address_parameters address_parameters = parameters.payment_address_parameters
@ -179,7 +175,6 @@ async def _show_cvote_registration(
address_parameters.address_type address_parameters.address_type
) )
await layout.show_cvote_registration_payment_credentials( await layout.show_cvote_registration_payment_credentials(
ctx,
Credential.payment_credential(address_parameters), Credential.payment_credential(address_parameters),
Credential.stake_credential(address_parameters), Credential.stake_credential(address_parameters),
show_both_credentials, show_both_credentials,
@ -197,7 +192,6 @@ async def _show_cvote_registration(
) )
await layout.confirm_cvote_registration( await layout.confirm_cvote_registration(
ctx,
encoded_public_key, encoded_public_key,
parameters.staking_path, parameters.staking_path,
parameters.nonce, parameters.nonce,

@ -3,13 +3,12 @@ from typing import TYPE_CHECKING
from . import seed from . import seed
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.wire import Context
from trezor.messages import CardanoGetAddress, CardanoAddress from trezor.messages import CardanoGetAddress, CardanoAddress
@seed.with_keychain @seed.with_keychain
async def get_address( async def get_address(
ctx: Context, msg: CardanoGetAddress, keychain: seed.Keychain msg: CardanoGetAddress, keychain: seed.Keychain
) -> CardanoAddress: ) -> CardanoAddress:
from trezor.messages import CardanoAddress from trezor.messages import CardanoAddress
from trezor import log, wire from trezor import log, wire
@ -36,10 +35,9 @@ async def get_address(
# _display_address # _display_address
if should_show_credentials(address_parameters): if should_show_credentials(address_parameters):
await show_credentials( await show_credentials(
ctx,
Credential.payment_credential(address_parameters), Credential.payment_credential(address_parameters),
Credential.stake_credential(address_parameters), Credential.stake_credential(address_parameters),
) )
await show_cardano_address(ctx, address_parameters, address, msg.protocol_magic) await show_cardano_address(address_parameters, address, msg.protocol_magic)
return CardanoAddress(address=address) return CardanoAddress(address=address)

@ -3,13 +3,12 @@ from typing import TYPE_CHECKING
from . import seed from . import seed
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.wire import Context
from trezor.messages import CardanoGetNativeScriptHash, CardanoNativeScriptHash from trezor.messages import CardanoGetNativeScriptHash, CardanoNativeScriptHash
@seed.with_keychain @seed.with_keychain
async def get_native_script_hash( async def get_native_script_hash(
ctx: Context, msg: CardanoGetNativeScriptHash, keychain: seed.Keychain msg: CardanoGetNativeScriptHash, keychain: seed.Keychain
) -> CardanoNativeScriptHash: ) -> CardanoNativeScriptHash:
from trezor.messages import CardanoNativeScriptHash from trezor.messages import CardanoNativeScriptHash
from trezor.enums import CardanoNativeScriptHashDisplayFormat from trezor.enums import CardanoNativeScriptHashDisplayFormat
@ -20,7 +19,7 @@ async def get_native_script_hash(
script_hash = native_script.get_native_script_hash(keychain, msg.script) script_hash = native_script.get_native_script_hash(keychain, msg.script)
if msg.display_format != CardanoNativeScriptHashDisplayFormat.HIDE: if msg.display_format != CardanoNativeScriptHashDisplayFormat.HIDE:
await layout.show_native_script(ctx, msg.script) await layout.show_native_script(msg.script)
await layout.show_script_hash(ctx, script_hash, msg.display_format) await layout.show_script_hash(script_hash, msg.display_format)
return CardanoNativeScriptHash(script_hash=script_hash) return CardanoNativeScriptHash(script_hash=script_hash)

@ -4,13 +4,12 @@ from ubinascii import hexlify
from . import seed from . import seed
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.wire import Context
from trezor.messages import CardanoGetPublicKey, CardanoPublicKey from trezor.messages import CardanoGetPublicKey, CardanoPublicKey
@seed.with_keychain @seed.with_keychain
async def get_public_key( async def get_public_key(
ctx: Context, msg: CardanoGetPublicKey, keychain: seed.Keychain msg: CardanoGetPublicKey, keychain: seed.Keychain
) -> CardanoPublicKey: ) -> CardanoPublicKey:
from trezor import log, wire from trezor import log, wire
from trezor.ui.layouts import show_pubkey from trezor.ui.layouts import show_pubkey
@ -20,7 +19,6 @@ async def get_public_key(
address_n = msg.address_n # local_cache_attribute address_n = msg.address_n # local_cache_attribute
await paths.validate_path( await paths.validate_path(
ctx,
keychain, keychain,
address_n, address_n,
# path must match the PUBKEY schema # path must match the PUBKEY schema
@ -35,7 +33,7 @@ async def get_public_key(
raise wire.ProcessError("Deriving public key failed") raise wire.ProcessError("Deriving public key failed")
if msg.show_display: if msg.show_display:
await show_pubkey(ctx, hexlify(key.node.public_key).decode()) await show_pubkey(hexlify(key.node.public_key).decode())
return key return key

@ -25,7 +25,6 @@ from .helpers.utils import (
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Literal from typing import Literal
from trezor.wire import Context
from trezor import messages from trezor import messages
from trezor.enums import CardanoNativeScriptHashDisplayFormat from trezor.enums import CardanoNativeScriptHashDisplayFormat
from trezor.ui.layouts import PropertyType from trezor.ui.layouts import PropertyType
@ -79,7 +78,6 @@ def format_coin_amount(amount: int, network_id: int) -> str:
async def show_native_script( async def show_native_script(
ctx: Context,
script: messages.CardanoNativeScript, script: messages.CardanoNativeScript,
indices: list[int] | None = None, indices: list[int] | None = None,
) -> None: ) -> None:
@ -140,7 +138,6 @@ async def show_native_script(
append((f"Contains {len(scripts)} nested scripts.", None)) append((f"Contains {len(scripts)} nested scripts.", None))
await confirm_properties( await confirm_properties(
ctx,
"verify_script", "verify_script",
"Verify script", "Verify script",
props, props,
@ -148,11 +145,10 @@ async def show_native_script(
) )
for i, sub_script in enumerate(scripts): for i, sub_script in enumerate(scripts):
await show_native_script(ctx, sub_script, indices + [i + 1]) await show_native_script(sub_script, indices + [i + 1])
async def show_script_hash( async def show_script_hash(
ctx: Context,
script_hash: bytes, script_hash: bytes,
display_format: CardanoNativeScriptHashDisplayFormat, display_format: CardanoNativeScriptHashDisplayFormat,
) -> None: ) -> None:
@ -165,7 +161,6 @@ async def show_script_hash(
if display_format == CardanoNativeScriptHashDisplayFormat.BECH32: if display_format == CardanoNativeScriptHashDisplayFormat.BECH32:
await confirm_properties( await confirm_properties(
ctx,
"verify_script", "verify_script",
"Verify script", "Verify script",
(("Script hash:", bech32.encode(bech32.HRP_SCRIPT_HASH, script_hash)),), (("Script hash:", bech32.encode(bech32.HRP_SCRIPT_HASH, script_hash)),),
@ -173,7 +168,6 @@ async def show_script_hash(
) )
elif display_format == CardanoNativeScriptHashDisplayFormat.POLICY_ID: elif display_format == CardanoNativeScriptHashDisplayFormat.POLICY_ID:
await layouts.confirm_blob( await layouts.confirm_blob(
ctx,
"verify_script", "verify_script",
"Verify script", "Verify script",
script_hash, script_hash,
@ -182,9 +176,8 @@ async def show_script_hash(
) )
async def show_tx_init(ctx: Context, title: str) -> bool: async def show_tx_init(title: str) -> bool:
should_show_details = await layouts.should_show_more( should_show_details = await layouts.should_show_more(
ctx,
"Confirm transaction", "Confirm transaction",
( (
( (
@ -200,9 +193,8 @@ async def show_tx_init(ctx: Context, title: str) -> bool:
return should_show_details return should_show_details
async def confirm_input(ctx: Context, input: messages.CardanoTxInput) -> None: async def confirm_input(input: messages.CardanoTxInput) -> None:
await confirm_properties( await confirm_properties(
ctx,
"confirm_input", "confirm_input",
"Confirm transaction", "Confirm transaction",
( (
@ -214,7 +206,6 @@ async def confirm_input(ctx: Context, input: messages.CardanoTxInput) -> None:
async def confirm_sending( async def confirm_sending(
ctx: Context,
ada_amount: int, ada_amount: int,
to: str, to: str,
output_type: Literal["address", "change", "collateral-return"], output_type: Literal["address", "change", "collateral-return"],
@ -230,7 +221,6 @@ async def confirm_sending(
raise RuntimeError # should be unreachable raise RuntimeError # should be unreachable
await layouts.confirm_output( await layouts.confirm_output(
ctx,
to, to,
format_coin_amount(ada_amount, network_id), format_coin_amount(ada_amount, network_id),
title, title,
@ -238,13 +228,10 @@ async def confirm_sending(
) )
async def confirm_sending_token( async def confirm_sending_token(policy_id: bytes, token: messages.CardanoToken) -> None:
ctx: Context, policy_id: bytes, token: messages.CardanoToken
) -> None:
assert token.amount is not None # _validate_token assert token.amount is not None # _validate_token
await confirm_properties( await confirm_properties(
ctx,
"confirm_token", "confirm_token",
"Confirm transaction", "Confirm transaction",
( (
@ -261,9 +248,8 @@ async def confirm_sending_token(
) )
async def confirm_datum_hash(ctx: Context, datum_hash: bytes) -> None: async def confirm_datum_hash(datum_hash: bytes) -> None:
await confirm_properties( await confirm_properties(
ctx,
"confirm_datum_hash", "confirm_datum_hash",
"Confirm transaction", "Confirm transaction",
( (
@ -276,11 +262,8 @@ async def confirm_datum_hash(ctx: Context, datum_hash: bytes) -> None:
) )
async def confirm_inline_datum( async def confirm_inline_datum(first_chunk: bytes, inline_datum_size: int) -> None:
ctx: Context, first_chunk: bytes, inline_datum_size: int
) -> None:
await _confirm_data_chunk( await _confirm_data_chunk(
ctx,
"confirm_inline_datum", "confirm_inline_datum",
"Inline datum", "Inline datum",
first_chunk, first_chunk,
@ -289,10 +272,9 @@ async def confirm_inline_datum(
async def confirm_reference_script( async def confirm_reference_script(
ctx: Context, first_chunk: bytes, reference_script_size: int first_chunk: bytes, reference_script_size: int
) -> None: ) -> None:
await _confirm_data_chunk( await _confirm_data_chunk(
ctx,
"confirm_reference_script", "confirm_reference_script",
"Reference script", "Reference script",
first_chunk, first_chunk,
@ -301,7 +283,7 @@ async def confirm_reference_script(
async def _confirm_data_chunk( async def _confirm_data_chunk(
ctx: Context, br_type: str, title: str, first_chunk: bytes, data_size: int br_type: str, title: str, first_chunk: bytes, data_size: int
) -> None: ) -> None:
MAX_DISPLAYED_SIZE = 56 MAX_DISPLAYED_SIZE = 56
displayed_bytes = first_chunk[:MAX_DISPLAYED_SIZE] displayed_bytes = first_chunk[:MAX_DISPLAYED_SIZE]
@ -315,7 +297,6 @@ async def _confirm_data_chunk(
if data_size > MAX_DISPLAYED_SIZE: if data_size > MAX_DISPLAYED_SIZE:
props.append(("...", None)) props.append(("...", None))
await confirm_properties( await confirm_properties(
ctx,
br_type, br_type,
title="Confirm transaction", title="Confirm transaction",
props=props, props=props,
@ -324,39 +305,35 @@ async def _confirm_data_chunk(
async def show_credentials( async def show_credentials(
ctx: Context,
payment_credential: Credential, payment_credential: Credential,
stake_credential: Credential, stake_credential: Credential,
) -> None: ) -> None:
intro_text = "Address" intro_text = "Address"
await _show_credential(ctx, payment_credential, intro_text, purpose="address") await _show_credential(payment_credential, intro_text, purpose="address")
await _show_credential(ctx, stake_credential, intro_text, purpose="address") await _show_credential(stake_credential, intro_text, purpose="address")
async def show_change_output_credentials( async def show_change_output_credentials(
ctx: Context,
payment_credential: Credential, payment_credential: Credential,
stake_credential: Credential, stake_credential: Credential,
) -> None: ) -> None:
intro_text = "The following address is a change address. Its" intro_text = "The following address is a change address. Its"
await _show_credential(ctx, payment_credential, intro_text, purpose="output") await _show_credential(payment_credential, intro_text, purpose="output")
await _show_credential(ctx, stake_credential, intro_text, purpose="output") await _show_credential(stake_credential, intro_text, purpose="output")
async def show_device_owned_output_credentials( async def show_device_owned_output_credentials(
ctx: Context,
payment_credential: Credential, payment_credential: Credential,
stake_credential: Credential, stake_credential: Credential,
show_both_credentials: bool, show_both_credentials: bool,
) -> None: ) -> None:
intro_text = "The following address is owned by this device. Its" intro_text = "The following address is owned by this device. Its"
await _show_credential(ctx, payment_credential, intro_text, purpose="output") await _show_credential(payment_credential, intro_text, purpose="output")
if show_both_credentials: if show_both_credentials:
await _show_credential(ctx, stake_credential, intro_text, purpose="output") await _show_credential(stake_credential, intro_text, purpose="output")
async def show_cvote_registration_payment_credentials( async def show_cvote_registration_payment_credentials(
ctx: Context,
payment_credential: Credential, payment_credential: Credential,
stake_credential: Credential, stake_credential: Credential,
show_both_credentials: bool, show_both_credentials: bool,
@ -366,12 +343,11 @@ async def show_cvote_registration_payment_credentials(
"The vote key registration payment address is owned by this device. Its" "The vote key registration payment address is owned by this device. Its"
) )
await _show_credential( await _show_credential(
ctx, payment_credential, intro_text, purpose="cvote_reg_payment_address" payment_credential, intro_text, purpose="cvote_reg_payment_address"
) )
if show_both_credentials or show_payment_warning: if show_both_credentials or show_payment_warning:
extra_text = CVOTE_REWARD_ELIGIBILITY_WARNING if show_payment_warning else None extra_text = CVOTE_REWARD_ELIGIBILITY_WARNING if show_payment_warning else None
await _show_credential( await _show_credential(
ctx,
stake_credential, stake_credential,
intro_text, intro_text,
purpose="cvote_reg_payment_address", purpose="cvote_reg_payment_address",
@ -380,7 +356,6 @@ async def show_cvote_registration_payment_credentials(
async def _show_credential( async def _show_credential(
ctx: Context,
credential: Credential, credential: Credential,
intro_text: str, intro_text: str,
purpose: Literal["address", "output", "cvote_reg_payment_address"], purpose: Literal["address", "output", "cvote_reg_payment_address"],
@ -428,7 +403,6 @@ async def _show_credential(
if len(props) > 0: if len(props) > 0:
await confirm_properties( await confirm_properties(
ctx,
"confirm_credential", "confirm_credential",
title, title,
props, props,
@ -436,20 +410,17 @@ async def _show_credential(
) )
async def warn_path(ctx: Context, path: list[int], title: str) -> None: async def warn_path(path: list[int], title: str) -> None:
await layouts.confirm_path_warning(ctx, address_n_to_str(path), path_type=title) await layouts.confirm_path_warning(address_n_to_str(path), path_type=title)
async def warn_tx_output_contains_tokens( async def warn_tx_output_contains_tokens(is_collateral_return: bool = False) -> None:
ctx: Context, is_collateral_return: bool = False
) -> None:
content = ( content = (
"The collateral return output contains tokens." "The collateral return output contains tokens."
if is_collateral_return if is_collateral_return
else "The following transaction output contains tokens." else "The following transaction output contains tokens."
) )
await confirm_metadata( await confirm_metadata(
ctx,
"confirm_tokens", "confirm_tokens",
"Confirm transaction", "Confirm transaction",
content, content,
@ -457,9 +428,8 @@ async def warn_tx_output_contains_tokens(
) )
async def warn_tx_contains_mint(ctx: Context) -> None: async def warn_tx_contains_mint() -> None:
await confirm_metadata( await confirm_metadata(
ctx,
"confirm_tokens", "confirm_tokens",
"Confirm transaction", "Confirm transaction",
"The transaction contains minting or burning of tokens.", "The transaction contains minting or burning of tokens.",
@ -467,9 +437,8 @@ async def warn_tx_contains_mint(ctx: Context) -> None:
) )
async def warn_tx_output_no_datum(ctx: Context) -> None: async def warn_tx_output_no_datum() -> None:
await confirm_metadata( await confirm_metadata(
ctx,
"confirm_no_datum_hash", "confirm_no_datum_hash",
"Confirm transaction", "Confirm transaction",
"The following transaction output contains a script address, but does not contain a datum.", "The following transaction output contains a script address, but does not contain a datum.",
@ -477,9 +446,8 @@ async def warn_tx_output_no_datum(ctx: Context) -> None:
) )
async def warn_no_script_data_hash(ctx: Context) -> None: async def warn_no_script_data_hash() -> None:
await confirm_metadata( await confirm_metadata(
ctx,
"confirm_no_script_data_hash", "confirm_no_script_data_hash",
"Confirm transaction", "Confirm transaction",
"The transaction contains no script data hash. Plutus script will not be able to run.", "The transaction contains no script data hash. Plutus script will not be able to run.",
@ -487,9 +455,8 @@ async def warn_no_script_data_hash(ctx: Context) -> None:
) )
async def warn_no_collateral_inputs(ctx: Context) -> None: async def warn_no_collateral_inputs() -> None:
await confirm_metadata( await confirm_metadata(
ctx,
"confirm_no_collateral_inputs", "confirm_no_collateral_inputs",
"Confirm transaction", "Confirm transaction",
"The transaction contains no collateral inputs. Plutus script will not be able to run.", "The transaction contains no collateral inputs. Plutus script will not be able to run.",
@ -497,9 +464,8 @@ async def warn_no_collateral_inputs(ctx: Context) -> None:
) )
async def warn_unknown_total_collateral(ctx: Context) -> None: async def warn_unknown_total_collateral() -> None:
await layouts.show_warning( await layouts.show_warning(
ctx,
"confirm_unknown_total_collateral", "confirm_unknown_total_collateral",
"Unknown collateral amount.", "Unknown collateral amount.",
"Check all items carefully.", "Check all items carefully.",
@ -508,7 +474,6 @@ async def warn_unknown_total_collateral(ctx: Context) -> None:
async def confirm_witness_request( async def confirm_witness_request(
ctx: Context,
witness_path: list[int], witness_path: list[int],
) -> None: ) -> None:
from . import seed from . import seed
@ -521,7 +486,6 @@ async def confirm_witness_request(
path_title = "path" path_title = "path"
await layouts.confirm_text( await layouts.confirm_text(
ctx,
"confirm_total", "confirm_total",
"Confirm transaction", "Confirm transaction",
address_n_to_str(witness_path), address_n_to_str(witness_path),
@ -531,7 +495,6 @@ async def confirm_witness_request(
async def confirm_tx( async def confirm_tx(
ctx: Context,
fee: int, fee: int,
network_id: int, network_id: int,
protocol_magic: int, protocol_magic: int,
@ -559,7 +522,6 @@ async def confirm_tx(
append(("Transaction ID:", tx_hash)) append(("Transaction ID:", tx_hash))
await confirm_properties( await confirm_properties(
ctx,
"confirm_total", "confirm_total",
"Confirm transaction", "Confirm transaction",
props, props,
@ -568,9 +530,7 @@ async def confirm_tx(
) )
async def confirm_certificate( async def confirm_certificate(certificate: messages.CardanoTxCertificate) -> None:
ctx: Context, certificate: messages.CardanoTxCertificate
) -> None:
# stake pool registration requires custom confirmation logic not covered # stake pool registration requires custom confirmation logic not covered
# in this call # in this call
assert certificate.type != CardanoCertificateType.STAKE_POOL_REGISTRATION assert certificate.type != CardanoCertificateType.STAKE_POOL_REGISTRATION
@ -587,7 +547,6 @@ async def confirm_certificate(
props.append(("to pool:", format_stake_pool_id(certificate.pool))) props.append(("to pool:", format_stake_pool_id(certificate.pool)))
await confirm_properties( await confirm_properties(
ctx,
"confirm_certificate", "confirm_certificate",
"Confirm transaction", "Confirm transaction",
props, props,
@ -596,7 +555,6 @@ async def confirm_certificate(
async def confirm_stake_pool_parameters( async def confirm_stake_pool_parameters(
ctx: Context,
pool_parameters: messages.CardanoPoolParametersType, pool_parameters: messages.CardanoPoolParametersType,
network_id: int, network_id: int,
) -> None: ) -> None:
@ -605,7 +563,6 @@ async def confirm_stake_pool_parameters(
) )
percentage_formatted = str(float(margin_percentage)).rstrip("0").rstrip(".") percentage_formatted = str(float(margin_percentage)).rstrip("0").rstrip(".")
await confirm_properties( await confirm_properties(
ctx,
"confirm_pool_registration", "confirm_pool_registration",
"Confirm transaction", "Confirm transaction",
( (
@ -626,7 +583,6 @@ async def confirm_stake_pool_parameters(
async def confirm_stake_pool_owner( async def confirm_stake_pool_owner(
ctx: Context,
keychain: Keychain, keychain: Keychain,
owner: messages.CardanoPoolOwner, owner: messages.CardanoPoolOwner,
protocol_magic: int, protocol_magic: int,
@ -669,7 +625,6 @@ async def confirm_stake_pool_owner(
) )
await confirm_properties( await confirm_properties(
ctx,
"confirm_pool_owners", "confirm_pool_owners",
"Confirm transaction", "Confirm transaction",
props, props,
@ -678,12 +633,10 @@ async def confirm_stake_pool_owner(
async def confirm_stake_pool_metadata( async def confirm_stake_pool_metadata(
ctx: Context,
metadata: messages.CardanoPoolMetadataType | None, metadata: messages.CardanoPoolMetadataType | None,
) -> None: ) -> None:
if metadata is None: if metadata is None:
await confirm_properties( await confirm_properties(
ctx,
"confirm_pool_metadata", "confirm_pool_metadata",
"Confirm transaction", "Confirm transaction",
(("Pool has no metadata (anonymous pool)", None),), (("Pool has no metadata (anonymous pool)", None),),
@ -692,7 +645,6 @@ async def confirm_stake_pool_metadata(
return return
await confirm_properties( await confirm_properties(
ctx,
"confirm_pool_metadata", "confirm_pool_metadata",
"Confirm transaction", "Confirm transaction",
( (
@ -704,13 +656,11 @@ async def confirm_stake_pool_metadata(
async def confirm_stake_pool_registration_final( async def confirm_stake_pool_registration_final(
ctx: Context,
protocol_magic: int, protocol_magic: int,
ttl: int | None, ttl: int | None,
validity_interval_start: int | None, validity_interval_start: int | None,
) -> None: ) -> None:
await confirm_properties( await confirm_properties(
ctx,
"confirm_pool_final", "confirm_pool_final",
"Confirm transaction", "Confirm transaction",
( (
@ -725,7 +675,6 @@ async def confirm_stake_pool_registration_final(
async def confirm_withdrawal( async def confirm_withdrawal(
ctx: Context,
withdrawal: messages.CardanoTxWithdrawal, withdrawal: messages.CardanoTxWithdrawal,
address_bytes: bytes, address_bytes: bytes,
network_id: int, network_id: int,
@ -746,7 +695,6 @@ async def confirm_withdrawal(
props.append(("Amount:", format_coin_amount(withdrawal.amount, network_id))) props.append(("Amount:", format_coin_amount(withdrawal.amount, network_id)))
await confirm_properties( await confirm_properties(
ctx,
"confirm_withdrawal", "confirm_withdrawal",
"Confirm transaction", "Confirm transaction",
props, props,
@ -774,7 +722,6 @@ def _format_stake_credential(
async def confirm_cvote_registration_delegation( async def confirm_cvote_registration_delegation(
ctx: Context,
public_key: str, public_key: str,
weight: int, weight: int,
) -> None: ) -> None:
@ -786,7 +733,6 @@ async def confirm_cvote_registration_delegation(
props.append(("Weight:", str(weight))) props.append(("Weight:", str(weight)))
await confirm_properties( await confirm_properties(
ctx,
"confirm_cvote_registration_delegation", "confirm_cvote_registration_delegation",
title="Confirm transaction", title="Confirm transaction",
props=props, props=props,
@ -795,7 +741,6 @@ async def confirm_cvote_registration_delegation(
async def confirm_cvote_registration_payment_address( async def confirm_cvote_registration_payment_address(
ctx: Context,
payment_address: str, payment_address: str,
should_show_payment_warning: bool, should_show_payment_warning: bool,
) -> None: ) -> None:
@ -806,7 +751,6 @@ async def confirm_cvote_registration_payment_address(
if should_show_payment_warning: if should_show_payment_warning:
props.append((CVOTE_REWARD_ELIGIBILITY_WARNING, None)) props.append((CVOTE_REWARD_ELIGIBILITY_WARNING, None))
await confirm_properties( await confirm_properties(
ctx,
"confirm_cvote_registration_payment_address", "confirm_cvote_registration_payment_address",
title="Confirm transaction", title="Confirm transaction",
props=props, props=props,
@ -815,7 +759,6 @@ async def confirm_cvote_registration_payment_address(
async def confirm_cvote_registration( async def confirm_cvote_registration(
ctx: Context,
vote_public_key: str | None, vote_public_key: str | None,
staking_path: list[int], staking_path: list[int],
nonce: int, nonce: int,
@ -842,7 +785,6 @@ async def confirm_cvote_registration(
) )
await confirm_properties( await confirm_properties(
ctx,
"confirm_cvote_registration", "confirm_cvote_registration",
title="Confirm transaction", title="Confirm transaction",
props=props, props=props,
@ -850,9 +792,8 @@ async def confirm_cvote_registration(
) )
async def show_auxiliary_data_hash(ctx: Context, auxiliary_data_hash: bytes) -> None: async def show_auxiliary_data_hash(auxiliary_data_hash: bytes) -> None:
await confirm_properties( await confirm_properties(
ctx,
"confirm_auxiliary_data", "confirm_auxiliary_data",
"Confirm transaction", "Confirm transaction",
(("Auxiliary data hash:", auxiliary_data_hash),), (("Auxiliary data hash:", auxiliary_data_hash),),
@ -860,12 +801,9 @@ async def show_auxiliary_data_hash(ctx: Context, auxiliary_data_hash: bytes) ->
) )
async def confirm_token_minting( async def confirm_token_minting(policy_id: bytes, token: messages.CardanoToken) -> None:
ctx: Context, policy_id: bytes, token: messages.CardanoToken
) -> None:
assert token.mint_amount is not None # _validate_token assert token.mint_amount is not None # _validate_token
await confirm_properties( await confirm_properties(
ctx,
"confirm_mint", "confirm_mint",
"Confirm transaction", "Confirm transaction",
( (
@ -885,9 +823,8 @@ async def confirm_token_minting(
) )
async def warn_tx_network_unverifiable(ctx: Context) -> None: async def warn_tx_network_unverifiable() -> None:
await confirm_metadata( await confirm_metadata(
ctx,
"warning_no_outputs", "warning_no_outputs",
"Warning", "Warning",
"Transaction has no outputs, network cannot be verified.", "Transaction has no outputs, network cannot be verified.",
@ -895,9 +832,8 @@ async def warn_tx_network_unverifiable(ctx: Context) -> None:
) )
async def confirm_script_data_hash(ctx: Context, script_data_hash: bytes) -> None: async def confirm_script_data_hash(script_data_hash: bytes) -> None:
await confirm_properties( await confirm_properties(
ctx,
"confirm_script_data_hash", "confirm_script_data_hash",
"Confirm transaction", "Confirm transaction",
( (
@ -911,10 +847,9 @@ async def confirm_script_data_hash(ctx: Context, script_data_hash: bytes) -> Non
async def confirm_collateral_input( async def confirm_collateral_input(
ctx: Context, collateral_input: messages.CardanoTxCollateralInput collateral_input: messages.CardanoTxCollateralInput,
) -> None: ) -> None:
await confirm_properties( await confirm_properties(
ctx,
"confirm_collateral_input", "confirm_collateral_input",
"Confirm transaction", "Confirm transaction",
( (
@ -926,10 +861,9 @@ async def confirm_collateral_input(
async def confirm_reference_input( async def confirm_reference_input(
ctx: Context, reference_input: messages.CardanoTxReferenceInput reference_input: messages.CardanoTxReferenceInput,
) -> None: ) -> None:
await confirm_properties( await confirm_properties(
ctx,
"confirm_reference_input", "confirm_reference_input",
"Confirm transaction", "Confirm transaction",
( (
@ -941,7 +875,7 @@ async def confirm_reference_input(
async def confirm_required_signer( async def confirm_required_signer(
ctx: Context, required_signer: messages.CardanoTxRequiredSigner required_signer: messages.CardanoTxRequiredSigner,
) -> None: ) -> None:
assert ( assert (
required_signer.key_hash is not None or required_signer.key_path required_signer.key_hash is not None or required_signer.key_path
@ -953,7 +887,6 @@ async def confirm_required_signer(
) )
await confirm_properties( await confirm_properties(
ctx,
"confirm_required_signer", "confirm_required_signer",
"Confirm transaction", "Confirm transaction",
(("Required signer", formatted_signer),), (("Required signer", formatted_signer),),
@ -962,7 +895,6 @@ async def confirm_required_signer(
async def show_cardano_address( async def show_cardano_address(
ctx: Context,
address_parameters: messages.CardanoAddressParametersType, address_parameters: messages.CardanoAddressParametersType,
address: str, address: str,
protocol_magic: int, protocol_magic: int,
@ -989,7 +921,6 @@ async def show_cardano_address(
path = address_n_to_str(address_parameters.address_n_staking) path = address_n_to_str(address_parameters.address_n_staking)
await layouts.show_address( await layouts.show_address(
ctx,
address, address,
path=path, path=path,
account=account, account=account,

@ -27,7 +27,7 @@ if TYPE_CHECKING:
) )
MsgIn = TypeVar("MsgIn", bound=CardanoMessages) MsgIn = TypeVar("MsgIn", bound=CardanoMessages)
HandlerWithKeychain = Callable[[wire.Context, MsgIn, "Keychain"], Awaitable[MsgOut]] HandlerWithKeychain = Callable[[MsgIn, "Keychain"], Awaitable[MsgOut]]
class Keychain: class Keychain:
@ -136,9 +136,7 @@ def derive_and_store_secrets(passphrase: str) -> None:
cache.set(cache.APP_CARDANO_ICARUS_TREZOR_SECRET, icarus_trezor_secret) cache.set(cache.APP_CARDANO_ICARUS_TREZOR_SECRET, icarus_trezor_secret)
async def _get_keychain_bip39( async def _get_keychain_bip39(derivation_type: CardanoDerivationType) -> Keychain:
ctx: wire.Context, derivation_type: CardanoDerivationType
) -> Keychain:
from apps.common.seed import derive_and_store_roots from apps.common.seed import derive_and_store_roots
from trezor.enums import CardanoDerivationType from trezor.enums import CardanoDerivationType
@ -146,7 +144,7 @@ async def _get_keychain_bip39(
raise wire.NotInitialized("Device is not initialized") raise wire.NotInitialized("Device is not initialized")
if derivation_type == CardanoDerivationType.LEDGER: if derivation_type == CardanoDerivationType.LEDGER:
seed = await get_seed(ctx) seed = await get_seed()
return Keychain(cardano.from_seed_ledger(seed)) return Keychain(cardano.from_seed_ledger(seed))
if not cache.get(cache.APP_COMMON_DERIVE_CARDANO): if not cache.get(cache.APP_COMMON_DERIVE_CARDANO):
@ -160,7 +158,7 @@ async def _get_keychain_bip39(
# _get_secret # _get_secret
secret = cache.get(cache_entry) secret = cache.get(cache_entry)
if secret is None: if secret is None:
await derive_and_store_roots(ctx) await derive_and_store_roots()
secret = cache.get(cache_entry) secret = cache.get(cache_entry)
assert secret is not None assert secret is not None
@ -168,20 +166,18 @@ async def _get_keychain_bip39(
return Keychain(root) return Keychain(root)
async def _get_keychain( async def _get_keychain(derivation_type: CardanoDerivationType) -> Keychain:
ctx: wire.Context, derivation_type: CardanoDerivationType
) -> Keychain:
if mnemonic.is_bip39(): if mnemonic.is_bip39():
return await _get_keychain_bip39(ctx, derivation_type) return await _get_keychain_bip39(derivation_type)
else: else:
# derive the root node via SLIP-0023 https://github.com/satoshilabs/slips/blob/master/slip-0022.md # derive the root node via SLIP-0023 https://github.com/satoshilabs/slips/blob/master/slip-0022.md
seed = await get_seed(ctx) seed = await get_seed()
return Keychain(cardano.from_seed_slip23(seed)) return Keychain(cardano.from_seed_slip23(seed))
def with_keychain(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]: def with_keychain(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]:
async def wrapper(ctx: wire.Context, msg: MsgIn) -> MsgOut: async def wrapper(msg: MsgIn) -> MsgOut:
keychain = await _get_keychain(ctx, msg.derivation_type) keychain = await _get_keychain(msg.derivation_type)
return await func(ctx, msg, keychain) return await func(msg, keychain)
return wrapper return wrapper

@ -4,13 +4,12 @@ from .. import seed
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Type from typing import Type
from trezor.wire import Context
from trezor.messages import CardanoSignTxFinished, CardanoSignTxInit from trezor.messages import CardanoSignTxFinished, CardanoSignTxInit
@seed.with_keychain @seed.with_keychain
async def sign_tx( async def sign_tx(
ctx: Context, msg: CardanoSignTxInit, keychain: seed.Keychain msg: CardanoSignTxInit, keychain: seed.Keychain
) -> CardanoSignTxFinished: ) -> CardanoSignTxFinished:
from trezor.messages import CardanoSignTxFinished from trezor.messages import CardanoSignTxFinished
from trezor import log, wire from trezor import log, wire
@ -40,7 +39,7 @@ async def sign_tx(
else: else:
raise RuntimeError # should be unreachable raise RuntimeError # should be unreachable
signer = signer_type(ctx, msg, keychain) signer = signer_type(msg, keychain)
try: try:
await signer.sign() await signer.sign()

@ -33,7 +33,6 @@ class MultisigSigner(Signer):
# super() omitted intentionally # super() omitted intentionally
is_network_id_verifiable = self._is_network_id_verifiable() is_network_id_verifiable = self._is_network_id_verifiable()
await layout.confirm_tx( await layout.confirm_tx(
self.ctx,
msg.fee, msg.fee,
msg.network_id, msg.network_id,
msg.protocol_magic, msg.protocol_magic,

@ -34,7 +34,6 @@ class OrdinarySigner(Signer):
# super() omitted intentionally # super() omitted intentionally
is_network_id_verifiable = self._is_network_id_verifiable() is_network_id_verifiable = self._is_network_id_verifiable()
await layout.confirm_tx( await layout.confirm_tx(
self.ctx,
msg.fee, msg.fee,
msg.network_id, msg.network_id,
msg.protocol_magic, msg.protocol_magic,
@ -92,10 +91,10 @@ class OrdinarySigner(Signer):
is_minting = SCHEMA_MINT.match(witness_path) is_minting = SCHEMA_MINT.match(witness_path)
if is_minting: if is_minting:
await layout.confirm_witness_request(self.ctx, witness_path) await layout.confirm_witness_request(witness_path)
elif not is_payment and not is_staking: elif not is_payment and not is_staking:
await self._fail_or_warn_path(witness_path, WITNESS_PATH_NAME) await self._fail_or_warn_path(witness_path, WITNESS_PATH_NAME)
else: else:
await self._show_if_showing_details( await self._show_if_showing_details(
layout.confirm_witness_request(self.ctx, witness_path) layout.confirm_witness_request(witness_path)
) )

@ -22,12 +22,12 @@ class PlutusSigner(Signer):
# These items should be present if a Plutus script is to be executed. # These items should be present if a Plutus script is to be executed.
if self.msg.script_data_hash is None: if self.msg.script_data_hash is None:
await layout.warn_no_script_data_hash(self.ctx) await layout.warn_no_script_data_hash()
if self.msg.collateral_inputs_count == 0: if self.msg.collateral_inputs_count == 0:
await layout.warn_no_collateral_inputs(self.ctx) await layout.warn_no_collateral_inputs()
if self.msg.total_collateral is None: if self.msg.total_collateral is None:
await layout.warn_unknown_total_collateral(self.ctx) await layout.warn_unknown_total_collateral()
async def _confirm_tx(self, tx_hash: bytes) -> None: async def _confirm_tx(self, tx_hash: bytes) -> None:
msg = self.msg # local_cache_attribute msg = self.msg # local_cache_attribute
@ -38,7 +38,6 @@ class PlutusSigner(Signer):
# tedious to check one by one on the Trezor screen). # tedious to check one by one on the Trezor screen).
is_network_id_verifiable = self._is_network_id_verifiable() is_network_id_verifiable = self._is_network_id_verifiable()
await layout.confirm_tx( await layout.confirm_tx(
self.ctx,
msg.fee, msg.fee,
msg.network_id, msg.network_id,
msg.protocol_magic, msg.protocol_magic,
@ -52,7 +51,7 @@ class PlutusSigner(Signer):
async def _show_input(self, input: messages.CardanoTxInput) -> None: async def _show_input(self, input: messages.CardanoTxInput) -> None:
# super() omitted intentionally # super() omitted intentionally
# The inputs are not interchangeable (because of datums), so we must show them. # The inputs are not interchangeable (because of datums), so we must show them.
await self._show_if_showing_details(layout.confirm_input(self.ctx, input)) await self._show_if_showing_details(layout.confirm_input(input))
async def _show_output_credentials( async def _show_output_credentials(
self, address_parameters: messages.CardanoAddressParametersType self, address_parameters: messages.CardanoAddressParametersType
@ -64,7 +63,6 @@ class PlutusSigner(Signer):
# evaluation. We at least hide the staking path if it matches the payment path. # evaluation. We at least hide the staking path if it matches the payment path.
show_both_credentials = should_show_credentials(address_parameters) show_both_credentials = should_show_credentials(address_parameters)
await layout.show_device_owned_output_credentials( await layout.show_device_owned_output_credentials(
self.ctx,
Credential.payment_credential(address_parameters), Credential.payment_credential(address_parameters),
Credential.stake_credential(address_parameters), Credential.stake_credential(address_parameters),
show_both_credentials, show_both_credentials,

@ -45,7 +45,6 @@ class PoolOwnerSigner(Signer):
# super() omitted intentionally # super() omitted intentionally
await layout.confirm_stake_pool_registration_final( await layout.confirm_stake_pool_registration_final(
self.ctx,
self.msg.protocol_magic, self.msg.protocol_magic,
self.msg.ttl, self.msg.ttl,
self.msg.validity_interval_start, self.msg.validity_interval_start,

@ -9,6 +9,7 @@ from trezor.enums import (
) )
from trezor.messages import CardanoTxItemAck, CardanoTxOutput from trezor.messages import CardanoTxItemAck, CardanoTxOutput
from trezor.wire import DataError, ProcessError from trezor.wire import DataError, ProcessError
from trezor.wire.context import call as ctx_call
from apps.common import safety_checks from apps.common import safety_checks
@ -21,7 +22,6 @@ from ..helpers.utils import derive_public_key
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Any, Awaitable, ClassVar from typing import Any, Awaitable, ClassVar
from trezor.wire import Context
from trezor.enums import CardanoAddressType from trezor.enums import CardanoAddressType
from apps.common.paths import PathSchema from apps.common.paths import PathSchema
from apps.common import cbor from apps.common import cbor
@ -76,13 +76,11 @@ class Signer:
def __init__( def __init__(
self, self,
ctx: Context,
msg: messages.CardanoSignTxInit, msg: messages.CardanoSignTxInit,
keychain: seed.Keychain, keychain: seed.Keychain,
) -> None: ) -> None:
from ..helpers.account_path_check import AccountPathChecker from ..helpers.account_path_check import AccountPathChecker
self.ctx = ctx
self.msg = msg self.msg = msg
self.keychain = keychain self.keychain = keychain
@ -124,8 +122,8 @@ class Signer:
await self._confirm_tx(tx_hash) await self._confirm_tx(tx_hash)
response_after_witness_requests = await self._process_witness_requests(tx_hash) response_after_witness_requests = await self._process_witness_requests(tx_hash)
await self.ctx.call(response_after_witness_requests, messages.CardanoTxHostAck) await ctx_call(response_after_witness_requests, messages.CardanoTxHostAck)
await self.ctx.call( await ctx_call(
messages.CardanoTxBodyHash(tx_hash=tx_hash), messages.CardanoTxHostAck messages.CardanoTxBodyHash(tx_hash=tx_hash), messages.CardanoTxHostAck
) )
@ -225,12 +223,10 @@ class Signer:
validate_network_info(msg.network_id, msg.protocol_magic) validate_network_info(msg.network_id, msg.protocol_magic)
async def _show_tx_init(self) -> None: async def _show_tx_init(self) -> None:
self.should_show_details = await layout.show_tx_init( self.should_show_details = await layout.show_tx_init(self.SIGNING_MODE_TITLE)
self.ctx, self.SIGNING_MODE_TITLE
)
if not self._is_network_id_verifiable(): if not self._is_network_id_verifiable():
await layout.warn_tx_network_unverifiable(self.ctx) await layout.warn_tx_network_unverifiable()
async def _confirm_tx(self, tx_hash: bytes) -> None: async def _confirm_tx(self, tx_hash: bytes) -> None:
# Final signing confirmation is handled separately in each signing mode. # Final signing confirmation is handled separately in each signing mode.
@ -242,7 +238,7 @@ class Signer:
self, inputs_list: HashBuilderList[tuple[bytes, int]] self, inputs_list: HashBuilderList[tuple[bytes, int]]
) -> None: ) -> None:
for _ in range(self.msg.inputs_count): for _ in range(self.msg.inputs_count):
input: messages.CardanoTxInput = await self.ctx.call( input: messages.CardanoTxInput = await ctx_call(
CardanoTxItemAck(), messages.CardanoTxInput CardanoTxItemAck(), messages.CardanoTxInput
) )
self._validate_input(input) self._validate_input(input)
@ -262,7 +258,7 @@ class Signer:
async def _process_outputs(self, outputs_list: HashBuilderList) -> None: async def _process_outputs(self, outputs_list: HashBuilderList) -> None:
total_amount = 0 total_amount = 0
for _ in range(self.msg.outputs_count): for _ in range(self.msg.outputs_count):
output: CardanoTxOutput = await self.ctx.call( output: CardanoTxOutput = await ctx_call(
CardanoTxItemAck(), CardanoTxOutput CardanoTxItemAck(), CardanoTxOutput
) )
await self._process_output(outputs_list, output) await self._process_output(outputs_list, output)
@ -346,10 +342,10 @@ class Signer:
and output.inline_datum_size == 0 and output.inline_datum_size == 0
and address_type in addresses.ADDRESS_TYPES_PAYMENT_SCRIPT and address_type in addresses.ADDRESS_TYPES_PAYMENT_SCRIPT
): ):
await layout.warn_tx_output_no_datum(self.ctx) await layout.warn_tx_output_no_datum()
if output.asset_groups_count > 0: if output.asset_groups_count > 0:
await layout.warn_tx_output_contains_tokens(self.ctx) await layout.warn_tx_output_contains_tokens()
if output.address_parameters is not None: if output.address_parameters is not None:
address = addresses.derive_human_readable( address = addresses.derive_human_readable(
@ -364,7 +360,6 @@ class Signer:
address = output.address address = output.address
await layout.confirm_sending( await layout.confirm_sending(
self.ctx,
output.amount, output.amount,
address, address,
"change" if self._is_change_output(output) else "address", "change" if self._is_change_output(output) else "address",
@ -375,7 +370,6 @@ class Signer:
self, address_parameters: messages.CardanoAddressParametersType self, address_parameters: messages.CardanoAddressParametersType
) -> None: ) -> None:
await layout.show_change_output_credentials( await layout.show_change_output_credentials(
self.ctx,
Credential.payment_credential(address_parameters), Credential.payment_credential(address_parameters),
Credential.stake_credential(address_parameters), Credential.stake_credential(address_parameters),
) )
@ -439,7 +433,7 @@ class Signer:
if output.datum_hash is not None: if output.datum_hash is not None:
if should_show: if should_show:
await self._show_if_showing_details( await self._show_if_showing_details(
layout.confirm_datum_hash(self.ctx, output.datum_hash) layout.confirm_datum_hash(output.datum_hash)
) )
output_list.append(output.datum_hash) output_list.append(output.datum_hash)
@ -472,7 +466,7 @@ class Signer:
if output.datum_hash is not None: if output.datum_hash is not None:
if should_show: if should_show:
await self._show_if_showing_details( await self._show_if_showing_details(
layout.confirm_datum_hash(self.ctx, output.datum_hash) layout.confirm_datum_hash(output.datum_hash)
) )
add( add(
_BABBAGE_OUTPUT_KEY_DATUM_OPTION, _BABBAGE_OUTPUT_KEY_DATUM_OPTION,
@ -532,7 +526,7 @@ class Signer:
should_show_tokens: bool, should_show_tokens: bool,
) -> None: ) -> None:
for _ in range(asset_groups_count): for _ in range(asset_groups_count):
asset_group: messages.CardanoAssetGroup = await self.ctx.call( asset_group: messages.CardanoAssetGroup = await ctx_call(
CardanoTxItemAck(), messages.CardanoAssetGroup CardanoTxItemAck(), messages.CardanoAssetGroup
) )
self._validate_asset_group(asset_group) self._validate_asset_group(asset_group)
@ -573,12 +567,12 @@ class Signer:
should_show_tokens: bool, should_show_tokens: bool,
) -> None: ) -> None:
for _ in range(tokens_count): for _ in range(tokens_count):
token: messages.CardanoToken = await self.ctx.call( token: messages.CardanoToken = await ctx_call(
CardanoTxItemAck(), messages.CardanoToken CardanoTxItemAck(), messages.CardanoToken
) )
self._validate_token(token) self._validate_token(token)
if should_show_tokens: if should_show_tokens:
await layout.confirm_sending_token(self.ctx, policy_id, token) await layout.confirm_sending_token(policy_id, token)
assert token.amount is not None # _validate_token assert token.amount is not None # _validate_token
tokens_dict.add(token.asset_name_bytes, token.amount) tokens_dict.add(token.asset_name_bytes, token.amount)
@ -614,7 +608,7 @@ class Signer:
chunks_count = self._get_chunks_count(inline_datum_size) chunks_count = self._get_chunks_count(inline_datum_size)
for chunk_number in range(chunks_count): for chunk_number in range(chunks_count):
chunk: messages.CardanoTxInlineDatumChunk = await self.ctx.call( chunk: messages.CardanoTxInlineDatumChunk = await ctx_call(
CardanoTxItemAck(), messages.CardanoTxInlineDatumChunk CardanoTxItemAck(), messages.CardanoTxInlineDatumChunk
) )
self._validate_chunk( self._validate_chunk(
@ -625,7 +619,7 @@ class Signer:
) )
if chunk_number == 0 and should_show: if chunk_number == 0 and should_show:
await self._show_if_showing_details( await self._show_if_showing_details(
layout.confirm_inline_datum(self.ctx, chunk.data, inline_datum_size) layout.confirm_inline_datum(chunk.data, inline_datum_size)
) )
inline_datum_cbor.add(chunk.data) inline_datum_cbor.add(chunk.data)
@ -641,7 +635,7 @@ class Signer:
chunks_count = self._get_chunks_count(reference_script_size) chunks_count = self._get_chunks_count(reference_script_size)
for chunk_number in range(chunks_count): for chunk_number in range(chunks_count):
chunk: messages.CardanoTxReferenceScriptChunk = await self.ctx.call( chunk: messages.CardanoTxReferenceScriptChunk = await ctx_call(
CardanoTxItemAck(), messages.CardanoTxReferenceScriptChunk CardanoTxItemAck(), messages.CardanoTxReferenceScriptChunk
) )
self._validate_chunk( self._validate_chunk(
@ -652,9 +646,7 @@ class Signer:
) )
if chunk_number == 0 and should_show: if chunk_number == 0 and should_show:
await self._show_if_showing_details( await self._show_if_showing_details(
layout.confirm_reference_script( layout.confirm_reference_script(chunk.data, reference_script_size)
self.ctx, chunk.data, reference_script_size
)
) )
reference_script_cbor.add(chunk.data) reference_script_cbor.add(chunk.data)
@ -662,7 +654,7 @@ class Signer:
async def _process_certificates(self, certificates_list: HashBuilderList) -> None: async def _process_certificates(self, certificates_list: HashBuilderList) -> None:
for _ in range(self.msg.certificates_count): for _ in range(self.msg.certificates_count):
certificate: messages.CardanoTxCertificate = await self.ctx.call( certificate: messages.CardanoTxCertificate = await ctx_call(
CardanoTxItemAck(), messages.CardanoTxCertificate CardanoTxItemAck(), messages.CardanoTxCertificate
) )
self._validate_certificate(certificate) self._validate_certificate(certificate)
@ -726,13 +718,13 @@ class Signer:
if certificate.type == CardanoCertificateType.STAKE_POOL_REGISTRATION: if certificate.type == CardanoCertificateType.STAKE_POOL_REGISTRATION:
assert certificate.pool_parameters is not None assert certificate.pool_parameters is not None
await layout.confirm_stake_pool_parameters( await layout.confirm_stake_pool_parameters(
self.ctx, certificate.pool_parameters, self.msg.network_id certificate.pool_parameters, self.msg.network_id
) )
await layout.confirm_stake_pool_metadata( await layout.confirm_stake_pool_metadata(
self.ctx, certificate.pool_parameters.metadata certificate.pool_parameters.metadata
) )
else: else:
await layout.confirm_certificate(self.ctx, certificate) await layout.confirm_certificate(certificate)
# pool owners # pool owners
@ -741,7 +733,7 @@ class Signer:
) -> None: ) -> None:
owners_as_path_count = 0 owners_as_path_count = 0
for _ in range(owners_count): for _ in range(owners_count):
owner: messages.CardanoPoolOwner = await self.ctx.call( owner: messages.CardanoPoolOwner = await ctx_call(
CardanoTxItemAck(), messages.CardanoPoolOwner CardanoTxItemAck(), messages.CardanoPoolOwner
) )
certificates.validate_pool_owner(owner, self.account_path_checker) certificates.validate_pool_owner(owner, self.account_path_checker)
@ -764,7 +756,7 @@ class Signer:
) )
await layout.confirm_stake_pool_owner( await layout.confirm_stake_pool_owner(
self.ctx, self.keychain, owner, self.msg.protocol_magic, self.msg.network_id self.keychain, owner, self.msg.protocol_magic, self.msg.network_id
) )
# pool relays # pool relays
@ -775,7 +767,7 @@ class Signer:
relays_count: int, relays_count: int,
) -> None: ) -> None:
for _ in range(relays_count): for _ in range(relays_count):
relay: messages.CardanoPoolRelayParameters = await self.ctx.call( relay: messages.CardanoPoolRelayParameters = await ctx_call(
CardanoTxItemAck(), messages.CardanoPoolRelayParameters CardanoTxItemAck(), messages.CardanoPoolRelayParameters
) )
certificates.validate_pool_relay(relay) certificates.validate_pool_relay(relay)
@ -787,14 +779,14 @@ class Signer:
self, withdrawals_dict: HashBuilderDict[bytes, int] self, withdrawals_dict: HashBuilderDict[bytes, int]
) -> None: ) -> None:
for _ in range(self.msg.withdrawals_count): for _ in range(self.msg.withdrawals_count):
withdrawal: messages.CardanoTxWithdrawal = await self.ctx.call( withdrawal: messages.CardanoTxWithdrawal = await ctx_call(
CardanoTxItemAck(), messages.CardanoTxWithdrawal CardanoTxItemAck(), messages.CardanoTxWithdrawal
) )
self._validate_withdrawal(withdrawal) self._validate_withdrawal(withdrawal)
address_bytes = self._derive_withdrawal_address_bytes(withdrawal) address_bytes = self._derive_withdrawal_address_bytes(withdrawal)
await self._show_if_showing_details( await self._show_if_showing_details(
layout.confirm_withdrawal( layout.confirm_withdrawal(
self.ctx, withdrawal, address_bytes, self.msg.network_id withdrawal, address_bytes, self.msg.network_id
) )
) )
withdrawals_dict.add(address_bytes, withdrawal.amount) withdrawals_dict.add(address_bytes, withdrawal.amount)
@ -821,7 +813,7 @@ class Signer:
msg = self.msg # local_cache_attribute msg = self.msg # local_cache_attribute
data: messages.CardanoTxAuxiliaryData = await self.ctx.call( data: messages.CardanoTxAuxiliaryData = await ctx_call(
CardanoTxItemAck(), messages.CardanoTxAuxiliaryData CardanoTxItemAck(), messages.CardanoTxAuxiliaryData
) )
auxiliary_data.validate(data, msg.protocol_magic, msg.network_id) auxiliary_data.validate(data, msg.protocol_magic, msg.network_id)
@ -833,7 +825,6 @@ class Signer:
self.keychain, data, msg.protocol_magic, msg.network_id self.keychain, data, msg.protocol_magic, msg.network_id
) )
await auxiliary_data.show( await auxiliary_data.show(
self.ctx,
self.keychain, self.keychain,
auxiliary_data_hash, auxiliary_data_hash,
data.cvote_registration_parameters, data.cvote_registration_parameters,
@ -843,21 +834,21 @@ class Signer:
) )
self.tx_dict.add(_TX_BODY_KEY_AUXILIARY_DATA, auxiliary_data_hash) self.tx_dict.add(_TX_BODY_KEY_AUXILIARY_DATA, auxiliary_data_hash)
await self.ctx.call(auxiliary_data_supplement, messages.CardanoTxHostAck) await ctx_call(auxiliary_data_supplement, messages.CardanoTxHostAck)
# minting # minting
async def _process_minting( async def _process_minting(
self, minting_dict: HashBuilderDict[bytes, HashBuilderDict] self, minting_dict: HashBuilderDict[bytes, HashBuilderDict]
) -> None: ) -> None:
token_minting: messages.CardanoTxMint = await self.ctx.call( token_minting: messages.CardanoTxMint = await ctx_call(
CardanoTxItemAck(), messages.CardanoTxMint CardanoTxItemAck(), messages.CardanoTxMint
) )
await layout.warn_tx_contains_mint(self.ctx) await layout.warn_tx_contains_mint()
for _ in range(token_minting.asset_groups_count): for _ in range(token_minting.asset_groups_count):
asset_group: messages.CardanoAssetGroup = await self.ctx.call( asset_group: messages.CardanoAssetGroup = await ctx_call(
CardanoTxItemAck(), messages.CardanoAssetGroup CardanoTxItemAck(), messages.CardanoAssetGroup
) )
self._validate_asset_group(asset_group, is_mint=True) self._validate_asset_group(asset_group, is_mint=True)
@ -881,11 +872,11 @@ class Signer:
tokens_count: int, tokens_count: int,
) -> None: ) -> None:
for _ in range(tokens_count): for _ in range(tokens_count):
token: messages.CardanoToken = await self.ctx.call( token: messages.CardanoToken = await ctx_call(
CardanoTxItemAck(), messages.CardanoToken CardanoTxItemAck(), messages.CardanoToken
) )
self._validate_token(token, is_mint=True) self._validate_token(token, is_mint=True)
await layout.confirm_token_minting(self.ctx, policy_id, token) await layout.confirm_token_minting(policy_id, token)
assert token.mint_amount is not None # _validate_token assert token.mint_amount is not None # _validate_token
tokens.add(token.asset_name_bytes, token.mint_amount) tokens.add(token.asset_name_bytes, token.mint_amount)
@ -896,7 +887,7 @@ class Signer:
assert self.msg.script_data_hash is not None assert self.msg.script_data_hash is not None
self._validate_script_data_hash() self._validate_script_data_hash()
await self._show_if_showing_details( await self._show_if_showing_details(
layout.confirm_script_data_hash(self.ctx, self.msg.script_data_hash) layout.confirm_script_data_hash(self.msg.script_data_hash)
) )
self.tx_dict.add(_TX_BODY_KEY_SCRIPT_DATA_HASH, self.msg.script_data_hash) self.tx_dict.add(_TX_BODY_KEY_SCRIPT_DATA_HASH, self.msg.script_data_hash)
@ -913,7 +904,7 @@ class Signer:
self, collateral_inputs_list: HashBuilderList[tuple[bytes, int]] self, collateral_inputs_list: HashBuilderList[tuple[bytes, int]]
) -> None: ) -> None:
for _ in range(self.msg.collateral_inputs_count): for _ in range(self.msg.collateral_inputs_count):
collateral_input: messages.CardanoTxCollateralInput = await self.ctx.call( collateral_input: messages.CardanoTxCollateralInput = await ctx_call(
CardanoTxItemAck(), messages.CardanoTxCollateralInput CardanoTxItemAck(), messages.CardanoTxCollateralInput
) )
self._validate_collateral_input(collateral_input) self._validate_collateral_input(collateral_input)
@ -933,7 +924,7 @@ class Signer:
) -> None: ) -> None:
if self.msg.total_collateral is None: if self.msg.total_collateral is None:
await self._show_if_showing_details( await self._show_if_showing_details(
layout.confirm_collateral_input(self.ctx, collateral_input) layout.confirm_collateral_input(collateral_input)
) )
# required signers # required signers
@ -944,12 +935,12 @@ class Signer:
from ..helpers.utils import get_public_key_hash from ..helpers.utils import get_public_key_hash
for _ in range(self.msg.required_signers_count): for _ in range(self.msg.required_signers_count):
required_signer: messages.CardanoTxRequiredSigner = await self.ctx.call( required_signer: messages.CardanoTxRequiredSigner = await ctx_call(
CardanoTxItemAck(), messages.CardanoTxRequiredSigner CardanoTxItemAck(), messages.CardanoTxRequiredSigner
) )
self._validate_required_signer(required_signer) self._validate_required_signer(required_signer)
await self._show_if_showing_details( await self._show_if_showing_details(
layout.confirm_required_signer(self.ctx, required_signer) layout.confirm_required_signer(required_signer)
) )
key_hash = required_signer.key_hash or get_public_key_hash( key_hash = required_signer.key_hash or get_public_key_hash(
@ -985,9 +976,7 @@ class Signer:
# collateral return # collateral return
async def _process_collateral_return(self) -> None: async def _process_collateral_return(self) -> None:
output: CardanoTxOutput = await self.ctx.call( output: CardanoTxOutput = await ctx_call(CardanoTxItemAck(), CardanoTxOutput)
CardanoTxItemAck(), CardanoTxOutput
)
self._validate_collateral_return(output) self._validate_collateral_return(output)
should_show_init = self._should_show_collateral_return_init(output) should_show_init = self._should_show_collateral_return_init(output)
should_show_tokens = self._should_show_collateral_return_tokens(output) should_show_tokens = self._should_show_collateral_return_tokens(output)
@ -1031,9 +1020,7 @@ class Signer:
# We don't display missing datum warning since datums are forbidden. # We don't display missing datum warning since datums are forbidden.
if output.asset_groups_count > 0: if output.asset_groups_count > 0:
await layout.warn_tx_output_contains_tokens( await layout.warn_tx_output_contains_tokens(is_collateral_return=True)
self.ctx, is_collateral_return=True
)
if output.address_parameters is not None: if output.address_parameters is not None:
address = addresses.derive_human_readable( address = addresses.derive_human_readable(
@ -1050,7 +1037,6 @@ class Signer:
address = output.address address = output.address
await layout.confirm_sending( await layout.confirm_sending(
self.ctx,
output.amount, output.amount,
address, address,
"collateral-return", "collateral-return",
@ -1078,12 +1064,12 @@ class Signer:
self, reference_inputs_list: HashBuilderList[tuple[bytes, int]] self, reference_inputs_list: HashBuilderList[tuple[bytes, int]]
) -> None: ) -> None:
for _ in range(self.msg.reference_inputs_count): for _ in range(self.msg.reference_inputs_count):
reference_input: messages.CardanoTxReferenceInput = await self.ctx.call( reference_input: messages.CardanoTxReferenceInput = await ctx_call(
CardanoTxItemAck(), messages.CardanoTxReferenceInput CardanoTxItemAck(), messages.CardanoTxReferenceInput
) )
self._validate_reference_input(reference_input) self._validate_reference_input(reference_input)
await self._show_if_showing_details( await self._show_if_showing_details(
layout.confirm_reference_input(self.ctx, reference_input) layout.confirm_reference_input(reference_input)
) )
reference_inputs_list.append( reference_inputs_list.append(
(reference_input.prev_hash, reference_input.prev_index) (reference_input.prev_hash, reference_input.prev_index)
@ -1101,9 +1087,7 @@ class Signer:
response: CardanoTxResponseType = CardanoTxItemAck() response: CardanoTxResponseType = CardanoTxItemAck()
for _ in range(self.msg.witness_requests_count): for _ in range(self.msg.witness_requests_count):
witness_request = await self.ctx.call( witness_request = await ctx_call(response, messages.CardanoTxWitnessRequest)
response, messages.CardanoTxWitnessRequest
)
self._validate_witness_request(witness_request) self._validate_witness_request(witness_request)
path = witness_request.path path = witness_request.path
await self._show_witness_request(path) await self._show_witness_request(path)
@ -1123,7 +1107,7 @@ class Signer:
self, self,
witness_path: list[int], witness_path: list[int],
) -> None: ) -> None:
await layout.confirm_witness_request(self.ctx, witness_path) await layout.confirm_witness_request(witness_path)
# helpers # helpers
@ -1241,7 +1225,7 @@ class Signer:
if safety_checks.is_strict(): if safety_checks.is_strict():
raise DataError(f"Invalid {path_name.lower()}") raise DataError(f"Invalid {path_name.lower()}")
else: else:
await layout.warn_path(self.ctx, path, path_name) await layout.warn_path(path, path_name)
def _fail_if_strict_and_unusual( def _fail_if_strict_and_unusual(
self, address_parameters: messages.CardanoAddressParametersType self, address_parameters: messages.CardanoAddressParametersType

@ -16,7 +16,6 @@ if TYPE_CHECKING:
from typing_extensions import Protocol from typing_extensions import Protocol
from trezor.protobuf import MessageType from trezor.protobuf import MessageType
from trezor.wire import Context
from .seed import Slip21Node from .seed import Slip21Node
T = TypeVar("T") T = TypeVar("T")
@ -36,8 +35,8 @@ if TYPE_CHECKING:
MsgIn = TypeVar("MsgIn", bound=MessageType) MsgIn = TypeVar("MsgIn", bound=MessageType)
MsgOut = TypeVar("MsgOut", bound=MessageType) MsgOut = TypeVar("MsgOut", bound=MessageType)
Handler = Callable[[Context, MsgIn], Awaitable[MsgOut]] Handler = Callable[[MsgIn], Awaitable[MsgOut]]
HandlerWithKeychain = Callable[[Context, MsgIn, "Keychain"], Awaitable[MsgOut]] HandlerWithKeychain = Callable[[MsgIn, "Keychain"], Awaitable[MsgOut]]
class Deletable(Protocol): class Deletable(Protocol):
def __del__(self) -> None: def __del__(self) -> None:
@ -176,14 +175,13 @@ class Keychain:
async def get_keychain( async def get_keychain(
ctx: Context,
curve: str, curve: str,
schemas: Iterable[paths.PathSchemaType], schemas: Iterable[paths.PathSchemaType],
slip21_namespaces: Iterable[paths.Slip21Path] = (), slip21_namespaces: Iterable[paths.Slip21Path] = (),
) -> Keychain: ) -> Keychain:
from .seed import get_seed from .seed import get_seed
seed = await get_seed(ctx) seed = await get_seed()
keychain = Keychain(seed, curve, schemas, slip21_namespaces) keychain = Keychain(seed, curve, schemas, slip21_namespaces)
return keychain return keychain
@ -205,10 +203,10 @@ def with_slip44_keychain(
schemas = [s.copy() for s in schemas] schemas = [s.copy() for s in schemas]
def decorator(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]: def decorator(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]:
async def wrapper(ctx: Context, msg: MsgIn) -> MsgOut: async def wrapper(msg: MsgIn) -> MsgOut:
keychain = await get_keychain(ctx, curve, schemas) keychain = await get_keychain(curve, schemas)
with keychain: with keychain:
return await func(ctx, msg, keychain) return await func(msg, keychain)
return wrapper return wrapper

@ -1,12 +1,8 @@
from micropython import const from micropython import const
from typing import TYPE_CHECKING
import storage.device as storage_device import storage.device as storage_device
from trezor.wire import DataError from trezor.wire import DataError
if TYPE_CHECKING:
from trezor.wire import Context
_MAX_PASSPHRASE_LEN = const(50) _MAX_PASSPHRASE_LEN = const(50)
@ -14,7 +10,7 @@ def is_enabled() -> bool:
return storage_device.is_passphrase_enabled() return storage_device.is_passphrase_enabled()
async def get(ctx: Context) -> str: async def get() -> str:
from trezor import workflow from trezor import workflow
if not is_enabled(): if not is_enabled():
@ -24,23 +20,24 @@ async def get(ctx: Context) -> str:
if storage_device.get_passphrase_always_on_device(): if storage_device.get_passphrase_always_on_device():
from trezor.ui.layouts import request_passphrase_on_device from trezor.ui.layouts import request_passphrase_on_device
passphrase = await request_passphrase_on_device(ctx, _MAX_PASSPHRASE_LEN) passphrase = await request_passphrase_on_device(_MAX_PASSPHRASE_LEN)
else: else:
passphrase = await _request_on_host(ctx) passphrase = await _request_on_host()
if len(passphrase.encode()) > _MAX_PASSPHRASE_LEN: if len(passphrase.encode()) > _MAX_PASSPHRASE_LEN:
raise DataError(f"Maximum passphrase length is {_MAX_PASSPHRASE_LEN} bytes") raise DataError(f"Maximum passphrase length is {_MAX_PASSPHRASE_LEN} bytes")
return passphrase return passphrase
async def _request_on_host(ctx: Context) -> str: async def _request_on_host() -> str:
from trezor.messages import PassphraseAck, PassphraseRequest from trezor.messages import PassphraseAck, PassphraseRequest
from trezor.ui.layouts import request_passphrase_on_host from trezor.ui.layouts import request_passphrase_on_host
from trezor.wire.context import call
request_passphrase_on_host() request_passphrase_on_host()
request = PassphraseRequest() request = PassphraseRequest()
ack = await ctx.call(request, PassphraseAck) ack = await call(request, PassphraseAck)
passphrase = ack.passphrase # local_cache_attribute passphrase = ack.passphrase # local_cache_attribute
if ack.on_device: if ack.on_device:
@ -48,7 +45,7 @@ async def _request_on_host(ctx: Context) -> str:
if passphrase is not None: if passphrase is not None:
raise DataError("Passphrase provided when it should not be") raise DataError("Passphrase provided when it should not be")
return await request_passphrase_on_device(ctx, _MAX_PASSPHRASE_LEN) return await request_passphrase_on_device(_MAX_PASSPHRASE_LEN)
if passphrase is None: if passphrase is None:
raise DataError( raise DataError(
@ -63,14 +60,12 @@ async def _request_on_host(ctx: Context) -> str:
if storage_device.get_hide_passphrase_from_host(): if storage_device.get_hide_passphrase_from_host():
explanation = "Passphrase provided by host will be used but will not be displayed due to the device settings." explanation = "Passphrase provided by host will be used but will not be displayed due to the device settings."
await confirm_action( await confirm_action(
ctx,
"passphrase_host1_hidden", "passphrase_host1_hidden",
"Hidden wallet", "Hidden wallet",
description=f"Access hidden wallet?\n{explanation}", description=f"Access hidden wallet?\n{explanation}",
) )
else: else:
await confirm_action( await confirm_action(
ctx,
"passphrase_host1", "passphrase_host1",
"Hidden wallet", "Hidden wallet",
description="Next screen will show the passphrase.", description="Next screen will show the passphrase.",
@ -78,7 +73,6 @@ async def _request_on_host(ctx: Context) -> str:
) )
await confirm_blob( await confirm_blob(
ctx,
"passphrase_host2", "passphrase_host2",
"Confirm passphrase", "Confirm passphrase",
passphrase, passphrase,

@ -15,7 +15,6 @@ if TYPE_CHECKING:
TypeVar, TypeVar,
) )
from typing_extensions import Protocol from typing_extensions import Protocol
from trezor import wire
Bip32Path = Sequence[int] Bip32Path = Sequence[int]
Slip21Path = Sequence[bytes] Slip21Path = Sequence[bytes]
@ -342,20 +341,19 @@ PATTERN_CASA = "m/45'/coin_type/account/change/address_index"
async def validate_path( async def validate_path(
ctx: wire.Context,
keychain: KeychainValidatorType, keychain: KeychainValidatorType,
path: Bip32Path, path: Bip32Path,
*additional_checks: bool, *additional_checks: bool,
) -> None: ) -> None:
keychain.verify_path(path) keychain.verify_path(path)
if not keychain.is_in_keychain(path) or not all(additional_checks): if not keychain.is_in_keychain(path) or not all(additional_checks):
await show_path_warning(ctx, path) await show_path_warning(path)
async def show_path_warning(ctx: wire.Context, path: Bip32Path) -> None: async def show_path_warning(path: Bip32Path) -> None:
from trezor.ui.layouts import confirm_path_warning from trezor.ui.layouts import confirm_path_warning
await confirm_path_warning(ctx, address_n_to_str(path)) await confirm_path_warning(address_n_to_str(path))
def is_hardened(i: int) -> bool: def is_hardened(i: int) -> bool:

@ -1,16 +1,12 @@
import utime import utime
from typing import TYPE_CHECKING from typing import Any, NoReturn
import storage.cache as storage_cache import storage.cache as storage_cache
from trezor import config, utils, wire from trezor import config, utils, wire
if TYPE_CHECKING:
from typing import Any, NoReturn
from trezor.wire import Context, GenericContext
async def _request_sd_salt( async def _request_sd_salt(
ctx: wire.GenericContext, raise_cancelled_on_unavailable: bool = False raise_cancelled_on_unavailable: bool = False,
) -> bytearray | None: ) -> bytearray | None:
"""Helper to get SD salt in a general manner, working for all models. """Helper to get SD salt in a general manner, working for all models.
@ -23,7 +19,7 @@ async def _request_sd_salt(
from .sdcard import request_sd_salt, SdCardUnavailable from .sdcard import request_sd_salt, SdCardUnavailable
try: try:
return await request_sd_salt(ctx) return await request_sd_salt()
except SdCardUnavailable: except SdCardUnavailable:
if raise_cancelled_on_unavailable: if raise_cancelled_on_unavailable:
raise wire.PinCancelled("SD salt is unavailable") raise wire.PinCancelled("SD salt is unavailable")
@ -43,38 +39,37 @@ def can_lock_device() -> bool:
async def request_pin( async def request_pin(
ctx: GenericContext,
prompt: str, prompt: str,
attempts_remaining: int | None = None, attempts_remaining: int | None = None,
allow_cancel: bool = True, allow_cancel: bool = True,
) -> str: ) -> str:
from trezor.ui.layouts import request_pin_on_device from trezor.ui.layouts import request_pin_on_device
return await request_pin_on_device(ctx, prompt, attempts_remaining, allow_cancel) return await request_pin_on_device(prompt, attempts_remaining, allow_cancel)
async def request_pin_confirm(ctx: Context, *args: Any, **kwargs: Any) -> str: async def request_pin_confirm(*args: Any, **kwargs: Any) -> str:
from trezor.ui.layouts import pin_mismatch_popup, confirm_reenter_pin from trezor.ui.layouts import pin_mismatch_popup, confirm_reenter_pin
while True: while True:
pin1 = await request_pin(ctx, "Enter new PIN", *args, **kwargs) pin1 = await request_pin("Enter new PIN", *args, **kwargs)
await confirm_reenter_pin(ctx) await confirm_reenter_pin()
pin2 = await request_pin(ctx, "Re-enter new PIN", *args, **kwargs) pin2 = await request_pin("Re-enter new PIN", *args, **kwargs)
if pin1 == pin2: if pin1 == pin2:
return pin1 return pin1
await pin_mismatch_popup(ctx) await pin_mismatch_popup()
async def request_pin_and_sd_salt( async def request_pin_and_sd_salt(
ctx: Context, prompt: str, allow_cancel: bool = True prompt: str, allow_cancel: bool = True
) -> tuple[str, bytearray | None]: ) -> tuple[str, bytearray | None]:
if config.has_pin(): if config.has_pin():
pin = await request_pin(ctx, prompt, config.get_pin_rem(), allow_cancel) pin = await request_pin(prompt, config.get_pin_rem(), allow_cancel)
config.ensure_not_wipe_code(pin) config.ensure_not_wipe_code(pin)
else: else:
pin = "" pin = ""
salt = await _request_sd_salt(ctx) salt = await _request_sd_salt()
return pin, salt return pin, salt
@ -85,7 +80,6 @@ def _set_last_unlock_time() -> None:
async def verify_user_pin( async def verify_user_pin(
ctx: GenericContext = wire.DUMMY_CONTEXT,
prompt: str = "Enter PIN", prompt: str = "Enter PIN",
allow_cancel: bool = True, allow_cancel: bool = True,
retry: bool = True, retry: bool = True,
@ -107,14 +101,12 @@ async def verify_user_pin(
if config.has_pin(): if config.has_pin():
from trezor.ui.layouts import request_pin_on_device from trezor.ui.layouts import request_pin_on_device
pin = await request_pin_on_device( pin = await request_pin_on_device(prompt, config.get_pin_rem(), allow_cancel)
ctx, prompt, config.get_pin_rem(), allow_cancel
)
config.ensure_not_wipe_code(pin) config.ensure_not_wipe_code(pin)
else: else:
pin = "" pin = ""
salt = await _request_sd_salt(ctx, raise_cancelled_on_unavailable=True) salt = await _request_sd_salt(raise_cancelled_on_unavailable=True)
if config.unlock(pin, salt): if config.unlock(pin, salt):
_set_last_unlock_time() _set_last_unlock_time()
return return
@ -123,7 +115,7 @@ async def verify_user_pin(
while retry: while retry:
pin = await request_pin_on_device( # type: ignore ["request_pin_on_device" is possibly unbound] pin = await request_pin_on_device( # type: ignore ["request_pin_on_device" is possibly unbound]
ctx, "Enter PIN", config.get_pin_rem(), allow_cancel, wrong_pin=True "Enter PIN", config.get_pin_rem(), allow_cancel, wrong_pin=True
) )
if config.unlock(pin, salt): if config.unlock(pin, salt):
_set_last_unlock_time() _set_last_unlock_time()
@ -132,11 +124,10 @@ async def verify_user_pin(
raise wire.PinInvalid raise wire.PinInvalid
async def error_pin_invalid(ctx: Context) -> NoReturn: async def error_pin_invalid() -> NoReturn:
from trezor.ui.layouts import show_error_and_raise from trezor.ui.layouts import show_error_and_raise
await show_error_and_raise( await show_error_and_raise(
ctx,
"warning_wrong_pin", "warning_wrong_pin",
"The PIN you have entered is not valid.", "The PIN you have entered is not valid.",
"Wrong PIN", # header "Wrong PIN", # header
@ -145,11 +136,10 @@ async def error_pin_invalid(ctx: Context) -> NoReturn:
assert False assert False
async def error_pin_matches_wipe_code(ctx: Context) -> NoReturn: async def error_pin_matches_wipe_code() -> NoReturn:
from trezor.ui.layouts import show_error_and_raise from trezor.ui.layouts import show_error_and_raise
await show_error_and_raise( await show_error_and_raise(
ctx,
"warning_invalid_new_pin", "warning_invalid_new_pin",
"The new PIN must be different from your wipe code.", "The new PIN must be different from your wipe code.",
"Invalid PIN", # header "Invalid PIN", # header

@ -7,10 +7,9 @@ class SdCardUnavailable(wire.ProcessError):
pass pass
async def _confirm_retry_wrong_card(ctx: wire.GenericContext) -> None: async def _confirm_retry_wrong_card() -> None:
if SD_CARD_HOT_SWAPPABLE: if SD_CARD_HOT_SWAPPABLE:
await confirm_action( await confirm_action(
ctx,
"warning_wrong_sd", "warning_wrong_sd",
"SD card protection", "SD card protection",
"Wrong SD card.", "Wrong SD card.",
@ -21,7 +20,6 @@ async def _confirm_retry_wrong_card(ctx: wire.GenericContext) -> None:
) )
else: else:
await show_error_and_raise( await show_error_and_raise(
ctx,
"warning_wrong_sd", "warning_wrong_sd",
"Please unplug the device and insert the correct SD card.", "Please unplug the device and insert the correct SD card.",
"Wrong SD card.", "Wrong SD card.",
@ -29,10 +27,9 @@ async def _confirm_retry_wrong_card(ctx: wire.GenericContext) -> None:
) )
async def _confirm_retry_insert_card(ctx: wire.GenericContext) -> None: async def _confirm_retry_insert_card() -> None:
if SD_CARD_HOT_SWAPPABLE: if SD_CARD_HOT_SWAPPABLE:
await confirm_action( await confirm_action(
ctx,
"warning_no_sd", "warning_no_sd",
"SD card protection", "SD card protection",
"SD card required.", "SD card required.",
@ -43,7 +40,6 @@ async def _confirm_retry_insert_card(ctx: wire.GenericContext) -> None:
) )
else: else:
await show_error_and_raise( await show_error_and_raise(
ctx,
"warning_no_sd", "warning_no_sd",
"Please unplug the device and insert your SD card.", "Please unplug the device and insert your SD card.",
"SD card required.", "SD card required.",
@ -51,10 +47,9 @@ async def _confirm_retry_insert_card(ctx: wire.GenericContext) -> None:
) )
async def _confirm_format_card(ctx: wire.GenericContext) -> None: async def _confirm_format_card() -> None:
# Format card? yes/no # Format card? yes/no
await confirm_action( await confirm_action(
ctx,
"warning_format_sd", "warning_format_sd",
"SD card error", "SD card error",
"Unknown filesystem.", "Unknown filesystem.",
@ -66,7 +61,6 @@ async def _confirm_format_card(ctx: wire.GenericContext) -> None:
# Confirm formatting # Confirm formatting
await confirm_action( await confirm_action(
ctx,
"confirm_format_sd", "confirm_format_sd",
"Format SD card", "Format SD card",
"All data on the SD card will be lost.", "All data on the SD card will be lost.",
@ -79,11 +73,9 @@ async def _confirm_format_card(ctx: wire.GenericContext) -> None:
async def confirm_retry_sd( async def confirm_retry_sd(
ctx: wire.GenericContext,
exc: wire.ProcessError = SdCardUnavailable("Error accessing SD card."), exc: wire.ProcessError = SdCardUnavailable("Error accessing SD card."),
) -> None: ) -> None:
await confirm_action( await confirm_action(
ctx,
"warning_sd_retry", "warning_sd_retry",
"SD card problem", "SD card problem",
None, None,
@ -94,9 +86,7 @@ async def confirm_retry_sd(
) )
async def ensure_sdcard( async def ensure_sdcard(ensure_filesystem: bool = True) -> None:
ctx: wire.GenericContext, ensure_filesystem: bool = True
) -> None:
"""Ensure a SD card is ready for use. """Ensure a SD card is ready for use.
This function runs the UI flow needed to ask the user to insert a SD card if there This function runs the UI flow needed to ask the user to insert a SD card if there
@ -109,7 +99,7 @@ async def ensure_sdcard(
from trezor import sdcard from trezor import sdcard
while not sdcard.is_present(): while not sdcard.is_present():
await _confirm_retry_insert_card(ctx) await _confirm_retry_insert_card()
if not ensure_filesystem: if not ensure_filesystem:
return return
@ -126,7 +116,7 @@ async def ensure_sdcard(
# no error when mounting # no error when mounting
return return
await _confirm_format_card(ctx) await _confirm_format_card()
# Proceed to formatting. Failure is caught by the outside OSError handler # Proceed to formatting. Failure is caught by the outside OSError handler
with sdcard.filesystem(mounted=False): with sdcard.filesystem(mounted=False):
@ -139,26 +129,24 @@ async def ensure_sdcard(
except OSError: except OSError:
# formatting failed, or generic I/O error (SD card power-on failed) # formatting failed, or generic I/O error (SD card power-on failed)
await confirm_retry_sd(ctx) await confirm_retry_sd()
async def request_sd_salt( async def request_sd_salt() -> bytearray | None:
ctx: wire.GenericContext = wire.DUMMY_CONTEXT,
) -> bytearray | None:
import storage.sd_salt as storage_sd_salt import storage.sd_salt as storage_sd_salt
if not storage_sd_salt.is_enabled(): if not storage_sd_salt.is_enabled():
return None return None
while True: while True:
await ensure_sdcard(ctx, ensure_filesystem=False) await ensure_sdcard(ensure_filesystem=False)
try: try:
return storage_sd_salt.load_sd_salt() return storage_sd_salt.load_sd_salt()
except (storage_sd_salt.WrongSdCard, io.fatfs.NoFilesystem): except (storage_sd_salt.WrongSdCard, io.fatfs.NoFilesystem):
await _confirm_retry_wrong_card(ctx) await _confirm_retry_wrong_card()
except OSError: except OSError:
# Generic problem with loading the SD salt (hardware problem, or we could # Generic problem with loading the SD salt (hardware problem, or we could
# not read the file, or there is a staged salt which cannot be committed). # not read the file, or there is a staged salt which cannot be committed).
# In either case, there is no good way to recover. If the user clicks Retry, # In either case, there is no good way to recover. If the user clicks Retry,
# we will try again. # we will try again.
await confirm_retry_sd(ctx) await confirm_retry_sd()

@ -10,7 +10,6 @@ from .passphrase import get as get_passphrase
if TYPE_CHECKING: if TYPE_CHECKING:
from .paths import Bip32Path, Slip21Path from .paths import Bip32Path, Slip21Path
from trezor.wire import Context
from trezor.crypto import bip32 from trezor.crypto import bip32
@ -50,7 +49,7 @@ if not utils.BITCOIN_ONLY:
# We want to derive both the normal seed and the Cardano seed together, AND # We want to derive both the normal seed and the Cardano seed together, AND
# expose a method for Cardano to do the same # expose a method for Cardano to do the same
async def derive_and_store_roots(ctx: Context) -> None: async def derive_and_store_roots() -> None:
from trezor import wire from trezor import wire
if not storage_device.is_initialized(): if not storage_device.is_initialized():
@ -64,7 +63,7 @@ if not utils.BITCOIN_ONLY:
if not need_seed and not need_cardano_secret: if not need_seed and not need_cardano_secret:
return return
passphrase = await get_passphrase(ctx) passphrase = await get_passphrase()
if need_seed: if need_seed:
common_seed = mnemonic.get_seed(passphrase) common_seed = mnemonic.get_seed(passphrase)
@ -76,8 +75,8 @@ if not utils.BITCOIN_ONLY:
derive_and_store_secrets(passphrase) derive_and_store_secrets(passphrase)
@storage_cache.stored_async(storage_cache.APP_COMMON_SEED) @storage_cache.stored_async(storage_cache.APP_COMMON_SEED)
async def get_seed(ctx: Context) -> bytes: async def get_seed() -> bytes:
await derive_and_store_roots(ctx) await derive_and_store_roots()
common_seed = storage_cache.get(storage_cache.APP_COMMON_SEED) common_seed = storage_cache.get(storage_cache.APP_COMMON_SEED)
assert common_seed is not None assert common_seed is not None
return common_seed return common_seed
@ -87,8 +86,8 @@ else:
# We use the simple version of `get_seed` that never needs to derive anything else. # We use the simple version of `get_seed` that never needs to derive anything else.
@storage_cache.stored_async(storage_cache.APP_COMMON_SEED) @storage_cache.stored_async(storage_cache.APP_COMMON_SEED)
async def get_seed(ctx: Context) -> bytes: async def get_seed() -> bytes:
passphrase = await get_passphrase(ctx) passphrase = await get_passphrase()
return mnemonic.get_seed(passphrase) return mnemonic.get_seed(passphrase)

@ -11,6 +11,7 @@ if __debug__:
from trezor import log, loop, utils, wire from trezor import log, loop, utils, wire
from trezor.ui import display from trezor.ui import display
from trezor.wire import context
from trezor.enums import MessageType from trezor.enums import MessageType
from trezor.messages import ( from trezor.messages import (
DebugLinkLayout, DebugLinkLayout,
@ -47,7 +48,7 @@ if __debug__:
layout_change_chan = loop.chan() layout_change_chan = loop.chan()
DEBUG_CONTEXT: wire.Context | None = None DEBUG_CONTEXT: context.Context | None = None
LAYOUT_WATCHER_NONE = 0 LAYOUT_WATCHER_NONE = 0
LAYOUT_WATCHER_STATE = 1 LAYOUT_WATCHER_STATE = 1
@ -139,9 +140,7 @@ if __debug__:
await DEBUG_CONTEXT.write(DebugLinkState(tokens=content_tokens)) await DEBUG_CONTEXT.write(DebugLinkState(tokens=content_tokens))
storage.layout_watcher = LAYOUT_WATCHER_NONE storage.layout_watcher = LAYOUT_WATCHER_NONE
async def dispatch_DebugLinkWatchLayout( async def dispatch_DebugLinkWatchLayout(msg: DebugLinkWatchLayout) -> Success:
ctx: wire.Context, msg: DebugLinkWatchLayout
) -> Success:
from trezor import ui from trezor import ui
layout_change_chan.putters.clear() layout_change_chan.putters.clear()
@ -152,16 +151,14 @@ if __debug__:
return Success() return Success()
async def dispatch_DebugLinkResetDebugEvents( async def dispatch_DebugLinkResetDebugEvents(
ctx: wire.Context, msg: DebugLinkResetDebugEvents msg: DebugLinkResetDebugEvents,
) -> Success: ) -> Success:
# Resetting the debug events makes sure that the previous # Resetting the debug events makes sure that the previous
# events/layouts are not mixed with the new ones. # events/layouts are not mixed with the new ones.
storage.reset_debug_events() storage.reset_debug_events()
return Success() return Success()
async def dispatch_DebugLinkDecision( async def dispatch_DebugLinkDecision(msg: DebugLinkDecision) -> None:
ctx: wire.Context, msg: DebugLinkDecision
) -> None:
from trezor import workflow from trezor import workflow
workflow.idle_timer.touch() workflow.idle_timer.touch()
@ -194,7 +191,7 @@ if __debug__:
loop.schedule(return_layout_change()) loop.schedule(return_layout_change())
async def dispatch_DebugLinkGetState( async def dispatch_DebugLinkGetState(
ctx: wire.Context, msg: DebugLinkGetState msg: DebugLinkGetState,
) -> DebugLinkState | None: ) -> DebugLinkState | None:
from trezor.messages import DebugLinkState from trezor.messages import DebugLinkState
from apps.common import mnemonic, passphrase from apps.common import mnemonic, passphrase
@ -218,9 +215,7 @@ if __debug__:
return m return m
async def dispatch_DebugLinkRecordScreen( async def dispatch_DebugLinkRecordScreen(msg: DebugLinkRecordScreen) -> Success:
ctx: wire.Context, msg: DebugLinkRecordScreen
) -> Success:
if msg.target_directory: if msg.target_directory:
# In case emulator is restarted but we still want to record screenshots # In case emulator is restarted but we still want to record screenshots
# into the same directory as before, we need to increment the refresh index, # into the same directory as before, we need to increment the refresh index,
@ -235,18 +230,14 @@ if __debug__:
return Success() return Success()
async def dispatch_DebugLinkReseedRandom( async def dispatch_DebugLinkReseedRandom(msg: DebugLinkReseedRandom) -> Success:
ctx: wire.Context, msg: DebugLinkReseedRandom
) -> Success:
if msg.value is not None: if msg.value is not None:
from trezor.crypto import random from trezor.crypto import random
random.reseed(msg.value) random.reseed(msg.value)
return Success() return Success()
async def dispatch_DebugLinkEraseSdCard( async def dispatch_DebugLinkEraseSdCard(msg: DebugLinkEraseSdCard) -> Success:
ctx: wire.Context, msg: DebugLinkEraseSdCard
) -> Success:
from trezor import io from trezor import io
sdcard = io.sdcard # local_cache_attribute sdcard = io.sdcard # local_cache_attribute
@ -271,8 +262,8 @@ if __debug__:
def boot() -> None: def boot() -> None:
register = workflow_handlers.register # local_cache_attribute register = workflow_handlers.register # local_cache_attribute
register(MessageType.DebugLinkDecision, dispatch_DebugLinkDecision) # type: ignore [Argument of type "(ctx: Context, msg: DebugLinkDecision) -> Coroutine[Any, Any, None]" cannot be assigned to parameter "handler" of type "Handler[Msg@register]" in function "register"] register(MessageType.DebugLinkDecision, dispatch_DebugLinkDecision) # type: ignore [Argument of type "(msg: DebugLinkDecision) -> Coroutine[Any, Any, None]" cannot be assigned to parameter "handler" of type "Handler[Msg@register]" in function "register"]
register(MessageType.DebugLinkGetState, dispatch_DebugLinkGetState) # type: ignore [Argument of type "(ctx: Context, msg: DebugLinkGetState) -> Coroutine[Any, Any, DebugLinkState | None]" cannot be assigned to parameter "handler" of type "Handler[Msg@register]" in function "register"] register(MessageType.DebugLinkGetState, dispatch_DebugLinkGetState) # type: ignore [Argument of type "(msg: DebugLinkGetState) -> Coroutine[Any, Any, DebugLinkState | None]" cannot be assigned to parameter "handler" of type "Handler[Msg@register]" in function "register"]
register(MessageType.DebugLinkReseedRandom, dispatch_DebugLinkReseedRandom) register(MessageType.DebugLinkReseedRandom, dispatch_DebugLinkReseedRandom)
register(MessageType.DebugLinkRecordScreen, dispatch_DebugLinkRecordScreen) register(MessageType.DebugLinkRecordScreen, dispatch_DebugLinkRecordScreen)
register(MessageType.DebugLinkEraseSdCard, dispatch_DebugLinkEraseSdCard) register(MessageType.DebugLinkEraseSdCard, dispatch_DebugLinkEraseSdCard)

@ -2,10 +2,9 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import LoadDevice, Success from trezor.messages import LoadDevice, Success
from trezor.wire import Context
async def load_device(ctx: Context, msg: LoadDevice) -> Success: async def load_device(msg: LoadDevice) -> Success:
import storage.device as storage_device import storage.device as storage_device
from trezor import config from trezor import config
from trezor.crypto import bip39, slip39 from trezor.crypto import bip39, slip39
@ -38,7 +37,6 @@ async def load_device(ctx: Context, msg: LoadDevice) -> Success:
# _warn # _warn
await confirm_action( await confirm_action(
ctx,
"warn_loading_seed", "warn_loading_seed",
"Loading seed", "Loading seed",
"Loading private seed is not recommended.", "Loading private seed is not recommended.",

@ -1,13 +1,12 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor import wire
from trezor.utils import Writer, HashWriter from trezor.utils import Writer, HashWriter
from trezor.messages import EosTxActionAck from trezor.messages import EosTxActionAck
async def process_action( async def process_action(
ctx: wire.Context, sha: HashWriter, action: EosTxActionAck, is_last: bool sha: HashWriter, action: EosTxActionAck, is_last: bool
) -> None: ) -> None:
from .. import helpers, writers from .. import helpers, writers
from . import layout from . import layout
@ -26,71 +25,70 @@ async def process_action(
if account == "eosio": if account == "eosio":
if name == "buyram": if name == "buyram":
assert action.buy_ram is not None # _check_action assert action.buy_ram is not None # _check_action
await layout.confirm_action_buyram(ctx, action.buy_ram) await layout.confirm_action_buyram(action.buy_ram)
writers.write_action_buyram(w, action.buy_ram) writers.write_action_buyram(w, action.buy_ram)
elif name == "buyrambytes": elif name == "buyrambytes":
assert action.buy_ram_bytes is not None # _check_action assert action.buy_ram_bytes is not None # _check_action
await layout.confirm_action_buyrambytes(ctx, action.buy_ram_bytes) await layout.confirm_action_buyrambytes(action.buy_ram_bytes)
writers.write_action_buyrambytes(w, action.buy_ram_bytes) writers.write_action_buyrambytes(w, action.buy_ram_bytes)
elif name == "sellram": elif name == "sellram":
assert action.sell_ram is not None # _check_action assert action.sell_ram is not None # _check_action
await layout.confirm_action_sellram(ctx, action.sell_ram) await layout.confirm_action_sellram(action.sell_ram)
writers.write_action_sellram(w, action.sell_ram) writers.write_action_sellram(w, action.sell_ram)
elif name == "delegatebw": elif name == "delegatebw":
assert action.delegate is not None # _check_action assert action.delegate is not None # _check_action
await layout.confirm_action_delegate(ctx, action.delegate) await layout.confirm_action_delegate(action.delegate)
writers.write_action_delegate(w, action.delegate) writers.write_action_delegate(w, action.delegate)
elif name == "undelegatebw": elif name == "undelegatebw":
assert action.undelegate is not None # _check_action assert action.undelegate is not None # _check_action
await layout.confirm_action_undelegate(ctx, action.undelegate) await layout.confirm_action_undelegate(action.undelegate)
writers.write_action_undelegate(w, action.undelegate) writers.write_action_undelegate(w, action.undelegate)
elif name == "refund": elif name == "refund":
assert action.refund is not None # _check_action assert action.refund is not None # _check_action
await layout.confirm_action_refund(ctx, action.refund) await layout.confirm_action_refund(action.refund)
writers.write_action_refund(w, action.refund) writers.write_action_refund(w, action.refund)
elif name == "voteproducer": elif name == "voteproducer":
assert action.vote_producer is not None # _check_action assert action.vote_producer is not None # _check_action
await layout.confirm_action_voteproducer(ctx, action.vote_producer) await layout.confirm_action_voteproducer(action.vote_producer)
writers.write_action_voteproducer(w, action.vote_producer) writers.write_action_voteproducer(w, action.vote_producer)
elif name == "updateauth": elif name == "updateauth":
assert action.update_auth is not None # _check_action assert action.update_auth is not None # _check_action
await layout.confirm_action_updateauth(ctx, action.update_auth) await layout.confirm_action_updateauth(action.update_auth)
writers.write_action_updateauth(w, action.update_auth) writers.write_action_updateauth(w, action.update_auth)
elif name == "deleteauth": elif name == "deleteauth":
assert action.delete_auth is not None # _check_action assert action.delete_auth is not None # _check_action
await layout.confirm_action_deleteauth(ctx, action.delete_auth) await layout.confirm_action_deleteauth(action.delete_auth)
writers.write_action_deleteauth(w, action.delete_auth) writers.write_action_deleteauth(w, action.delete_auth)
elif name == "linkauth": elif name == "linkauth":
assert action.link_auth is not None # _check_action assert action.link_auth is not None # _check_action
await layout.confirm_action_linkauth(ctx, action.link_auth) await layout.confirm_action_linkauth(action.link_auth)
writers.write_action_linkauth(w, action.link_auth) writers.write_action_linkauth(w, action.link_auth)
elif name == "unlinkauth": elif name == "unlinkauth":
assert action.unlink_auth is not None # _check_action assert action.unlink_auth is not None # _check_action
await layout.confirm_action_unlinkauth(ctx, action.unlink_auth) await layout.confirm_action_unlinkauth(action.unlink_auth)
writers.write_action_unlinkauth(w, action.unlink_auth) writers.write_action_unlinkauth(w, action.unlink_auth)
elif name == "newaccount": elif name == "newaccount":
assert action.new_account is not None # _check_action assert action.new_account is not None # _check_action
await layout.confirm_action_newaccount(ctx, action.new_account) await layout.confirm_action_newaccount(action.new_account)
writers.write_action_newaccount(w, action.new_account) writers.write_action_newaccount(w, action.new_account)
else: else:
raise ValueError("Unrecognized action type for eosio") raise ValueError("Unrecognized action type for eosio")
elif name == "transfer": elif name == "transfer":
assert action.transfer is not None # _check_action assert action.transfer is not None # _check_action
await layout.confirm_action_transfer(ctx, action.transfer, account) await layout.confirm_action_transfer(action.transfer, account)
writers.write_action_transfer(w, action.transfer) writers.write_action_transfer(w, action.transfer)
else: else:
await _process_unknown_action(ctx, w, action) await _process_unknown_action(w, action)
writers.write_action_common(sha, action.common) writers.write_action_common(sha, action.common)
writers.write_bytes_prefixed(sha, w) writers.write_bytes_prefixed(sha, w)
async def _process_unknown_action( async def _process_unknown_action(w: Writer, action: EosTxActionAck) -> None:
ctx: wire.Context, w: Writer, action: EosTxActionAck
) -> None:
from trezor.crypto.hashlib import sha256 from trezor.crypto.hashlib import sha256
from trezor.utils import HashWriter from trezor.utils import HashWriter
from trezor.messages import EosTxActionAck, EosTxActionRequest from trezor.messages import EosTxActionAck, EosTxActionRequest
from trezor.wire.context import call
from .. import writers from .. import writers
from . import layout from . import layout
@ -106,9 +104,7 @@ async def _process_unknown_action(
bytes_left = unknown.data_size - len(data_chunk) bytes_left = unknown.data_size - len(data_chunk)
while bytes_left != 0: while bytes_left != 0:
action = await ctx.call( action = await call(EosTxActionRequest(data_size=bytes_left), EosTxActionAck)
EosTxActionRequest(data_size=bytes_left), EosTxActionAck
)
if unknown is None: if unknown is None:
raise ValueError("Bad response. Unknown struct expected.") raise ValueError("Bad response. Unknown struct expected.")
@ -120,7 +116,7 @@ async def _process_unknown_action(
if bytes_left < 0: if bytes_left < 0:
raise ValueError("Bad response. Buffer overflow.") raise ValueError("Bad response. Buffer overflow.")
await layout.confirm_action_unknown(ctx, action.common, checksum.get_digest()) await layout.confirm_action_unknown(action.common, checksum.get_digest())
def _check_action(action: EosTxActionAck, name: str, account: str) -> bool: def _check_action(action: EosTxActionAck, name: str, account: str) -> bool:

@ -7,7 +7,6 @@ from ..helpers import eos_asset_to_string, eos_name_to_string
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Iterable from typing import Iterable
from trezor.wire import Context
from trezor.messages import ( from trezor.messages import (
EosActionBuyRam, EosActionBuyRam,
EosActionBuyRamBytes, EosActionBuyRamBytes,
@ -35,13 +34,11 @@ is_last = False
# Because icon and br_code are almost always the same # Because icon and br_code are almost always the same
# (and also calling with positional arguments takes less space) # (and also calling with positional arguments takes less space)
async def _confirm_properties( async def _confirm_properties(
ctx: Context,
br_type: str, br_type: str,
title: str, title: str,
props: Iterable[PropertyType], props: Iterable[PropertyType],
) -> None: ) -> None:
await confirm_properties( await confirm_properties(
ctx,
br_type, br_type,
title, title,
props, props,
@ -50,9 +47,8 @@ async def _confirm_properties(
) )
async def confirm_action_buyram(ctx: Context, msg: EosActionBuyRam) -> None: async def confirm_action_buyram(msg: EosActionBuyRam) -> None:
await _confirm_properties( await _confirm_properties(
ctx,
"confirm_buyram", "confirm_buyram",
"Buy RAM", "Buy RAM",
( (
@ -63,9 +59,8 @@ async def confirm_action_buyram(ctx: Context, msg: EosActionBuyRam) -> None:
) )
async def confirm_action_buyrambytes(ctx: Context, msg: EosActionBuyRamBytes) -> None: async def confirm_action_buyrambytes(msg: EosActionBuyRamBytes) -> None:
await _confirm_properties( await _confirm_properties(
ctx,
"confirm_buyrambytes", "confirm_buyrambytes",
"Buy RAM", "Buy RAM",
( (
@ -76,7 +71,7 @@ async def confirm_action_buyrambytes(ctx: Context, msg: EosActionBuyRamBytes) ->
) )
async def confirm_action_delegate(ctx: Context, msg: EosActionDelegate) -> None: async def confirm_action_delegate(msg: EosActionDelegate) -> None:
props = [ props = [
("Sender:", eos_name_to_string(msg.sender)), ("Sender:", eos_name_to_string(msg.sender)),
("Receiver:", eos_name_to_string(msg.receiver)), ("Receiver:", eos_name_to_string(msg.receiver)),
@ -91,16 +86,14 @@ async def confirm_action_delegate(ctx: Context, msg: EosActionDelegate) -> None:
append(("Transfer:", "No")) append(("Transfer:", "No"))
await _confirm_properties( await _confirm_properties(
ctx,
"confirm_delegate", "confirm_delegate",
"Delegate", "Delegate",
props, props,
) )
async def confirm_action_sellram(ctx: Context, msg: EosActionSellRam) -> None: async def confirm_action_sellram(msg: EosActionSellRam) -> None:
await _confirm_properties( await _confirm_properties(
ctx,
"confirm_sellram", "confirm_sellram",
"Sell RAM", "Sell RAM",
( (
@ -110,9 +103,8 @@ async def confirm_action_sellram(ctx: Context, msg: EosActionSellRam) -> None:
) )
async def confirm_action_undelegate(ctx: Context, msg: EosActionUndelegate) -> None: async def confirm_action_undelegate(msg: EosActionUndelegate) -> None:
await _confirm_properties( await _confirm_properties(
ctx,
"confirm_undelegate", "confirm_undelegate",
"Undelegate", "Undelegate",
( (
@ -124,22 +116,20 @@ async def confirm_action_undelegate(ctx: Context, msg: EosActionUndelegate) -> N
) )
async def confirm_action_refund(ctx: Context, msg: EosActionRefund) -> None: async def confirm_action_refund(msg: EosActionRefund) -> None:
await _confirm_properties( await _confirm_properties(
ctx,
"confirm_refund", "confirm_refund",
"Refund", "Refund",
(("Owner:", eos_name_to_string(msg.owner)),), (("Owner:", eos_name_to_string(msg.owner)),),
) )
async def confirm_action_voteproducer(ctx: Context, msg: EosActionVoteProducer) -> None: async def confirm_action_voteproducer(msg: EosActionVoteProducer) -> None:
producers = msg.producers # local_cache_attribute producers = msg.producers # local_cache_attribute
if msg.proxy and not producers: if msg.proxy and not producers:
# PROXY # PROXY
await _confirm_properties( await _confirm_properties(
ctx,
"confirm_voteproducer", "confirm_voteproducer",
"Vote for proxy", "Vote for proxy",
( (
@ -151,7 +141,6 @@ async def confirm_action_voteproducer(ctx: Context, msg: EosActionVoteProducer)
elif producers: elif producers:
# PRODUCERS # PRODUCERS
await _confirm_properties( await _confirm_properties(
ctx,
"confirm_voteproducer", "confirm_voteproducer",
"Vote for producers", "Vote for producers",
( (
@ -163,16 +152,13 @@ async def confirm_action_voteproducer(ctx: Context, msg: EosActionVoteProducer)
else: else:
# Cancel vote # Cancel vote
await _confirm_properties( await _confirm_properties(
ctx,
"confirm_voteproducer", "confirm_voteproducer",
"Cancel vote", "Cancel vote",
(("Voter:", eos_name_to_string(msg.voter)),), (("Voter:", eos_name_to_string(msg.voter)),),
) )
async def confirm_action_transfer( async def confirm_action_transfer(msg: EosActionTransfer, account: str) -> None:
ctx: Context, msg: EosActionTransfer, account: str
) -> None:
props = [ props = [
("From:", eos_name_to_string(msg.sender)), ("From:", eos_name_to_string(msg.sender)),
("To:", eos_name_to_string(msg.receiver)), ("To:", eos_name_to_string(msg.receiver)),
@ -182,14 +168,13 @@ async def confirm_action_transfer(
if msg.memo is not None: if msg.memo is not None:
props.append(("Memo", msg.memo[:512])) props.append(("Memo", msg.memo[:512]))
await _confirm_properties( await _confirm_properties(
ctx,
"confirm_transfer", "confirm_transfer",
"Transfer", "Transfer",
props, props,
) )
async def confirm_action_updateauth(ctx: Context, msg: EosActionUpdateAuth) -> None: async def confirm_action_updateauth(msg: EosActionUpdateAuth) -> None:
props: list[PropertyType] = [ props: list[PropertyType] = [
("Account:", eos_name_to_string(msg.account)), ("Account:", eos_name_to_string(msg.account)),
("Permission:", eos_name_to_string(msg.permission)), ("Permission:", eos_name_to_string(msg.permission)),
@ -197,16 +182,14 @@ async def confirm_action_updateauth(ctx: Context, msg: EosActionUpdateAuth) -> N
] ]
props.extend(authorization_fields(msg.auth)) props.extend(authorization_fields(msg.auth))
await _confirm_properties( await _confirm_properties(
ctx,
"confirm_updateauth", "confirm_updateauth",
"Update Auth", "Update Auth",
props, props,
) )
async def confirm_action_deleteauth(ctx: Context, msg: EosActionDeleteAuth) -> None: async def confirm_action_deleteauth(msg: EosActionDeleteAuth) -> None:
await _confirm_properties( await _confirm_properties(
ctx,
"confirm_deleteauth", "confirm_deleteauth",
"Delete Auth", "Delete Auth",
( (
@ -216,9 +199,8 @@ async def confirm_action_deleteauth(ctx: Context, msg: EosActionDeleteAuth) -> N
) )
async def confirm_action_linkauth(ctx: Context, msg: EosActionLinkAuth) -> None: async def confirm_action_linkauth(msg: EosActionLinkAuth) -> None:
await _confirm_properties( await _confirm_properties(
ctx,
"confirm_linkauth", "confirm_linkauth",
"Link Auth", "Link Auth",
( (
@ -230,9 +212,8 @@ async def confirm_action_linkauth(ctx: Context, msg: EosActionLinkAuth) -> None:
) )
async def confirm_action_unlinkauth(ctx: Context, msg: EosActionUnlinkAuth) -> None: async def confirm_action_unlinkauth(msg: EosActionUnlinkAuth) -> None:
await _confirm_properties( await _confirm_properties(
ctx,
"confirm_unlinkauth", "confirm_unlinkauth",
"Unlink Auth", "Unlink Auth",
( (
@ -243,7 +224,7 @@ async def confirm_action_unlinkauth(ctx: Context, msg: EosActionUnlinkAuth) -> N
) )
async def confirm_action_newaccount(ctx: Context, msg: EosActionNewAccount) -> None: async def confirm_action_newaccount(msg: EosActionNewAccount) -> None:
props: list[PropertyType] = [ props: list[PropertyType] = [
("Creator:", eos_name_to_string(msg.creator)), ("Creator:", eos_name_to_string(msg.creator)),
("Name:", eos_name_to_string(msg.name)), ("Name:", eos_name_to_string(msg.name)),
@ -251,18 +232,14 @@ async def confirm_action_newaccount(ctx: Context, msg: EosActionNewAccount) -> N
props.extend(authorization_fields(msg.owner)) props.extend(authorization_fields(msg.owner))
props.extend(authorization_fields(msg.active)) props.extend(authorization_fields(msg.active))
await _confirm_properties( await _confirm_properties(
ctx,
"confirm_newaccount", "confirm_newaccount",
"New Account", "New Account",
props, props,
) )
async def confirm_action_unknown( async def confirm_action_unknown(action: EosActionCommon, checksum: bytes) -> None:
ctx: Context, action: EosActionCommon, checksum: bytes
) -> None:
await confirm_properties( await confirm_properties(
ctx,
"confirm_unknown", "confirm_unknown",
"Arbitrary data", "Arbitrary data",
( (

@ -5,20 +5,17 @@ from apps.common.keychain import auto_keychain
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import EosGetPublicKey, EosPublicKey from trezor.messages import EosGetPublicKey, EosPublicKey
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
from trezor.wire import Context
@auto_keychain(__name__) @auto_keychain(__name__)
async def get_public_key( async def get_public_key(msg: EosGetPublicKey, keychain: Keychain) -> EosPublicKey:
ctx: Context, msg: EosGetPublicKey, keychain: Keychain
) -> EosPublicKey:
from trezor.crypto.curve import secp256k1 from trezor.crypto.curve import secp256k1
from trezor.messages import EosPublicKey from trezor.messages import EosPublicKey
from apps.common import paths from apps.common import paths
from .helpers import public_key_to_wif from .helpers import public_key_to_wif
from .layout import require_get_public_key from .layout import require_get_public_key
await paths.validate_path(ctx, keychain, msg.address_n) await paths.validate_path(keychain, msg.address_n)
node = keychain.derive(msg.address_n) node = keychain.derive(msg.address_n)
@ -26,5 +23,5 @@ async def get_public_key(
wif = public_key_to_wif(public_key) wif = public_key_to_wif(public_key)
if msg.show_display: if msg.show_display:
await require_get_public_key(ctx, wif) await require_get_public_key(wif)
return EosPublicKey(wif_public_key=wif, raw_public_key=public_key) return EosPublicKey(wif_public_key=wif, raw_public_key=public_key)

@ -1,22 +1,15 @@
from typing import TYPE_CHECKING async def require_get_public_key(public_key: str) -> None:
if TYPE_CHECKING:
from trezor.wire import Context
async def require_get_public_key(ctx: Context, public_key: str) -> None:
from trezor.ui.layouts import show_pubkey from trezor.ui.layouts import show_pubkey
await show_pubkey(ctx, public_key) await show_pubkey(public_key)
async def require_sign_tx(ctx: Context, num_actions: int) -> None: async def require_sign_tx(num_actions: int) -> None:
from trezor.enums import ButtonRequestType from trezor.enums import ButtonRequestType
from trezor.strings import format_plural from trezor.strings import format_plural
from trezor.ui.layouts import confirm_action from trezor.ui.layouts import confirm_action
await confirm_action( await confirm_action(
ctx,
"confirm_tx", "confirm_tx",
"Sign transaction", "Sign transaction",
description="You are about to sign {}.", description="You are about to sign {}.",

@ -3,14 +3,14 @@ from typing import TYPE_CHECKING
from apps.common.keychain import auto_keychain from apps.common.keychain import auto_keychain
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.wire import Context
from trezor.messages import EosSignTx, EosSignedTx from trezor.messages import EosSignTx, EosSignedTx
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
@auto_keychain(__name__) @auto_keychain(__name__)
async def sign_tx(ctx: Context, msg: EosSignTx, keychain: Keychain) -> EosSignedTx: async def sign_tx(msg: EosSignTx, keychain: Keychain) -> EosSignedTx:
from trezor.wire import DataError from trezor.wire import DataError
from trezor.wire.context import call
from trezor.crypto.curve import secp256k1 from trezor.crypto.curve import secp256k1
from trezor.crypto.hashlib import sha256 from trezor.crypto.hashlib import sha256
from trezor.messages import EosSignedTx, EosTxActionAck, EosTxActionRequest from trezor.messages import EosSignedTx, EosTxActionAck, EosTxActionRequest
@ -26,7 +26,7 @@ async def sign_tx(ctx: Context, msg: EosSignTx, keychain: Keychain) -> EosSigned
if not num_actions: if not num_actions:
raise DataError("No actions") raise DataError("No actions")
await paths.validate_path(ctx, keychain, msg.address_n) await paths.validate_path(keychain, msg.address_n)
node = keychain.derive(msg.address_n) node = keychain.derive(msg.address_n)
sha = HashWriter(sha256()) sha = HashWriter(sha256())
@ -36,13 +36,13 @@ async def sign_tx(ctx: Context, msg: EosSignTx, keychain: Keychain) -> EosSigned
write_header(sha, msg.header) write_header(sha, msg.header)
write_uvarint(sha, 0) write_uvarint(sha, 0)
write_uvarint(sha, num_actions) write_uvarint(sha, num_actions)
await require_sign_tx(ctx, num_actions) await require_sign_tx(num_actions)
# actions # actions
for index in range(num_actions): for index in range(num_actions):
action = await ctx.call(EosTxActionRequest(), EosTxActionAck) action = await call(EosTxActionRequest(), EosTxActionAck)
is_last = index == num_actions - 1 is_last = index == num_actions - 1
await process_action(ctx, sha, action, is_last) await process_action(sha, action, is_last)
write_uvarint(sha, 0) write_uvarint(sha, 0)
write_bytes_fixed(sha, bytearray(32), 32) write_bytes_fixed(sha, bytearray(32), 32)

@ -4,7 +4,6 @@ from .keychain import PATTERNS_ADDRESS, with_keychain_from_path
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import EthereumGetAddress, EthereumAddress from trezor.messages import EthereumGetAddress, EthereumAddress
from trezor.wire import Context
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
from .definitions import Definitions from .definitions import Definitions
@ -12,7 +11,6 @@ if TYPE_CHECKING:
@with_keychain_from_path(*PATTERNS_ADDRESS) @with_keychain_from_path(*PATTERNS_ADDRESS)
async def get_address( async def get_address(
ctx: Context,
msg: EthereumGetAddress, msg: EthereumGetAddress,
keychain: Keychain, keychain: Keychain,
defs: Definitions, defs: Definitions,
@ -24,13 +22,13 @@ async def get_address(
address_n = msg.address_n # local_cache_attribute address_n = msg.address_n # local_cache_attribute
await paths.validate_path(ctx, keychain, address_n) await paths.validate_path(keychain, address_n)
node = keychain.derive(address_n) node = keychain.derive(address_n)
address = address_from_bytes(node.ethereum_pubkeyhash(), defs.network) address = address_from_bytes(node.ethereum_pubkeyhash(), defs.network)
if msg.show_display: if msg.show_display:
await show_address(ctx, address, path=paths.address_n_to_str(address_n)) await show_address(address, path=paths.address_n_to_str(address_n))
return EthereumAddress(address=address) return EthereumAddress(address=address)

@ -2,10 +2,9 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import EthereumGetPublicKey, EthereumPublicKey from trezor.messages import EthereumGetPublicKey, EthereumPublicKey
from trezor.wire import Context
async def get_public_key(ctx: Context, msg: EthereumGetPublicKey) -> EthereumPublicKey: async def get_public_key(msg: EthereumGetPublicKey) -> EthereumPublicKey:
from ubinascii import hexlify from ubinascii import hexlify
from trezor.messages import EthereumPublicKey, GetPublicKey from trezor.messages import EthereumPublicKey, GetPublicKey
from trezor.ui.layouts import show_pubkey from trezor.ui.layouts import show_pubkey
@ -13,9 +12,9 @@ async def get_public_key(ctx: Context, msg: EthereumGetPublicKey) -> EthereumPub
# we use the Bitcoin format for Ethereum xpubs # we use the Bitcoin format for Ethereum xpubs
btc_pubkey_msg = GetPublicKey(address_n=msg.address_n) btc_pubkey_msg = GetPublicKey(address_n=msg.address_n)
resp = await bitcoin_get_public_key.get_public_key(ctx, btc_pubkey_msg) resp = await bitcoin_get_public_key.get_public_key(btc_pubkey_msg)
if msg.show_display: if msg.show_display:
await show_pubkey(ctx, hexlify(resp.node.public_key).decode()) await show_pubkey(hexlify(resp.node.public_key).decode())
return EthereumPublicKey(node=resp.node, xpub=resp.xpub) return EthereumPublicKey(node=resp.node, xpub=resp.xpub)

@ -12,8 +12,6 @@ if TYPE_CHECKING:
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
from trezor.wire import Context
from trezor.messages import ( from trezor.messages import (
EthereumGetAddress, EthereumGetAddress,
EthereumSignMessage, EthereumSignMessage,
@ -36,7 +34,7 @@ if TYPE_CHECKING:
) )
HandlerAddressN = Callable[ HandlerAddressN = Callable[
[Context, MsgInAddressN, Keychain, definitions.Definitions], [MsgInAddressN, Keychain, definitions.Definitions],
Awaitable[MsgOut], Awaitable[MsgOut],
] ]
@ -48,7 +46,7 @@ if TYPE_CHECKING:
) )
HandlerChainId = Callable[ HandlerChainId = Callable[
[Context, MsgInSignTx, Keychain, definitions.Definitions], [MsgInSignTx, Keychain, definitions.Definitions],
Awaitable[MsgOut], Awaitable[MsgOut],
] ]
@ -122,13 +120,13 @@ def with_keychain_from_path(
def decorator( def decorator(
func: HandlerAddressN[MsgInAddressN, MsgOut] func: HandlerAddressN[MsgInAddressN, MsgOut]
) -> Handler[MsgInAddressN, MsgOut]: ) -> Handler[MsgInAddressN, MsgOut]:
async def wrapper(ctx: Context, msg: MsgInAddressN) -> MsgOut: async def wrapper(msg: MsgInAddressN) -> MsgOut:
slip44 = _slip44_from_address_n(msg.address_n) slip44 = _slip44_from_address_n(msg.address_n)
defs = _defs_from_message(msg, slip44=slip44) defs = _defs_from_message(msg, slip44=slip44)
schemas = _schemas_from_network(patterns, defs.network) schemas = _schemas_from_network(patterns, defs.network)
keychain = await get_keychain(ctx, CURVE, schemas) keychain = await get_keychain(CURVE, schemas)
with keychain: with keychain:
return await func(ctx, msg, keychain, defs) return await func(msg, keychain, defs)
return wrapper return wrapper
@ -139,11 +137,11 @@ def with_keychain_from_chain_id(
func: HandlerChainId[MsgInSignTx, MsgOut] func: HandlerChainId[MsgInSignTx, MsgOut]
) -> Handler[MsgInSignTx, MsgOut]: ) -> Handler[MsgInSignTx, MsgOut]:
# this is only for SignTx, and only PATTERN_ADDRESS is allowed # this is only for SignTx, and only PATTERN_ADDRESS is allowed
async def wrapper(ctx: Context, msg: MsgInSignTx) -> MsgOut: async def wrapper(msg: MsgInSignTx) -> MsgOut:
defs = _defs_from_message(msg, chain_id=msg.chain_id) defs = _defs_from_message(msg, chain_id=msg.chain_id)
schemas = _schemas_from_network(PATTERNS_ADDRESS, defs.network) schemas = _schemas_from_network(PATTERNS_ADDRESS, defs.network)
keychain = await get_keychain(ctx, CURVE, schemas) keychain = await get_keychain(CURVE, schemas)
with keychain: with keychain:
return await func(ctx, msg, keychain, defs) return await func(msg, keychain, defs)
return wrapper return wrapper

@ -22,11 +22,9 @@ if TYPE_CHECKING:
EthereumStructMember, EthereumStructMember,
EthereumTokenInfo, EthereumTokenInfo,
) )
from trezor.wire import Context
def require_confirm_tx( def require_confirm_tx(
ctx: Context,
to_bytes: bytes, to_bytes: bytes,
value: int, value: int,
network: EthereumNetworkInfo, network: EthereumNetworkInfo,
@ -40,7 +38,6 @@ def require_confirm_tx(
else: else:
to_str = "new contract?" to_str = "new contract?"
return confirm_output( return confirm_output(
ctx,
to_str, to_str,
format_ethereum_amount(value, token, network), format_ethereum_amount(value, token, network),
br_code=ButtonRequestType.SignTx, br_code=ButtonRequestType.SignTx,
@ -48,7 +45,6 @@ def require_confirm_tx(
async def require_confirm_fee( async def require_confirm_fee(
ctx: Context,
spending: int, spending: int,
gas_price: int, gas_price: int,
gas_limit: int, gas_limit: int,
@ -56,13 +52,11 @@ async def require_confirm_fee(
token: EthereumTokenInfo | None, token: EthereumTokenInfo | None,
) -> None: ) -> None:
await confirm_amount( await confirm_amount(
ctx,
title="Confirm fee", title="Confirm fee",
description="Gas price:", description="Gas price:",
amount=format_ethereum_amount(gas_price, None, network), amount=format_ethereum_amount(gas_price, None, network),
) )
await confirm_total( await confirm_total(
ctx,
total_amount=format_ethereum_amount(spending, token, network), total_amount=format_ethereum_amount(spending, token, network),
fee_amount=format_ethereum_amount(gas_price * gas_limit, None, network), fee_amount=format_ethereum_amount(gas_price * gas_limit, None, network),
total_label="Amount sent:", total_label="Amount sent:",
@ -71,7 +65,6 @@ async def require_confirm_fee(
async def require_confirm_eip1559_fee( async def require_confirm_eip1559_fee(
ctx: Context,
spending: int, spending: int,
max_priority_fee: int, max_priority_fee: int,
max_gas_fee: int, max_gas_fee: int,
@ -80,19 +73,16 @@ async def require_confirm_eip1559_fee(
token: EthereumTokenInfo | None, token: EthereumTokenInfo | None,
) -> None: ) -> None:
await confirm_amount( await confirm_amount(
ctx,
"Confirm fee", "Confirm fee",
format_ethereum_amount(max_gas_fee, None, network), format_ethereum_amount(max_gas_fee, None, network),
"Maximum fee per gas", "Maximum fee per gas",
) )
await confirm_amount( await confirm_amount(
ctx,
"Confirm fee", "Confirm fee",
format_ethereum_amount(max_priority_fee, None, network), format_ethereum_amount(max_priority_fee, None, network),
"Priority fee per gas", "Priority fee per gas",
) )
await confirm_total( await confirm_total(
ctx,
format_ethereum_amount(spending, token, network), format_ethereum_amount(spending, token, network),
format_ethereum_amount(max_gas_fee * gas_limit, None, network), format_ethereum_amount(max_gas_fee * gas_limit, None, network),
total_label="Amount sent:", total_label="Amount sent:",
@ -100,15 +90,12 @@ async def require_confirm_eip1559_fee(
) )
def require_confirm_unknown_token( def require_confirm_unknown_token(address_bytes: bytes) -> Awaitable[None]:
ctx: Context, address_bytes: bytes
) -> Awaitable[None]:
from ubinascii import hexlify from ubinascii import hexlify
from trezor.ui.layouts import confirm_address from trezor.ui.layouts import confirm_address
contract_address_hex = "0x" + hexlify(address_bytes).decode() contract_address_hex = "0x" + hexlify(address_bytes).decode()
return confirm_address( return confirm_address(
ctx,
"Unknown token", "Unknown token",
contract_address_hex, contract_address_hex,
"Contract:", "Contract:",
@ -117,22 +104,20 @@ def require_confirm_unknown_token(
) )
def require_confirm_address(ctx: Context, address_bytes: bytes) -> Awaitable[None]: def require_confirm_address(address_bytes: bytes) -> Awaitable[None]:
from ubinascii import hexlify from ubinascii import hexlify
from trezor.ui.layouts import confirm_address from trezor.ui.layouts import confirm_address
address_hex = "0x" + hexlify(address_bytes).decode() address_hex = "0x" + hexlify(address_bytes).decode()
return confirm_address( return confirm_address(
ctx,
"Signing address", "Signing address",
address_hex, address_hex,
br_code=ButtonRequestType.SignTx, br_code=ButtonRequestType.SignTx,
) )
def require_confirm_data(ctx: Context, data: bytes, data_total: int) -> Awaitable[None]: def require_confirm_data(data: bytes, data_total: int) -> Awaitable[None]:
return confirm_blob( return confirm_blob(
ctx,
"confirm_data", "confirm_data",
"Confirm data", "Confirm data",
data, data,
@ -142,11 +127,10 @@ def require_confirm_data(ctx: Context, data: bytes, data_total: int) -> Awaitabl
) )
async def confirm_typed_data_final(ctx: Context) -> None: async def confirm_typed_data_final() -> None:
from trezor.ui.layouts import confirm_action from trezor.ui.layouts import confirm_action
await confirm_action( await confirm_action(
ctx,
"confirm_typed_data_final", "confirm_typed_data_final",
"Confirm typed data", "Confirm typed data",
"Really sign EIP-712 typed data?", "Really sign EIP-712 typed data?",
@ -155,9 +139,8 @@ async def confirm_typed_data_final(ctx: Context) -> None:
) )
def confirm_empty_typed_message(ctx: Context) -> Awaitable[None]: def confirm_empty_typed_message() -> Awaitable[None]:
return confirm_text( return confirm_text(
ctx,
"confirm_empty_typed_message", "confirm_empty_typed_message",
"Confirm message", "Confirm message",
"", "",
@ -165,7 +148,7 @@ def confirm_empty_typed_message(ctx: Context) -> Awaitable[None]:
) )
async def should_show_domain(ctx: Context, name: bytes, version: bytes) -> bool: async def should_show_domain(name: bytes, version: bytes) -> bool:
domain_name = decode_typed_data(name, "string") domain_name = decode_typed_data(name, "string")
domain_version = decode_typed_data(version, "string") domain_version = decode_typed_data(version, "string")
@ -175,7 +158,6 @@ async def should_show_domain(ctx: Context, name: bytes, version: bytes) -> bool:
(ui.DEMIBOLD, domain_version), (ui.DEMIBOLD, domain_version),
) )
return await should_show_more( return await should_show_more(
ctx,
"Confirm domain", "Confirm domain",
para, para,
"Show full domain", "Show full domain",
@ -184,7 +166,6 @@ async def should_show_domain(ctx: Context, name: bytes, version: bytes) -> bool:
async def should_show_struct( async def should_show_struct(
ctx: Context,
description: str, description: str,
data_members: list[EthereumStructMember], data_members: list[EthereumStructMember],
title: str = "Confirm struct", title: str = "Confirm struct",
@ -199,7 +180,6 @@ async def should_show_struct(
(ui.NORMAL, ", ".join(field.name for field in data_members)), (ui.NORMAL, ", ".join(field.name for field in data_members)),
) )
return await should_show_more( return await should_show_more(
ctx,
title, title,
para, para,
button_text, button_text,
@ -208,14 +188,12 @@ async def should_show_struct(
async def should_show_array( async def should_show_array(
ctx: Context,
parent_objects: Iterable[str], parent_objects: Iterable[str],
data_type: str, data_type: str,
size: int, size: int,
) -> bool: ) -> bool:
para = ((ui.NORMAL, format_plural("Array of {count} {plural}", size, data_type)),) para = ((ui.NORMAL, format_plural("Array of {count} {plural}", size, data_type)),)
return await should_show_more( return await should_show_more(
ctx,
limit_str(".".join(parent_objects)), limit_str(".".join(parent_objects)),
para, para,
"Show full array", "Show full array",
@ -224,7 +202,6 @@ async def should_show_array(
async def confirm_typed_value( async def confirm_typed_value(
ctx: Context,
name: str, name: str,
value: bytes, value: bytes,
parent_objects: list[str], parent_objects: list[str],
@ -247,7 +224,6 @@ async def confirm_typed_value(
if field.data_type in (EthereumDataType.ADDRESS, EthereumDataType.BYTES): if field.data_type in (EthereumDataType.ADDRESS, EthereumDataType.BYTES):
await confirm_blob( await confirm_blob(
ctx,
"confirm_typed_value", "confirm_typed_value",
title, title,
data, data,
@ -256,7 +232,6 @@ async def confirm_typed_value(
) )
else: else:
await confirm_text( await confirm_text(
ctx,
"confirm_typed_value", "confirm_typed_value",
title, title,
data, data,

@ -7,7 +7,6 @@ if TYPE_CHECKING:
EthereumSignMessage, EthereumSignMessage,
EthereumMessageSignature, EthereumMessageSignature,
) )
from trezor.wire import Context
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
from .definitions import Definitions from .definitions import Definitions
@ -27,7 +26,6 @@ def message_digest(message: bytes) -> bytes:
@with_keychain_from_path(*PATTERNS_ADDRESS) @with_keychain_from_path(*PATTERNS_ADDRESS)
async def sign_message( async def sign_message(
ctx: Context,
msg: EthereumSignMessage, msg: EthereumSignMessage,
keychain: Keychain, keychain: Keychain,
defs: Definitions, defs: Definitions,
@ -41,13 +39,11 @@ async def sign_message(
from .helpers import address_from_bytes from .helpers import address_from_bytes
await paths.validate_path(ctx, keychain, msg.address_n) await paths.validate_path(keychain, msg.address_n)
node = keychain.derive(msg.address_n) node = keychain.derive(msg.address_n)
address = address_from_bytes(node.ethereum_pubkeyhash(), defs.network) address = address_from_bytes(node.ethereum_pubkeyhash(), defs.network)
await confirm_signverify( await confirm_signverify("ETH", decode_message(msg.message), address, verify=False)
ctx, "ETH", decode_message(msg.message), address, verify=False
)
signature = secp256k1.sign( signature = secp256k1.sign(
node.private_key(), node.private_key(),

@ -10,7 +10,6 @@ from .keychain import with_keychain_from_chain_id
if TYPE_CHECKING: if TYPE_CHECKING:
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
from trezor.messages import EthereumSignTx, EthereumTxAck, EthereumTokenInfo from trezor.messages import EthereumSignTx, EthereumTxAck, EthereumTokenInfo
from trezor.wire import Context
from .definitions import Definitions from .definitions import Definitions
from .keychain import MsgInSignTx from .keychain import MsgInSignTx
@ -24,7 +23,6 @@ MAX_CHAIN_ID = (0xFFFF_FFFF - 36) // 2
@with_keychain_from_chain_id @with_keychain_from_chain_id
async def sign_tx( async def sign_tx(
ctx: Context,
msg: EthereumSignTx, msg: EthereumSignTx,
keychain: Keychain, keychain: Keychain,
defs: Definitions, defs: Definitions,
@ -45,19 +43,18 @@ async def sign_tx(
raise DataError("Fee overflow") raise DataError("Fee overflow")
check_common_fields(msg) check_common_fields(msg)
await paths.validate_path(ctx, keychain, msg.address_n) await paths.validate_path(keychain, msg.address_n)
# Handle ERC20s # Handle ERC20s
token, address_bytes, recipient, value = await handle_erc20(ctx, msg, defs) token, address_bytes, recipient, value = await handle_erc20(msg, defs)
data_total = msg.data_length data_total = msg.data_length
await require_confirm_tx(ctx, recipient, value, defs.network, token) await require_confirm_tx(recipient, value, defs.network, token)
if token is None and msg.data_length > 0: if token is None and msg.data_length > 0:
await require_confirm_data(ctx, msg.data_initial_chunk, data_total) await require_confirm_data(msg.data_initial_chunk, data_total)
await require_confirm_fee( await require_confirm_fee(
ctx,
value, value,
int.from_bytes(msg.gas_price, "big"), int.from_bytes(msg.gas_price, "big"),
int.from_bytes(msg.gas_limit, "big"), int.from_bytes(msg.gas_limit, "big"),
@ -87,7 +84,7 @@ async def sign_tx(
sha.extend(data) sha.extend(data)
while data_left > 0: while data_left > 0:
resp = await send_request_chunk(ctx, data_left) resp = await send_request_chunk(data_left)
data_left -= len(resp.data_chunk) data_left -= len(resp.data_chunk)
sha.extend(resp.data_chunk) sha.extend(resp.data_chunk)
@ -103,7 +100,6 @@ async def sign_tx(
async def handle_erc20( async def handle_erc20(
ctx: Context,
msg: MsgInSignTx, msg: MsgInSignTx,
definitions: Definitions, definitions: Definitions,
) -> tuple[EthereumTokenInfo | None, bytes, bytes, int]: ) -> tuple[EthereumTokenInfo | None, bytes, bytes, int]:
@ -127,7 +123,7 @@ async def handle_erc20(
value = int.from_bytes(data_initial_chunk[36:68], "big") value = int.from_bytes(data_initial_chunk[36:68], "big")
if token is tokens.UNKNOWN_TOKEN: if token is tokens.UNKNOWN_TOKEN:
await require_confirm_unknown_token(ctx, address_bytes) await require_confirm_unknown_token(address_bytes)
return token, address_bytes, recipient, value return token, address_bytes, recipient, value
@ -157,13 +153,14 @@ def _get_total_length(msg: EthereumSignTx, data_total: int) -> int:
return length return length
async def send_request_chunk(ctx: Context, data_left: int) -> EthereumTxAck: async def send_request_chunk(data_left: int) -> EthereumTxAck:
from trezor.messages import EthereumTxAck from trezor.messages import EthereumTxAck
from trezor.wire.context import call
# TODO: layoutProgress ? # TODO: layoutProgress ?
req = EthereumTxRequest() req = EthereumTxRequest()
req.data_length = min(data_left, 1024) req.data_length = min(data_left, 1024)
return await ctx.call(req, EthereumTxAck) return await call(req, EthereumTxAck)
def _sign_digest( def _sign_digest(

@ -12,7 +12,6 @@ if TYPE_CHECKING:
EthereumAccessList, EthereumAccessList,
EthereumTxRequest, EthereumTxRequest,
) )
from trezor.wire import Context
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
from .definitions import Definitions from .definitions import Definitions
@ -30,7 +29,6 @@ def access_list_item_length(item: EthereumAccessList) -> int:
@with_keychain_from_chain_id @with_keychain_from_chain_id
async def sign_tx_eip1559( async def sign_tx_eip1559(
ctx: Context,
msg: EthereumSignTxEIP1559, msg: EthereumSignTxEIP1559,
keychain: Keychain, keychain: Keychain,
defs: Definitions, defs: Definitions,
@ -56,19 +54,18 @@ async def sign_tx_eip1559(
raise wire.DataError("Fee overflow") raise wire.DataError("Fee overflow")
check_common_fields(msg) check_common_fields(msg)
await paths.validate_path(ctx, keychain, msg.address_n) await paths.validate_path(keychain, msg.address_n)
# Handle ERC20s # Handle ERC20s
token, address_bytes, recipient, value = await handle_erc20(ctx, msg, defs) token, address_bytes, recipient, value = await handle_erc20(msg, defs)
data_total = msg.data_length data_total = msg.data_length
await require_confirm_tx(ctx, recipient, value, defs.network, token) await require_confirm_tx(recipient, value, defs.network, token)
if token is None and msg.data_length > 0: if token is None and msg.data_length > 0:
await require_confirm_data(ctx, msg.data_initial_chunk, data_total) await require_confirm_data(msg.data_initial_chunk, data_total)
await require_confirm_eip1559_fee( await require_confirm_eip1559_fee(
ctx,
value, value,
int.from_bytes(msg.max_priority_fee, "big"), int.from_bytes(msg.max_priority_fee, "big"),
int.from_bytes(msg.max_gas_fee, "big"), int.from_bytes(msg.max_gas_fee, "big"),
@ -108,7 +105,7 @@ async def sign_tx_eip1559(
sha.extend(data) sha.extend(data)
while data_left > 0: while data_left > 0:
resp = await send_request_chunk(ctx, data_left) resp = await send_request_chunk(data_left)
data_left -= len(resp.data_chunk) data_left -= len(resp.data_chunk)
sha.extend(resp.data_chunk) sha.extend(resp.data_chunk)

@ -2,6 +2,7 @@ from typing import TYPE_CHECKING
from trezor.enums import EthereumDataType from trezor.enums import EthereumDataType
from trezor.wire import DataError from trezor.wire import DataError
from trezor.wire.context import call
from .helpers import get_type_name from .helpers import get_type_name
from .keychain import PATTERNS_ADDRESS, with_keychain_from_path from .keychain import PATTERNS_ADDRESS, with_keychain_from_path
@ -9,7 +10,6 @@ from .layout import should_show_struct
if TYPE_CHECKING: if TYPE_CHECKING:
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
from trezor.wire import Context
from trezor.utils import HashWriter from trezor.utils import HashWriter
from .definitions import Definitions from .definitions import Definitions
@ -23,7 +23,6 @@ if TYPE_CHECKING:
@with_keychain_from_path(*PATTERNS_ADDRESS) @with_keychain_from_path(*PATTERNS_ADDRESS)
async def sign_typed_data( async def sign_typed_data(
ctx: Context,
msg: EthereumSignTypedData, msg: EthereumSignTypedData,
keychain: Keychain, keychain: Keychain,
defs: Definitions, defs: Definitions,
@ -34,16 +33,16 @@ async def sign_typed_data(
from .layout import require_confirm_address from .layout import require_confirm_address
from trezor.messages import EthereumTypedDataSignature from trezor.messages import EthereumTypedDataSignature
await paths.validate_path(ctx, keychain, msg.address_n) await paths.validate_path(keychain, msg.address_n)
node = keychain.derive(msg.address_n) node = keychain.derive(msg.address_n)
address_bytes: bytes = node.ethereum_pubkeyhash() address_bytes: bytes = node.ethereum_pubkeyhash()
# Display address so user can validate it # Display address so user can validate it
await require_confirm_address(ctx, address_bytes) await require_confirm_address(address_bytes)
data_hash = await _generate_typed_data_hash( data_hash = await _generate_typed_data_hash(
ctx, msg.primary_type, msg.metamask_v4_compat msg.primary_type, msg.metamask_v4_compat
) )
signature = secp256k1.sign( signature = secp256k1.sign(
@ -57,7 +56,7 @@ async def sign_typed_data(
async def _generate_typed_data_hash( async def _generate_typed_data_hash(
ctx: Context, primary_type: str, metamask_v4_compat: bool = True primary_type: str, metamask_v4_compat: bool = True
) -> bytes: ) -> bytes:
""" """
Generate typed data hash according to EIP-712 specification Generate typed data hash according to EIP-712 specification
@ -72,14 +71,13 @@ async def _generate_typed_data_hash(
) )
typed_data_envelope = TypedDataEnvelope( typed_data_envelope = TypedDataEnvelope(
ctx,
primary_type, primary_type,
metamask_v4_compat, metamask_v4_compat,
) )
await typed_data_envelope.collect_types() await typed_data_envelope.collect_types()
name, version = await _get_name_and_version_for_domain(ctx, typed_data_envelope) name, version = await _get_name_and_version_for_domain(typed_data_envelope)
show_domain = await should_show_domain(ctx, name, version) show_domain = await should_show_domain(name, version)
domain_separator = await typed_data_envelope.hash_struct( domain_separator = await typed_data_envelope.hash_struct(
"EIP712Domain", "EIP712Domain",
[0], [0],
@ -91,11 +89,10 @@ async def _generate_typed_data_hash(
# In this case, we ignore the "message" part and only use the "domain" part # In this case, we ignore the "message" part and only use the "domain" part
# https://ethereum-magicians.org/t/eip-712-standards-clarification-primarytype-as-domaintype/3286 # https://ethereum-magicians.org/t/eip-712-standards-clarification-primarytype-as-domaintype/3286
if primary_type == "EIP712Domain": if primary_type == "EIP712Domain":
await confirm_empty_typed_message(ctx) await confirm_empty_typed_message()
message_hash = b"" message_hash = b""
else: else:
show_message = await should_show_struct( show_message = await should_show_struct(
ctx,
primary_type, primary_type,
typed_data_envelope.types[primary_type].members, typed_data_envelope.types[primary_type].members,
"Confirm message", "Confirm message",
@ -108,7 +105,7 @@ async def _generate_typed_data_hash(
[primary_type], [primary_type],
) )
await confirm_typed_data_final(ctx) await confirm_typed_data_final()
return keccak256(b"\x19\x01" + domain_separator + message_hash) return keccak256(b"\x19\x01" + domain_separator + message_hash)
@ -131,11 +128,9 @@ class TypedDataEnvelope:
def __init__( def __init__(
self, self,
ctx: Context,
primary_type: str, primary_type: str,
metamask_v4_compat: bool, metamask_v4_compat: bool,
) -> None: ) -> None:
self.ctx = ctx
self.primary_type = primary_type self.primary_type = primary_type
self.metamask_v4_compat = metamask_v4_compat self.metamask_v4_compat = metamask_v4_compat
self.types: dict[str, EthereumTypedDataStructAck] = {} self.types: dict[str, EthereumTypedDataStructAck] = {}
@ -153,7 +148,7 @@ class TypedDataEnvelope:
) )
req = EthereumTypedDataStructRequest(name=type_name) req = EthereumTypedDataStructRequest(name=type_name)
current_type = await self.ctx.call(req, EthereumTypedDataStructAck) current_type = await call(req, EthereumTypedDataStructAck)
self.types[type_name] = current_type self.types[type_name] = current_type
for member in current_type.members: for member in current_type.members:
member_type = member.type member_type = member.type
@ -254,8 +249,6 @@ class TypedDataEnvelope:
""" """
from .layout import confirm_typed_value, should_show_array from .layout import confirm_typed_value, should_show_array
ctx = self.ctx # local_cache_attribute
type_members = self.types[primary_type].members type_members = self.types[primary_type].members
member_value_path = member_path + [0] member_value_path = member_path + [0]
current_parent_objects = parent_objects + [""] current_parent_objects = parent_objects + [""]
@ -272,7 +265,6 @@ class TypedDataEnvelope:
if show_data: if show_data:
show_struct = await should_show_struct( show_struct = await should_show_struct(
ctx,
struct_name, # description struct_name, # description
self.types[struct_name].members, # data_members self.types[struct_name].members, # data_members
".".join(current_parent_objects), # title ".".join(current_parent_objects), # title
@ -290,7 +282,7 @@ class TypedDataEnvelope:
elif field_type.data_type == EthereumDataType.ARRAY: elif field_type.data_type == EthereumDataType.ARRAY:
# Getting the length of the array first, if not fixed # Getting the length of the array first, if not fixed
if field_type.size is None: if field_type.size is None:
array_size = await _get_array_size(ctx, member_value_path) array_size = await _get_array_size(member_value_path)
else: else:
array_size = field_type.size array_size = field_type.size
@ -300,7 +292,6 @@ class TypedDataEnvelope:
if show_data: if show_data:
show_array = await should_show_array( show_array = await should_show_array(
ctx,
current_parent_objects, current_parent_objects,
get_type_name(entry_type), get_type_name(entry_type),
array_size, array_size,
@ -338,11 +329,10 @@ class TypedDataEnvelope:
current_parent_objects, current_parent_objects,
) )
else: else:
value = await get_value(ctx, entry_type, el_member_path) value = await get_value(entry_type, el_member_path)
encode_field(arr_w, entry_type, value) encode_field(arr_w, entry_type, value)
if show_array: if show_array:
await confirm_typed_value( await confirm_typed_value(
ctx,
field_name, field_name,
value, value,
parent_objects, parent_objects,
@ -351,11 +341,10 @@ class TypedDataEnvelope:
) )
w.extend(arr_w.get_digest()) w.extend(arr_w.get_digest())
else: else:
value = await get_value(ctx, field_type, member_value_path) value = await get_value(field_type, member_value_path)
encode_field(w, field_type, value) encode_field(w, field_type, value)
if show_data: if show_data:
await confirm_typed_value( await confirm_typed_value(
ctx,
field_name, field_name,
value, value,
parent_objects, parent_objects,
@ -503,18 +492,17 @@ def validate_field_type(field: EthereumFieldType) -> None:
raise DataError("Unexpected size in str/bool/addr") raise DataError("Unexpected size in str/bool/addr")
async def _get_array_size(ctx: Context, member_path: list[int]) -> int: async def _get_array_size(member_path: list[int]) -> int:
"""Get the length of an array at specific `member_path` from the client.""" """Get the length of an array at specific `member_path` from the client."""
from trezor.messages import EthereumFieldType from trezor.messages import EthereumFieldType
# Field type for getting the array length from client, so we can check the return value # Field type for getting the array length from client, so we can check the return value
ARRAY_LENGTH_TYPE = EthereumFieldType(data_type=EthereumDataType.UINT, size=2) ARRAY_LENGTH_TYPE = EthereumFieldType(data_type=EthereumDataType.UINT, size=2)
length_value = await get_value(ctx, ARRAY_LENGTH_TYPE, member_path) length_value = await get_value(ARRAY_LENGTH_TYPE, member_path)
return int.from_bytes(length_value, "big") return int.from_bytes(length_value, "big")
async def get_value( async def get_value(
ctx: Context,
field: EthereumFieldType, field: EthereumFieldType,
member_value_path: list[int], member_value_path: list[int],
) -> bytes: ) -> bytes:
@ -524,7 +512,7 @@ async def get_value(
req = EthereumTypedDataValueRequest( req = EthereumTypedDataValueRequest(
member_path=member_value_path, member_path=member_value_path,
) )
res = await ctx.call(req, EthereumTypedDataValueAck) res = await call(req, EthereumTypedDataValueAck)
value = res.value value = res.value
_validate_value(field=field, value=value) _validate_value(field=field, value=value)
@ -533,7 +521,7 @@ async def get_value(
async def _get_name_and_version_for_domain( async def _get_name_and_version_for_domain(
ctx: Context, typed_data_envelope: TypedDataEnvelope typed_data_envelope: TypedDataEnvelope,
) -> tuple[bytes, bytes]: ) -> tuple[bytes, bytes]:
domain_name = b"unknown" domain_name = b"unknown"
domain_version = b"unknown" domain_version = b"unknown"
@ -543,8 +531,8 @@ async def _get_name_and_version_for_domain(
for member_index, member in enumerate(domain_members): for member_index, member in enumerate(domain_members):
member_value_path[-1] = member_index member_value_path[-1] = member_index
if member.name == "name": if member.name == "name":
domain_name = await get_value(ctx, member.type, member_value_path) domain_name = await get_value(member.type, member_value_path)
elif member.name == "version": elif member.name == "version":
domain_version = await get_value(ctx, member.type, member_value_path) domain_version = await get_value(member.type, member_value_path)
return domain_name, domain_version return domain_name, domain_version

@ -2,10 +2,9 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import EthereumVerifyMessage, Success from trezor.messages import EthereumVerifyMessage, Success
from trezor.wire import Context
async def verify_message(ctx: Context, msg: EthereumVerifyMessage) -> Success: async def verify_message(msg: EthereumVerifyMessage) -> Success:
from trezor.wire import DataError from trezor.wire import DataError
from trezor.crypto.curve import secp256k1 from trezor.crypto.curve import secp256k1
from trezor.crypto.hashlib import sha3_256 from trezor.crypto.hashlib import sha3_256
@ -35,9 +34,7 @@ async def verify_message(ctx: Context, msg: EthereumVerifyMessage) -> Success:
address = address_from_bytes(address_bytes) address = address_from_bytes(address_bytes)
await confirm_signverify( await confirm_signverify("ETH", decode_message(msg.message), address, verify=True)
ctx, "ETH", decode_message(msg.message), address, verify=True
)
await show_success(ctx, "verify_message", "The signature is valid.") await show_success("verify_message", "The signature is valid.")
return Success(message="Message verified") return Success(message="Message verified")

@ -3,10 +3,9 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import Success from trezor.messages import Success
from trezor.messages import ApplyFlags from trezor.messages import ApplyFlags
from trezor.wire import GenericContext
async def apply_flags(ctx: GenericContext, msg: ApplyFlags) -> Success: async def apply_flags(msg: ApplyFlags) -> Success:
import storage.device import storage.device
from storage.device import set_flags from storage.device import set_flags
from trezor.wire import NotInitialized from trezor.wire import NotInitialized

@ -9,7 +9,6 @@ import trezorui2
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import ApplySettings, Success from trezor.messages import ApplySettings, Success
from trezor.wire import Context, GenericContext
from trezor.enums import SafetyCheckLevel from trezor.enums import SafetyCheckLevel
@ -62,7 +61,7 @@ def _validate_homescreen(homescreen: bytes) -> None:
_validate_homescreen_model_specific(homescreen) _validate_homescreen_model_specific(homescreen)
async def apply_settings(ctx: Context, msg: ApplySettings) -> Success: async def apply_settings(msg: ApplySettings) -> Success:
import storage.device as storage_device import storage.device as storage_device
from apps.common import safety_checks from apps.common import safety_checks
from trezor.messages import Success from trezor.messages import Success
@ -98,7 +97,7 @@ async def apply_settings(ctx: Context, msg: ApplySettings) -> Success:
if homescreen is not None: if homescreen is not None:
_validate_homescreen(homescreen) _validate_homescreen(homescreen)
await _require_confirm_change_homescreen(ctx, homescreen) await _require_confirm_change_homescreen(homescreen)
try: try:
storage_device.set_homescreen(homescreen) storage_device.set_homescreen(homescreen)
except ValueError: except ValueError:
@ -107,19 +106,17 @@ async def apply_settings(ctx: Context, msg: ApplySettings) -> Success:
if label is not None: if label is not None:
if len(label) > storage_device.LABEL_MAXLENGTH: if len(label) > storage_device.LABEL_MAXLENGTH:
raise DataError("Label too long") raise DataError("Label too long")
await _require_confirm_change_label(ctx, label) await _require_confirm_change_label(label)
storage_device.set_label(label) storage_device.set_label(label)
if use_passphrase is not None: if use_passphrase is not None:
await _require_confirm_change_passphrase(ctx, use_passphrase) await _require_confirm_change_passphrase(use_passphrase)
storage_device.set_passphrase_enabled(use_passphrase) storage_device.set_passphrase_enabled(use_passphrase)
if passphrase_always_on_device is not None: if passphrase_always_on_device is not None:
if not storage_device.is_passphrase_enabled(): if not storage_device.is_passphrase_enabled():
raise DataError("Passphrase is not enabled") raise DataError("Passphrase is not enabled")
await _require_confirm_change_passphrase_source( await _require_confirm_change_passphrase_source(passphrase_always_on_device)
ctx, passphrase_always_on_device
)
storage_device.set_passphrase_always_on_device(passphrase_always_on_device) storage_device.set_passphrase_always_on_device(passphrase_always_on_device)
if auto_lock_delay_ms is not None: if auto_lock_delay_ms is not None:
@ -127,25 +124,25 @@ async def apply_settings(ctx: Context, msg: ApplySettings) -> Success:
raise ProcessError("Auto-lock delay too short") raise ProcessError("Auto-lock delay too short")
if auto_lock_delay_ms > storage_device.AUTOLOCK_DELAY_MAXIMUM: if auto_lock_delay_ms > storage_device.AUTOLOCK_DELAY_MAXIMUM:
raise ProcessError("Auto-lock delay too long") raise ProcessError("Auto-lock delay too long")
await _require_confirm_change_autolock_delay(ctx, auto_lock_delay_ms) await _require_confirm_change_autolock_delay(auto_lock_delay_ms)
storage_device.set_autolock_delay_ms(auto_lock_delay_ms) storage_device.set_autolock_delay_ms(auto_lock_delay_ms)
if msg_safety_checks is not None: if msg_safety_checks is not None:
await _require_confirm_safety_checks(ctx, msg_safety_checks) await _require_confirm_safety_checks(msg_safety_checks)
safety_checks.apply_setting(msg_safety_checks) safety_checks.apply_setting(msg_safety_checks)
if display_rotation is not None: if display_rotation is not None:
await _require_confirm_change_display_rotation(ctx, display_rotation) await _require_confirm_change_display_rotation(display_rotation)
storage_device.set_rotation(display_rotation) storage_device.set_rotation(display_rotation)
if experimental_features is not None: if experimental_features is not None:
await _require_confirm_experimental_features(ctx, experimental_features) await _require_confirm_experimental_features(experimental_features)
storage_device.set_experimental_features(experimental_features) storage_device.set_experimental_features(experimental_features)
if hide_passphrase_from_host is not None: if hide_passphrase_from_host is not None:
if safety_checks.is_strict(): if safety_checks.is_strict():
raise ProcessError("Safety checks are strict") raise ProcessError("Safety checks are strict")
await _require_confirm_hide_passphrase_from_host(ctx, hide_passphrase_from_host) await _require_confirm_hide_passphrase_from_host(hide_passphrase_from_host)
storage_device.set_hide_passphrase_from_host(hide_passphrase_from_host) storage_device.set_hide_passphrase_from_host(hide_passphrase_from_host)
reload_settings_from_storage() reload_settings_from_storage()
@ -153,12 +150,9 @@ async def apply_settings(ctx: Context, msg: ApplySettings) -> Success:
return Success(message="Settings applied") return Success(message="Settings applied")
async def _require_confirm_change_homescreen( async def _require_confirm_change_homescreen(homescreen: bytes) -> None:
ctx: GenericContext, homescreen: bytes
) -> None:
if homescreen == b"": if homescreen == b"":
await confirm_action( await confirm_action(
ctx,
"set_homescreen", "set_homescreen",
"Set homescreen", "Set homescreen",
description="Do you really want to set default homescreen image?", description="Do you really want to set default homescreen image?",
@ -166,14 +160,12 @@ async def _require_confirm_change_homescreen(
) )
else: else:
await confirm_homescreen( await confirm_homescreen(
ctx,
homescreen, homescreen,
) )
async def _require_confirm_change_label(ctx: GenericContext, label: str) -> None: async def _require_confirm_change_label(label: str) -> None:
await confirm_single( await confirm_single(
ctx,
"set_label", "set_label",
"Device name", "Device name",
description="Change device name to {}?", description="Change device name to {}?",
@ -182,11 +174,10 @@ async def _require_confirm_change_label(ctx: GenericContext, label: str) -> None
) )
async def _require_confirm_change_passphrase(ctx: GenericContext, use: bool) -> None: async def _require_confirm_change_passphrase(use: bool) -> None:
template = "Do you want to {} passphrase protection?" template = "Do you want to {} passphrase protection?"
description = template.format("enable" if use else "disable") description = template.format("enable" if use else "disable")
await confirm_action( await confirm_action(
ctx,
"set_passphrase", "set_passphrase",
"Enable passphrase" if use else "Disable passphrase", "Enable passphrase" if use else "Disable passphrase",
description=description, description=description,
@ -195,7 +186,7 @@ async def _require_confirm_change_passphrase(ctx: GenericContext, use: bool) ->
async def _require_confirm_change_passphrase_source( async def _require_confirm_change_passphrase_source(
ctx: GenericContext, passphrase_always_on_device: bool passphrase_always_on_device: bool,
) -> None: ) -> None:
description = ( description = (
"Do you really want to enter passphrase always on the device?" "Do you really want to enter passphrase always on the device?"
@ -203,7 +194,6 @@ async def _require_confirm_change_passphrase_source(
else "Do you want to revoke the passphrase on device setting?" else "Do you want to revoke the passphrase on device setting?"
) )
await confirm_action( await confirm_action(
ctx,
"set_passphrase_source", "set_passphrase_source",
"Passphrase source", "Passphrase source",
description=description, description=description,
@ -211,9 +201,7 @@ async def _require_confirm_change_passphrase_source(
) )
async def _require_confirm_change_display_rotation( async def _require_confirm_change_display_rotation(rotation: int) -> None:
ctx: GenericContext, rotation: int
) -> None:
if rotation == 0: if rotation == 0:
label = "north" label = "north"
elif rotation == 90: elif rotation == 90:
@ -226,7 +214,6 @@ async def _require_confirm_change_display_rotation(
raise DataError("Unsupported display rotation") raise DataError("Unsupported display rotation")
await confirm_action( await confirm_action(
ctx,
"set_rotation", "set_rotation",
"Change rotation", "Change rotation",
description="Do you want to change device rotation to {}?", description="Do you want to change device rotation to {}?",
@ -235,13 +222,10 @@ async def _require_confirm_change_display_rotation(
) )
async def _require_confirm_change_autolock_delay( async def _require_confirm_change_autolock_delay(delay_ms: int) -> None:
ctx: GenericContext, delay_ms: int
) -> None:
from trezor.strings import format_duration_ms from trezor.strings import format_duration_ms
await confirm_action( await confirm_action(
ctx,
"set_autolock_delay", "set_autolock_delay",
"Auto-lock delay", "Auto-lock delay",
description="Do you really want to auto-lock your device after {}?", description="Do you really want to auto-lock your device after {}?",
@ -250,14 +234,11 @@ async def _require_confirm_change_autolock_delay(
) )
async def _require_confirm_safety_checks( async def _require_confirm_safety_checks(level: SafetyCheckLevel) -> None:
ctx: GenericContext, level: SafetyCheckLevel
) -> None:
from trezor.enums import SafetyCheckLevel from trezor.enums import SafetyCheckLevel
if level == SafetyCheckLevel.Strict: if level == SafetyCheckLevel.Strict:
await confirm_action( await confirm_action(
ctx,
"set_safety_checks", "set_safety_checks",
"Safety checks", "Safety checks",
description="Do you really want to enforce strict safety checks (recommended)?", description="Do you really want to enforce strict safety checks (recommended)?",
@ -273,7 +254,6 @@ async def _require_confirm_safety_checks(
) )
await confirm_action( await confirm_action(
ctx,
"set_safety_checks", "set_safety_checks",
"Safety override", "Safety override",
"Are you sure?", "Are you sure?",
@ -287,12 +267,9 @@ async def _require_confirm_safety_checks(
raise ValueError # enum value out of range raise ValueError # enum value out of range
async def _require_confirm_experimental_features( async def _require_confirm_experimental_features(enable: bool) -> None:
ctx: GenericContext, enable: bool
) -> None:
if enable: if enable:
await confirm_action( await confirm_action(
ctx,
"set_experimental_features", "set_experimental_features",
"Experimental mode", "Experimental mode",
"Only for development and beta testing!", "Only for development and beta testing!",
@ -302,12 +279,9 @@ async def _require_confirm_experimental_features(
) )
async def _require_confirm_hide_passphrase_from_host( async def _require_confirm_hide_passphrase_from_host(enable: bool) -> None:
ctx: GenericContext, enable: bool
) -> None:
if enable: if enable:
await confirm_action( await confirm_action(
ctx,
"set_hide_passphrase_from_host", "set_hide_passphrase_from_host",
"Hide passphrase", "Hide passphrase",
description="Hide passphrase coming from host?", description="Hide passphrase coming from host?",

@ -2,10 +2,9 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import BackupDevice, Success from trezor.messages import BackupDevice, Success
from trezor.wire import Context
async def backup_device(ctx: Context, msg: BackupDevice) -> Success: async def backup_device(msg: BackupDevice) -> Success:
import storage.device as storage_device import storage.device as storage_device
from trezor import wire from trezor import wire
from trezor.messages import Success from trezor.messages import Success
@ -26,10 +25,10 @@ async def backup_device(ctx: Context, msg: BackupDevice) -> Success:
storage_device.set_unfinished_backup(True) storage_device.set_unfinished_backup(True)
storage_device.set_backed_up() storage_device.set_backed_up()
await backup_seed(ctx, mnemonic_type, mnemonic_secret) await backup_seed(mnemonic_type, mnemonic_secret)
storage_device.set_unfinished_backup(False) storage_device.set_unfinished_backup(False)
await layout.show_backup_success(ctx) await layout.show_backup_success()
return Success(message="Seed successfully backed up") return Success(message="Seed successfully backed up")

@ -6,10 +6,9 @@ if TYPE_CHECKING:
from typing import Awaitable from typing import Awaitable
from trezor.messages import ChangePin, Success from trezor.messages import ChangePin, Success
from trezor.wire import Context
async def change_pin(ctx: Context, msg: ChangePin) -> Success: async def change_pin(msg: ChangePin) -> Success:
from storage.device import is_initialized from storage.device import is_initialized
from trezor.messages import Success from trezor.messages import Success
from trezor.ui.layouts import show_success from trezor.ui.layouts import show_success
@ -25,28 +24,28 @@ async def change_pin(ctx: Context, msg: ChangePin) -> Success:
raise wire.NotInitialized("Device is not initialized") raise wire.NotInitialized("Device is not initialized")
# confirm that user wants to change the pin # confirm that user wants to change the pin
await _require_confirm_change_pin(ctx, msg) await _require_confirm_change_pin(msg)
# get old pin # get old pin
curpin, salt = await request_pin_and_sd_salt(ctx, "Enter PIN") curpin, salt = await request_pin_and_sd_salt("Enter PIN")
# if changing pin, pre-check the entered pin before getting new pin # if changing pin, pre-check the entered pin before getting new pin
if curpin and not msg.remove: if curpin and not msg.remove:
if not config.check_pin(curpin, salt): if not config.check_pin(curpin, salt):
await error_pin_invalid(ctx) await error_pin_invalid()
# get new pin # get new pin
if not msg.remove: if not msg.remove:
newpin = await request_pin_confirm(ctx) newpin = await request_pin_confirm()
else: else:
newpin = "" newpin = ""
# write into storage # write into storage
if not config.change_pin(curpin, newpin, salt, salt): if not config.change_pin(curpin, newpin, salt, salt):
if newpin: if newpin:
await error_pin_matches_wipe_code(ctx) await error_pin_matches_wipe_code()
else: else:
await error_pin_invalid(ctx) await error_pin_invalid()
if newpin: if newpin:
if curpin: if curpin:
@ -59,11 +58,11 @@ async def change_pin(ctx: Context, msg: ChangePin) -> Success:
msg_screen = "PIN protection disabled." msg_screen = "PIN protection disabled."
msg_wire = "PIN removed" msg_wire = "PIN removed"
await show_success(ctx, "success_pin", msg_screen) await show_success("success_pin", msg_screen)
return Success(message=msg_wire) return Success(message=msg_wire)
def _require_confirm_change_pin(ctx: Context, msg: ChangePin) -> Awaitable[None]: def _require_confirm_change_pin(msg: ChangePin) -> Awaitable[None]:
from trezor.ui.layouts import confirm_action, confirm_set_new_pin from trezor.ui.layouts import confirm_action, confirm_set_new_pin
has_pin = config.has_pin() has_pin = config.has_pin()
@ -73,7 +72,6 @@ def _require_confirm_change_pin(ctx: Context, msg: ChangePin) -> Awaitable[None]
if msg.remove and has_pin: # removing pin if msg.remove and has_pin: # removing pin
return confirm_action( return confirm_action(
ctx,
br_type, br_type,
title, title,
description="Do you want to disable PIN protection?", description="Do you want to disable PIN protection?",
@ -82,7 +80,6 @@ def _require_confirm_change_pin(ctx: Context, msg: ChangePin) -> Awaitable[None]
if not msg.remove and has_pin: # changing pin if not msg.remove and has_pin: # changing pin
return confirm_action( return confirm_action(
ctx,
br_type, br_type,
title, title,
description="Do you want to change your PIN?", description="Do you want to change your PIN?",
@ -91,7 +88,6 @@ def _require_confirm_change_pin(ctx: Context, msg: ChangePin) -> Awaitable[None]
if not msg.remove and not has_pin: # setting new pin if not msg.remove and not has_pin: # setting new pin
return confirm_set_new_pin( return confirm_set_new_pin(
ctx,
br_type, br_type,
title, title,
"Do you want to enable PIN protection?", "Do you want to enable PIN protection?",

@ -2,12 +2,11 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Awaitable from typing import Awaitable
from trezor.wire import Context
from trezor.messages import ChangeWipeCode, Success from trezor.messages import ChangeWipeCode, Success
async def change_wipe_code(ctx: Context, msg: ChangeWipeCode) -> Success: async def change_wipe_code(msg: ChangeWipeCode) -> Success:
from storage.device import is_initialized from storage.device import is_initialized
from trezor.wire import NotInitialized from trezor.wire import NotInitialized
from trezor.ui.layouts import show_success from trezor.ui.layouts import show_success
@ -23,24 +22,24 @@ async def change_wipe_code(ctx: Context, msg: ChangeWipeCode) -> Success:
# Confirm that user wants to set or remove the wipe code. # Confirm that user wants to set or remove the wipe code.
has_wipe_code = config.has_wipe_code() has_wipe_code = config.has_wipe_code()
await _require_confirm_action(ctx, msg, has_wipe_code) await _require_confirm_action(msg, has_wipe_code)
# Get the unlocking PIN. # Get the unlocking PIN.
pin, salt = await request_pin_and_sd_salt(ctx, "Enter PIN") pin, salt = await request_pin_and_sd_salt("Enter PIN")
if not msg.remove: if not msg.remove:
# Pre-check the entered PIN. # Pre-check the entered PIN.
if config.has_pin() and not config.check_pin(pin, salt): if config.has_pin() and not config.check_pin(pin, salt):
await error_pin_invalid(ctx) await error_pin_invalid()
# Get new wipe code. # Get new wipe code.
wipe_code = await _request_wipe_code_confirm(ctx, pin) wipe_code = await _request_wipe_code_confirm(pin)
else: else:
wipe_code = "" wipe_code = ""
# Write into storage. # Write into storage.
if not config.change_wipe_code(pin, salt, wipe_code): if not config.change_wipe_code(pin, salt, wipe_code):
await error_pin_invalid(ctx) await error_pin_invalid()
if wipe_code: if wipe_code:
if has_wipe_code: if has_wipe_code:
@ -53,12 +52,12 @@ async def change_wipe_code(ctx: Context, msg: ChangeWipeCode) -> Success:
msg_screen = "Wipe code disabled." msg_screen = "Wipe code disabled."
msg_wire = "Wipe code removed" msg_wire = "Wipe code removed"
await show_success(ctx, "success_wipe_code", msg_screen) await show_success("success_wipe_code", msg_screen)
return Success(message=msg_wire) return Success(message=msg_wire)
def _require_confirm_action( def _require_confirm_action(
ctx: Context, msg: ChangeWipeCode, has_wipe_code: bool msg: ChangeWipeCode, has_wipe_code: bool
) -> Awaitable[None]: ) -> Awaitable[None]:
from trezor.wire import ProcessError from trezor.wire import ProcessError
from trezor.ui.layouts import confirm_action, confirm_set_new_pin from trezor.ui.layouts import confirm_action, confirm_set_new_pin
@ -67,7 +66,6 @@ def _require_confirm_action(
if msg.remove and has_wipe_code: if msg.remove and has_wipe_code:
return confirm_action( return confirm_action(
ctx,
"disable_wipe_code", "disable_wipe_code",
title, title,
description="Do you want to disable wipe code protection?", description="Do you want to disable wipe code protection?",
@ -76,7 +74,6 @@ def _require_confirm_action(
if not msg.remove and has_wipe_code: if not msg.remove and has_wipe_code:
return confirm_action( return confirm_action(
ctx,
"change_wipe_code", "change_wipe_code",
title, title,
description="Do you want to change the wipe code?", description="Do you want to change the wipe code?",
@ -85,7 +82,6 @@ def _require_confirm_action(
if not msg.remove and not has_wipe_code: if not msg.remove and not has_wipe_code:
return confirm_set_new_pin( return confirm_set_new_pin(
ctx,
"set_wipe_code", "set_wipe_code",
title, title,
"Do you want to enable wipe code?", "Do you want to enable wipe code?",
@ -98,7 +94,7 @@ def _require_confirm_action(
raise ProcessError("Wipe code protection is already disabled") raise ProcessError("Wipe code protection is already disabled")
async def _request_wipe_code_confirm(ctx: Context, pin: str) -> str: async def _request_wipe_code_confirm(pin: str) -> str:
from apps.common.request_pin import request_pin from apps.common.request_pin import request_pin
from trezor.ui.layouts import ( from trezor.ui.layouts import (
confirm_reenter_pin, confirm_reenter_pin,
@ -107,12 +103,12 @@ async def _request_wipe_code_confirm(ctx: Context, pin: str) -> str:
) )
while True: while True:
code1 = await request_pin(ctx, "Enter new wipe code") code1 = await request_pin("Enter new wipe code")
if code1 == pin: if code1 == pin:
await wipe_code_same_as_pin_popup(ctx) await wipe_code_same_as_pin_popup()
continue continue
await confirm_reenter_pin(ctx, is_wipe_code=True) await confirm_reenter_pin(is_wipe_code=True)
code2 = await request_pin(ctx, "Re-enter wipe code") code2 = await request_pin("Re-enter wipe code")
if code1 == code2: if code1 == code2:
return code1 return code1
await pin_mismatch_popup(ctx, is_wipe_code=True) await pin_mismatch_popup(is_wipe_code=True)

@ -2,10 +2,9 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import GetNextU2FCounter, NextU2FCounter from trezor.messages import GetNextU2FCounter, NextU2FCounter
from trezor.wire import Context
async def get_next_u2f_counter(ctx: Context, msg: GetNextU2FCounter) -> NextU2FCounter: async def get_next_u2f_counter(msg: GetNextU2FCounter) -> NextU2FCounter:
import storage.device as storage_device import storage.device as storage_device
from trezor.wire import NotInitialized from trezor.wire import NotInitialized
from trezor.enums import ButtonRequestType from trezor.enums import ButtonRequestType
@ -16,7 +15,6 @@ async def get_next_u2f_counter(ctx: Context, msg: GetNextU2FCounter) -> NextU2FC
raise NotInitialized("Device is not initialized") raise NotInitialized("Device is not initialized")
await confirm_action( await confirm_action(
ctx,
"get_u2f_counter", "get_u2f_counter",
"Get next U2F counter", "Get next U2F counter",
description="Do you really want to increase and retrieve the U2F counter?", description="Do you really want to increase and retrieve the U2F counter?",

@ -2,10 +2,9 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import GetNonce, Nonce from trezor.messages import GetNonce, Nonce
from trezor.wire import Context
async def get_nonce(ctx: Context, msg: GetNonce) -> Nonce: async def get_nonce(msg: GetNonce) -> Nonce:
from storage import cache from storage import cache
from trezor.crypto import random from trezor.crypto import random
from trezor.messages import Nonce from trezor.messages import Nonce

@ -3,21 +3,21 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import RebootToBootloader from trezor.messages import RebootToBootloader
from typing import NoReturn from typing import NoReturn
from trezor.wire import Context
async def reboot_to_bootloader(ctx: Context, msg: RebootToBootloader) -> NoReturn: async def reboot_to_bootloader(msg: RebootToBootloader) -> NoReturn:
from trezor import io, loop, utils from trezor import io, loop, utils
from trezor.messages import Success from trezor.messages import Success
from trezor.ui.layouts import confirm_action from trezor.ui.layouts import confirm_action
from trezor.wire.context import get_context
await confirm_action( await confirm_action(
ctx,
"reboot", "reboot",
"Go to bootloader", "Go to bootloader",
"Do you want to restart Trezor in bootloader mode?", "Do you want to restart Trezor in bootloader mode?",
verb="Restart", verb="Restart",
) )
ctx = get_context()
await ctx.write(Success(message="Rebooting")) await ctx.write(Success(message="Rebooting"))
# make sure the outgoing USB buffer is flushed # make sure the outgoing USB buffer is flushed
await loop.wait(ctx.iface.iface_num() | io.POLL_WRITE) await loop.wait(ctx.iface.iface_num() | io.POLL_WRITE)

@ -2,7 +2,6 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import RecoveryDevice from trezor.messages import RecoveryDevice
from trezor.wire import Context
from trezor.messages import Success from trezor.messages import Success
# List of RecoveryDevice fields that can be set when doing dry-run recovery. # List of RecoveryDevice fields that can be set when doing dry-run recovery.
@ -11,7 +10,7 @@ if TYPE_CHECKING:
DRY_RUN_ALLOWED_FIELDS = ("dry_run", "word_count", "enforce_wordlist", "type") DRY_RUN_ALLOWED_FIELDS = ("dry_run", "word_count", "enforce_wordlist", "type")
async def recovery_device(ctx: Context, msg: RecoveryDevice) -> Success: async def recovery_device(msg: RecoveryDevice) -> Success:
""" """
Recover BIP39/SLIP39 seed into empty device. Recover BIP39/SLIP39 seed into empty device.
Recovery is also possible with replugged Trezor. We call this process Persistence. Recovery is also possible with replugged Trezor. We call this process Persistence.
@ -52,15 +51,14 @@ async def recovery_device(ctx: Context, msg: RecoveryDevice) -> Success:
# -------------------------------------------------------- # --------------------------------------------------------
if storage_recovery.is_in_progress(): if storage_recovery.is_in_progress():
return await recovery_process(ctx) return await recovery_process()
# -------------------------------------------------------- # --------------------------------------------------------
# _continue_dialog # _continue_dialog
if not dry_run: if not dry_run:
await confirm_reset_device(ctx, "Wallet recovery", recovery=True) await confirm_reset_device("Wallet recovery", recovery=True)
else: else:
await confirm_action( await confirm_action(
ctx,
"confirm_seedcheck", "confirm_seedcheck",
"Seed check", "Seed check",
description="Do you really want to check the recovery seed?", description="Do you really want to check the recovery seed?",
@ -75,14 +73,14 @@ async def recovery_device(ctx: Context, msg: RecoveryDevice) -> Success:
# for dry run pin needs to be entered # for dry run pin needs to be entered
if dry_run: if dry_run:
curpin, salt = await request_pin_and_sd_salt(ctx, "Enter PIN") curpin, salt = await request_pin_and_sd_salt("Enter PIN")
if not config.check_pin(curpin, salt): if not config.check_pin(curpin, salt):
await error_pin_invalid(ctx) await error_pin_invalid()
if not dry_run: if not dry_run:
# set up pin if requested # set up pin if requested
if msg.pin_protection: if msg.pin_protection:
newpin = await request_pin_confirm(ctx, allow_cancel=False) newpin = await request_pin_confirm(allow_cancel=False)
config.change_pin("", newpin, None, None) config.change_pin("", newpin, None, None)
storage_device.set_passphrase_enabled(bool(msg.passphrase_protection)) storage_device.set_passphrase_enabled(bool(msg.passphrase_protection))
@ -95,4 +93,4 @@ async def recovery_device(ctx: Context, msg: RecoveryDevice) -> Success:
storage_recovery.set_dry_run(bool(dry_run)) storage_recovery.set_dry_run(bool(dry_run))
workflow.set_default(recovery_homescreen) workflow.set_default(recovery_homescreen)
return await recovery_process(ctx) return await recovery_process()

@ -9,7 +9,6 @@ from .. import backup_types
from . import layout, recover from . import layout, recover
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.wire import GenericContext
from trezor.enums import BackupType from trezor.enums import BackupType
@ -21,18 +20,16 @@ async def recovery_homescreen() -> None:
workflow.set_default(homescreen) workflow.set_default(homescreen)
return return
# recovery process does not communicate on the wire await recovery_process()
ctx = wire.DUMMY_CONTEXT
await recovery_process(ctx)
async def recovery_process(ctx: GenericContext) -> Success: async def recovery_process() -> Success:
from trezor.enums import MessageType from trezor.enums import MessageType
import storage import storage
wire.AVOID_RESTARTING_FOR = (MessageType.Initialize, MessageType.GetFeatures) wire.AVOID_RESTARTING_FOR = (MessageType.Initialize, MessageType.GetFeatures)
try: try:
return await _continue_recovery_process(ctx) return await _continue_recovery_process()
except recover.RecoveryAborted: except recover.RecoveryAborted:
dry_run = storage_recovery.is_dry_run() dry_run = storage_recovery.is_dry_run()
if dry_run: if dry_run:
@ -42,7 +39,7 @@ async def recovery_process(ctx: GenericContext) -> Success:
raise wire.ActionCancelled raise wire.ActionCancelled
async def _continue_recovery_process(ctx: GenericContext) -> Success: async def _continue_recovery_process() -> Success:
from trezor.errors import MnemonicError from trezor.errors import MnemonicError
# gather the current recovery state from storage # gather the current recovery state from storage
@ -58,48 +55,46 @@ async def _continue_recovery_process(ctx: GenericContext) -> Success:
if not is_first_step: if not is_first_step:
assert word_count is not None assert word_count is not None
# If we continue recovery, show starting screen with word count immediately. # If we continue recovery, show starting screen with word count immediately.
await _request_share_first_screen(ctx, word_count) await _request_share_first_screen(word_count)
secret = None secret = None
while secret is None: while secret is None:
if is_first_step: if is_first_step:
# If we are starting recovery, ask for word count first... # If we are starting recovery, ask for word count first...
# _request_word_count # _request_word_count
await layout.homescreen_dialog(ctx, "Select", "Select number of words") await layout.homescreen_dialog("Select", "Select number of words")
# ask for the number of words # ask for the number of words
word_count = await layout.request_word_count(ctx, dry_run) word_count = await layout.request_word_count(dry_run)
# ...and only then show the starting screen with word count. # ...and only then show the starting screen with word count.
await _request_share_first_screen(ctx, word_count) await _request_share_first_screen(word_count)
assert word_count is not None assert word_count is not None
# ask for mnemonic words one by one # ask for mnemonic words one by one
words = await layout.request_mnemonic(ctx, word_count, backup_type) words = await layout.request_mnemonic(word_count, backup_type)
# if they were invalid or some checks failed we continue and request them again # if they were invalid or some checks failed we continue and request them again
if not words: if not words:
continue continue
try: try:
secret, backup_type = await _process_words(ctx, words) secret, backup_type = await _process_words(words)
# If _process_words succeeded, we now have both backup_type (from # If _process_words succeeded, we now have both backup_type (from
# its result) and word_count (from _request_word_count earlier), which means # its result) and word_count (from _request_word_count earlier), which means
# that the first step is complete. # that the first step is complete.
is_first_step = False is_first_step = False
except MnemonicError: except MnemonicError:
await layout.show_invalid_mnemonic(ctx, word_count) await layout.show_invalid_mnemonic(word_count)
assert backup_type is not None assert backup_type is not None
if dry_run: if dry_run:
result = await _finish_recovery_dry_run(ctx, secret, backup_type) result = await _finish_recovery_dry_run(secret, backup_type)
else: else:
result = await _finish_recovery(ctx, secret, backup_type) result = await _finish_recovery(secret, backup_type)
return result return result
async def _finish_recovery_dry_run( async def _finish_recovery_dry_run(secret: bytes, backup_type: BackupType) -> Success:
ctx: GenericContext, secret: bytes, backup_type: BackupType
) -> Success:
from trezor.crypto.hashlib import sha256 from trezor.crypto.hashlib import sha256
from trezor import utils from trezor import utils
from apps.common import mnemonic from apps.common import mnemonic
@ -126,7 +121,7 @@ async def _finish_recovery_dry_run(
storage_recovery.end_progress() storage_recovery.end_progress()
await layout.show_dry_run_result(ctx, result, is_slip39) await layout.show_dry_run_result(result, is_slip39)
if result: if result:
return Success(message="The seed is valid and matches the one in the device") return Success(message="The seed is valid and matches the one in the device")
@ -134,9 +129,7 @@ async def _finish_recovery_dry_run(
raise wire.ProcessError("The seed does not match the one in the device") raise wire.ProcessError("The seed does not match the one in the device")
async def _finish_recovery( async def _finish_recovery(secret: bytes, backup_type: BackupType) -> Success:
ctx: GenericContext, secret: bytes, backup_type: BackupType
) -> Success:
from trezor.ui.layouts import show_success from trezor.ui.layouts import show_success
from trezor.enums import BackupType from trezor.enums import BackupType
@ -157,15 +150,11 @@ async def _finish_recovery(
storage_recovery.end_progress() storage_recovery.end_progress()
await show_success( await show_success("success_recovery", "You have finished recovering your wallet.")
ctx, "success_recovery", "You have finished recovering your wallet."
)
return Success(message="Device recovered") return Success(message="Device recovered")
async def _process_words( async def _process_words(words: str) -> tuple[bytes | None, BackupType]:
ctx: GenericContext, words: str
) -> tuple[bytes | None, BackupType]:
word_count = len(words.split(" ")) word_count = len(words.split(" "))
is_slip39 = backup_types.is_slip39_word_count(word_count) is_slip39 = backup_types.is_slip39_word_count(word_count)
@ -179,28 +168,28 @@ async def _process_words(
if secret is None: # SLIP-39 if secret is None: # SLIP-39
assert share is not None assert share is not None
if share.group_count and share.group_count > 1: if share.group_count and share.group_count > 1:
await layout.show_group_share_success(ctx, share.index, share.group_index) await layout.show_group_share_success(share.index, share.group_index)
await _request_share_next_screen(ctx) await _request_share_next_screen()
return secret, backup_type return secret, backup_type
async def _request_share_first_screen(ctx: GenericContext, word_count: int) -> None: async def _request_share_first_screen(word_count: int) -> None:
if backup_types.is_slip39_word_count(word_count): if backup_types.is_slip39_word_count(word_count):
remaining = storage_recovery.fetch_slip39_remaining_shares() remaining = storage_recovery.fetch_slip39_remaining_shares()
if remaining: if remaining:
await _request_share_next_screen(ctx) await _request_share_next_screen()
else: else:
await layout.homescreen_dialog( await layout.homescreen_dialog(
ctx, "Enter share", "Enter any share", f"({word_count} words)" "Enter share", "Enter any share", f"({word_count} words)"
) )
else: # BIP-39 else: # BIP-39
await layout.homescreen_dialog( await layout.homescreen_dialog(
ctx, "Enter seed", "Enter recovery seed", f"({word_count} words)" "Enter seed", "Enter recovery seed", f"({word_count} words)"
) )
async def _request_share_next_screen(ctx: GenericContext) -> None: async def _request_share_next_screen() -> None:
from trezor import strings from trezor import strings
remaining = storage_recovery.fetch_slip39_remaining_shares() remaining = storage_recovery.fetch_slip39_remaining_shares()
@ -211,17 +200,16 @@ async def _request_share_next_screen(ctx: GenericContext) -> None:
if group_count > 1: if group_count > 1:
await layout.homescreen_dialog( await layout.homescreen_dialog(
ctx,
"Enter", "Enter",
"More shares needed", "More shares needed",
info_func=_show_remaining_groups_and_shares, info_func=_show_remaining_groups_and_shares,
) )
else: else:
text = strings.format_plural("{count} more {plural}", remaining[0], "share") text = strings.format_plural("{count} more {plural}", remaining[0], "share")
await layout.homescreen_dialog(ctx, "Enter share", text, "needed to enter") await layout.homescreen_dialog("Enter share", text, "needed to enter")
async def _show_remaining_groups_and_shares(ctx: GenericContext) -> None: async def _show_remaining_groups_and_shares() -> None:
""" """
Show info dialog for Slip39 Advanced - what shares are to be entered. Show info dialog for Slip39 Advanced - what shares are to be entered.
""" """
@ -254,5 +242,5 @@ async def _show_remaining_groups_and_shares(ctx: GenericContext) -> None:
assert share # share needs to be set assert share # share needs to be set
return await layout.show_remaining_shares( return await layout.show_remaining_shares(
ctx, groups, shares_remaining, share.group_threshold groups, shares_remaining, share.group_threshold
) )

@ -14,13 +14,11 @@ from .. import backup_types
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Callable from typing import Callable
from trezor.enums import BackupType from trezor.enums import BackupType
from trezor.wire import GenericContext
async def _confirm_abort(ctx: GenericContext, dry_run: bool = False) -> None: async def _confirm_abort(dry_run: bool = False) -> None:
if dry_run: if dry_run:
await confirm_action( await confirm_action(
ctx,
"abort_recovery", "abort_recovery",
"Abort seed check", "Abort seed check",
description="Do you really want to abort the seed check?", description="Do you really want to abort the seed check?",
@ -28,7 +26,6 @@ async def _confirm_abort(ctx: GenericContext, dry_run: bool = False) -> None:
) )
else: else:
await confirm_action( await confirm_action(
ctx,
"abort_recovery", "abort_recovery",
"Abort recovery", "Abort recovery",
"All progress will be lost.", "All progress will be lost.",
@ -39,21 +36,21 @@ async def _confirm_abort(ctx: GenericContext, dry_run: bool = False) -> None:
async def request_mnemonic( async def request_mnemonic(
ctx: GenericContext, word_count: int, backup_type: BackupType | None word_count: int, backup_type: BackupType | None
) -> str | None: ) -> str | None:
from . import word_validity from . import word_validity
from trezor.ui.layouts.common import button_request from trezor.ui.layouts.common import button_request
from trezor.ui.layouts.recovery import request_word from trezor.ui.layouts.recovery import request_word
from trezor.ui.layouts import mnemonic_word_entering from trezor.ui.layouts import mnemonic_word_entering
await mnemonic_word_entering(ctx) await mnemonic_word_entering()
await button_request(ctx, "mnemonic", code=ButtonRequestType.MnemonicInput) await button_request("mnemonic", code=ButtonRequestType.MnemonicInput)
words: list[str] = [] words: list[str] = []
for i in range(word_count): for i in range(word_count):
word = await request_word( word = await request_word(
ctx, i, word_count, is_slip39=backup_types.is_slip39_word_count(word_count) i, word_count, is_slip39=backup_types.is_slip39_word_count(word_count)
) )
words.append(word) words.append(word)
@ -62,7 +59,6 @@ async def request_mnemonic(
except word_validity.AlreadyAdded: except word_validity.AlreadyAdded:
# show_share_already_added # show_share_already_added
await show_recovery_warning( await show_recovery_warning(
ctx,
"warning_known_share", "warning_known_share",
"Share already entered, please enter a different share.", "Share already entered, please enter a different share.",
) )
@ -70,7 +66,6 @@ async def request_mnemonic(
except word_validity.IdentifierMismatch: except word_validity.IdentifierMismatch:
# show_identifier_mismatch # show_identifier_mismatch
await show_recovery_warning( await show_recovery_warning(
ctx,
"warning_mismatched_share", "warning_mismatched_share",
"You have entered a share from another Shamir Backup.", "You have entered a share from another Shamir Backup.",
) )
@ -78,7 +73,6 @@ async def request_mnemonic(
except word_validity.ThresholdReached: except word_validity.ThresholdReached:
# show_group_threshold_reached # show_group_threshold_reached
await show_recovery_warning( await show_recovery_warning(
ctx,
"warning_group_threshold", "warning_group_threshold",
"Threshold of this group has been reached. Input share from different group.", "Threshold of this group has been reached. Input share from different group.",
) )
@ -87,9 +81,7 @@ async def request_mnemonic(
return " ".join(words) return " ".join(words)
async def show_dry_run_result( async def show_dry_run_result(result: bool, is_slip39: bool) -> None:
ctx: GenericContext, result: bool, is_slip39: bool
) -> None:
from trezor.ui.layouts import show_success from trezor.ui.layouts import show_success
if result: if result:
@ -99,34 +91,29 @@ async def show_dry_run_result(
text = ( text = (
"The entered recovery seed is valid and matches the one in the device." "The entered recovery seed is valid and matches the one in the device."
) )
await show_success(ctx, "success_dry_recovery", text, button="Continue") await show_success("success_dry_recovery", text, button="Continue")
else: else:
if is_slip39: if is_slip39:
text = "The entered recovery shares are valid but do not match what is currently in the device." text = "The entered recovery shares are valid but do not match what is currently in the device."
else: else:
text = "The entered recovery seed is valid but does not match the one in the device." text = "The entered recovery seed is valid but does not match the one in the device."
await show_recovery_warning( await show_recovery_warning("warning_dry_recovery", text, button="Continue")
ctx, "warning_dry_recovery", text, button="Continue"
)
async def show_invalid_mnemonic(ctx: GenericContext, word_count: int) -> None: async def show_invalid_mnemonic(word_count: int) -> None:
if backup_types.is_slip39_word_count(word_count): if backup_types.is_slip39_word_count(word_count):
await show_recovery_warning( await show_recovery_warning(
ctx,
"warning_invalid_share", "warning_invalid_share",
"You have entered an invalid recovery share.", "You have entered an invalid recovery share.",
) )
else: else:
await show_recovery_warning( await show_recovery_warning(
ctx,
"warning_invalid_seed", "warning_invalid_seed",
"You have entered an invalid recovery seed.", "You have entered an invalid recovery seed.",
) )
async def homescreen_dialog( async def homescreen_dialog(
ctx: GenericContext,
button_label: str, button_label: str,
text: str, text: str,
subtext: str | None = None, subtext: str | None = None,
@ -139,14 +126,12 @@ async def homescreen_dialog(
while True: while True:
dry_run = storage_recovery.is_dry_run() dry_run = storage_recovery.is_dry_run()
if await continue_recovery( if await continue_recovery(button_label, text, subtext, info_func, dry_run):
ctx, button_label, text, subtext, info_func, dry_run
):
# go forward in the recovery process # go forward in the recovery process
break break
# user has chosen to abort, confirm the choice # user has chosen to abort, confirm the choice
try: try:
await _confirm_abort(ctx, dry_run) await _confirm_abort(dry_run)
except ActionCancelled: except ActionCancelled:
pass pass
else: else:

@ -13,7 +13,6 @@ if __debug__:
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import ResetDevice from trezor.messages import ResetDevice
from trezor.wire import Context
from trezor.messages import Success from trezor.messages import Success
@ -23,7 +22,7 @@ BAK_T_SLIP39_ADVANCED = BackupType.Slip39_Advanced # global_import_cache
_DEFAULT_BACKUP_TYPE = BAK_T_BIP39 _DEFAULT_BACKUP_TYPE = BAK_T_BIP39
async def reset_device(ctx: Context, msg: ResetDevice) -> Success: async def reset_device(msg: ResetDevice) -> Success:
from trezor import config, utils from trezor import config, utils
from apps.common.request_pin import request_pin_confirm from apps.common.request_pin import request_pin_confirm
from trezor.ui.layouts import ( from trezor.ui.layouts import (
@ -33,6 +32,7 @@ async def reset_device(ctx: Context, msg: ResetDevice) -> Success:
from trezor.crypto import bip39, random from trezor.crypto import bip39, random
from trezor.messages import Success, EntropyAck, EntropyRequest from trezor.messages import Success, EntropyAck, EntropyRequest
from trezor.pin import render_empty_loader from trezor.pin import render_empty_loader
from trezor.wire.context import call
backup_type = msg.backup_type # local_cache_attribute backup_type = msg.backup_type # local_cache_attribute
@ -49,7 +49,7 @@ async def reset_device(ctx: Context, msg: ResetDevice) -> Success:
title = f"Create wallet{delimiter}(Super Shamir)" title = f"Create wallet{delimiter}(Super Shamir)"
else: else:
title = "Create wallet" title = "Create wallet"
await confirm_reset_device(ctx, title) await confirm_reset_device(title)
# Rendering empty loader so users do not feel a freezing screen # Rendering empty loader so users do not feel a freezing screen
render_empty_loader("PROCESSING", "") render_empty_loader("PROCESSING", "")
@ -59,7 +59,7 @@ async def reset_device(ctx: Context, msg: ResetDevice) -> Success:
# request and set new PIN # request and set new PIN
if msg.pin_protection: if msg.pin_protection:
newpin = await request_pin_confirm(ctx) newpin = await request_pin_confirm()
if not config.change_pin("", newpin, None, None): if not config.change_pin("", newpin, None, None):
raise ProcessError("Failed to set PIN") raise ProcessError("Failed to set PIN")
@ -68,10 +68,10 @@ async def reset_device(ctx: Context, msg: ResetDevice) -> Success:
if __debug__: if __debug__:
storage.debug.reset_internal_entropy = int_entropy storage.debug.reset_internal_entropy = int_entropy
if msg.display_random: if msg.display_random:
await layout.show_internal_entropy(ctx, int_entropy) await layout.show_internal_entropy(int_entropy)
# request external entropy and compute the master secret # request external entropy and compute the master secret
entropy_ack = await ctx.call(EntropyRequest(), EntropyAck) entropy_ack = await call(EntropyRequest(), EntropyAck)
ext_entropy = entropy_ack.entropy ext_entropy = entropy_ack.entropy
# For SLIP-39 this is the Encrypted Master Secret # For SLIP-39 this is the Encrypted Master Secret
secret = _compute_secret_from_entropy(int_entropy, ext_entropy, msg.strength) secret = _compute_secret_from_entropy(int_entropy, ext_entropy, msg.strength)
@ -94,11 +94,11 @@ async def reset_device(ctx: Context, msg: ResetDevice) -> Success:
# If doing backup, ask the user to confirm. # If doing backup, ask the user to confirm.
if perform_backup: if perform_backup:
perform_backup = await confirm_backup(ctx) perform_backup = await confirm_backup()
# generate and display backup information for the master secret # generate and display backup information for the master secret
if perform_backup: if perform_backup:
await backup_seed(ctx, backup_type, secret) await backup_seed(backup_type, secret)
# write settings and master secret into storage # write settings and master secret into storage
if msg.label is not None: if msg.label is not None:
@ -113,19 +113,19 @@ async def reset_device(ctx: Context, msg: ResetDevice) -> Success:
# if we backed up the wallet, show success message # if we backed up the wallet, show success message
if perform_backup: if perform_backup:
await layout.show_backup_success(ctx) await layout.show_backup_success()
return Success(message="Initialized") return Success(message="Initialized")
async def _backup_slip39_basic(ctx: Context, encrypted_master_secret: bytes) -> None: async def _backup_slip39_basic(encrypted_master_secret: bytes) -> None:
# get number of shares # get number of shares
await layout.slip39_show_checklist(ctx, 0, BAK_T_SLIP39_BASIC) await layout.slip39_show_checklist(0, BAK_T_SLIP39_BASIC)
shares_count = await layout.slip39_prompt_number_of_shares(ctx) shares_count = await layout.slip39_prompt_number_of_shares()
# get threshold # get threshold
await layout.slip39_show_checklist(ctx, 1, BAK_T_SLIP39_BASIC) await layout.slip39_show_checklist(1, BAK_T_SLIP39_BASIC)
threshold = await layout.slip39_prompt_threshold(ctx, shares_count) threshold = await layout.slip39_prompt_threshold(shares_count)
identifier = storage_device.get_slip39_identifier() identifier = storage_device.get_slip39_identifier()
iteration_exponent = storage_device.get_slip39_iteration_exponent() iteration_exponent = storage_device.get_slip39_iteration_exponent()
@ -142,27 +142,25 @@ async def _backup_slip39_basic(ctx: Context, encrypted_master_secret: bytes) ->
)[0] )[0]
# show and confirm individual shares # show and confirm individual shares
await layout.slip39_show_checklist(ctx, 2, BAK_T_SLIP39_BASIC) await layout.slip39_show_checklist(2, BAK_T_SLIP39_BASIC)
await layout.slip39_basic_show_and_confirm_shares(ctx, mnemonics) await layout.slip39_basic_show_and_confirm_shares(mnemonics)
async def _backup_slip39_advanced(ctx: Context, encrypted_master_secret: bytes) -> None: async def _backup_slip39_advanced(encrypted_master_secret: bytes) -> None:
# get number of groups # get number of groups
await layout.slip39_show_checklist(ctx, 0, BAK_T_SLIP39_ADVANCED) await layout.slip39_show_checklist(0, BAK_T_SLIP39_ADVANCED)
groups_count = await layout.slip39_advanced_prompt_number_of_groups(ctx) groups_count = await layout.slip39_advanced_prompt_number_of_groups()
# get group threshold # get group threshold
await layout.slip39_show_checklist(ctx, 1, BAK_T_SLIP39_ADVANCED) await layout.slip39_show_checklist(1, BAK_T_SLIP39_ADVANCED)
group_threshold = await layout.slip39_advanced_prompt_group_threshold( group_threshold = await layout.slip39_advanced_prompt_group_threshold(groups_count)
ctx, groups_count
)
# get shares and thresholds # get shares and thresholds
await layout.slip39_show_checklist(ctx, 2, BAK_T_SLIP39_ADVANCED) await layout.slip39_show_checklist(2, BAK_T_SLIP39_ADVANCED)
groups = [] groups = []
for i in range(groups_count): for i in range(groups_count):
share_count = await layout.slip39_prompt_number_of_shares(ctx, i) share_count = await layout.slip39_prompt_number_of_shares(i)
share_threshold = await layout.slip39_prompt_threshold(ctx, share_count, i) share_threshold = await layout.slip39_prompt_threshold(share_count, i)
groups.append((share_threshold, share_count)) groups.append((share_threshold, share_count))
identifier = storage_device.get_slip39_identifier() identifier = storage_device.get_slip39_identifier()
@ -180,7 +178,7 @@ async def _backup_slip39_advanced(ctx: Context, encrypted_master_secret: bytes)
) )
# show and confirm individual shares # show and confirm individual shares
await layout.slip39_advanced_show_and_confirm_shares(ctx, mnemonics) await layout.slip39_advanced_show_and_confirm_shares(mnemonics)
def _validate_reset_device(msg: ResetDevice) -> None: def _validate_reset_device(msg: ResetDevice) -> None:
@ -222,12 +220,10 @@ def _compute_secret_from_entropy(
return secret return secret
async def backup_seed( async def backup_seed(backup_type: BackupType, mnemonic_secret: bytes) -> None:
ctx: Context, backup_type: BackupType, mnemonic_secret: bytes
) -> None:
if backup_type == BAK_T_SLIP39_BASIC: if backup_type == BAK_T_SLIP39_BASIC:
await _backup_slip39_basic(ctx, mnemonic_secret) await _backup_slip39_basic(mnemonic_secret)
elif backup_type == BAK_T_SLIP39_ADVANCED: elif backup_type == BAK_T_SLIP39_ADVANCED:
await _backup_slip39_advanced(ctx, mnemonic_secret) await _backup_slip39_advanced(mnemonic_secret)
else: else:
await layout.bip39_show_and_confirm_mnemonic(ctx, mnemonic_secret.decode()) await layout.bip39_show_and_confirm_mnemonic(mnemonic_secret.decode())

@ -1,5 +1,5 @@
from micropython import const from micropython import const
from typing import TYPE_CHECKING from typing import Sequence
from trezor.enums import ButtonRequestType from trezor.enums import ButtonRequestType
from trezor.ui.layouts import show_success from trezor.ui.layouts import show_success
@ -12,18 +12,13 @@ from trezor.ui.layouts.reset import ( # noqa: F401
slip39_show_checklist, slip39_show_checklist,
) )
if TYPE_CHECKING:
from typing import Sequence
from trezor.wire import GenericContext
_NUM_OF_CHOICES = const(3) _NUM_OF_CHOICES = const(3)
async def show_internal_entropy(ctx: GenericContext, entropy: bytes) -> None: async def show_internal_entropy(entropy: bytes) -> None:
from trezor.ui.layouts import confirm_blob from trezor.ui.layouts import confirm_blob
await confirm_blob( await confirm_blob(
ctx,
"entropy", "entropy",
"Internal entropy", "Internal entropy",
entropy, entropy,
@ -32,7 +27,6 @@ async def show_internal_entropy(ctx: GenericContext, entropy: bytes) -> None:
async def _confirm_word( async def _confirm_word(
ctx: GenericContext,
share_index: int | None, share_index: int | None,
share_words: Sequence[str], share_words: Sequence[str],
offset: int, offset: int,
@ -56,14 +50,13 @@ async def _confirm_word(
random.shuffle(choices) random.shuffle(choices)
# let the user pick a word # let the user pick a word
selected_word: str = await select_word( selected_word: str = await select_word(
ctx, choices, share_index, checked_index, count, group_index choices, share_index, checked_index, count, group_index
) )
# confirm it is the correct one # confirm it is the correct one
return selected_word == checked_word return selected_word == checked_word
async def _share_words_confirmed( async def _share_words_confirmed(
ctx: GenericContext,
share_index: int | None, share_index: int | None,
share_words: Sequence[str], share_words: Sequence[str],
num_of_shares: int | None = None, num_of_shares: int | None = None,
@ -77,22 +70,20 @@ async def _share_words_confirmed(
""" """
# TODO: confirm_action("Select the words bla bla") # TODO: confirm_action("Select the words bla bla")
if await _do_confirm_share_words(ctx, share_index, share_words, group_index): if await _do_confirm_share_words(share_index, share_words, group_index):
await _show_confirmation_success( await _show_confirmation_success(
ctx,
share_index, share_index,
num_of_shares, num_of_shares,
group_index, group_index,
) )
return True return True
else: else:
await _show_confirmation_failure(ctx) await _show_confirmation_failure()
return False return False
async def _do_confirm_share_words( async def _do_confirm_share_words(
ctx: GenericContext,
share_index: int | None, share_index: int | None,
share_words: Sequence[str], share_words: Sequence[str],
group_index: int | None = None, group_index: int | None = None,
@ -106,7 +97,7 @@ async def _do_confirm_share_words(
offset = 0 offset = 0
count = len(share_words) count = len(share_words)
for part in utils.chunks(share_words, third): for part in utils.chunks(share_words, third):
if not await _confirm_word(ctx, share_index, part, offset, count, group_index): if not await _confirm_word(share_index, part, offset, count, group_index):
return False return False
offset += len(part) offset += len(part)
@ -114,7 +105,6 @@ async def _do_confirm_share_words(
async def _show_confirmation_success( async def _show_confirmation_success(
ctx: GenericContext,
share_index: int | None = None, share_index: int | None = None,
num_of_shares: int | None = None, num_of_shares: int | None = None,
group_index: int | None = None, group_index: int | None = None,
@ -138,14 +128,13 @@ async def _show_confirmation_success(
subheader = f"Group {group_index + 1} - Share {share_index + 1} checked successfully." subheader = f"Group {group_index + 1} - Share {share_index + 1} checked successfully."
text = "Continue with the next share." text = "Continue with the next share."
return await show_success(ctx, "success_recovery", text, subheader) return await show_success("success_recovery", text, subheader)
async def _show_confirmation_failure(ctx: GenericContext) -> None: async def _show_confirmation_failure() -> None:
from trezor.ui.layouts.recovery import show_recovery_warning from trezor.ui.layouts.recovery import show_recovery_warning
await show_recovery_warning( await show_recovery_warning(
ctx,
"warning_backup_check", "warning_backup_check",
"Please check again.", "Please check again.",
"That is the wrong word.", "That is the wrong word.",
@ -154,34 +143,34 @@ async def _show_confirmation_failure(ctx: GenericContext) -> None:
) )
async def show_backup_warning(ctx: GenericContext, slip39: bool = False) -> None: async def show_backup_warning(slip39: bool = False) -> None:
from trezor.ui.layouts.reset import show_warning_backup from trezor.ui.layouts.reset import show_warning_backup
await show_warning_backup(ctx, slip39) await show_warning_backup(slip39)
async def show_backup_success(ctx: GenericContext) -> None: async def show_backup_success() -> None:
from trezor.ui.layouts.reset import show_success_backup from trezor.ui.layouts.reset import show_success_backup
await show_success_backup(ctx) await show_success_backup()
# BIP39 # BIP39
# === # ===
async def bip39_show_and_confirm_mnemonic(ctx: GenericContext, mnemonic: str) -> None: async def bip39_show_and_confirm_mnemonic(mnemonic: str) -> None:
# warn user about mnemonic safety # warn user about mnemonic safety
await show_backup_warning(ctx) await show_backup_warning()
words = mnemonic.split() words = mnemonic.split()
while True: while True:
# display paginated mnemonic on the screen # display paginated mnemonic on the screen
await show_share_words(ctx, words) await show_share_words(words)
# make the user confirm some words from the mnemonic # make the user confirm some words from the mnemonic
if await _share_words_confirmed(ctx, None, words): if await _share_words_confirmed(None, words):
break # this share is confirmed, go to next one break # this share is confirmed, go to next one
@ -189,38 +178,36 @@ async def bip39_show_and_confirm_mnemonic(ctx: GenericContext, mnemonic: str) ->
# === # ===
async def slip39_basic_show_and_confirm_shares( async def slip39_basic_show_and_confirm_shares(shares: Sequence[str]) -> None:
ctx: GenericContext, shares: Sequence[str]
) -> None:
# warn user about mnemonic safety # warn user about mnemonic safety
await show_backup_warning(ctx, True) await show_backup_warning(True)
for index, share in enumerate(shares): for index, share in enumerate(shares):
share_words = share.split(" ") share_words = share.split(" ")
while True: while True:
# display paginated share on the screen # display paginated share on the screen
await show_share_words(ctx, share_words, index) await show_share_words(share_words, index)
# make the user confirm words from the share # make the user confirm words from the share
if await _share_words_confirmed(ctx, index, share_words, len(shares)): if await _share_words_confirmed(index, share_words, len(shares)):
break # this share is confirmed, go to next one break # this share is confirmed, go to next one
async def slip39_advanced_show_and_confirm_shares( async def slip39_advanced_show_and_confirm_shares(
ctx: GenericContext, shares: Sequence[Sequence[str]] shares: Sequence[Sequence[str]],
) -> None: ) -> None:
# warn user about mnemonic safety # warn user about mnemonic safety
await show_backup_warning(ctx, True) await show_backup_warning(True)
for group_index, group in enumerate(shares): for group_index, group in enumerate(shares):
for share_index, share in enumerate(group): for share_index, share in enumerate(group):
share_words = share.split(" ") share_words = share.split(" ")
while True: while True:
# display paginated share on the screen # display paginated share on the screen
await show_share_words(ctx, share_words, share_index, group_index) await show_share_words(share_words, share_index, group_index)
# make the user confirm words from the share # make the user confirm words from the share
if await _share_words_confirmed( if await _share_words_confirmed(
ctx, share_index, share_words, len(group), group_index share_index, share_words, len(group), group_index
): ):
break # this share is confirmed, go to next one break # this share is confirmed, go to next one

@ -14,7 +14,6 @@ from apps.common.sdcard import ensure_sdcard
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Awaitable from typing import Awaitable
from trezor.messages import SdProtect from trezor.messages import SdProtect
from trezor.wire import Context
def _make_salt() -> tuple[bytes, bytes, bytes]: def _make_salt() -> tuple[bytes, bytes, bytes]:
@ -26,56 +25,54 @@ def _make_salt() -> tuple[bytes, bytes, bytes]:
return salt, auth_key, tag return salt, auth_key, tag
async def _set_salt( async def _set_salt(salt: bytes, salt_tag: bytes, stage: bool = False) -> None:
ctx: Context, salt: bytes, salt_tag: bytes, stage: bool = False
) -> None:
from apps.common.sdcard import confirm_retry_sd from apps.common.sdcard import confirm_retry_sd
while True: while True:
await ensure_sdcard(ctx) await ensure_sdcard()
try: try:
return storage_sd_salt.set_sd_salt(salt, salt_tag, stage) return storage_sd_salt.set_sd_salt(salt, salt_tag, stage)
except OSError: except OSError:
await confirm_retry_sd(ctx, ProcessError("SD card I/O error.")) await confirm_retry_sd(ProcessError("SD card I/O error."))
async def sd_protect(ctx: Context, msg: SdProtect) -> Success: async def sd_protect(msg: SdProtect) -> Success:
from trezor.wire import NotInitialized from trezor.wire import NotInitialized
if not storage_device.is_initialized(): if not storage_device.is_initialized():
raise NotInitialized("Device is not initialized") raise NotInitialized("Device is not initialized")
if msg.operation == SdProtectOperationType.ENABLE: if msg.operation == SdProtectOperationType.ENABLE:
return await _sd_protect_enable(ctx, msg) return await _sd_protect_enable(msg)
elif msg.operation == SdProtectOperationType.DISABLE: elif msg.operation == SdProtectOperationType.DISABLE:
return await _sd_protect_disable(ctx, msg) return await _sd_protect_disable(msg)
elif msg.operation == SdProtectOperationType.REFRESH: elif msg.operation == SdProtectOperationType.REFRESH:
return await _sd_protect_refresh(ctx, msg) return await _sd_protect_refresh(msg)
else: else:
raise ProcessError("Unknown operation") raise ProcessError("Unknown operation")
async def _sd_protect_enable(ctx: Context, msg: SdProtect) -> Success: async def _sd_protect_enable(msg: SdProtect) -> Success:
from apps.common.request_pin import request_pin from apps.common.request_pin import request_pin
if storage_sd_salt.is_enabled(): if storage_sd_salt.is_enabled():
raise ProcessError("SD card protection already enabled") raise ProcessError("SD card protection already enabled")
# Confirm that user wants to proceed with the operation. # Confirm that user wants to proceed with the operation.
await require_confirm_sd_protect(ctx, msg) await require_confirm_sd_protect(msg)
# Make sure SD card is present. # Make sure SD card is present.
await ensure_sdcard(ctx) await ensure_sdcard()
# Get the current PIN. # Get the current PIN.
if config.has_pin(): if config.has_pin():
pin = await request_pin(ctx, "Enter PIN", config.get_pin_rem()) pin = await request_pin("Enter PIN", config.get_pin_rem())
else: else:
pin = "" pin = ""
# Check PIN and prepare salt file. # Check PIN and prepare salt file.
salt, salt_auth_key, salt_tag = _make_salt() salt, salt_auth_key, salt_tag = _make_salt()
await _set_salt(ctx, salt, salt_tag) await _set_salt(salt, salt_tag)
if not config.change_pin(pin, pin, None, salt): if not config.change_pin(pin, pin, None, salt):
# Wrong PIN. Clean up the prepared salt file. # Wrong PIN. Clean up the prepared salt file.
@ -86,17 +83,15 @@ async def _sd_protect_enable(ctx: Context, msg: SdProtect) -> Success:
# SD-protection. If it fails for any reason, we suppress the # SD-protection. If it fails for any reason, we suppress the
# exception, because primarily we need to raise wire.PinInvalid. # exception, because primarily we need to raise wire.PinInvalid.
pass pass
await error_pin_invalid(ctx) await error_pin_invalid()
storage_device.set_sd_salt_auth_key(salt_auth_key) storage_device.set_sd_salt_auth_key(salt_auth_key)
await show_success( await show_success("success_sd", "You have successfully enabled SD protection.")
ctx, "success_sd", "You have successfully enabled SD protection."
)
return Success(message="SD card protection enabled") return Success(message="SD card protection enabled")
async def _sd_protect_disable(ctx: Context, msg: SdProtect) -> Success: async def _sd_protect_disable(msg: SdProtect) -> Success:
if not storage_sd_salt.is_enabled(): if not storage_sd_salt.is_enabled():
raise ProcessError("SD card protection not enabled") raise ProcessError("SD card protection not enabled")
@ -104,14 +99,14 @@ async def _sd_protect_disable(ctx: Context, msg: SdProtect) -> Success:
# protection. The cleanup will not happen in such case, but that does not matter. # protection. The cleanup will not happen in such case, but that does not matter.
# Confirm that user wants to proceed with the operation. # Confirm that user wants to proceed with the operation.
await require_confirm_sd_protect(ctx, msg) await require_confirm_sd_protect(msg)
# Get the current PIN and salt from the SD card. # Get the current PIN and salt from the SD card.
pin, salt = await request_pin_and_sd_salt(ctx, "Enter PIN") pin, salt = await request_pin_and_sd_salt("Enter PIN")
# Check PIN and remove salt. # Check PIN and remove salt.
if not config.change_pin(pin, pin, salt, None): if not config.change_pin(pin, pin, salt, None):
await error_pin_invalid(ctx) await error_pin_invalid()
storage_device.set_sd_salt_auth_key(None) storage_device.set_sd_salt_auth_key(None)
@ -124,31 +119,29 @@ async def _sd_protect_disable(ctx: Context, msg: SdProtect) -> Success:
# because overall SD-protection was successfully disabled. # because overall SD-protection was successfully disabled.
pass pass
await show_success( await show_success("success_sd", "You have successfully disabled SD protection.")
ctx, "success_sd", "You have successfully disabled SD protection."
)
return Success(message="SD card protection disabled") return Success(message="SD card protection disabled")
async def _sd_protect_refresh(ctx: Context, msg: SdProtect) -> Success: async def _sd_protect_refresh(msg: SdProtect) -> Success:
if not storage_sd_salt.is_enabled(): if not storage_sd_salt.is_enabled():
raise ProcessError("SD card protection not enabled") raise ProcessError("SD card protection not enabled")
# Confirm that user wants to proceed with the operation. # Confirm that user wants to proceed with the operation.
await require_confirm_sd_protect(ctx, msg) await require_confirm_sd_protect(msg)
# Make sure SD card is present. # Make sure SD card is present.
await ensure_sdcard(ctx) await ensure_sdcard()
# Get the current PIN and salt from the SD card. # Get the current PIN and salt from the SD card.
pin, old_salt = await request_pin_and_sd_salt(ctx, "Enter PIN") pin, old_salt = await request_pin_and_sd_salt("Enter PIN")
# Check PIN and change salt. # Check PIN and change salt.
new_salt, new_auth_key, new_salt_tag = _make_salt() new_salt, new_auth_key, new_salt_tag = _make_salt()
await _set_salt(ctx, new_salt, new_salt_tag, stage=True) await _set_salt(new_salt, new_salt_tag, stage=True)
if not config.change_pin(pin, pin, old_salt, new_salt): if not config.change_pin(pin, pin, old_salt, new_salt):
await error_pin_invalid(ctx) await error_pin_invalid()
storage_device.set_sd_salt_auth_key(new_auth_key) storage_device.set_sd_salt_auth_key(new_auth_key)
@ -161,13 +154,11 @@ async def _sd_protect_refresh(ctx: Context, msg: SdProtect) -> Success:
# SD-protection was successfully refreshed. # SD-protection was successfully refreshed.
pass pass
await show_success( await show_success("success_sd", "You have successfully refreshed SD protection.")
ctx, "success_sd", "You have successfully refreshed SD protection."
)
return Success(message="SD card protection refreshed") return Success(message="SD card protection refreshed")
def require_confirm_sd_protect(ctx: Context, msg: SdProtect) -> Awaitable[None]: def require_confirm_sd_protect(msg: SdProtect) -> Awaitable[None]:
from trezor.ui.layouts import confirm_action from trezor.ui.layouts import confirm_action
if msg.operation == SdProtectOperationType.ENABLE: if msg.operation == SdProtectOperationType.ENABLE:
@ -179,4 +170,4 @@ def require_confirm_sd_protect(ctx: Context, msg: SdProtect) -> Awaitable[None]:
else: else:
raise ProcessError("Unknown operation") raise ProcessError("Unknown operation")
return confirm_action(ctx, "set_sd", "SD card protection", description=text) return confirm_action("set_sd", "SD card protection", description=text)

@ -2,10 +2,9 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import SetU2FCounter, Success from trezor.messages import SetU2FCounter, Success
from trezor.wire import Context
async def set_u2f_counter(ctx: Context, msg: SetU2FCounter) -> Success: async def set_u2f_counter(msg: SetU2FCounter) -> Success:
import storage.device as storage_device import storage.device as storage_device
from trezor import wire from trezor import wire
from trezor.enums import ButtonRequestType from trezor.enums import ButtonRequestType
@ -18,7 +17,6 @@ async def set_u2f_counter(ctx: Context, msg: SetU2FCounter) -> Success:
raise wire.ProcessError("No value provided") raise wire.ProcessError("No value provided")
await confirm_action( await confirm_action(
ctx,
"set_u2f_counter", "set_u2f_counter",
"Set U2F counter", "Set U2F counter",
description="Do you really want to set the U2F counter to {}?", description="Do you really want to set the U2F counter to {}?",

@ -2,16 +2,15 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import ShowDeviceTutorial, Success from trezor.messages import ShowDeviceTutorial, Success
from trezor.wire import Context
async def show_tutorial(ctx: Context, msg: ShowDeviceTutorial) -> Success: async def show_tutorial(msg: ShowDeviceTutorial) -> Success:
from trezor.messages import Success from trezor.messages import Success
# NOTE: tutorial is defined only for TR, and this function should # NOTE: tutorial is defined only for TR, and this function should
# also be called only in case of TR # also be called only in case of TR
from trezor.ui.layouts import tutorial from trezor.ui.layouts import tutorial
await tutorial(ctx) await tutorial()
return Success(message="Tutorial shown") return Success(message="Tutorial shown")

@ -1,11 +1,10 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.wire import GenericContext
from trezor.messages import WipeDevice, Success from trezor.messages import WipeDevice, Success
async def wipe_device(ctx: GenericContext, msg: WipeDevice) -> Success: async def wipe_device(msg: WipeDevice) -> Success:
import storage import storage
from trezor.enums import ButtonRequestType from trezor.enums import ButtonRequestType
from trezor.messages import Success from trezor.messages import Success
@ -14,7 +13,6 @@ async def wipe_device(ctx: GenericContext, msg: WipeDevice) -> Success:
from apps.base import reload_settings_from_storage from apps.base import reload_settings_from_storage
await confirm_action( await confirm_action(
ctx,
"confirm_wipe", "confirm_wipe",
"Wipe device", "Wipe device",
"All data will be erased.", "All data will be erased.",

@ -2,13 +2,12 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import CipherKeyValue, CipheredKeyValue from trezor.messages import CipherKeyValue, CipheredKeyValue
from trezor.wire import Context
# This module implements the SLIP-0011 symmetric encryption of key-value pairs using a # This module implements the SLIP-0011 symmetric encryption of key-value pairs using a
# deterministic hierarchy, see https://github.com/satoshilabs/slips/blob/master/slip-0011.md. # deterministic hierarchy, see https://github.com/satoshilabs/slips/blob/master/slip-0011.md.
async def cipher_key_value(ctx: Context, msg: CipherKeyValue) -> CipheredKeyValue: async def cipher_key_value(msg: CipherKeyValue) -> CipheredKeyValue:
from trezor.wire import DataError from trezor.wire import DataError
from trezor.messages import CipheredKeyValue from trezor.messages import CipheredKeyValue
from trezor.crypto import aes, hmac from trezor.crypto import aes, hmac
@ -16,7 +15,7 @@ async def cipher_key_value(ctx: Context, msg: CipherKeyValue) -> CipheredKeyValu
from apps.common.paths import AlwaysMatchingSchema from apps.common.paths import AlwaysMatchingSchema
from trezor.ui.layouts import confirm_action from trezor.ui.layouts import confirm_action
keychain = await get_keychain(ctx, "secp256k1", [AlwaysMatchingSchema]) keychain = await get_keychain("secp256k1", [AlwaysMatchingSchema])
if len(msg.value) % 16 > 0: if len(msg.value) % 16 > 0:
raise DataError("Value length must be a multiple of 16") raise DataError("Value length must be a multiple of 16")
@ -35,9 +34,7 @@ async def cipher_key_value(ctx: Context, msg: CipherKeyValue) -> CipheredKeyValu
title = "Decrypt value" title = "Decrypt value"
verb = "CONFIRM" verb = "CONFIRM"
await confirm_action( await confirm_action("cipher_key_value", title, description=msg.key, verb=verb)
ctx, "cipher_key_value", title, description=msg.key, verb=verb
)
node = keychain.derive(msg.address_n) node = keychain.derive(msg.address_n)

@ -8,7 +8,6 @@ from apps.common.paths import PathSchema, unharden
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import CosiCommit from trezor.messages import CosiCommit
from trezor.wire import Context
# This module implements the cosigner part of the CoSi collective signatures # This module implements the cosigner part of the CoSi collective signatures
# as described in https://dedis.cs.yale.edu/dissent/papers/witness.pdf # as described in https://dedis.cs.yale.edu/dissent/papers/witness.pdf
@ -55,16 +54,17 @@ def _decode_path(address_n: list[int]) -> str | None:
return None return None
async def cosi_commit(ctx: Context, msg: CosiCommit) -> CosiSignature: async def cosi_commit(msg: CosiCommit) -> CosiSignature:
import storage.cache as storage_cache import storage.cache as storage_cache
from trezor.crypto import cosi from trezor.crypto import cosi
from trezor.crypto.curve import ed25519 from trezor.crypto.curve import ed25519
from trezor.ui.layouts import confirm_blob, confirm_text from trezor.ui.layouts import confirm_blob, confirm_text
from trezor.wire.context import call
from apps.common import paths from apps.common import paths
from apps.common.keychain import get_keychain from apps.common.keychain import get_keychain
keychain = await get_keychain(ctx, "ed25519", [SCHEMA_SLIP18, SCHEMA_SLIP26]) keychain = await get_keychain("ed25519", [SCHEMA_SLIP18, SCHEMA_SLIP26])
await paths.validate_path(ctx, keychain, msg.address_n) await paths.validate_path(keychain, msg.address_n)
node = keychain.derive(msg.address_n) node = keychain.derive(msg.address_n)
seckey = node.private_key() seckey = node.private_key()
@ -78,7 +78,7 @@ async def cosi_commit(ctx: Context, msg: CosiCommit) -> CosiSignature:
if commitment is None: if commitment is None:
raise RuntimeError raise RuntimeError
sign_msg = await ctx.call( sign_msg = await call(
CosiCommitment(commitment=commitment, pubkey=pubkey), CosiSign CosiCommitment(commitment=commitment, pubkey=pubkey), CosiSign
) )
@ -87,14 +87,12 @@ async def cosi_commit(ctx: Context, msg: CosiCommit) -> CosiSignature:
path_description = _decode_path(sign_msg.address_n) path_description = _decode_path(sign_msg.address_n)
await confirm_text( await confirm_text(
ctx,
"cosi_confirm_key", "cosi_confirm_key",
"COSI KEYS", "COSI KEYS",
paths.address_n_to_str(sign_msg.address_n), paths.address_n_to_str(sign_msg.address_n),
path_description, path_description,
) )
await confirm_blob( await confirm_blob(
ctx,
"cosi_sign", "cosi_sign",
"COSI DATA", "COSI DATA",
sign_msg.data, sign_msg.data,

@ -2,13 +2,12 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import GetECDHSessionKey, ECDHSessionKey from trezor.messages import GetECDHSessionKey, ECDHSessionKey
from trezor.wire import Context
# This module implements the SLIP-0017 Elliptic Curve Diffie-Hellman algorithm, using a # This module implements the SLIP-0017 Elliptic Curve Diffie-Hellman algorithm, using a
# deterministic hierarchy, see https://github.com/satoshilabs/slips/blob/master/slip-0017.md. # deterministic hierarchy, see https://github.com/satoshilabs/slips/blob/master/slip-0017.md.
async def get_ecdh_session_key(ctx: Context, msg: GetECDHSessionKey) -> ECDHSessionKey: async def get_ecdh_session_key(msg: GetECDHSessionKey) -> ECDHSessionKey:
from trezor.ui.layouts import confirm_address from trezor.ui.layouts import confirm_address
from .sign_identity import ( from .sign_identity import (
get_identity_path, get_identity_path,
@ -24,13 +23,12 @@ async def get_ecdh_session_key(ctx: Context, msg: GetECDHSessionKey) -> ECDHSess
peer_public_key = msg.peer_public_key # local_cache_attribute peer_public_key = msg.peer_public_key # local_cache_attribute
curve_name = msg.ecdsa_curve_name or "secp256k1" curve_name = msg.ecdsa_curve_name or "secp256k1"
keychain = await get_keychain(ctx, curve_name, [AlwaysMatchingSchema]) keychain = await get_keychain(curve_name, [AlwaysMatchingSchema])
identity = serialize_identity(msg_identity) identity = serialize_identity(msg_identity)
# require_confirm_ecdh_session_key # require_confirm_ecdh_session_key
proto = msg_identity.proto.upper() if msg_identity.proto else "identity" proto = msg_identity.proto.upper() if msg_identity.proto else "identity"
await confirm_address( await confirm_address(
ctx,
f"Decrypt {proto}", f"Decrypt {proto}",
serialize_identity_without_proto(msg_identity), serialize_identity_without_proto(msg_identity),
None, None,

@ -1,18 +1,16 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.wire import Context
from trezor.messages import GetEntropy, Entropy from trezor.messages import GetEntropy, Entropy
async def get_entropy(ctx: Context, msg: GetEntropy) -> Entropy: async def get_entropy(msg: GetEntropy) -> Entropy:
from trezor.crypto import random from trezor.crypto import random
from trezor.enums import ButtonRequestType from trezor.enums import ButtonRequestType
from trezor.messages import Entropy from trezor.messages import Entropy
from trezor.ui.layouts import confirm_action from trezor.ui.layouts import confirm_action
await confirm_action( await confirm_action(
ctx,
"get_entropy", "get_entropy",
"Confirm entropy", "Confirm entropy",
"Do you really want to send entropy?", "Do you really want to send entropy?",

@ -2,13 +2,12 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import GetFirmwareHash, FirmwareHash from trezor.messages import GetFirmwareHash, FirmwareHash
from trezor.wire import Context
from trezor.ui.layouts.common import ProgressLayout from trezor.ui.layouts.common import ProgressLayout
_progress_obj: ProgressLayout | None = None _progress_obj: ProgressLayout | None = None
async def get_firmware_hash(ctx: Context, msg: GetFirmwareHash) -> FirmwareHash: async def get_firmware_hash(msg: GetFirmwareHash) -> FirmwareHash:
from trezor.messages import FirmwareHash from trezor.messages import FirmwareHash
from trezor.utils import firmware_hash from trezor.utils import firmware_hash
from trezor.ui.layouts.progress import progress from trezor.ui.layouts.progress import progress

@ -6,14 +6,13 @@ from apps.common import coininfo
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import IdentityType, SignIdentity, SignedIdentity from trezor.messages import IdentityType, SignIdentity, SignedIdentity
from trezor.wire import Context
from apps.common.paths import Bip32Path from apps.common.paths import Bip32Path
# This module implements the SLIP-0013 authentication using a deterministic hierarchy, see # This module implements the SLIP-0013 authentication using a deterministic hierarchy, see
# https://github.com/satoshilabs/slips/blob/master/slip-0013.md. # https://github.com/satoshilabs/slips/blob/master/slip-0013.md.
async def sign_identity(ctx: Context, msg: SignIdentity) -> SignedIdentity: async def sign_identity(msg: SignIdentity) -> SignedIdentity:
from trezor.messages import SignedIdentity from trezor.messages import SignedIdentity
from trezor.ui.layouts import confirm_sign_identity from trezor.ui.layouts import confirm_sign_identity
from apps.common.keychain import get_keychain from apps.common.keychain import get_keychain
@ -25,13 +24,13 @@ async def sign_identity(ctx: Context, msg: SignIdentity) -> SignedIdentity:
challenge_hidden = msg.challenge_hidden # local_cache_attribute challenge_hidden = msg.challenge_hidden # local_cache_attribute
curve_name = msg.ecdsa_curve_name or "secp256k1" curve_name = msg.ecdsa_curve_name or "secp256k1"
keychain = await get_keychain(ctx, curve_name, [AlwaysMatchingSchema]) keychain = await get_keychain(curve_name, [AlwaysMatchingSchema])
identity = serialize_identity(msg_identity) identity = serialize_identity(msg_identity)
# require_confirm_sign_identity # require_confirm_sign_identity
proto = msg_identity_proto.upper() if msg_identity_proto else "identity" proto = msg_identity_proto.upper() if msg_identity_proto else "identity"
await confirm_sign_identity( await confirm_sign_identity(
ctx, proto, serialize_identity_without_proto(msg_identity), challenge_visual proto, serialize_identity_without_proto(msg_identity), challenge_visual
) )
# END require_confirm_sign_identity # END require_confirm_sign_identity

@ -44,7 +44,7 @@ if __debug__:
return Failure(**kwargs) return Failure(**kwargs)
async def diag(ctx, msg, **kwargs) -> Failure: async def diag(msg, **kwargs) -> Failure:
ins = msg.ins # local_cache_attribute ins = msg.ins # local_cache_attribute
debug = log.debug # local_cache_attribute debug = log.debug # local_cache_attribute

@ -4,15 +4,12 @@ from apps.common.keychain import auto_keychain
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import MoneroGetAddress, MoneroAddress from trezor.messages import MoneroGetAddress, MoneroAddress
from trezor.wire import Context
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
@auto_keychain(__name__) @auto_keychain(__name__)
async def get_address( async def get_address(msg: MoneroGetAddress, keychain: Keychain) -> MoneroAddress:
ctx: Context, msg: MoneroGetAddress, keychain: Keychain
) -> MoneroAddress:
from trezor import wire from trezor import wire
from trezor.messages import MoneroAddress from trezor.messages import MoneroAddress
from trezor.ui.layouts import show_address from trezor.ui.layouts import show_address
@ -26,7 +23,7 @@ async def get_address(
minor = msg.minor # local_cache_attribute minor = msg.minor # local_cache_attribute
payment_id = msg.payment_id # local_cache_attribute payment_id = msg.payment_id # local_cache_attribute
await paths.validate_path(ctx, keychain, msg.address_n) await paths.validate_path(keychain, msg.address_n)
creds = misc.get_creds(keychain, msg.address_n, msg.network_type) creds = misc.get_creds(keychain, msg.address_n, msg.network_type)
addr = creds.address addr = creds.address
@ -69,7 +66,6 @@ async def get_address(
if msg.show_display: if msg.show_display:
await show_address( await show_address(
ctx,
addr, addr,
address_qr="monero:" + addr, address_qr="monero:" + addr,
path=paths.address_n_to_str(msg.address_n), path=paths.address_n_to_str(msg.address_n),

@ -24,12 +24,11 @@ _GET_TX_KEY_REASON_TX_DERIVATION = const(1)
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import MoneroGetTxKeyRequest, MoneroGetTxKeyAck from trezor.messages import MoneroGetTxKeyRequest, MoneroGetTxKeyAck
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
from trezor.wire import Context
@auto_keychain(__name__) @auto_keychain(__name__)
async def get_tx_keys( async def get_tx_keys(
ctx: Context, msg: MoneroGetTxKeyRequest, keychain: Keychain msg: MoneroGetTxKeyRequest, keychain: Keychain
) -> MoneroGetTxKeyAck: ) -> MoneroGetTxKeyAck:
from trezor import utils, wire from trezor import utils, wire
from trezor.messages import MoneroGetTxKeyAck from trezor.messages import MoneroGetTxKeyAck
@ -38,10 +37,10 @@ async def get_tx_keys(
from apps.monero import layout, misc from apps.monero import layout, misc
from apps.monero.xmr import chacha_poly, crypto, crypto_helpers from apps.monero.xmr import chacha_poly, crypto, crypto_helpers
await paths.validate_path(ctx, keychain, msg.address_n) await paths.validate_path(keychain, msg.address_n)
do_deriv = msg.reason == _GET_TX_KEY_REASON_TX_DERIVATION do_deriv = msg.reason == _GET_TX_KEY_REASON_TX_DERIVATION
await layout.require_confirm_tx_key(ctx, export_key=not do_deriv) await layout.require_confirm_tx_key(export_key=not do_deriv)
creds = misc.get_creds(keychain, msg.address_n, msg.network_type) creds = misc.get_creds(keychain, msg.address_n, msg.network_type)

@ -3,24 +3,21 @@ from typing import TYPE_CHECKING
from apps.common.keychain import auto_keychain from apps.common.keychain import auto_keychain
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.wire import Context
from trezor.messages import MoneroGetWatchKey, MoneroWatchKey from trezor.messages import MoneroGetWatchKey, MoneroWatchKey
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
@auto_keychain(__name__) @auto_keychain(__name__)
async def get_watch_only( async def get_watch_only(msg: MoneroGetWatchKey, keychain: Keychain) -> MoneroWatchKey:
ctx: Context, msg: MoneroGetWatchKey, keychain: Keychain
) -> MoneroWatchKey:
from apps.common import paths from apps.common import paths
from apps.monero import layout, misc from apps.monero import layout, misc
from apps.monero.xmr import crypto_helpers from apps.monero.xmr import crypto_helpers
from trezor.messages import MoneroWatchKey from trezor.messages import MoneroWatchKey
await paths.validate_path(ctx, keychain, msg.address_n) await paths.validate_path(keychain, msg.address_n)
await layout.require_confirm_watchkey(ctx) await layout.require_confirm_watchkey()
creds = misc.get_creds(keychain, msg.address_n, msg.network_type) creds = misc.get_creds(keychain, msg.address_n, msg.network_type)
address = creds.address address = creds.address

@ -14,7 +14,6 @@ if TYPE_CHECKING:
MoneroKeyImageSyncStepRequest, MoneroKeyImageSyncStepRequest,
) )
from trezor.ui.layouts.common import ProgressLayout from trezor.ui.layouts.common import ProgressLayout
from trezor.wire import Context
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
@ -23,7 +22,7 @@ if TYPE_CHECKING:
@auto_keychain(__name__) @auto_keychain(__name__)
async def key_image_sync( async def key_image_sync(
ctx: Context, msg: MoneroKeyImageExportInitRequest, keychain: Keychain msg: MoneroKeyImageExportInitRequest, keychain: Keychain
) -> MoneroKeyImageSyncFinalAck: ) -> MoneroKeyImageSyncFinalAck:
import gc import gc
from trezor.messages import ( from trezor.messages import (
@ -31,16 +30,17 @@ async def key_image_sync(
MoneroKeyImageSyncFinalRequest, MoneroKeyImageSyncFinalRequest,
MoneroKeyImageSyncStepRequest, MoneroKeyImageSyncStepRequest,
) )
from trezor.wire.context import call
state = KeyImageSync() state = KeyImageSync()
res = await _init_step(state, ctx, msg, keychain) res = await _init_step(state, msg, keychain)
progress = layout.monero_keyimage_sync_progress() progress = layout.monero_keyimage_sync_progress()
while state.current_output + 1 < state.num_outputs: while state.current_output + 1 < state.num_outputs:
step = await ctx.call(res, MoneroKeyImageSyncStepRequest) step = await call(res, MoneroKeyImageSyncStepRequest)
res = _sync_step(state, ctx, step, progress) res = _sync_step(state, step, progress)
gc.collect() gc.collect()
await ctx.call(res, MoneroKeyImageSyncFinalRequest) await call(res, MoneroKeyImageSyncFinalRequest)
# _final_step # _final_step
if state.current_output + 1 != state.num_outputs: if state.current_output + 1 != state.num_outputs:
@ -66,7 +66,6 @@ class KeyImageSync:
async def _init_step( async def _init_step(
s: KeyImageSync, s: KeyImageSync,
ctx: Context,
msg: MoneroKeyImageExportInitRequest, msg: MoneroKeyImageExportInitRequest,
keychain: Keychain, keychain: Keychain,
) -> MoneroKeyImageExportInitAck: ) -> MoneroKeyImageExportInitAck:
@ -76,11 +75,11 @@ async def _init_step(
from apps.monero.xmr import monero from apps.monero.xmr import monero
from apps.monero import misc from apps.monero import misc
await paths.validate_path(ctx, keychain, msg.address_n) await paths.validate_path(keychain, msg.address_n)
s.creds = misc.get_creds(keychain, msg.address_n, msg.network_type) s.creds = misc.get_creds(keychain, msg.address_n, msg.network_type)
await layout.require_confirm_keyimage_sync(ctx) await layout.require_confirm_keyimage_sync()
s.num_outputs = msg.num s.num_outputs = msg.num
s.expected_hash = msg.hash s.expected_hash = msg.hash
@ -96,7 +95,6 @@ async def _init_step(
def _sync_step( def _sync_step(
s: KeyImageSync, s: KeyImageSync,
ctx: Context,
tds: MoneroKeyImageSyncStepRequest, tds: MoneroKeyImageSyncStepRequest,
progress: ProgressLayout, progress: ProgressLayout,
) -> MoneroKeyImageSyncStepAck: ) -> MoneroKeyImageSyncStepAck:

@ -17,7 +17,6 @@ if TYPE_CHECKING:
MoneroTransactionData, MoneroTransactionData,
MoneroTransactionDestinationEntry, MoneroTransactionDestinationEntry,
) )
from trezor.wire import Context
from .signing.state import State from .signing.state import State
@ -60,9 +59,8 @@ def _format_amount(value: int) -> str:
return f"{strings.format_amount(value, 12)} XMR" return f"{strings.format_amount(value, 12)} XMR"
async def require_confirm_watchkey(ctx: Context) -> None: async def require_confirm_watchkey() -> None:
await confirm_action( await confirm_action(
ctx,
"get_watchkey", "get_watchkey",
"Confirm export", "Confirm export",
description="Do you really want to export watch-only credentials?", description="Do you really want to export watch-only credentials?",
@ -70,9 +68,8 @@ async def require_confirm_watchkey(ctx: Context) -> None:
) )
async def require_confirm_keyimage_sync(ctx: Context) -> None: async def require_confirm_keyimage_sync() -> None:
await confirm_action( await confirm_action(
ctx,
"key_image_sync", "key_image_sync",
"Confirm ki sync", "Confirm ki sync",
description="Do you really want to\nsync key images?", description="Do you really want to\nsync key images?",
@ -80,9 +77,8 @@ async def require_confirm_keyimage_sync(ctx: Context) -> None:
) )
async def require_confirm_live_refresh(ctx: Context) -> None: async def require_confirm_live_refresh() -> None:
await confirm_action( await confirm_action(
ctx,
"live_refresh", "live_refresh",
"Confirm refresh", "Confirm refresh",
description="Do you really want to\nstart refresh?", description="Do you really want to\nstart refresh?",
@ -90,14 +86,13 @@ async def require_confirm_live_refresh(ctx: Context) -> None:
) )
async def require_confirm_tx_key(ctx: Context, export_key: bool = False) -> None: async def require_confirm_tx_key(export_key: bool = False) -> None:
description = ( description = (
"Do you really want to export tx_key?" "Do you really want to export tx_key?"
if export_key if export_key
else "Do you really want to export tx_der\nfor tx_proof?" else "Do you really want to export tx_der\nfor tx_proof?"
) )
await confirm_action( await confirm_action(
ctx,
"export_tx_key", "export_tx_key",
"Confirm export", "Confirm export",
description=description, description=description,
@ -106,7 +101,6 @@ async def require_confirm_tx_key(ctx: Context, export_key: bool = False) -> None
async def require_confirm_transaction( async def require_confirm_transaction(
ctx: Context,
state: State, state: State,
tsx_data: MoneroTransactionData, tsx_data: MoneroTransactionData,
network_type: MoneroNetworkType, network_type: MoneroNetworkType,
@ -122,7 +116,7 @@ async def require_confirm_transaction(
payment_id = tsx_data.payment_id # local_cache_attribute payment_id = tsx_data.payment_id # local_cache_attribute
if tsx_data.unlock_time != 0: if tsx_data.unlock_time != 0:
await _require_confirm_unlock_time(ctx, tsx_data.unlock_time) await _require_confirm_unlock_time(tsx_data.unlock_time)
for idx, dst in enumerate(outputs): for idx, dst in enumerate(outputs):
is_change = change_idx is not None and idx == change_idx is_change = change_idx is not None and idx == change_idx
@ -135,21 +129,20 @@ async def require_confirm_transaction(
cur_payment = payment_id cur_payment = payment_id
else: else:
cur_payment = None cur_payment = None
await _require_confirm_output(ctx, dst, network_type, cur_payment) await _require_confirm_output(dst, network_type, cur_payment)
if ( if (
payment_id payment_id
and not tsx_data.integrated_indices and not tsx_data.integrated_indices
and payment_id != DUMMY_PAYMENT_ID and payment_id != DUMMY_PAYMENT_ID
): ):
await _require_confirm_payment_id(ctx, payment_id) await _require_confirm_payment_id(payment_id)
await _require_confirm_fee(ctx, tsx_data.fee) await _require_confirm_fee(tsx_data.fee)
progress.step(state, 0) progress.step(state, 0)
async def _require_confirm_output( async def _require_confirm_output(
ctx: Context,
dst: MoneroTransactionDestinationEntry, dst: MoneroTransactionDestinationEntry,
network_type: MoneroNetworkType, network_type: MoneroNetworkType,
payment_id: bytes | None, payment_id: bytes | None,
@ -167,18 +160,16 @@ async def _require_confirm_output(
) )
await confirm_output( await confirm_output(
ctx,
addr, addr,
_format_amount(dst.amount), _format_amount(dst.amount),
br_code=BRT_SignTx, br_code=BRT_SignTx,
) )
async def _require_confirm_payment_id(ctx: Context, payment_id: bytes) -> None: async def _require_confirm_payment_id(payment_id: bytes) -> None:
from trezor.ui.layouts import confirm_blob from trezor.ui.layouts import confirm_blob
await confirm_blob( await confirm_blob(
ctx,
"confirm_payment_id", "confirm_payment_id",
"Payment ID", "Payment ID",
payment_id, payment_id,
@ -186,9 +177,8 @@ async def _require_confirm_payment_id(ctx: Context, payment_id: bytes) -> None:
) )
async def _require_confirm_fee(ctx: Context, fee: int) -> None: async def _require_confirm_fee(fee: int) -> None:
await confirm_metadata( await confirm_metadata(
ctx,
"confirm_final", "confirm_final",
"Confirm fee", "Confirm fee",
"{}", "{}",
@ -197,9 +187,8 @@ async def _require_confirm_fee(ctx: Context, fee: int) -> None:
) )
async def _require_confirm_unlock_time(ctx: Context, unlock_time: int) -> None: async def _require_confirm_unlock_time(unlock_time: int) -> None:
await confirm_metadata( await confirm_metadata(
ctx,
"confirm_locktime", "confirm_locktime",
"Confirm unlock time", "Confirm unlock time",
"Unlock time for this transaction is set to {}", "Unlock time for this transaction is set to {}",

@ -12,7 +12,6 @@ if TYPE_CHECKING:
MoneroLiveRefreshStartAck, MoneroLiveRefreshStartAck,
) )
from trezor.ui.layouts.common import ProgressLayout from trezor.ui.layouts.common import ProgressLayout
from trezor.wire import Context
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
from .xmr.credentials import AccountCreds from .xmr.credentials import AccountCreds
@ -20,25 +19,26 @@ if TYPE_CHECKING:
@auto_keychain(__name__) @auto_keychain(__name__)
async def live_refresh( async def live_refresh(
ctx: Context, msg: MoneroLiveRefreshStartRequest, keychain: Keychain msg: MoneroLiveRefreshStartRequest, keychain: Keychain
) -> MoneroLiveRefreshFinalAck: ) -> MoneroLiveRefreshFinalAck:
import gc import gc
from trezor.enums import MessageType from trezor.enums import MessageType
from trezor.messages import MoneroLiveRefreshFinalAck, MoneroLiveRefreshStepRequest from trezor.messages import MoneroLiveRefreshFinalAck, MoneroLiveRefreshStepRequest
from trezor.wire.context import call_any
state = LiveRefreshState() state = LiveRefreshState()
res = await _init_step(state, ctx, msg, keychain) res = await _init_step(state, msg, keychain)
progress = layout.monero_live_refresh_progress() progress = layout.monero_live_refresh_progress()
while True: while True:
step = await ctx.call_any( step = await call_any(
res, res,
MessageType.MoneroLiveRefreshStepRequest, MessageType.MoneroLiveRefreshStepRequest,
MessageType.MoneroLiveRefreshFinalRequest, MessageType.MoneroLiveRefreshFinalRequest,
) )
del res del res
if MoneroLiveRefreshStepRequest.is_type_of(step): if MoneroLiveRefreshStepRequest.is_type_of(step):
res = _refresh_step(state, ctx, step, progress) res = _refresh_step(state, step, progress)
else: else:
return MoneroLiveRefreshFinalAck() return MoneroLiveRefreshFinalAck()
gc.collect() gc.collect()
@ -52,7 +52,6 @@ class LiveRefreshState:
async def _init_step( async def _init_step(
s: LiveRefreshState, s: LiveRefreshState,
ctx: Context,
msg: MoneroLiveRefreshStartRequest, msg: MoneroLiveRefreshStartRequest,
keychain: Keychain, keychain: Keychain,
) -> MoneroLiveRefreshStartAck: ) -> MoneroLiveRefreshStartAck:
@ -60,10 +59,10 @@ async def _init_step(
from apps.common import paths from apps.common import paths
from trezor.messages import MoneroLiveRefreshStartAck from trezor.messages import MoneroLiveRefreshStartAck
await paths.validate_path(ctx, keychain, msg.address_n) await paths.validate_path(keychain, msg.address_n)
if not storage_cache.get(storage_cache.APP_MONERO_LIVE_REFRESH): if not storage_cache.get(storage_cache.APP_MONERO_LIVE_REFRESH):
await layout.require_confirm_live_refresh(ctx) await layout.require_confirm_live_refresh()
storage_cache.set(storage_cache.APP_MONERO_LIVE_REFRESH, b"\x01") storage_cache.set(storage_cache.APP_MONERO_LIVE_REFRESH, b"\x01")
s.creds = misc.get_creds(keychain, msg.address_n, msg.network_type) s.creds = misc.get_creds(keychain, msg.address_n, msg.network_type)
@ -73,7 +72,6 @@ async def _init_step(
def _refresh_step( def _refresh_step(
s: LiveRefreshState, s: LiveRefreshState,
ctx: Context,
msg: MoneroLiveRefreshStepRequest, msg: MoneroLiveRefreshStepRequest,
progress: ProgressLayout, progress: ProgressLayout,
) -> MoneroLiveRefreshStepAck: ) -> MoneroLiveRefreshStepAck:

@ -7,18 +7,16 @@ if TYPE_CHECKING:
from trezor.messages import MoneroTransactionFinalAck from trezor.messages import MoneroTransactionFinalAck
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
from apps.monero.signing.state import State from apps.monero.signing.state import State
from trezor.wire import Context
@auto_keychain(__name__) @auto_keychain(__name__)
async def sign_tx( async def sign_tx(received_msg, keychain: Keychain) -> MoneroTransactionFinalAck:
ctx: Context, received_msg, keychain: Keychain
) -> MoneroTransactionFinalAck:
import gc import gc
from trezor import log, utils from trezor import log, utils
from trezor.wire.context import get_context
from apps.monero.signing.state import State from apps.monero.signing.state import State
state = State(ctx) state = State()
mods = utils.unimport_begin() mods = utils.unimport_begin()
progress = MoneroTransactionProgress() progress = MoneroTransactionProgress()
@ -36,11 +34,12 @@ async def sign_tx(
if accept_msgs is None: if accept_msgs is None:
break break
ctx = get_context()
await ctx.write(result_msg) await ctx.write(result_msg)
del (result_msg, received_msg) del (result_msg, received_msg)
utils.unimport_end(mods) utils.unimport_end(mods)
received_msg = await ctx.read_any(accept_msgs) received_msg = await ctx.read(accept_msgs)
utils.unimport_end(mods) utils.unimport_end(mods)
return result_msg return result_msg

@ -1,7 +1,6 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.wire import Context
from apps.monero.xmr.crypto import Point, Scalar from apps.monero.xmr.crypto import Point, Scalar
from apps.monero.xmr.credentials import AccountCreds from apps.monero.xmr.credentials import AccountCreds
from trezor.messages import MoneroTransactionDestinationEntry from trezor.messages import MoneroTransactionDestinationEntry
@ -19,20 +18,16 @@ class State:
STEP_ALL_OUT = 500 STEP_ALL_OUT = 500
STEP_SIGN = 600 STEP_SIGN = 600
def __init__(self, ctx: Context) -> None: def __init__(self) -> None:
from apps.monero.xmr.keccak_hasher import KeccakXmrArchive from apps.monero.xmr.keccak_hasher import KeccakXmrArchive
from apps.monero.xmr.mlsag_hasher import PreMlsagHasher from apps.monero.xmr.mlsag_hasher import PreMlsagHasher
from apps.monero.xmr import crypto from apps.monero.xmr import crypto
self.ctx = ctx # Account credentials
# type: AccountCreds
""" # - view private/public key
Account credentials # - spend private/public key
type: AccountCreds # - and its corresponding address
- view private/public key
- spend private/public key
- and its corresponding address
"""
self.creds: AccountCreds | None = None self.creds: AccountCreds | None = None
# HMAC/encryption keys used to protect offloaded data # HMAC/encryption keys used to protect offloaded data

@ -37,7 +37,7 @@ async def init_transaction(
mem_trace = state.mem_trace # local_cache_attribute mem_trace = state.mem_trace # local_cache_attribute
outputs = tsx_data.outputs # local_cache_attribute outputs = tsx_data.outputs # local_cache_attribute
await paths.validate_path(state.ctx, keychain, address_n) await paths.validate_path(keychain, address_n)
state.creds = misc.get_creds(keychain, address_n, network_type) state.creds = misc.get_creds(keychain, address_n, network_type)
state.client_version = tsx_data.client_version or 0 state.client_version = tsx_data.client_version or 0
@ -57,7 +57,6 @@ async def init_transaction(
# Ask for confirmation # Ask for confirmation
await layout.require_confirm_transaction( await layout.require_confirm_transaction(
state.ctx,
state, state,
tsx_data, tsx_data,
state.creds.network_type, state.creds.network_type,

@ -6,14 +6,11 @@ from . import CURVE, PATTERNS, SLIP44_ID
if TYPE_CHECKING: if TYPE_CHECKING:
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
from trezor.wire import Context
from trezor.messages import NEMGetAddress, NEMAddress from trezor.messages import NEMGetAddress, NEMAddress
@with_slip44_keychain(*PATTERNS, slip44_id=SLIP44_ID, curve=CURVE) @with_slip44_keychain(*PATTERNS, slip44_id=SLIP44_ID, curve=CURVE)
async def get_address( async def get_address(msg: NEMGetAddress, keychain: Keychain) -> NEMAddress:
ctx: Context, msg: NEMGetAddress, keychain: Keychain
) -> NEMAddress:
from trezor.messages import NEMAddress from trezor.messages import NEMAddress
from trezor.ui.layouts import show_address from trezor.ui.layouts import show_address
from apps.common.paths import address_n_to_str, validate_path from apps.common.paths import address_n_to_str, validate_path
@ -24,14 +21,13 @@ async def get_address(
network = msg.network # local_cache_attribute network = msg.network # local_cache_attribute
validate_network(network) validate_network(network)
await validate_path(ctx, keychain, address_n, check_path(address_n, network)) await validate_path(keychain, address_n, check_path(address_n, network))
node = keychain.derive(address_n) node = keychain.derive(address_n)
address = node.nem_address(network) address = node.nem_address(network)
if msg.show_display: if msg.show_display:
await show_address( await show_address(
ctx,
address, address,
case_sensitive=False, case_sensitive=False,
path=address_n_to_str(address_n), path=address_n_to_str(address_n),

@ -1,18 +1,12 @@
from typing import TYPE_CHECKING
from trezor.enums import ButtonRequestType from trezor.enums import ButtonRequestType
from trezor.strings import format_amount from trezor.strings import format_amount
from trezor.ui.layouts import confirm_metadata from trezor.ui.layouts import confirm_metadata
from .helpers import NEM_MAX_DIVISIBILITY from .helpers import NEM_MAX_DIVISIBILITY
if TYPE_CHECKING:
from trezor.wire import Context
async def require_confirm_text(ctx: Context, action: str) -> None: async def require_confirm_text(action: str) -> None:
await confirm_metadata( await confirm_metadata(
ctx,
"confirm_nem", "confirm_nem",
"Confirm action", "Confirm action",
action, action,
@ -20,9 +14,8 @@ async def require_confirm_text(ctx: Context, action: str) -> None:
) )
async def require_confirm_fee(ctx: Context, action: str, fee: int) -> None: async def require_confirm_fee(action: str, fee: int) -> None:
await confirm_metadata( await confirm_metadata(
ctx,
"confirm_fee", "confirm_fee",
"Confirm fee", "Confirm fee",
action + "\n{}", action + "\n{}",
@ -31,21 +24,19 @@ async def require_confirm_fee(ctx: Context, action: str, fee: int) -> None:
) )
async def require_confirm_content(ctx: Context, headline: str, content: list) -> None: async def require_confirm_content(headline: str, content: list) -> None:
from trezor.ui.layouts import confirm_properties from trezor.ui.layouts import confirm_properties
await confirm_properties( await confirm_properties(
ctx,
"confirm_content", "confirm_content",
headline, headline,
content, content,
) )
async def require_confirm_final(ctx: Context, fee: int) -> None: async def require_confirm_final(fee: int) -> None:
# we use SignTx, not ConfirmOutput, for compatibility with T1 # we use SignTx, not ConfirmOutput, for compatibility with T1
await confirm_metadata( await confirm_metadata(
ctx,
"confirm_final", "confirm_final",
"Final confirm", "Final confirm",
"Sign this transaction\n{}\nfor network fee?", "Sign this transaction\n{}\nfor network fee?",

@ -8,24 +8,21 @@ if TYPE_CHECKING:
NEMMosaicSupplyChange, NEMMosaicSupplyChange,
NEMTransactionCommon, NEMTransactionCommon,
) )
from trezor.wire import Context
async def mosaic_creation( async def mosaic_creation(
ctx: Context,
public_key: bytes, public_key: bytes,
common: NEMTransactionCommon, common: NEMTransactionCommon,
creation: NEMMosaicCreation, creation: NEMMosaicCreation,
) -> bytes: ) -> bytes:
await layout.ask_mosaic_creation(ctx, common, creation) await layout.ask_mosaic_creation(common, creation)
return serialize.serialize_mosaic_creation(common, creation, public_key) return serialize.serialize_mosaic_creation(common, creation, public_key)
async def supply_change( async def supply_change(
ctx: Context,
public_key: bytes, public_key: bytes,
common: NEMTransactionCommon, common: NEMTransactionCommon,
change: NEMMosaicSupplyChange, change: NEMMosaicSupplyChange,
) -> bytes: ) -> bytes:
await layout.ask_supply_change(ctx, common, change) await layout.ask_supply_change(common, change)
return serialize.serialize_mosaic_supply_change(common, change, public_key) return serialize.serialize_mosaic_supply_change(common, change, public_key)

@ -9,11 +9,10 @@ if TYPE_CHECKING:
NEMMosaicSupplyChange, NEMMosaicSupplyChange,
NEMTransactionCommon, NEMTransactionCommon,
) )
from trezor.wire import Context
async def ask_mosaic_creation( async def ask_mosaic_creation(
ctx: Context, common: NEMTransactionCommon, creation: NEMMosaicCreation common: NEMTransactionCommon, creation: NEMMosaicCreation
) -> None: ) -> None:
from ..layout import require_confirm_fee from ..layout import require_confirm_fee
@ -21,15 +20,15 @@ async def ask_mosaic_creation(
("Create mosaic", creation.definition.mosaic), ("Create mosaic", creation.definition.mosaic),
("under namespace", creation.definition.namespace), ("under namespace", creation.definition.namespace),
] ]
await require_confirm_content(ctx, "Create mosaic", creation_message) await require_confirm_content("Create mosaic", creation_message)
await _require_confirm_properties(ctx, creation.definition) await _require_confirm_properties(creation.definition)
await require_confirm_fee(ctx, "Confirm creation fee", creation.fee) await require_confirm_fee("Confirm creation fee", creation.fee)
await require_confirm_final(ctx, common.fee) await require_confirm_final(common.fee)
async def ask_supply_change( async def ask_supply_change(
ctx: Context, common: NEMTransactionCommon, change: NEMMosaicSupplyChange common: NEMTransactionCommon, change: NEMMosaicSupplyChange
) -> None: ) -> None:
from trezor.enums import NEMSupplyChangeType from trezor.enums import NEMSupplyChangeType
from ..layout import require_confirm_text from ..layout import require_confirm_text
@ -38,21 +37,19 @@ async def ask_supply_change(
("Modify supply for", change.mosaic), ("Modify supply for", change.mosaic),
("under namespace", change.namespace), ("under namespace", change.namespace),
] ]
await require_confirm_content(ctx, "Supply change", supply_message) await require_confirm_content("Supply change", supply_message)
if change.type == NEMSupplyChangeType.SupplyChange_Decrease: if change.type == NEMSupplyChangeType.SupplyChange_Decrease:
action = "Decrease" action = "Decrease"
elif change.type == NEMSupplyChangeType.SupplyChange_Increase: elif change.type == NEMSupplyChangeType.SupplyChange_Increase:
action = "Increase" action = "Increase"
else: else:
raise ValueError("Invalid supply change type") raise ValueError("Invalid supply change type")
await require_confirm_text(ctx, f"{action} supply by {change.delta} whole units?") await require_confirm_text(f"{action} supply by {change.delta} whole units?")
await require_confirm_final(ctx, common.fee) await require_confirm_final(common.fee)
async def _require_confirm_properties( async def _require_confirm_properties(definition: NEMMosaicDefinition) -> None:
ctx: Context, definition: NEMMosaicDefinition
) -> None:
from trezor.enums import NEMMosaicLevy from trezor.enums import NEMMosaicLevy
from trezor.ui.layouts import confirm_properties from trezor.ui.layouts import confirm_properties
@ -97,7 +94,6 @@ async def _require_confirm_properties(
append(("Levy type:", levy_type)) append(("Levy type:", levy_type))
await confirm_properties( await confirm_properties(
ctx,
"confirm_properties", "confirm_properties",
"Confirm properties", "Confirm properties",
properties, properties,

@ -8,11 +8,10 @@ if TYPE_CHECKING:
NEMSignTx, NEMSignTx,
NEMTransactionCommon, NEMTransactionCommon,
) )
from trezor.wire import Context
async def ask(ctx: Context, msg: NEMSignTx) -> None: async def ask(msg: NEMSignTx) -> None:
await layout.ask_multisig(ctx, msg) await layout.ask_multisig(msg)
def initiate(public_key: bytes, common: NEMTransactionCommon, inner_tx: bytes) -> bytes: def initiate(public_key: bytes, common: NEMTransactionCommon, inner_tx: bytes) -> bytes:
@ -26,13 +25,12 @@ def cosign(
async def aggregate_modification( async def aggregate_modification(
ctx: Context,
public_key: bytes, public_key: bytes,
common: NEMTransactionCommon, common: NEMTransactionCommon,
aggr: NEMAggregateModification, aggr: NEMAggregateModification,
multisig: bool, multisig: bool,
) -> bytes: ) -> bytes:
await layout.ask_aggregate_modification(ctx, common, aggr, multisig) await layout.ask_aggregate_modification(common, aggr, multisig)
w = serialize.serialize_aggregate_modification(common, aggr, public_key) w = serialize.serialize_aggregate_modification(common, aggr, public_key)
for m in aggr.modifications: for m in aggr.modifications:

@ -8,24 +8,22 @@ if TYPE_CHECKING:
NEMSignTx, NEMSignTx,
NEMTransactionCommon, NEMTransactionCommon,
) )
from trezor.wire import Context
async def ask_multisig(ctx: Context, msg: NEMSignTx) -> None: async def ask_multisig(msg: NEMSignTx) -> None:
from ..layout import require_confirm_fee from ..layout import require_confirm_fee
assert msg.multisig is not None # sign_tx assert msg.multisig is not None # sign_tx
assert msg.multisig.signer is not None # sign_tx assert msg.multisig.signer is not None # sign_tx
address = nem.compute_address(msg.multisig.signer, msg.transaction.network) address = nem.compute_address(msg.multisig.signer, msg.transaction.network)
if msg.cosigning: if msg.cosigning:
await _require_confirm_address(ctx, "Cosign transaction for", address) await _require_confirm_address("Cosign transaction for", address)
else: else:
await _require_confirm_address(ctx, "Initiate transaction for", address) await _require_confirm_address("Initiate transaction for", address)
await require_confirm_fee(ctx, "Confirm multisig fee", msg.transaction.fee) await require_confirm_fee("Confirm multisig fee", msg.transaction.fee)
async def ask_aggregate_modification( async def ask_aggregate_modification(
ctx: Context,
common: NEMTransactionCommon, common: NEMTransactionCommon,
mod: NEMAggregateModification, mod: NEMAggregateModification,
multisig: bool, multisig: bool,
@ -34,7 +32,7 @@ async def ask_aggregate_modification(
from ..layout import require_confirm_final, require_confirm_text from ..layout import require_confirm_final, require_confirm_text
if not multisig: if not multisig:
await require_confirm_text(ctx, "Convert account to multisig account?") await require_confirm_text("Convert account to multisig account?")
for m in mod.modifications: for m in mod.modifications:
if m.type == NEMModificationType.CosignatoryModification_Add: if m.type == NEMModificationType.CosignatoryModification_Add:
@ -42,24 +40,23 @@ async def ask_aggregate_modification(
else: else:
action = "Remove" action = "Remove"
address = nem.compute_address(m.public_key, common.network) address = nem.compute_address(m.public_key, common.network)
await _require_confirm_address(ctx, action + " cosignatory", address) await _require_confirm_address(action + " cosignatory", address)
if mod.relative_change: if mod.relative_change:
if multisig: if multisig:
action = "Modify the number of cosignatories by " action = "Modify the number of cosignatories by "
else: else:
action = "Set minimum cosignatories to " action = "Set minimum cosignatories to "
await require_confirm_text(ctx, action + str(mod.relative_change) + "?") await require_confirm_text(action + str(mod.relative_change) + "?")
await require_confirm_final(ctx, common.fee) await require_confirm_final(common.fee)
async def _require_confirm_address(ctx: Context, action: str, address: str) -> None: async def _require_confirm_address(action: str, address: str) -> None:
from trezor.enums import ButtonRequestType from trezor.enums import ButtonRequestType
from trezor.ui.layouts import confirm_address from trezor.ui.layouts import confirm_address
await confirm_address( await confirm_address(
ctx,
"Confirm address", "Confirm address",
address, address,
action, action,

@ -2,16 +2,14 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import NEMProvisionNamespace, NEMTransactionCommon from trezor.messages import NEMProvisionNamespace, NEMTransactionCommon
from trezor.wire import Context
async def namespace( async def namespace(
ctx: Context,
public_key: bytes, public_key: bytes,
common: NEMTransactionCommon, common: NEMTransactionCommon,
namespace: NEMProvisionNamespace, namespace: NEMProvisionNamespace,
) -> bytes: ) -> bytes:
from . import layout, serialize from . import layout, serialize
await layout.ask_provision_namespace(ctx, common, namespace) await layout.ask_provision_namespace(common, namespace)
return serialize.serialize_provision_namespace(common, namespace, public_key) return serialize.serialize_provision_namespace(common, namespace, public_key)

@ -2,11 +2,10 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import NEMProvisionNamespace, NEMTransactionCommon from trezor.messages import NEMProvisionNamespace, NEMTransactionCommon
from trezor.wire import Context
async def ask_provision_namespace( async def ask_provision_namespace(
ctx: Context, common: NEMTransactionCommon, namespace: NEMProvisionNamespace common: NEMTransactionCommon, namespace: NEMProvisionNamespace
) -> None: ) -> None:
from ..layout import ( from ..layout import (
require_confirm_content, require_confirm_content,
@ -19,11 +18,11 @@ async def ask_provision_namespace(
("Create namespace", namespace.namespace), ("Create namespace", namespace.namespace),
("under namespace", namespace.parent), ("under namespace", namespace.parent),
] ]
await require_confirm_content(ctx, "Confirm namespace", content) await require_confirm_content("Confirm namespace", content)
else: else:
content = [("Create namespace", namespace.namespace)] content = [("Create namespace", namespace.namespace)]
await require_confirm_content(ctx, "Confirm namespace", content) await require_confirm_content("Confirm namespace", content)
await require_confirm_fee(ctx, "Confirm rental fee", namespace.fee) await require_confirm_fee("Confirm rental fee", namespace.fee)
await require_confirm_final(ctx, common.fee) await require_confirm_final(common.fee)

@ -7,11 +7,10 @@ from . import CURVE, PATTERNS, SLIP44_ID
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import NEMSignTx, NEMSignedTx from trezor.messages import NEMSignTx, NEMSignedTx
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
from trezor.wire import Context
@with_slip44_keychain(*PATTERNS, slip44_id=SLIP44_ID, curve=CURVE) @with_slip44_keychain(*PATTERNS, slip44_id=SLIP44_ID, curve=CURVE)
async def sign_tx(ctx: Context, msg: NEMSignTx, keychain: Keychain) -> NEMSignedTx: async def sign_tx(msg: NEMSignTx, keychain: Keychain) -> NEMSignedTx:
from trezor.wire import DataError from trezor.wire import DataError
from trezor.crypto.curve import ed25519 from trezor.crypto.curve import ed25519
from trezor.messages import NEMSignedTx from trezor.messages import NEMSignedTx
@ -28,7 +27,6 @@ async def sign_tx(ctx: Context, msg: NEMSignTx, keychain: Keychain) -> NEMSigned
transaction = msg.transaction # local_cache_attribute transaction = msg.transaction # local_cache_attribute
await validate_path( await validate_path(
ctx,
keychain, keychain,
transaction.address_n, transaction.address_n,
check_path(transaction.address_n, transaction.network), check_path(transaction.address_n, transaction.network),
@ -41,22 +39,21 @@ async def sign_tx(ctx: Context, msg: NEMSignTx, keychain: Keychain) -> NEMSigned
raise DataError("No signer provided") raise DataError("No signer provided")
public_key = msg_multisig.signer public_key = msg_multisig.signer
common = msg_multisig common = msg_multisig
await multisig.ask(ctx, msg) await multisig.ask(msg)
else: else:
public_key = seed.remove_ed25519_prefix(node.public_key()) public_key = seed.remove_ed25519_prefix(node.public_key())
common = transaction common = transaction
if msg.transfer: if msg.transfer:
tx = await transfer.transfer(ctx, public_key, common, msg.transfer, node) tx = await transfer.transfer(public_key, common, msg.transfer, node)
elif msg.provision_namespace: elif msg.provision_namespace:
tx = await namespace.namespace(ctx, public_key, common, msg.provision_namespace) tx = await namespace.namespace(public_key, common, msg.provision_namespace)
elif msg.mosaic_creation: elif msg.mosaic_creation:
tx = await mosaic.mosaic_creation(ctx, public_key, common, msg.mosaic_creation) tx = await mosaic.mosaic_creation(public_key, common, msg.mosaic_creation)
elif msg.supply_change: elif msg.supply_change:
tx = await mosaic.supply_change(ctx, public_key, common, msg.supply_change) tx = await mosaic.supply_change(public_key, common, msg.supply_change)
elif msg.aggregate_modification: elif msg.aggregate_modification:
tx = await multisig.aggregate_modification( tx = await multisig.aggregate_modification(
ctx,
public_key, public_key,
common, common,
msg.aggregate_modification, msg.aggregate_modification,
@ -64,7 +61,7 @@ async def sign_tx(ctx: Context, msg: NEMSignTx, keychain: Keychain) -> NEMSigned
) )
elif msg.importance_transfer: elif msg.importance_transfer:
tx = await transfer.importance_transfer( tx = await transfer.importance_transfer(
ctx, public_key, common, msg.importance_transfer public_key, common, msg.importance_transfer
) )
else: else:
raise DataError("No transaction provided") raise DataError("No transaction provided")

@ -4,12 +4,10 @@ from . import layout, serialize
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import NEMImportanceTransfer, NEMTransactionCommon, NEMTransfer from trezor.messages import NEMImportanceTransfer, NEMTransactionCommon, NEMTransfer
from trezor.wire import Context
from trezor.crypto import bip32 from trezor.crypto import bip32
async def transfer( async def transfer(
ctx: Context,
public_key: bytes, public_key: bytes,
common: NEMTransactionCommon, common: NEMTransactionCommon,
transfer: NEMTransfer, transfer: NEMTransfer,
@ -18,7 +16,7 @@ async def transfer(
transfer.mosaics = serialize.canonicalize_mosaics(transfer.mosaics) transfer.mosaics = serialize.canonicalize_mosaics(transfer.mosaics)
payload, encrypted = serialize.get_transfer_payload(transfer, node) payload, encrypted = serialize.get_transfer_payload(transfer, node)
await layout.ask_transfer(ctx, common, transfer, encrypted) await layout.ask_transfer(common, transfer, encrypted)
w = serialize.serialize_transfer(common, transfer, public_key, payload, encrypted) w = serialize.serialize_transfer(common, transfer, public_key, payload, encrypted)
for mosaic in transfer.mosaics: for mosaic in transfer.mosaics:
@ -27,10 +25,9 @@ async def transfer(
async def importance_transfer( async def importance_transfer(
ctx: Context,
public_key: bytes, public_key: bytes,
common: NEMTransactionCommon, common: NEMTransactionCommon,
imp: NEMImportanceTransfer, imp: NEMImportanceTransfer,
) -> bytes: ) -> bytes:
await layout.ask_importance_transfer(ctx, common, imp) await layout.ask_importance_transfer(common, imp)
return serialize.serialize_importance_transfer(common, imp, public_key) return serialize.serialize_importance_transfer(common, imp, public_key)

@ -14,11 +14,9 @@ if TYPE_CHECKING:
NEMTransactionCommon, NEMTransactionCommon,
NEMTransfer, NEMTransfer,
) )
from trezor.wire import Context
async def ask_transfer( async def ask_transfer(
ctx: Context,
common: NEMTransactionCommon, common: NEMTransactionCommon,
transfer: NEMTransfer, transfer: NEMTransfer,
encrypted: bool, encrypted: bool,
@ -29,7 +27,6 @@ async def ask_transfer(
if transfer.payload: if transfer.payload:
# require_confirm_payload # require_confirm_payload
await confirm_text( await confirm_text(
ctx,
"confirm_payload", "confirm_payload",
"Confirm payload", "Confirm payload",
bytes(transfer.payload).decode(), bytes(transfer.payload).decode(),
@ -38,20 +35,19 @@ async def ask_transfer(
) )
for mosaic in transfer.mosaics: for mosaic in transfer.mosaics:
await _ask_transfer_mosaic(ctx, common, transfer, mosaic) await _ask_transfer_mosaic(common, transfer, mosaic)
# require_confirm_transfer # require_confirm_transfer
await confirm_output( await confirm_output(
ctx,
transfer.recipient, transfer.recipient,
f"{format_amount(_get_xem_amount(transfer), NEM_MAX_DIVISIBILITY)} XEM", f"{format_amount(_get_xem_amount(transfer), NEM_MAX_DIVISIBILITY)} XEM",
) )
await require_confirm_final(ctx, common.fee) await require_confirm_final(common.fee)
async def _ask_transfer_mosaic( async def _ask_transfer_mosaic(
ctx: Context, common: NEMTransactionCommon, transfer: NEMTransfer, mosaic: NEMMosaic common: NEMTransactionCommon, transfer: NEMTransfer, mosaic: NEMMosaic
) -> None: ) -> None:
from trezor.enums import NEMMosaicLevy from trezor.enums import NEMMosaicLevy
from trezor.ui.layouts import confirm_action, confirm_properties from trezor.ui.layouts import confirm_action, confirm_properties
@ -66,7 +62,6 @@ async def _ask_transfer_mosaic(
if definition: if definition:
await confirm_properties( await confirm_properties(
ctx,
"confirm_mosaic", "confirm_mosaic",
"Confirm mosaic", "Confirm mosaic",
( (
@ -93,7 +88,6 @@ async def _ask_transfer_mosaic(
) )
await confirm_properties( await confirm_properties(
ctx,
"confirm_mosaic_levy", "confirm_mosaic_levy",
"Confirm mosaic", "Confirm mosaic",
(("Confirm mosaic\nlevy fee of", levy_msg),), (("Confirm mosaic\nlevy fee of", levy_msg),),
@ -101,7 +95,6 @@ async def _ask_transfer_mosaic(
else: else:
await confirm_action( await confirm_action(
ctx,
"confirm_mosaic_unknown", "confirm_mosaic_unknown",
"Confirm mosaic", "Confirm mosaic",
"Unknown mosaic!", "Unknown mosaic!",
@ -110,7 +103,6 @@ async def _ask_transfer_mosaic(
) )
await confirm_properties( await confirm_properties(
ctx,
"confirm_mosaic_transfer", "confirm_mosaic_transfer",
"Confirm mosaic", "Confirm mosaic",
( (
@ -133,7 +125,7 @@ def _get_xem_amount(transfer: NEMTransfer) -> int:
async def ask_importance_transfer( async def ask_importance_transfer(
ctx: Context, common: NEMTransactionCommon, imp: NEMImportanceTransfer common: NEMTransactionCommon, imp: NEMImportanceTransfer
) -> None: ) -> None:
from trezor.enums import NEMImportanceTransferMode from trezor.enums import NEMImportanceTransferMode
from ..layout import require_confirm_text from ..layout import require_confirm_text
@ -142,5 +134,5 @@ async def ask_importance_transfer(
m = "Activate" m = "Activate"
else: else:
m = "Deactivate" m = "Deactivate"
await require_confirm_text(ctx, m + " remote harvesting?") await require_confirm_text(m + " remote harvesting?")
await require_confirm_final(ctx, common.fee) await require_confirm_final(common.fee)

@ -5,26 +5,23 @@ from apps.common.keychain import auto_keychain
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import RippleGetAddress, RippleAddress from trezor.messages import RippleGetAddress, RippleAddress
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
from trezor.wire import Context
@auto_keychain(__name__) @auto_keychain(__name__)
async def get_address( async def get_address(msg: RippleGetAddress, keychain: Keychain) -> RippleAddress:
ctx: Context, msg: RippleGetAddress, keychain: Keychain
) -> RippleAddress:
# NOTE: local imports here saves 20 bytes # NOTE: local imports here saves 20 bytes
from trezor.messages import RippleAddress from trezor.messages import RippleAddress
from trezor.ui.layouts import show_address from trezor.ui.layouts import show_address
from apps.common import paths from apps.common import paths
from .helpers import address_from_public_key from .helpers import address_from_public_key
await paths.validate_path(ctx, keychain, msg.address_n) await paths.validate_path(keychain, msg.address_n)
node = keychain.derive(msg.address_n) node = keychain.derive(msg.address_n)
pubkey = node.public_key() pubkey = node.public_key()
address = address_from_public_key(pubkey) address = address_from_public_key(pubkey)
if msg.show_display: if msg.show_display:
await show_address(ctx, address, path=paths.address_n_to_str(msg.address_n)) await show_address(address, path=paths.address_n_to_str(msg.address_n))
return RippleAddress(address=address) return RippleAddress(address=address)

@ -1,26 +1,19 @@
from typing import TYPE_CHECKING
from trezor.enums import ButtonRequestType from trezor.enums import ButtonRequestType
from trezor.strings import format_amount from trezor.strings import format_amount
from trezor.ui.layouts import confirm_metadata, confirm_total from trezor.ui.layouts import confirm_metadata, confirm_total
from .helpers import DECIMALS from .helpers import DECIMALS
if TYPE_CHECKING:
from trezor.wire import Context
async def require_confirm_total(ctx: Context, total: int, fee: int) -> None: async def require_confirm_total(total: int, fee: int) -> None:
await confirm_total( await confirm_total(
ctx,
format_amount(total, DECIMALS) + " XRP", format_amount(total, DECIMALS) + " XRP",
format_amount(fee, DECIMALS) + " XRP", format_amount(fee, DECIMALS) + " XRP",
) )
async def require_confirm_destination_tag(ctx: Context, tag: int) -> None: async def require_confirm_destination_tag(tag: int) -> None:
await confirm_metadata( await confirm_metadata(
ctx,
"confirm_destination_tag", "confirm_destination_tag",
"Confirm tag", "Confirm tag",
"Destination tag:\n{}", "Destination tag:\n{}",
@ -29,7 +22,7 @@ async def require_confirm_destination_tag(ctx: Context, tag: int) -> None:
) )
async def require_confirm_tx(ctx: Context, to: str, value: int) -> None: async def require_confirm_tx(to: str, value: int) -> None:
from trezor.ui.layouts import confirm_output from trezor.ui.layouts import confirm_output
await confirm_output(ctx, to, format_amount(value, DECIMALS) + " XRP") await confirm_output(to, format_amount(value, DECIMALS) + " XRP")

@ -5,14 +5,11 @@ from apps.common.keychain import auto_keychain
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import RippleSignTx, RippleSignedTx from trezor.messages import RippleSignTx, RippleSignedTx
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
from trezor.wire import Context
# NOTE: it is one big function because that way it is the most flash-space-efficient # NOTE: it is one big function because that way it is the most flash-space-efficient
@auto_keychain(__name__) @auto_keychain(__name__)
async def sign_tx( async def sign_tx(msg: RippleSignTx, keychain: Keychain) -> RippleSignedTx:
ctx: Context, msg: RippleSignTx, keychain: Keychain
) -> RippleSignedTx:
from trezor.crypto import der from trezor.crypto import der
from trezor.crypto.curve import secp256k1 from trezor.crypto.curve import secp256k1
from trezor.crypto.hashlib import sha512 from trezor.crypto.hashlib import sha512
@ -26,7 +23,7 @@ async def sign_tx(
if payment.amount > helpers.MAX_ALLOWED_AMOUNT: if payment.amount > helpers.MAX_ALLOWED_AMOUNT:
raise ProcessError("Amount exceeds maximum allowed amount.") raise ProcessError("Amount exceeds maximum allowed amount.")
await paths.validate_path(ctx, keychain, msg.address_n) await paths.validate_path(keychain, msg.address_n)
node = keychain.derive(msg.address_n) node = keychain.derive(msg.address_n)
source_address = helpers.address_from_public_key(node.public_key()) source_address = helpers.address_from_public_key(node.public_key())
@ -46,9 +43,9 @@ async def sign_tx(
raise ProcessError("Fee must be in the range of 10 to 10,000 drops") raise ProcessError("Fee must be in the range of 10 to 10,000 drops")
if payment.destination_tag is not None: if payment.destination_tag is not None:
await layout.require_confirm_destination_tag(ctx, payment.destination_tag) await layout.require_confirm_destination_tag(payment.destination_tag)
await layout.require_confirm_tx(ctx, payment.destination, payment.amount) await layout.require_confirm_tx(payment.destination, payment.amount)
await layout.require_confirm_total(ctx, payment.amount + msg.fee, msg.fee) await layout.require_confirm_total(payment.amount + msg.fee, msg.fee)
# Signs and encodes signature into DER format # Signs and encodes signature into DER format
first_half_of_sha512 = sha512(to_sign).digest()[:32] first_half_of_sha512 = sha512(to_sign).digest()[:32]

@ -4,20 +4,17 @@ from apps.common.keychain import auto_keychain
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import StellarGetAddress, StellarAddress from trezor.messages import StellarGetAddress, StellarAddress
from trezor.wire import Context
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
@auto_keychain(__name__) @auto_keychain(__name__)
async def get_address( async def get_address(msg: StellarGetAddress, keychain: Keychain) -> StellarAddress:
ctx: Context, msg: StellarGetAddress, keychain: Keychain
) -> StellarAddress:
from apps.common import paths, seed from apps.common import paths, seed
from trezor.messages import StellarAddress from trezor.messages import StellarAddress
from trezor.ui.layouts import show_address from trezor.ui.layouts import show_address
from . import helpers from . import helpers
await paths.validate_path(ctx, keychain, msg.address_n) await paths.validate_path(keychain, msg.address_n)
node = keychain.derive(msg.address_n) node = keychain.derive(msg.address_n)
pubkey = seed.remove_ed25519_prefix(node.public_key()) pubkey = seed.remove_ed25519_prefix(node.public_key())
@ -25,6 +22,6 @@ async def get_address(
if msg.show_display: if msg.show_display:
path = paths.address_n_to_str(msg.address_n) path = paths.address_n_to_str(msg.address_n)
await show_address(ctx, address, case_sensitive=False, path=path) await show_address(address, case_sensitive=False, path=path)
return StellarAddress(address=address) return StellarAddress(address=address)

@ -7,21 +7,18 @@ from trezor.enums import ButtonRequestType
from . import consts from . import consts
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.wire import Context
from trezor.enums import StellarMemoType from trezor.enums import StellarMemoType
from trezor.messages import StellarAsset from trezor.messages import StellarAsset
async def require_confirm_init( async def require_confirm_init(
ctx: Context,
address: str, address: str,
network_passphrase: str, network_passphrase: str,
accounts_match: bool, accounts_match: bool,
) -> None: ) -> None:
description = "Initialize signing with" + " your account" if accounts_match else "" description = "Initialize signing with" + " your account" if accounts_match else ""
await layouts.confirm_address( await layouts.confirm_address(
ctx,
"Confirm Stellar", "Confirm Stellar",
address, address,
description, description,
@ -38,7 +35,6 @@ async def require_confirm_init(
if network: if network:
await layouts.confirm_metadata( await layouts.confirm_metadata(
ctx,
"confirm_init_network", "confirm_init_network",
"Confirm network", "Confirm network",
"Transaction is on {}", "Transaction is on {}",
@ -47,9 +43,8 @@ async def require_confirm_init(
) )
async def require_confirm_timebounds(ctx: Context, start: int, end: int) -> None: async def require_confirm_timebounds(start: int, end: int) -> None:
await layouts.confirm_properties( await layouts.confirm_properties(
ctx,
"confirm_timebounds", "confirm_timebounds",
"Confirm timebounds", "Confirm timebounds",
( (
@ -65,9 +60,7 @@ async def require_confirm_timebounds(ctx: Context, start: int, end: int) -> None
) )
async def require_confirm_memo( async def require_confirm_memo(memo_type: StellarMemoType, memo_text: str) -> None:
ctx: Context, memo_type: StellarMemoType, memo_text: str
) -> None:
from trezor.enums import StellarMemoType from trezor.enums import StellarMemoType
if memo_type == StellarMemoType.TEXT: if memo_type == StellarMemoType.TEXT:
@ -80,7 +73,6 @@ async def require_confirm_memo(
description = "Memo (RETURN)" description = "Memo (RETURN)"
else: else:
return await layouts.confirm_action( return await layouts.confirm_action(
ctx,
"confirm_memo", "confirm_memo",
"Confirm memo", "Confirm memo",
"No memo set!", "No memo set!",
@ -89,7 +81,6 @@ async def require_confirm_memo(
) )
await layouts.confirm_blob( await layouts.confirm_blob(
ctx,
"confirm_memo", "confirm_memo",
"Confirm memo", "Confirm memo",
memo_text, memo_text,
@ -97,10 +88,9 @@ async def require_confirm_memo(
) )
async def require_confirm_final(ctx: Context, fee: int, num_operations: int) -> None: async def require_confirm_final(fee: int, num_operations: int) -> None:
op_str = strings.format_plural("{count} {plural}", num_operations, "operation") op_str = strings.format_plural("{count} {plural}", num_operations, "operation")
await layouts.confirm_metadata( await layouts.confirm_metadata(
ctx,
"confirm_final", "confirm_final",
"Final confirm", "Final confirm",
"Sign this transaction made up of " + op_str + " and pay {}\nfor fee?", "Sign this transaction made up of " + op_str + " and pay {}\nfor fee?",

@ -2,11 +2,10 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.utils import Writer from trezor.utils import Writer
from trezor.wire import Context
from consts import StellarMessageType from consts import StellarMessageType
async def process_operation(ctx: Context, w: Writer, op: StellarMessageType) -> None: async def process_operation(w: Writer, op: StellarMessageType) -> None:
# Importing the stuff inside (only) function saves around 100 bytes here # Importing the stuff inside (only) function saves around 100 bytes here
# (probably because the local lookup is more efficient than a global lookup) # (probably because the local lookup is more efficient than a global lookup)
@ -17,48 +16,48 @@ async def process_operation(ctx: Context, w: Writer, op: StellarMessageType) ->
from . import layout, serialize from . import layout, serialize
if op.source_account: if op.source_account:
await layout.confirm_source_account(ctx, op.source_account) await layout.confirm_source_account(op.source_account)
serialize.write_account(w, op.source_account) serialize.write_account(w, op.source_account)
writers.write_uint32(w, consts.get_op_code(op)) writers.write_uint32(w, consts.get_op_code(op))
# NOTE: each branch below has 45 bytes (26 the actions, 19 the condition) # NOTE: each branch below has 45 bytes (26 the actions, 19 the condition)
if messages.StellarAccountMergeOp.is_type_of(op): if messages.StellarAccountMergeOp.is_type_of(op):
await layout.confirm_account_merge_op(ctx, op) await layout.confirm_account_merge_op(op)
serialize.write_account_merge_op(w, op) serialize.write_account_merge_op(w, op)
elif messages.StellarAllowTrustOp.is_type_of(op): elif messages.StellarAllowTrustOp.is_type_of(op):
await layout.confirm_allow_trust_op(ctx, op) await layout.confirm_allow_trust_op(op)
serialize.write_allow_trust_op(w, op) serialize.write_allow_trust_op(w, op)
elif messages.StellarBumpSequenceOp.is_type_of(op): elif messages.StellarBumpSequenceOp.is_type_of(op):
await layout.confirm_bump_sequence_op(ctx, op) await layout.confirm_bump_sequence_op(op)
serialize.write_bump_sequence_op(w, op) serialize.write_bump_sequence_op(w, op)
elif messages.StellarChangeTrustOp.is_type_of(op): elif messages.StellarChangeTrustOp.is_type_of(op):
await layout.confirm_change_trust_op(ctx, op) await layout.confirm_change_trust_op(op)
serialize.write_change_trust_op(w, op) serialize.write_change_trust_op(w, op)
elif messages.StellarCreateAccountOp.is_type_of(op): elif messages.StellarCreateAccountOp.is_type_of(op):
await layout.confirm_create_account_op(ctx, op) await layout.confirm_create_account_op(op)
serialize.write_create_account_op(w, op) serialize.write_create_account_op(w, op)
elif messages.StellarCreatePassiveSellOfferOp.is_type_of(op): elif messages.StellarCreatePassiveSellOfferOp.is_type_of(op):
await layout.confirm_create_passive_sell_offer_op(ctx, op) await layout.confirm_create_passive_sell_offer_op(op)
serialize.write_create_passive_sell_offer_op(w, op) serialize.write_create_passive_sell_offer_op(w, op)
elif messages.StellarManageDataOp.is_type_of(op): elif messages.StellarManageDataOp.is_type_of(op):
await layout.confirm_manage_data_op(ctx, op) await layout.confirm_manage_data_op(op)
serialize.write_manage_data_op(w, op) serialize.write_manage_data_op(w, op)
elif messages.StellarManageBuyOfferOp.is_type_of(op): elif messages.StellarManageBuyOfferOp.is_type_of(op):
await layout.confirm_manage_buy_offer_op(ctx, op) await layout.confirm_manage_buy_offer_op(op)
serialize.write_manage_buy_offer_op(w, op) serialize.write_manage_buy_offer_op(w, op)
elif messages.StellarManageSellOfferOp.is_type_of(op): elif messages.StellarManageSellOfferOp.is_type_of(op):
await layout.confirm_manage_sell_offer_op(ctx, op) await layout.confirm_manage_sell_offer_op(op)
serialize.write_manage_sell_offer_op(w, op) serialize.write_manage_sell_offer_op(w, op)
elif messages.StellarPathPaymentStrictReceiveOp.is_type_of(op): elif messages.StellarPathPaymentStrictReceiveOp.is_type_of(op):
await layout.confirm_path_payment_strict_receive_op(ctx, op) await layout.confirm_path_payment_strict_receive_op(op)
serialize.write_path_payment_strict_receive_op(w, op) serialize.write_path_payment_strict_receive_op(w, op)
elif messages.StellarPathPaymentStrictSendOp.is_type_of(op): elif messages.StellarPathPaymentStrictSendOp.is_type_of(op):
await layout.confirm_path_payment_strict_send_op(ctx, op) await layout.confirm_path_payment_strict_send_op(op)
serialize.write_path_payment_strict_send_op(w, op) serialize.write_path_payment_strict_send_op(w, op)
elif messages.StellarPaymentOp.is_type_of(op): elif messages.StellarPaymentOp.is_type_of(op):
await layout.confirm_payment_op(ctx, op) await layout.confirm_payment_op(op)
serialize.write_payment_op(w, op) serialize.write_payment_op(w, op)
elif messages.StellarSetOptionsOp.is_type_of(op): elif messages.StellarSetOptionsOp.is_type_of(op):
await layout.confirm_set_options_op(ctx, op) await layout.confirm_set_options_op(op)
serialize.write_set_options_op(w, op) serialize.write_set_options_op(w, op)
else: else:
raise ValueError("Unknown operation") raise ValueError("Unknown operation")

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save