refactor(core): get rid of passing Context around

pull/3138/head
matejcik 11 months ago committed by matejcik
parent fe80793b47
commit 8c5c2f4204

@ -141,7 +141,7 @@ def get_features() -> Features:
return f
async def handle_Initialize(ctx: wire.Context, msg: Initialize) -> Features:
async def handle_Initialize(msg: Initialize) -> Features:
session_id = storage_cache.start_session(msg.session_id)
if not utils.BITCOIN_ONLY:
@ -170,20 +170,20 @@ async def handle_Initialize(ctx: wire.Context, msg: Initialize) -> Features:
return features
async def handle_GetFeatures(ctx: wire.Context, msg: GetFeatures) -> Features:
async def handle_GetFeatures(msg: GetFeatures) -> Features:
return get_features()
async def handle_Cancel(ctx: wire.Context, msg: Cancel) -> Success:
async def handle_Cancel(msg: Cancel) -> Success:
raise wire.ActionCancelled
async def handle_LockDevice(ctx: wire.Context, msg: LockDevice) -> Success:
async def handle_LockDevice(msg: LockDevice) -> Success:
lock_device()
return Success()
async def handle_SetBusy(ctx: wire.Context, msg: SetBusy) -> Success:
async def handle_SetBusy(msg: SetBusy) -> Success:
if not storage_device.is_initialized():
raise wire.NotInitialized("Device is not initialized")
@ -199,24 +199,23 @@ async def handle_SetBusy(ctx: wire.Context, msg: SetBusy) -> Success:
return Success()
async def handle_EndSession(ctx: wire.Context, msg: EndSession) -> Success:
async def handle_EndSession(msg: EndSession) -> Success:
storage_cache.end_current_session()
return Success()
async def handle_Ping(ctx: wire.Context, msg: Ping) -> Success:
async def handle_Ping(msg: Ping) -> Success:
if msg.button_protection:
from trezor.ui.layouts import confirm_action
from trezor.enums import ButtonRequestType as B
await confirm_action(ctx, "ping", "Confirm", "ping", br_code=B.ProtectCall)
await confirm_action("ping", "Confirm", "ping", br_code=B.ProtectCall)
return Success(message=msg.message)
async def handle_DoPreauthorized(
ctx: wire.Context, msg: DoPreauthorized
) -> protobuf.MessageType:
async def handle_DoPreauthorized(msg: DoPreauthorized) -> protobuf.MessageType:
from trezor.messages import PreauthorizedRequest
from trezor.wire.context import call_any, get_context
from apps.common import authorization
if not authorization.is_set():
@ -225,22 +224,23 @@ async def handle_DoPreauthorized(
wire_types = authorization.get_wire_types()
utils.ensure(bool(wire_types), "Unsupported preauthorization found")
req = await ctx.call_any(PreauthorizedRequest(), *wire_types)
req = await call_any(PreauthorizedRequest(), *wire_types)
assert req.MESSAGE_WIRE_TYPE is not None
handler = workflow_handlers.find_registered_handler(
ctx.iface, req.MESSAGE_WIRE_TYPE
get_context().iface, req.MESSAGE_WIRE_TYPE
)
if handler is None:
return wire.unexpected_message()
return await handler(ctx, req, authorization.get()) # type: ignore [Expected 2 positional arguments]
return await handler(req, authorization.get()) # type: ignore [Expected 1 positional argument]
async def handle_UnlockPath(ctx: wire.Context, msg: UnlockPath) -> protobuf.MessageType:
async def handle_UnlockPath(msg: UnlockPath) -> protobuf.MessageType:
from trezor.crypto import hmac
from trezor.messages import UnlockedPathRequest
from trezor.ui.layouts import confirm_action
from trezor.wire.context import call_any, get_context
from apps.common.paths import SLIP25_PURPOSE
from apps.common.seed import Slip21Node, get_seed
from apps.common.writers import write_uint32_le
@ -253,7 +253,7 @@ async def handle_UnlockPath(ctx: wire.Context, msg: UnlockPath) -> protobuf.Mess
if msg.address_n != [SLIP25_PURPOSE]:
raise wire.DataError("Invalid path")
seed = await get_seed(ctx)
seed = await get_seed()
node = Slip21Node(seed)
node.derive_path(_KEYCHAIN_MAC_KEY_PATH)
mac = utils.HashWriter(hmac(hmac.SHA256, node.key()))
@ -269,7 +269,6 @@ async def handle_UnlockPath(ctx: wire.Context, msg: UnlockPath) -> protobuf.Mess
raise wire.DataError("Invalid MAC")
else:
await confirm_action(
ctx,
"confirm_coinjoin_access",
title="Coinjoin",
description="Access your coinjoin account?",
@ -277,19 +276,17 @@ async def handle_UnlockPath(ctx: wire.Context, msg: UnlockPath) -> protobuf.Mess
)
wire_types = (MessageType.GetAddress, MessageType.GetPublicKey, MessageType.SignTx)
req = await ctx.call_any(UnlockedPathRequest(mac=expected_mac), *wire_types)
req = await call_any(UnlockedPathRequest(mac=expected_mac), *wire_types)
assert req.MESSAGE_WIRE_TYPE in wire_types
handler = workflow_handlers.find_registered_handler(
ctx.iface, req.MESSAGE_WIRE_TYPE
get_context().iface, req.MESSAGE_WIRE_TYPE
)
assert handler is not None
return await handler(ctx, req, msg) # type: ignore [Expected 2 positional arguments]
return await handler(req, msg) # type: ignore [Expected 1 positional argument]
async def handle_CancelAuthorization(
ctx: wire.Context, msg: CancelAuthorization
) -> protobuf.MessageType:
async def handle_CancelAuthorization(msg: CancelAuthorization) -> protobuf.MessageType:
from apps.common import authorization
authorization.clear()
@ -337,7 +334,7 @@ def lock_device_if_unlocked() -> None:
lock_device(interrupt_workflow=workflow.autolock_interrupts_workflow)
async def unlock_device(ctx: wire.GenericContext = wire.DUMMY_CONTEXT) -> None:
async def unlock_device() -> None:
"""Ensure the device is in unlocked state.
If the storage is locked, attempt to unlock it. Reset the homescreen and the wire
@ -347,7 +344,7 @@ async def unlock_device(ctx: wire.GenericContext = wire.DUMMY_CONTEXT) -> None:
if not config.is_unlocked():
# verify_user_pin will raise if the PIN was invalid
await verify_user_pin(ctx)
await verify_user_pin()
set_homescreen()
wire.find_handler = workflow_handlers.find_registered_handler
@ -369,9 +366,9 @@ def get_pinlocked_handler(
if msg_type in workflow.ALLOW_WHILE_LOCKED:
return orig_handler
async def wrapper(ctx: wire.Context, msg: wire.Msg) -> protobuf.MessageType:
await unlock_device(ctx)
return await orig_handler(ctx, msg)
async def wrapper(msg: wire.Msg) -> protobuf.MessageType:
await unlock_device()
return await orig_handler(msg)
return wrapper
@ -383,7 +380,7 @@ def reload_settings_from_storage() -> None:
workflow.idle_timer.set(
storage_device.get_autolock_delay_ms(), lock_device_if_unlocked
)
wire.experimental_enabled = storage_device.get_experimental_features()
wire.EXPERIMENTAL_ENABLED = storage_device.get_experimental_features()
ui.display.orientation(storage_device.get_rotation())

@ -4,14 +4,11 @@ from apps.common.keychain import auto_keychain
if TYPE_CHECKING:
from trezor.messages import BinanceGetAddress, BinanceAddress
from trezor.wire import Context
from apps.common.keychain import Keychain
@auto_keychain(__name__)
async def get_address(
ctx: Context, msg: BinanceGetAddress, keychain: Keychain
) -> BinanceAddress:
async def get_address(msg: BinanceGetAddress, keychain: Keychain) -> BinanceAddress:
from trezor.messages import BinanceAddress
from trezor.ui.layouts import show_address
@ -22,12 +19,12 @@ async def get_address(
HRP = "bnb"
address_n = msg.address_n # local_cache_attribute
await paths.validate_path(ctx, keychain, address_n)
await paths.validate_path(keychain, address_n)
node = keychain.derive(address_n)
pubkey = node.public_key()
address = address_from_public_key(pubkey, HRP)
if msg.show_display:
await show_address(ctx, address, path=paths.address_n_to_str(address_n))
await show_address(address, path=paths.address_n_to_str(address_n))
return BinanceAddress(address=address)

@ -4,13 +4,12 @@ from apps.common.keychain import auto_keychain
if TYPE_CHECKING:
from trezor.messages import BinanceGetPublicKey, BinancePublicKey
from trezor.wire import Context
from apps.common.keychain import Keychain
@auto_keychain(__name__)
async def get_public_key(
ctx: Context, msg: BinanceGetPublicKey, keychain: Keychain
msg: BinanceGetPublicKey, keychain: Keychain
) -> BinancePublicKey:
from ubinascii import hexlify
@ -19,11 +18,11 @@ async def get_public_key(
from apps.common import paths
await paths.validate_path(ctx, keychain, msg.address_n)
await paths.validate_path(keychain, msg.address_n)
node = keychain.derive(msg.address_n)
pubkey = node.public_key()
if msg.show_display:
await show_pubkey(ctx, hexlify(pubkey).decode())
await show_pubkey(hexlify(pubkey).decode())
return BinancePublicKey(public_key=pubkey)

@ -13,10 +13,9 @@ if TYPE_CHECKING:
BinanceOrderMsg,
BinanceTransferMsg,
)
from trezor.wire import Context
async def require_confirm_transfer(ctx: Context, msg: BinanceTransferMsg) -> None:
async def require_confirm_transfer(msg: BinanceTransferMsg) -> None:
items: list[tuple[str, str, str]] = []
def make_input_output_pages(msg: BinanceInputOutput, direction: str) -> None:
@ -35,19 +34,16 @@ async def require_confirm_transfer(ctx: Context, msg: BinanceTransferMsg) -> Non
for txoutput in msg.outputs:
make_input_output_pages(txoutput, "Confirm output")
await _confirm_transfer(ctx, items)
await _confirm_transfer(items)
async def _confirm_transfer(
ctx: Context, inputs_outputs: Sequence[tuple[str, str, str]]
) -> None:
async def _confirm_transfer(inputs_outputs: Sequence[tuple[str, str, str]]) -> None:
from trezor.ui.layouts import confirm_output
for index, (title, amount, address) in enumerate(inputs_outputs):
# Having hold=True on the last item
hold = index == len(inputs_outputs) - 1
await confirm_output(
ctx,
address,
amount,
title,
@ -55,9 +51,8 @@ async def _confirm_transfer(
)
async def require_confirm_cancel(ctx: Context, msg: BinanceCancelMsg) -> None:
async def require_confirm_cancel(msg: BinanceCancelMsg) -> None:
await confirm_properties(
ctx,
"confirm_cancel",
"Confirm cancel",
(
@ -70,7 +65,7 @@ async def require_confirm_cancel(ctx: Context, msg: BinanceCancelMsg) -> None:
)
async def require_confirm_order(ctx: Context, msg: BinanceOrderMsg) -> None:
async def require_confirm_order(msg: BinanceOrderMsg) -> None:
from trezor.enums import BinanceOrderSide
if msg.side == BinanceOrderSide.BUY:
@ -81,7 +76,6 @@ async def require_confirm_order(ctx: Context, msg: BinanceOrderMsg) -> None:
side = "Unknown"
await confirm_properties(
ctx,
"confirm_order",
"Confirm order",
(

@ -5,14 +5,12 @@ from apps.common.keychain import auto_keychain
if TYPE_CHECKING:
from trezor.messages import BinanceSignTx, BinanceSignedTx
from apps.common.keychain import Keychain
from trezor.wire import Context
@auto_keychain(__name__)
async def sign_tx(
ctx: Context, envelope: BinanceSignTx, keychain: Keychain
) -> BinanceSignedTx:
async def sign_tx(envelope: BinanceSignTx, keychain: Keychain) -> BinanceSignedTx:
from trezor import wire
from trezor.wire.context import call_any
from trezor.crypto.curve import secp256k1
from trezor.crypto.hashlib import sha256
from trezor.enums import MessageType
@ -32,12 +30,12 @@ async def sign_tx(
if envelope.msg_count > 1:
raise wire.DataError("Multiple messages not supported.")
await paths.validate_path(ctx, keychain, envelope.address_n)
await paths.validate_path(keychain, envelope.address_n)
node = keychain.derive(envelope.address_n)
tx_req = BinanceTxRequest()
msg = await ctx.call_any(
msg = await call_any(
tx_req,
MessageType.BinanceCancelMsg,
MessageType.BinanceOrderMsg,
@ -50,11 +48,11 @@ async def sign_tx(
msg_json = helpers.produce_json_for_signing(envelope, msg)
if BinanceTransferMsg.is_type_of(msg):
await layout.require_confirm_transfer(ctx, msg)
await layout.require_confirm_transfer(msg)
elif BinanceOrderMsg.is_type_of(msg):
await layout.require_confirm_order(ctx, msg)
await layout.require_confirm_order(msg)
elif BinanceCancelMsg.is_type_of(msg):
await layout.require_confirm_cancel(ctx, msg)
await layout.require_confirm_cancel(msg)
else:
raise wire.ProcessError("input message unrecognized")

@ -8,7 +8,6 @@ if TYPE_CHECKING:
from trezor.messages import AuthorizeCoinJoin, Success
from apps.common.coininfo import CoinInfo
from apps.common.keychain import Keychain
from trezor.wire import Context
_MAX_COORDINATOR_LEN = const(36)
_MAX_ROUNDS = const(500)
@ -17,7 +16,7 @@ _MAX_COORDINATOR_FEE_RATE = 5 * pow(10, FEE_RATE_DECIMALS) # 5 %
@with_keychain
async def authorize_coinjoin(
ctx: Context, msg: AuthorizeCoinJoin, keychain: Keychain, coin: CoinInfo
msg: AuthorizeCoinJoin, keychain: Keychain, coin: CoinInfo
) -> Success:
from trezor.enums import ButtonRequestType
from trezor.messages import Success
@ -64,11 +63,10 @@ async def authorize_coinjoin(
msg.max_fee_per_kvbyte / 1000, coin, include_shortcut=True
)
await confirm_coinjoin(ctx, msg.max_rounds, max_fee_per_vbyte)
await confirm_coinjoin(msg.max_rounds, max_fee_per_vbyte)
validation_path = msg.address_n + [0] * BIP32_WALLET_DEPTH
await validate_path(
ctx,
keychain,
validation_path,
address_n[0] == SLIP25_PURPOSE,
@ -79,7 +77,6 @@ async def authorize_coinjoin(
if msg.max_fee_per_kvbyte > coin.maxfee_kb:
await confirm_metadata(
ctx,
"fee_over_threshold",
"High mining fee",
"The mining fee of\n{}\nis unexpectedly high.",

@ -4,7 +4,6 @@ from .keychain import with_keychain
if TYPE_CHECKING:
from trezor.messages import GetAddress, HDNodeType, Address
from trezor import wire
from apps.common.keychain import Keychain
from apps.common.coininfo import CoinInfo
@ -30,9 +29,7 @@ def _get_xpubs(
@with_keychain
async def get_address(
ctx: wire.Context, msg: GetAddress, keychain: Keychain, coin: CoinInfo
) -> Address:
async def get_address(msg: GetAddress, keychain: Keychain, coin: CoinInfo) -> Address:
from trezor.enums import InputScriptType
from trezor.messages import Address
from trezor.ui.layouts import show_address
@ -51,7 +48,6 @@ async def get_address(
if msg.show_display:
# skip soft-validation for silent calls
await validate_path(
ctx,
keychain,
address_n,
validate_path_against_script_type(coin, msg),
@ -104,7 +100,6 @@ async def get_address(
multisig_index = multisig_pubkey_index(multisig, node.public_key())
await show_address(
ctx,
address_short,
case_sensitive=address_case_sensitive,
path=path,
@ -121,7 +116,6 @@ async def get_address(
else:
account = f"{coin.coin_shortcut} {account_name}"
await show_address(
ctx,
address_short,
address_qr=address,
case_sensitive=address_case_sensitive,

@ -4,14 +4,13 @@ from .keychain import with_keychain
if TYPE_CHECKING:
from trezor.messages import GetOwnershipId, OwnershipId
from trezor.wire import Context
from apps.common.coininfo import CoinInfo
from apps.common.keychain import Keychain
@with_keychain
async def get_ownership_id(
ctx: Context, msg: GetOwnershipId, keychain: Keychain, coin: CoinInfo
msg: GetOwnershipId, keychain: Keychain, coin: CoinInfo
) -> OwnershipId:
from trezor.wire import DataError
from trezor.enums import InputScriptType
@ -26,7 +25,6 @@ async def get_ownership_id(
script_type = msg.script_type # local_cache_attribute
await validate_path(
ctx,
keychain,
msg.address_n,
validate_path_against_script_type(coin, msg),

@ -4,7 +4,6 @@ from .keychain import with_keychain
if TYPE_CHECKING:
from trezor.messages import GetOwnershipProof, OwnershipProof
from trezor.wire import Context
from apps.common.coininfo import CoinInfo
from apps.common.keychain import Keychain
from .authorization import CoinJoinAuthorization
@ -12,7 +11,6 @@ if TYPE_CHECKING:
@with_keychain
async def get_ownership_proof(
ctx: Context,
msg: GetOwnershipProof,
keychain: Keychain,
coin: CoinInfo,
@ -37,7 +35,6 @@ async def get_ownership_proof(
raise ProcessError("Unauthorized operation")
else:
await validate_path(
ctx,
keychain,
msg.address_n,
validate_path_against_script_type(coin, msg),
@ -71,14 +68,12 @@ async def get_ownership_proof(
# In order to set the "user confirmation" bit in the proof, the user must actually confirm.
if msg.user_confirmation and not authorization:
await confirm_action(
ctx,
"confirm_ownership_proof",
"Proof of ownership",
description="Do you want to create a proof of ownership?",
)
if msg.commitment_data:
await confirm_blob(
ctx,
"confirm_ownership_proof",
"Proof of ownership",
msg.commitment_data,

@ -3,11 +3,10 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from trezor.messages import GetPublicKey, PublicKey
from trezor.protobuf import MessageType
from trezor.wire import Context
async def get_public_key(
ctx: Context, msg: GetPublicKey, auth_msg: MessageType | None = None
msg: GetPublicKey, auth_msg: MessageType | None = None
) -> PublicKey:
from trezor import wire
from trezor.enums import InputScriptType
@ -33,7 +32,7 @@ async def get_public_key(
if auth_msg.address_n != address_n[: len(auth_msg.address_n)]:
raise FORBIDDEN_KEY_PATH
keychain = await get_keychain(ctx, curve_name, [paths.AlwaysMatchingSchema])
keychain = await get_keychain(curve_name, [paths.AlwaysMatchingSchema])
node = keychain.derive(address_n)
@ -82,7 +81,7 @@ async def get_public_key(
if msg.show_display:
from trezor.ui.layouts import show_xpub
await show_xpub(ctx, node_xpub, "XPUB")
await show_xpub(node_xpub, "XPUB")
return PublicKey(
node=node_type,

@ -14,7 +14,6 @@ if TYPE_CHECKING:
from typing_extensions import Protocol
from trezor.protobuf import MessageType
from trezor.wire import Context
from trezor.messages import (
GetAddress,
@ -265,7 +264,6 @@ def _get_coin_by_name(coin_name: str | None) -> coininfo.CoinInfo:
async def _get_keychain_for_coin(
ctx: Context,
coin: coininfo.CoinInfo,
unlock_schemas: Iterable[PathSchema] = (),
) -> Keychain:
@ -273,7 +271,7 @@ async def _get_keychain_for_coin(
schemas = _get_schemas_for_coin(coin, unlock_schemas)
slip21_namespaces = [[b"SLIP-0019"], [b"SLIP-0024"]]
keychain = await get_keychain(ctx, coin.curve_name, schemas, slip21_namespaces)
keychain = await get_keychain(coin.curve_name, schemas, slip21_namespaces)
return keychain
@ -318,19 +316,18 @@ def _get_unlock_schemas(
def with_keychain(func: HandlerWithCoinInfo[MsgOut]) -> Handler[MsgIn, MsgOut]:
async def wrapper(
ctx: Context,
msg: MsgIn,
auth_msg: MessageType | None = None,
) -> MsgOut:
coin = _get_coin_by_name(msg.coin_name)
unlock_schemas = _get_unlock_schemas(msg, auth_msg, coin)
keychain = await _get_keychain_for_coin(ctx, coin, unlock_schemas)
keychain = await _get_keychain_for_coin(coin, unlock_schemas)
if AuthorizeCoinJoin.is_type_of(auth_msg):
auth_obj = authorization.from_cached_message(auth_msg)
return await func(ctx, msg, keychain, coin, auth_obj)
return await func(msg, keychain, coin, auth_obj)
else:
with keychain:
return await func(ctx, msg, keychain, coin)
return await func(msg, keychain, coin)
return wrapper

@ -4,7 +4,6 @@ from .keychain import with_keychain
if TYPE_CHECKING:
from trezor.messages import SignMessage, MessageSignature
from trezor.wire import Context
from apps.common.coininfo import CoinInfo
from apps.common.keychain import Keychain
@ -12,7 +11,7 @@ if TYPE_CHECKING:
@with_keychain
async def sign_message(
ctx: Context, msg: SignMessage, keychain: Keychain, coin: CoinInfo
msg: SignMessage, keychain: Keychain, coin: CoinInfo
) -> MessageSignature:
from trezor import wire
from trezor.crypto.curve import secp256k1
@ -31,13 +30,12 @@ async def sign_message(
script_type = msg.script_type or InputScriptType.SPENDADDRESS
await validate_path(
ctx, keychain, address_n, validate_path_against_script_type(coin, msg)
keychain, address_n, validate_path_against_script_type(coin, msg)
)
node = keychain.derive(address_n)
address = get_address(script_type, coin, node)
await confirm_signverify(
ctx,
coin.coin_shortcut,
decode_message(message),
address_short(coin, address),

@ -11,7 +11,6 @@ if not utils.BITCOIN_ONLY:
if TYPE_CHECKING:
from typing import Protocol
from trezor.wire import Context
from trezor.messages import (
SignTx,
TxAckInput,
@ -54,7 +53,6 @@ if TYPE_CHECKING:
@with_keychain
async def sign_tx(
ctx: Context,
msg: SignTx,
keychain: Keychain,
coin: CoinInfo,
@ -62,6 +60,7 @@ async def sign_tx(
) -> TxRequest:
from trezor.enums import RequestType
from trezor.messages import TxRequest
from trezor.wire.context import call
from ..common import BITCOIN_NAMES
from . import approvers, bitcoin, helpers, progress
@ -93,9 +92,9 @@ async def sign_tx(
assert TxRequest.is_type_of(req)
if req.request_type == RequestType.TXFINISHED:
return req
res = await ctx.call(req, request_class)
res = await call(req, request_class)
elif isinstance(req, helpers.UiConfirm):
res = await req.confirm_dialog(ctx)
res = await req.confirm_dialog()
progress.progress.report_init()
else:
raise TypeError("Invalid signing instruction")

@ -11,7 +11,6 @@ from . import layout
if TYPE_CHECKING:
from typing import Any, Awaitable
from trezor.enums import AmountUnit
from trezor.wire import Context
from trezor.messages import (
PrevInput,
@ -31,7 +30,7 @@ if TYPE_CHECKING:
class UiConfirm:
def confirm_dialog(self, ctx: Context) -> Awaitable[Any]:
def confirm_dialog(self) -> Awaitable[Any]:
raise NotImplementedError
__eq__ = utils.obj_eq
@ -50,9 +49,8 @@ class UiConfirmOutput(UiConfirm):
self.amount_unit = amount_unit
self.output_index = output_index
def confirm_dialog(self, ctx: Context) -> Awaitable[Any]:
def confirm_dialog(self) -> Awaitable[Any]:
return layout.confirm_output(
ctx,
self.output,
self.coin,
self.amount_unit,
@ -66,9 +64,9 @@ class UiConfirmDecredSSTXSubmission(UiConfirm):
self.coin = coin
self.amount_unit = amount_unit
def confirm_dialog(self, ctx: Context) -> Awaitable[Any]:
def confirm_dialog(self) -> Awaitable[Any]:
return layout.confirm_decred_sstx_submission(
ctx, self.output, self.coin, self.amount_unit
self.output, self.coin, self.amount_unit
)
@ -83,9 +81,9 @@ class UiConfirmPaymentRequest(UiConfirm):
self.amount_unit = amount_unit
self.coin = coin
def confirm_dialog(self, ctx: Context) -> Awaitable[Any]:
def confirm_dialog(self) -> Awaitable[Any]:
return layout.confirm_payment_request(
ctx, self.payment_req, self.coin, self.amount_unit
self.payment_req, self.coin, self.amount_unit
)
__eq__ = utils.obj_eq
@ -96,8 +94,8 @@ class UiConfirmReplacement(UiConfirm):
self.title = title
self.txid = txid
def confirm_dialog(self, ctx: Context) -> Awaitable[Any]:
return layout.confirm_replacement(ctx, self.title, self.txid)
def confirm_dialog(self) -> Awaitable[Any]:
return layout.confirm_replacement(self.title, self.txid)
class UiConfirmModifyOutput(UiConfirm):
@ -113,9 +111,9 @@ class UiConfirmModifyOutput(UiConfirm):
self.coin = coin
self.amount_unit = amount_unit
def confirm_dialog(self, ctx: Context) -> Awaitable[Any]:
def confirm_dialog(self) -> Awaitable[Any]:
return layout.confirm_modify_output(
ctx, self.txo, self.orig_txo, self.coin, self.amount_unit
self.txo, self.orig_txo, self.coin, self.amount_unit
)
@ -136,9 +134,8 @@ class UiConfirmModifyFee(UiConfirm):
self.coin = coin
self.amount_unit = amount_unit
def confirm_dialog(self, ctx: Context) -> Awaitable[Any]:
def confirm_dialog(self) -> Awaitable[Any]:
return layout.confirm_modify_fee(
ctx,
self.title,
self.user_fee_change,
self.total_fee_new,
@ -165,9 +162,8 @@ class UiConfirmTotal(UiConfirm):
self.amount_unit = amount_unit
self.address_n = address_n
def confirm_dialog(self, ctx: Context) -> Awaitable[Any]:
def confirm_dialog(self) -> Awaitable[Any]:
return layout.confirm_total(
ctx,
self.spending,
self.fee,
self.fee_rate,
@ -186,9 +182,9 @@ class UiConfirmJointTotal(UiConfirm):
self.coin = coin
self.amount_unit = amount_unit
def confirm_dialog(self, ctx: Context) -> Awaitable[Any]:
def confirm_dialog(self) -> Awaitable[Any]:
return layout.confirm_joint_total(
ctx, self.spending, self.total, self.coin, self.amount_unit
self.spending, self.total, self.coin, self.amount_unit
)
@ -198,33 +194,31 @@ class UiConfirmFeeOverThreshold(UiConfirm):
self.coin = coin
self.amount_unit = amount_unit
def confirm_dialog(self, ctx: Context) -> Awaitable[Any]:
return layout.confirm_feeoverthreshold(
ctx, self.fee, self.coin, self.amount_unit
)
def confirm_dialog(self) -> Awaitable[Any]:
return layout.confirm_feeoverthreshold(self.fee, self.coin, self.amount_unit)
class UiConfirmChangeCountOverThreshold(UiConfirm):
def __init__(self, change_count: int):
self.change_count = change_count
def confirm_dialog(self, ctx: Context) -> Awaitable[Any]:
return layout.confirm_change_count_over_threshold(ctx, self.change_count)
def confirm_dialog(self) -> Awaitable[Any]:
return layout.confirm_change_count_over_threshold(self.change_count)
class UiConfirmUnverifiedExternalInput(UiConfirm):
def confirm_dialog(self, ctx: Context) -> Awaitable[Any]:
return layout.confirm_unverified_external_input(ctx)
def confirm_dialog(self) -> Awaitable[Any]:
return layout.confirm_unverified_external_input()
class UiConfirmForeignAddress(UiConfirm):
def __init__(self, address_n: list):
self.address_n = address_n
def confirm_dialog(self, ctx: Context) -> Awaitable[Any]:
def confirm_dialog(self) -> Awaitable[Any]:
from apps.common import paths
return paths.show_path_warning(ctx, self.address_n)
return paths.show_path_warning(self.address_n)
class UiConfirmNonDefaultLocktime(UiConfirm):
@ -232,9 +226,9 @@ class UiConfirmNonDefaultLocktime(UiConfirm):
self.lock_time = lock_time
self.lock_time_disabled = lock_time_disabled
def confirm_dialog(self, ctx: Context) -> Awaitable[Any]:
def confirm_dialog(self) -> Awaitable[Any]:
return layout.confirm_nondefault_locktime(
ctx, self.lock_time, self.lock_time_disabled
self.lock_time, self.lock_time_disabled
)

@ -22,7 +22,6 @@ if TYPE_CHECKING:
from trezor.messages import TxAckPaymentRequest, TxOutput
from trezor.ui.layouts import LayoutType
from trezor.enums import AmountUnit
from trezor.wire import Context
from apps.common.coininfo import CoinInfo
from apps.common.paths import Bip32Path
@ -59,7 +58,6 @@ def account_label(coin: CoinInfo, address_n: Bip32Path | None) -> str:
async def confirm_output(
ctx: Context,
output: TxOutput,
coin: CoinInfo,
amount_unit: AmountUnit,
@ -74,7 +72,6 @@ async def confirm_output(
if omni.is_valid(data):
# OMNI transaction
layout: LayoutType = confirm_metadata(
ctx,
"omni_transaction",
"OMNI transaction",
omni.parse(data),
@ -84,7 +81,6 @@ async def confirm_output(
else:
# generic OP_RETURN
layout = layouts.confirm_blob(
ctx,
"op_return",
"OP_RETURN",
data,
@ -107,7 +103,6 @@ async def confirm_output(
)
layout = layouts.confirm_output(
ctx,
address_short,
format_coin_amount(output.amount, coin, amount_unit),
title=title,
@ -119,14 +114,13 @@ async def confirm_output(
async def confirm_decred_sstx_submission(
ctx: Context, output: TxOutput, coin: CoinInfo, amount_unit: AmountUnit
output: TxOutput, coin: CoinInfo, amount_unit: AmountUnit
) -> None:
assert output.address is not None
address_short = addresses.address_short(coin, output.address)
amount = format_coin_amount(output.amount, coin, amount_unit)
await layouts.confirm_value(
ctx,
"Purchase ticket",
amount,
"Ticket amount:",
@ -136,7 +130,6 @@ async def confirm_decred_sstx_submission(
)
await layouts.confirm_value(
ctx,
"Purchase ticket",
address_short,
"Voting rights to:",
@ -147,7 +140,6 @@ async def confirm_decred_sstx_submission(
async def confirm_payment_request(
ctx: Context,
msg: TxAckPaymentRequest,
coin: CoinInfo,
amount_unit: AmountUnit,
@ -168,25 +160,22 @@ async def confirm_payment_request(
assert msg.amount is not None
return await layouts.confirm_payment_request(
ctx,
msg.recipient_name,
format_coin_amount(msg.amount, coin, amount_unit),
memo_texts,
)
async def confirm_replacement(ctx: Context, title: str, txid: bytes) -> None:
async def confirm_replacement(title: str, txid: bytes) -> None:
from ubinascii import hexlify
await layouts.confirm_replacement(
ctx,
title,
hexlify(txid).decode(),
)
async def confirm_modify_output(
ctx: Context,
txo: TxOutput,
orig_txo: TxOutput,
coin: CoinInfo,
@ -196,7 +185,6 @@ async def confirm_modify_output(
address_short = addresses.address_short(coin, txo.address)
amount_change = txo.amount - orig_txo.amount
await layouts.confirm_modify_output(
ctx,
address_short,
amount_change,
format_coin_amount(abs(amount_change), coin, amount_unit),
@ -205,7 +193,6 @@ async def confirm_modify_output(
async def confirm_modify_fee(
ctx: Context,
title: str,
user_fee_change: int,
total_fee_new: int,
@ -214,7 +201,6 @@ async def confirm_modify_fee(
amount_unit: AmountUnit,
) -> None:
await layouts.confirm_modify_fee(
ctx,
title,
user_fee_change,
format_coin_amount(abs(user_fee_change), coin, amount_unit),
@ -224,21 +210,18 @@ async def confirm_modify_fee(
async def confirm_joint_total(
ctx: Context,
spending: int,
total: int,
coin: CoinInfo,
amount_unit: AmountUnit,
) -> None:
await layouts.confirm_joint_total(
ctx,
spending_amount=format_coin_amount(spending, coin, amount_unit),
total_amount=format_coin_amount(total, coin, amount_unit),
)
async def confirm_total(
ctx: Context,
spending: int,
fee: int,
fee_rate: float,
@ -248,7 +231,6 @@ async def confirm_total(
) -> None:
await layouts.confirm_total(
ctx,
format_coin_amount(spending, coin, amount_unit),
format_coin_amount(fee, coin, amount_unit),
fee_rate_amount=format_fee_rate(fee_rate, coin) if fee_rate >= 0 else None,
@ -257,11 +239,10 @@ async def confirm_total(
async def confirm_feeoverthreshold(
ctx: Context, fee: int, coin: CoinInfo, amount_unit: AmountUnit
fee: int, coin: CoinInfo, amount_unit: AmountUnit
) -> None:
fee_amount = format_coin_amount(fee, coin, amount_unit)
await layouts.show_warning(
ctx,
"fee_over_threshold",
"Unusually high fee.",
fee_amount,
@ -269,9 +250,8 @@ async def confirm_feeoverthreshold(
)
async def confirm_change_count_over_threshold(ctx: Context, change_count: int) -> None:
async def confirm_change_count_over_threshold(change_count: int) -> None:
await layouts.show_warning(
ctx,
"change_count_over_threshold",
"A lot of change-outputs.",
f"{str(change_count)} outputs",
@ -279,9 +259,8 @@ async def confirm_change_count_over_threshold(ctx: Context, change_count: int) -
)
async def confirm_unverified_external_input(ctx: Context) -> None:
async def confirm_unverified_external_input() -> None:
await layouts.show_warning(
ctx,
"unverified_external_input",
"The transaction contains unverified external inputs.",
"Proceed anyway?",
@ -290,14 +269,11 @@ async def confirm_unverified_external_input(ctx: Context) -> None:
)
async def confirm_nondefault_locktime(
ctx: Context, lock_time: int, lock_time_disabled: bool
) -> None:
async def confirm_nondefault_locktime(lock_time: int, lock_time_disabled: bool) -> None:
from trezor.strings import format_timestamp
if lock_time_disabled:
await layouts.show_warning(
ctx,
"nondefault_locktime",
"Locktime is set but will have no effect.",
"Proceed anyway?",
@ -312,7 +288,6 @@ async def confirm_nondefault_locktime(
text = "Locktime for this transaction is set to:"
value = format_timestamp(lock_time)
await layouts.confirm_value(
ctx,
"Confirm locktime",
value,
text,

@ -3,7 +3,6 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from apps.common.coininfo import CoinInfo
from trezor.messages import VerifyMessage, Success
from trezor.wire import Context
from trezor.enums import InputScriptType
@ -48,7 +47,7 @@ def _address_to_script_type(address: str, coin: CoinInfo) -> InputScriptType:
raise DataError("Invalid address")
async def verify_message(ctx: Context, msg: VerifyMessage) -> Success:
async def verify_message(msg: VerifyMessage) -> Success:
from trezor import utils
from trezor.wire import ProcessError
from trezor.crypto.curve import secp256k1
@ -109,12 +108,11 @@ async def verify_message(ctx: Context, msg: VerifyMessage) -> Success:
raise ProcessError("Invalid signature")
await confirm_signverify(
ctx,
coin.coin_shortcut,
decode_message(message),
address_short(coin, address),
verify=True,
)
await show_success(ctx, "verify_message", "The signature is valid.")
await show_success("verify_message", "The signature is valid.")
return Success(message="Message verified")

@ -16,7 +16,6 @@ if TYPE_CHECKING:
SignedCVoteRegistrationPayload = tuple[CVoteRegistrationPayload, bytes]
from trezor import messages
from trezor.wire import Context
from . import seed
@ -115,7 +114,6 @@ def _get_voting_purpose_to_serialize(
async def show(
ctx: Context,
keychain: seed.Keychain,
auxiliary_data_hash: bytes,
parameters: messages.CardanoCVoteRegistrationParametersType | None,
@ -125,7 +123,6 @@ async def show(
) -> None:
if parameters:
await _show_cvote_registration(
ctx,
keychain,
parameters,
protocol_magic,
@ -134,7 +131,7 @@ async def show(
)
if should_show_details:
await layout.show_auxiliary_data_hash(ctx, auxiliary_data_hash)
await layout.show_auxiliary_data_hash(auxiliary_data_hash)
def _should_show_payment_warning(address_type: CardanoAddressType) -> bool:
@ -146,7 +143,6 @@ def _should_show_payment_warning(address_type: CardanoAddressType) -> bool:
async def _show_cvote_registration(
ctx: Context,
keychain: seed.Keychain,
parameters: messages.CardanoCVoteRegistrationParametersType,
protocol_magic: int,
@ -161,7 +157,7 @@ async def _show_cvote_registration(
bech32.HRP_CVOTE_PUBLIC_KEY, delegation.vote_public_key
)
await layout.confirm_cvote_registration_delegation(
ctx, encoded_public_key, delegation.weight
encoded_public_key, delegation.weight
)
if parameters.payment_address:
@ -169,7 +165,7 @@ async def _show_cvote_registration(
addresses.get_type(addresses.get_bytes_unsafe(parameters.payment_address))
)
await layout.confirm_cvote_registration_payment_address(
ctx, parameters.payment_address, show_payment_warning
parameters.payment_address, show_payment_warning
)
else:
address_parameters = parameters.payment_address_parameters
@ -179,7 +175,6 @@ async def _show_cvote_registration(
address_parameters.address_type
)
await layout.show_cvote_registration_payment_credentials(
ctx,
Credential.payment_credential(address_parameters),
Credential.stake_credential(address_parameters),
show_both_credentials,
@ -197,7 +192,6 @@ async def _show_cvote_registration(
)
await layout.confirm_cvote_registration(
ctx,
encoded_public_key,
parameters.staking_path,
parameters.nonce,

@ -3,13 +3,12 @@ from typing import TYPE_CHECKING
from . import seed
if TYPE_CHECKING:
from trezor.wire import Context
from trezor.messages import CardanoGetAddress, CardanoAddress
@seed.with_keychain
async def get_address(
ctx: Context, msg: CardanoGetAddress, keychain: seed.Keychain
msg: CardanoGetAddress, keychain: seed.Keychain
) -> CardanoAddress:
from trezor.messages import CardanoAddress
from trezor import log, wire
@ -36,10 +35,9 @@ async def get_address(
# _display_address
if should_show_credentials(address_parameters):
await show_credentials(
ctx,
Credential.payment_credential(address_parameters),
Credential.stake_credential(address_parameters),
)
await show_cardano_address(ctx, address_parameters, address, msg.protocol_magic)
await show_cardano_address(address_parameters, address, msg.protocol_magic)
return CardanoAddress(address=address)

@ -3,13 +3,12 @@ from typing import TYPE_CHECKING
from . import seed
if TYPE_CHECKING:
from trezor.wire import Context
from trezor.messages import CardanoGetNativeScriptHash, CardanoNativeScriptHash
@seed.with_keychain
async def get_native_script_hash(
ctx: Context, msg: CardanoGetNativeScriptHash, keychain: seed.Keychain
msg: CardanoGetNativeScriptHash, keychain: seed.Keychain
) -> CardanoNativeScriptHash:
from trezor.messages import CardanoNativeScriptHash
from trezor.enums import CardanoNativeScriptHashDisplayFormat
@ -20,7 +19,7 @@ async def get_native_script_hash(
script_hash = native_script.get_native_script_hash(keychain, msg.script)
if msg.display_format != CardanoNativeScriptHashDisplayFormat.HIDE:
await layout.show_native_script(ctx, msg.script)
await layout.show_script_hash(ctx, script_hash, msg.display_format)
await layout.show_native_script(msg.script)
await layout.show_script_hash(script_hash, msg.display_format)
return CardanoNativeScriptHash(script_hash=script_hash)

@ -4,13 +4,12 @@ from ubinascii import hexlify
from . import seed
if TYPE_CHECKING:
from trezor.wire import Context
from trezor.messages import CardanoGetPublicKey, CardanoPublicKey
@seed.with_keychain
async def get_public_key(
ctx: Context, msg: CardanoGetPublicKey, keychain: seed.Keychain
msg: CardanoGetPublicKey, keychain: seed.Keychain
) -> CardanoPublicKey:
from trezor import log, wire
from trezor.ui.layouts import show_pubkey
@ -20,7 +19,6 @@ async def get_public_key(
address_n = msg.address_n # local_cache_attribute
await paths.validate_path(
ctx,
keychain,
address_n,
# path must match the PUBKEY schema
@ -35,7 +33,7 @@ async def get_public_key(
raise wire.ProcessError("Deriving public key failed")
if msg.show_display:
await show_pubkey(ctx, hexlify(key.node.public_key).decode())
await show_pubkey(hexlify(key.node.public_key).decode())
return key

@ -25,7 +25,6 @@ from .helpers.utils import (
if TYPE_CHECKING:
from typing import Literal
from trezor.wire import Context
from trezor import messages
from trezor.enums import CardanoNativeScriptHashDisplayFormat
from trezor.ui.layouts import PropertyType
@ -79,7 +78,6 @@ def format_coin_amount(amount: int, network_id: int) -> str:
async def show_native_script(
ctx: Context,
script: messages.CardanoNativeScript,
indices: list[int] | None = None,
) -> None:
@ -140,7 +138,6 @@ async def show_native_script(
append((f"Contains {len(scripts)} nested scripts.", None))
await confirm_properties(
ctx,
"verify_script",
"Verify script",
props,
@ -148,11 +145,10 @@ async def show_native_script(
)
for i, sub_script in enumerate(scripts):
await show_native_script(ctx, sub_script, indices + [i + 1])
await show_native_script(sub_script, indices + [i + 1])
async def show_script_hash(
ctx: Context,
script_hash: bytes,
display_format: CardanoNativeScriptHashDisplayFormat,
) -> None:
@ -165,7 +161,6 @@ async def show_script_hash(
if display_format == CardanoNativeScriptHashDisplayFormat.BECH32:
await confirm_properties(
ctx,
"verify_script",
"Verify script",
(("Script hash:", bech32.encode(bech32.HRP_SCRIPT_HASH, script_hash)),),
@ -173,7 +168,6 @@ async def show_script_hash(
)
elif display_format == CardanoNativeScriptHashDisplayFormat.POLICY_ID:
await layouts.confirm_blob(
ctx,
"verify_script",
"Verify script",
script_hash,
@ -182,9 +176,8 @@ async def show_script_hash(
)
async def show_tx_init(ctx: Context, title: str) -> bool:
async def show_tx_init(title: str) -> bool:
should_show_details = await layouts.should_show_more(
ctx,
"Confirm transaction",
(
(
@ -200,9 +193,8 @@ async def show_tx_init(ctx: Context, title: str) -> bool:
return should_show_details
async def confirm_input(ctx: Context, input: messages.CardanoTxInput) -> None:
async def confirm_input(input: messages.CardanoTxInput) -> None:
await confirm_properties(
ctx,
"confirm_input",
"Confirm transaction",
(
@ -214,7 +206,6 @@ async def confirm_input(ctx: Context, input: messages.CardanoTxInput) -> None:
async def confirm_sending(
ctx: Context,
ada_amount: int,
to: str,
output_type: Literal["address", "change", "collateral-return"],
@ -230,7 +221,6 @@ async def confirm_sending(
raise RuntimeError # should be unreachable
await layouts.confirm_output(
ctx,
to,
format_coin_amount(ada_amount, network_id),
title,
@ -238,13 +228,10 @@ async def confirm_sending(
)
async def confirm_sending_token(
ctx: Context, policy_id: bytes, token: messages.CardanoToken
) -> None:
async def confirm_sending_token(policy_id: bytes, token: messages.CardanoToken) -> None:
assert token.amount is not None # _validate_token
await confirm_properties(
ctx,
"confirm_token",
"Confirm transaction",
(
@ -261,9 +248,8 @@ async def confirm_sending_token(
)
async def confirm_datum_hash(ctx: Context, datum_hash: bytes) -> None:
async def confirm_datum_hash(datum_hash: bytes) -> None:
await confirm_properties(
ctx,
"confirm_datum_hash",
"Confirm transaction",
(
@ -276,11 +262,8 @@ async def confirm_datum_hash(ctx: Context, datum_hash: bytes) -> None:
)
async def confirm_inline_datum(
ctx: Context, first_chunk: bytes, inline_datum_size: int
) -> None:
async def confirm_inline_datum(first_chunk: bytes, inline_datum_size: int) -> None:
await _confirm_data_chunk(
ctx,
"confirm_inline_datum",
"Inline datum",
first_chunk,
@ -289,10 +272,9 @@ async def confirm_inline_datum(
async def confirm_reference_script(
ctx: Context, first_chunk: bytes, reference_script_size: int
first_chunk: bytes, reference_script_size: int
) -> None:
await _confirm_data_chunk(
ctx,
"confirm_reference_script",
"Reference script",
first_chunk,
@ -301,7 +283,7 @@ async def confirm_reference_script(
async def _confirm_data_chunk(
ctx: Context, br_type: str, title: str, first_chunk: bytes, data_size: int
br_type: str, title: str, first_chunk: bytes, data_size: int
) -> None:
MAX_DISPLAYED_SIZE = 56
displayed_bytes = first_chunk[:MAX_DISPLAYED_SIZE]
@ -315,7 +297,6 @@ async def _confirm_data_chunk(
if data_size > MAX_DISPLAYED_SIZE:
props.append(("...", None))
await confirm_properties(
ctx,
br_type,
title="Confirm transaction",
props=props,
@ -324,39 +305,35 @@ async def _confirm_data_chunk(
async def show_credentials(
ctx: Context,
payment_credential: Credential,
stake_credential: Credential,
) -> None:
intro_text = "Address"
await _show_credential(ctx, payment_credential, intro_text, purpose="address")
await _show_credential(ctx, stake_credential, intro_text, purpose="address")
await _show_credential(payment_credential, intro_text, purpose="address")
await _show_credential(stake_credential, intro_text, purpose="address")
async def show_change_output_credentials(
ctx: Context,
payment_credential: Credential,
stake_credential: Credential,
) -> None:
intro_text = "The following address is a change address. Its"
await _show_credential(ctx, payment_credential, intro_text, purpose="output")
await _show_credential(ctx, stake_credential, intro_text, purpose="output")
await _show_credential(payment_credential, intro_text, purpose="output")
await _show_credential(stake_credential, intro_text, purpose="output")
async def show_device_owned_output_credentials(
ctx: Context,
payment_credential: Credential,
stake_credential: Credential,
show_both_credentials: bool,
) -> None:
intro_text = "The following address is owned by this device. Its"
await _show_credential(ctx, payment_credential, intro_text, purpose="output")
await _show_credential(payment_credential, intro_text, purpose="output")
if show_both_credentials:
await _show_credential(ctx, stake_credential, intro_text, purpose="output")
await _show_credential(stake_credential, intro_text, purpose="output")
async def show_cvote_registration_payment_credentials(
ctx: Context,
payment_credential: Credential,
stake_credential: Credential,
show_both_credentials: bool,
@ -366,12 +343,11 @@ async def show_cvote_registration_payment_credentials(
"The vote key registration payment address is owned by this device. Its"
)
await _show_credential(
ctx, payment_credential, intro_text, purpose="cvote_reg_payment_address"
payment_credential, intro_text, purpose="cvote_reg_payment_address"
)
if show_both_credentials or show_payment_warning:
extra_text = CVOTE_REWARD_ELIGIBILITY_WARNING if show_payment_warning else None
await _show_credential(
ctx,
stake_credential,
intro_text,
purpose="cvote_reg_payment_address",
@ -380,7 +356,6 @@ async def show_cvote_registration_payment_credentials(
async def _show_credential(
ctx: Context,
credential: Credential,
intro_text: str,
purpose: Literal["address", "output", "cvote_reg_payment_address"],
@ -428,7 +403,6 @@ async def _show_credential(
if len(props) > 0:
await confirm_properties(
ctx,
"confirm_credential",
title,
props,
@ -436,20 +410,17 @@ async def _show_credential(
)
async def warn_path(ctx: Context, path: list[int], title: str) -> None:
await layouts.confirm_path_warning(ctx, address_n_to_str(path), path_type=title)
async def warn_path(path: list[int], title: str) -> None:
await layouts.confirm_path_warning(address_n_to_str(path), path_type=title)
async def warn_tx_output_contains_tokens(
ctx: Context, is_collateral_return: bool = False
) -> None:
async def warn_tx_output_contains_tokens(is_collateral_return: bool = False) -> None:
content = (
"The collateral return output contains tokens."
if is_collateral_return
else "The following transaction output contains tokens."
)
await confirm_metadata(
ctx,
"confirm_tokens",
"Confirm transaction",
content,
@ -457,9 +428,8 @@ async def warn_tx_output_contains_tokens(
)
async def warn_tx_contains_mint(ctx: Context) -> None:
async def warn_tx_contains_mint() -> None:
await confirm_metadata(
ctx,
"confirm_tokens",
"Confirm transaction",
"The transaction contains minting or burning of tokens.",
@ -467,9 +437,8 @@ async def warn_tx_contains_mint(ctx: Context) -> None:
)
async def warn_tx_output_no_datum(ctx: Context) -> None:
async def warn_tx_output_no_datum() -> None:
await confirm_metadata(
ctx,
"confirm_no_datum_hash",
"Confirm transaction",
"The following transaction output contains a script address, but does not contain a datum.",
@ -477,9 +446,8 @@ async def warn_tx_output_no_datum(ctx: Context) -> None:
)
async def warn_no_script_data_hash(ctx: Context) -> None:
async def warn_no_script_data_hash() -> None:
await confirm_metadata(
ctx,
"confirm_no_script_data_hash",
"Confirm transaction",
"The transaction contains no script data hash. Plutus script will not be able to run.",
@ -487,9 +455,8 @@ async def warn_no_script_data_hash(ctx: Context) -> None:
)
async def warn_no_collateral_inputs(ctx: Context) -> None:
async def warn_no_collateral_inputs() -> None:
await confirm_metadata(
ctx,
"confirm_no_collateral_inputs",
"Confirm transaction",
"The transaction contains no collateral inputs. Plutus script will not be able to run.",
@ -497,9 +464,8 @@ async def warn_no_collateral_inputs(ctx: Context) -> None:
)
async def warn_unknown_total_collateral(ctx: Context) -> None:
async def warn_unknown_total_collateral() -> None:
await layouts.show_warning(
ctx,
"confirm_unknown_total_collateral",
"Unknown collateral amount.",
"Check all items carefully.",
@ -508,7 +474,6 @@ async def warn_unknown_total_collateral(ctx: Context) -> None:
async def confirm_witness_request(
ctx: Context,
witness_path: list[int],
) -> None:
from . import seed
@ -521,7 +486,6 @@ async def confirm_witness_request(
path_title = "path"
await layouts.confirm_text(
ctx,
"confirm_total",
"Confirm transaction",
address_n_to_str(witness_path),
@ -531,7 +495,6 @@ async def confirm_witness_request(
async def confirm_tx(
ctx: Context,
fee: int,
network_id: int,
protocol_magic: int,
@ -559,7 +522,6 @@ async def confirm_tx(
append(("Transaction ID:", tx_hash))
await confirm_properties(
ctx,
"confirm_total",
"Confirm transaction",
props,
@ -568,9 +530,7 @@ async def confirm_tx(
)
async def confirm_certificate(
ctx: Context, certificate: messages.CardanoTxCertificate
) -> None:
async def confirm_certificate(certificate: messages.CardanoTxCertificate) -> None:
# stake pool registration requires custom confirmation logic not covered
# in this call
assert certificate.type != CardanoCertificateType.STAKE_POOL_REGISTRATION
@ -587,7 +547,6 @@ async def confirm_certificate(
props.append(("to pool:", format_stake_pool_id(certificate.pool)))
await confirm_properties(
ctx,
"confirm_certificate",
"Confirm transaction",
props,
@ -596,7 +555,6 @@ async def confirm_certificate(
async def confirm_stake_pool_parameters(
ctx: Context,
pool_parameters: messages.CardanoPoolParametersType,
network_id: int,
) -> None:
@ -605,7 +563,6 @@ async def confirm_stake_pool_parameters(
)
percentage_formatted = str(float(margin_percentage)).rstrip("0").rstrip(".")
await confirm_properties(
ctx,
"confirm_pool_registration",
"Confirm transaction",
(
@ -626,7 +583,6 @@ async def confirm_stake_pool_parameters(
async def confirm_stake_pool_owner(
ctx: Context,
keychain: Keychain,
owner: messages.CardanoPoolOwner,
protocol_magic: int,
@ -669,7 +625,6 @@ async def confirm_stake_pool_owner(
)
await confirm_properties(
ctx,
"confirm_pool_owners",
"Confirm transaction",
props,
@ -678,12 +633,10 @@ async def confirm_stake_pool_owner(
async def confirm_stake_pool_metadata(
ctx: Context,
metadata: messages.CardanoPoolMetadataType | None,
) -> None:
if metadata is None:
await confirm_properties(
ctx,
"confirm_pool_metadata",
"Confirm transaction",
(("Pool has no metadata (anonymous pool)", None),),
@ -692,7 +645,6 @@ async def confirm_stake_pool_metadata(
return
await confirm_properties(
ctx,
"confirm_pool_metadata",
"Confirm transaction",
(
@ -704,13 +656,11 @@ async def confirm_stake_pool_metadata(
async def confirm_stake_pool_registration_final(
ctx: Context,
protocol_magic: int,
ttl: int | None,
validity_interval_start: int | None,
) -> None:
await confirm_properties(
ctx,
"confirm_pool_final",
"Confirm transaction",
(
@ -725,7 +675,6 @@ async def confirm_stake_pool_registration_final(
async def confirm_withdrawal(
ctx: Context,
withdrawal: messages.CardanoTxWithdrawal,
address_bytes: bytes,
network_id: int,
@ -746,7 +695,6 @@ async def confirm_withdrawal(
props.append(("Amount:", format_coin_amount(withdrawal.amount, network_id)))
await confirm_properties(
ctx,
"confirm_withdrawal",
"Confirm transaction",
props,
@ -774,7 +722,6 @@ def _format_stake_credential(
async def confirm_cvote_registration_delegation(
ctx: Context,
public_key: str,
weight: int,
) -> None:
@ -786,7 +733,6 @@ async def confirm_cvote_registration_delegation(
props.append(("Weight:", str(weight)))
await confirm_properties(
ctx,
"confirm_cvote_registration_delegation",
title="Confirm transaction",
props=props,
@ -795,7 +741,6 @@ async def confirm_cvote_registration_delegation(
async def confirm_cvote_registration_payment_address(
ctx: Context,
payment_address: str,
should_show_payment_warning: bool,
) -> None:
@ -806,7 +751,6 @@ async def confirm_cvote_registration_payment_address(
if should_show_payment_warning:
props.append((CVOTE_REWARD_ELIGIBILITY_WARNING, None))
await confirm_properties(
ctx,
"confirm_cvote_registration_payment_address",
title="Confirm transaction",
props=props,
@ -815,7 +759,6 @@ async def confirm_cvote_registration_payment_address(
async def confirm_cvote_registration(
ctx: Context,
vote_public_key: str | None,
staking_path: list[int],
nonce: int,
@ -842,7 +785,6 @@ async def confirm_cvote_registration(
)
await confirm_properties(
ctx,
"confirm_cvote_registration",
title="Confirm transaction",
props=props,
@ -850,9 +792,8 @@ async def confirm_cvote_registration(
)
async def show_auxiliary_data_hash(ctx: Context, auxiliary_data_hash: bytes) -> None:
async def show_auxiliary_data_hash(auxiliary_data_hash: bytes) -> None:
await confirm_properties(
ctx,
"confirm_auxiliary_data",
"Confirm transaction",
(("Auxiliary data hash:", auxiliary_data_hash),),
@ -860,12 +801,9 @@ async def show_auxiliary_data_hash(ctx: Context, auxiliary_data_hash: bytes) ->
)
async def confirm_token_minting(
ctx: Context, policy_id: bytes, token: messages.CardanoToken
) -> None:
async def confirm_token_minting(policy_id: bytes, token: messages.CardanoToken) -> None:
assert token.mint_amount is not None # _validate_token
await confirm_properties(
ctx,
"confirm_mint",
"Confirm transaction",
(
@ -885,9 +823,8 @@ async def confirm_token_minting(
)
async def warn_tx_network_unverifiable(ctx: Context) -> None:
async def warn_tx_network_unverifiable() -> None:
await confirm_metadata(
ctx,
"warning_no_outputs",
"Warning",
"Transaction has no outputs, network cannot be verified.",
@ -895,9 +832,8 @@ async def warn_tx_network_unverifiable(ctx: Context) -> None:
)
async def confirm_script_data_hash(ctx: Context, script_data_hash: bytes) -> None:
async def confirm_script_data_hash(script_data_hash: bytes) -> None:
await confirm_properties(
ctx,
"confirm_script_data_hash",
"Confirm transaction",
(
@ -911,10 +847,9 @@ async def confirm_script_data_hash(ctx: Context, script_data_hash: bytes) -> Non
async def confirm_collateral_input(
ctx: Context, collateral_input: messages.CardanoTxCollateralInput
collateral_input: messages.CardanoTxCollateralInput,
) -> None:
await confirm_properties(
ctx,
"confirm_collateral_input",
"Confirm transaction",
(
@ -926,10 +861,9 @@ async def confirm_collateral_input(
async def confirm_reference_input(
ctx: Context, reference_input: messages.CardanoTxReferenceInput
reference_input: messages.CardanoTxReferenceInput,
) -> None:
await confirm_properties(
ctx,
"confirm_reference_input",
"Confirm transaction",
(
@ -941,7 +875,7 @@ async def confirm_reference_input(
async def confirm_required_signer(
ctx: Context, required_signer: messages.CardanoTxRequiredSigner
required_signer: messages.CardanoTxRequiredSigner,
) -> None:
assert (
required_signer.key_hash is not None or required_signer.key_path
@ -953,7 +887,6 @@ async def confirm_required_signer(
)
await confirm_properties(
ctx,
"confirm_required_signer",
"Confirm transaction",
(("Required signer", formatted_signer),),
@ -962,7 +895,6 @@ async def confirm_required_signer(
async def show_cardano_address(
ctx: Context,
address_parameters: messages.CardanoAddressParametersType,
address: str,
protocol_magic: int,
@ -989,7 +921,6 @@ async def show_cardano_address(
path = address_n_to_str(address_parameters.address_n_staking)
await layouts.show_address(
ctx,
address,
path=path,
account=account,

@ -27,7 +27,7 @@ if TYPE_CHECKING:
)
MsgIn = TypeVar("MsgIn", bound=CardanoMessages)
HandlerWithKeychain = Callable[[wire.Context, MsgIn, "Keychain"], Awaitable[MsgOut]]
HandlerWithKeychain = Callable[[MsgIn, "Keychain"], Awaitable[MsgOut]]
class Keychain:
@ -136,9 +136,7 @@ def derive_and_store_secrets(passphrase: str) -> None:
cache.set(cache.APP_CARDANO_ICARUS_TREZOR_SECRET, icarus_trezor_secret)
async def _get_keychain_bip39(
ctx: wire.Context, derivation_type: CardanoDerivationType
) -> Keychain:
async def _get_keychain_bip39(derivation_type: CardanoDerivationType) -> Keychain:
from apps.common.seed import derive_and_store_roots
from trezor.enums import CardanoDerivationType
@ -146,7 +144,7 @@ async def _get_keychain_bip39(
raise wire.NotInitialized("Device is not initialized")
if derivation_type == CardanoDerivationType.LEDGER:
seed = await get_seed(ctx)
seed = await get_seed()
return Keychain(cardano.from_seed_ledger(seed))
if not cache.get(cache.APP_COMMON_DERIVE_CARDANO):
@ -160,7 +158,7 @@ async def _get_keychain_bip39(
# _get_secret
secret = cache.get(cache_entry)
if secret is None:
await derive_and_store_roots(ctx)
await derive_and_store_roots()
secret = cache.get(cache_entry)
assert secret is not None
@ -168,20 +166,18 @@ async def _get_keychain_bip39(
return Keychain(root)
async def _get_keychain(
ctx: wire.Context, derivation_type: CardanoDerivationType
) -> Keychain:
async def _get_keychain(derivation_type: CardanoDerivationType) -> Keychain:
if mnemonic.is_bip39():
return await _get_keychain_bip39(ctx, derivation_type)
return await _get_keychain_bip39(derivation_type)
else:
# derive the root node via SLIP-0023 https://github.com/satoshilabs/slips/blob/master/slip-0022.md
seed = await get_seed(ctx)
seed = await get_seed()
return Keychain(cardano.from_seed_slip23(seed))
def with_keychain(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]:
async def wrapper(ctx: wire.Context, msg: MsgIn) -> MsgOut:
keychain = await _get_keychain(ctx, msg.derivation_type)
return await func(ctx, msg, keychain)
async def wrapper(msg: MsgIn) -> MsgOut:
keychain = await _get_keychain(msg.derivation_type)
return await func(msg, keychain)
return wrapper

@ -4,13 +4,12 @@ from .. import seed
if TYPE_CHECKING:
from typing import Type
from trezor.wire import Context
from trezor.messages import CardanoSignTxFinished, CardanoSignTxInit
@seed.with_keychain
async def sign_tx(
ctx: Context, msg: CardanoSignTxInit, keychain: seed.Keychain
msg: CardanoSignTxInit, keychain: seed.Keychain
) -> CardanoSignTxFinished:
from trezor.messages import CardanoSignTxFinished
from trezor import log, wire
@ -40,7 +39,7 @@ async def sign_tx(
else:
raise RuntimeError # should be unreachable
signer = signer_type(ctx, msg, keychain)
signer = signer_type(msg, keychain)
try:
await signer.sign()

@ -33,7 +33,6 @@ class MultisigSigner(Signer):
# super() omitted intentionally
is_network_id_verifiable = self._is_network_id_verifiable()
await layout.confirm_tx(
self.ctx,
msg.fee,
msg.network_id,
msg.protocol_magic,

@ -34,7 +34,6 @@ class OrdinarySigner(Signer):
# super() omitted intentionally
is_network_id_verifiable = self._is_network_id_verifiable()
await layout.confirm_tx(
self.ctx,
msg.fee,
msg.network_id,
msg.protocol_magic,
@ -92,10 +91,10 @@ class OrdinarySigner(Signer):
is_minting = SCHEMA_MINT.match(witness_path)
if is_minting:
await layout.confirm_witness_request(self.ctx, witness_path)
await layout.confirm_witness_request(witness_path)
elif not is_payment and not is_staking:
await self._fail_or_warn_path(witness_path, WITNESS_PATH_NAME)
else:
await self._show_if_showing_details(
layout.confirm_witness_request(self.ctx, witness_path)
layout.confirm_witness_request(witness_path)
)

@ -22,12 +22,12 @@ class PlutusSigner(Signer):
# These items should be present if a Plutus script is to be executed.
if self.msg.script_data_hash is None:
await layout.warn_no_script_data_hash(self.ctx)
await layout.warn_no_script_data_hash()
if self.msg.collateral_inputs_count == 0:
await layout.warn_no_collateral_inputs(self.ctx)
await layout.warn_no_collateral_inputs()
if self.msg.total_collateral is None:
await layout.warn_unknown_total_collateral(self.ctx)
await layout.warn_unknown_total_collateral()
async def _confirm_tx(self, tx_hash: bytes) -> None:
msg = self.msg # local_cache_attribute
@ -38,7 +38,6 @@ class PlutusSigner(Signer):
# tedious to check one by one on the Trezor screen).
is_network_id_verifiable = self._is_network_id_verifiable()
await layout.confirm_tx(
self.ctx,
msg.fee,
msg.network_id,
msg.protocol_magic,
@ -52,7 +51,7 @@ class PlutusSigner(Signer):
async def _show_input(self, input: messages.CardanoTxInput) -> None:
# super() omitted intentionally
# The inputs are not interchangeable (because of datums), so we must show them.
await self._show_if_showing_details(layout.confirm_input(self.ctx, input))
await self._show_if_showing_details(layout.confirm_input(input))
async def _show_output_credentials(
self, address_parameters: messages.CardanoAddressParametersType
@ -64,7 +63,6 @@ class PlutusSigner(Signer):
# evaluation. We at least hide the staking path if it matches the payment path.
show_both_credentials = should_show_credentials(address_parameters)
await layout.show_device_owned_output_credentials(
self.ctx,
Credential.payment_credential(address_parameters),
Credential.stake_credential(address_parameters),
show_both_credentials,

@ -45,7 +45,6 @@ class PoolOwnerSigner(Signer):
# super() omitted intentionally
await layout.confirm_stake_pool_registration_final(
self.ctx,
self.msg.protocol_magic,
self.msg.ttl,
self.msg.validity_interval_start,

@ -9,6 +9,7 @@ from trezor.enums import (
)
from trezor.messages import CardanoTxItemAck, CardanoTxOutput
from trezor.wire import DataError, ProcessError
from trezor.wire.context import call as ctx_call
from apps.common import safety_checks
@ -21,7 +22,6 @@ from ..helpers.utils import derive_public_key
if TYPE_CHECKING:
from typing import Any, Awaitable, ClassVar
from trezor.wire import Context
from trezor.enums import CardanoAddressType
from apps.common.paths import PathSchema
from apps.common import cbor
@ -76,13 +76,11 @@ class Signer:
def __init__(
self,
ctx: Context,
msg: messages.CardanoSignTxInit,
keychain: seed.Keychain,
) -> None:
from ..helpers.account_path_check import AccountPathChecker
self.ctx = ctx
self.msg = msg
self.keychain = keychain
@ -124,8 +122,8 @@ class Signer:
await self._confirm_tx(tx_hash)
response_after_witness_requests = await self._process_witness_requests(tx_hash)
await self.ctx.call(response_after_witness_requests, messages.CardanoTxHostAck)
await self.ctx.call(
await ctx_call(response_after_witness_requests, messages.CardanoTxHostAck)
await ctx_call(
messages.CardanoTxBodyHash(tx_hash=tx_hash), messages.CardanoTxHostAck
)
@ -225,12 +223,10 @@ class Signer:
validate_network_info(msg.network_id, msg.protocol_magic)
async def _show_tx_init(self) -> None:
self.should_show_details = await layout.show_tx_init(
self.ctx, self.SIGNING_MODE_TITLE
)
self.should_show_details = await layout.show_tx_init(self.SIGNING_MODE_TITLE)
if not self._is_network_id_verifiable():
await layout.warn_tx_network_unverifiable(self.ctx)
await layout.warn_tx_network_unverifiable()
async def _confirm_tx(self, tx_hash: bytes) -> None:
# Final signing confirmation is handled separately in each signing mode.
@ -242,7 +238,7 @@ class Signer:
self, inputs_list: HashBuilderList[tuple[bytes, int]]
) -> None:
for _ in range(self.msg.inputs_count):
input: messages.CardanoTxInput = await self.ctx.call(
input: messages.CardanoTxInput = await ctx_call(
CardanoTxItemAck(), messages.CardanoTxInput
)
self._validate_input(input)
@ -262,7 +258,7 @@ class Signer:
async def _process_outputs(self, outputs_list: HashBuilderList) -> None:
total_amount = 0
for _ in range(self.msg.outputs_count):
output: CardanoTxOutput = await self.ctx.call(
output: CardanoTxOutput = await ctx_call(
CardanoTxItemAck(), CardanoTxOutput
)
await self._process_output(outputs_list, output)
@ -346,10 +342,10 @@ class Signer:
and output.inline_datum_size == 0
and address_type in addresses.ADDRESS_TYPES_PAYMENT_SCRIPT
):
await layout.warn_tx_output_no_datum(self.ctx)
await layout.warn_tx_output_no_datum()
if output.asset_groups_count > 0:
await layout.warn_tx_output_contains_tokens(self.ctx)
await layout.warn_tx_output_contains_tokens()
if output.address_parameters is not None:
address = addresses.derive_human_readable(
@ -364,7 +360,6 @@ class Signer:
address = output.address
await layout.confirm_sending(
self.ctx,
output.amount,
address,
"change" if self._is_change_output(output) else "address",
@ -375,7 +370,6 @@ class Signer:
self, address_parameters: messages.CardanoAddressParametersType
) -> None:
await layout.show_change_output_credentials(
self.ctx,
Credential.payment_credential(address_parameters),
Credential.stake_credential(address_parameters),
)
@ -439,7 +433,7 @@ class Signer:
if output.datum_hash is not None:
if should_show:
await self._show_if_showing_details(
layout.confirm_datum_hash(self.ctx, output.datum_hash)
layout.confirm_datum_hash(output.datum_hash)
)
output_list.append(output.datum_hash)
@ -472,7 +466,7 @@ class Signer:
if output.datum_hash is not None:
if should_show:
await self._show_if_showing_details(
layout.confirm_datum_hash(self.ctx, output.datum_hash)
layout.confirm_datum_hash(output.datum_hash)
)
add(
_BABBAGE_OUTPUT_KEY_DATUM_OPTION,
@ -532,7 +526,7 @@ class Signer:
should_show_tokens: bool,
) -> None:
for _ in range(asset_groups_count):
asset_group: messages.CardanoAssetGroup = await self.ctx.call(
asset_group: messages.CardanoAssetGroup = await ctx_call(
CardanoTxItemAck(), messages.CardanoAssetGroup
)
self._validate_asset_group(asset_group)
@ -573,12 +567,12 @@ class Signer:
should_show_tokens: bool,
) -> None:
for _ in range(tokens_count):
token: messages.CardanoToken = await self.ctx.call(
token: messages.CardanoToken = await ctx_call(
CardanoTxItemAck(), messages.CardanoToken
)
self._validate_token(token)
if should_show_tokens:
await layout.confirm_sending_token(self.ctx, policy_id, token)
await layout.confirm_sending_token(policy_id, token)
assert token.amount is not None # _validate_token
tokens_dict.add(token.asset_name_bytes, token.amount)
@ -614,7 +608,7 @@ class Signer:
chunks_count = self._get_chunks_count(inline_datum_size)
for chunk_number in range(chunks_count):
chunk: messages.CardanoTxInlineDatumChunk = await self.ctx.call(
chunk: messages.CardanoTxInlineDatumChunk = await ctx_call(
CardanoTxItemAck(), messages.CardanoTxInlineDatumChunk
)
self._validate_chunk(
@ -625,7 +619,7 @@ class Signer:
)
if chunk_number == 0 and should_show:
await self._show_if_showing_details(
layout.confirm_inline_datum(self.ctx, chunk.data, inline_datum_size)
layout.confirm_inline_datum(chunk.data, inline_datum_size)
)
inline_datum_cbor.add(chunk.data)
@ -641,7 +635,7 @@ class Signer:
chunks_count = self._get_chunks_count(reference_script_size)
for chunk_number in range(chunks_count):
chunk: messages.CardanoTxReferenceScriptChunk = await self.ctx.call(
chunk: messages.CardanoTxReferenceScriptChunk = await ctx_call(
CardanoTxItemAck(), messages.CardanoTxReferenceScriptChunk
)
self._validate_chunk(
@ -652,9 +646,7 @@ class Signer:
)
if chunk_number == 0 and should_show:
await self._show_if_showing_details(
layout.confirm_reference_script(
self.ctx, chunk.data, reference_script_size
)
layout.confirm_reference_script(chunk.data, reference_script_size)
)
reference_script_cbor.add(chunk.data)
@ -662,7 +654,7 @@ class Signer:
async def _process_certificates(self, certificates_list: HashBuilderList) -> None:
for _ in range(self.msg.certificates_count):
certificate: messages.CardanoTxCertificate = await self.ctx.call(
certificate: messages.CardanoTxCertificate = await ctx_call(
CardanoTxItemAck(), messages.CardanoTxCertificate
)
self._validate_certificate(certificate)
@ -726,13 +718,13 @@ class Signer:
if certificate.type == CardanoCertificateType.STAKE_POOL_REGISTRATION:
assert certificate.pool_parameters is not None
await layout.confirm_stake_pool_parameters(
self.ctx, certificate.pool_parameters, self.msg.network_id
certificate.pool_parameters, self.msg.network_id
)
await layout.confirm_stake_pool_metadata(
self.ctx, certificate.pool_parameters.metadata
certificate.pool_parameters.metadata
)
else:
await layout.confirm_certificate(self.ctx, certificate)
await layout.confirm_certificate(certificate)
# pool owners
@ -741,7 +733,7 @@ class Signer:
) -> None:
owners_as_path_count = 0
for _ in range(owners_count):
owner: messages.CardanoPoolOwner = await self.ctx.call(
owner: messages.CardanoPoolOwner = await ctx_call(
CardanoTxItemAck(), messages.CardanoPoolOwner
)
certificates.validate_pool_owner(owner, self.account_path_checker)
@ -764,7 +756,7 @@ class Signer:
)
await layout.confirm_stake_pool_owner(
self.ctx, self.keychain, owner, self.msg.protocol_magic, self.msg.network_id
self.keychain, owner, self.msg.protocol_magic, self.msg.network_id
)
# pool relays
@ -775,7 +767,7 @@ class Signer:
relays_count: int,
) -> None:
for _ in range(relays_count):
relay: messages.CardanoPoolRelayParameters = await self.ctx.call(
relay: messages.CardanoPoolRelayParameters = await ctx_call(
CardanoTxItemAck(), messages.CardanoPoolRelayParameters
)
certificates.validate_pool_relay(relay)
@ -787,14 +779,14 @@ class Signer:
self, withdrawals_dict: HashBuilderDict[bytes, int]
) -> None:
for _ in range(self.msg.withdrawals_count):
withdrawal: messages.CardanoTxWithdrawal = await self.ctx.call(
withdrawal: messages.CardanoTxWithdrawal = await ctx_call(
CardanoTxItemAck(), messages.CardanoTxWithdrawal
)
self._validate_withdrawal(withdrawal)
address_bytes = self._derive_withdrawal_address_bytes(withdrawal)
await self._show_if_showing_details(
layout.confirm_withdrawal(
self.ctx, withdrawal, address_bytes, self.msg.network_id
withdrawal, address_bytes, self.msg.network_id
)
)
withdrawals_dict.add(address_bytes, withdrawal.amount)
@ -821,7 +813,7 @@ class Signer:
msg = self.msg # local_cache_attribute
data: messages.CardanoTxAuxiliaryData = await self.ctx.call(
data: messages.CardanoTxAuxiliaryData = await ctx_call(
CardanoTxItemAck(), messages.CardanoTxAuxiliaryData
)
auxiliary_data.validate(data, msg.protocol_magic, msg.network_id)
@ -833,7 +825,6 @@ class Signer:
self.keychain, data, msg.protocol_magic, msg.network_id
)
await auxiliary_data.show(
self.ctx,
self.keychain,
auxiliary_data_hash,
data.cvote_registration_parameters,
@ -843,21 +834,21 @@ class Signer:
)
self.tx_dict.add(_TX_BODY_KEY_AUXILIARY_DATA, auxiliary_data_hash)
await self.ctx.call(auxiliary_data_supplement, messages.CardanoTxHostAck)
await ctx_call(auxiliary_data_supplement, messages.CardanoTxHostAck)
# minting
async def _process_minting(
self, minting_dict: HashBuilderDict[bytes, HashBuilderDict]
) -> None:
token_minting: messages.CardanoTxMint = await self.ctx.call(
token_minting: messages.CardanoTxMint = await ctx_call(
CardanoTxItemAck(), messages.CardanoTxMint
)
await layout.warn_tx_contains_mint(self.ctx)
await layout.warn_tx_contains_mint()
for _ in range(token_minting.asset_groups_count):
asset_group: messages.CardanoAssetGroup = await self.ctx.call(
asset_group: messages.CardanoAssetGroup = await ctx_call(
CardanoTxItemAck(), messages.CardanoAssetGroup
)
self._validate_asset_group(asset_group, is_mint=True)
@ -881,11 +872,11 @@ class Signer:
tokens_count: int,
) -> None:
for _ in range(tokens_count):
token: messages.CardanoToken = await self.ctx.call(
token: messages.CardanoToken = await ctx_call(
CardanoTxItemAck(), messages.CardanoToken
)
self._validate_token(token, is_mint=True)
await layout.confirm_token_minting(self.ctx, policy_id, token)
await layout.confirm_token_minting(policy_id, token)
assert token.mint_amount is not None # _validate_token
tokens.add(token.asset_name_bytes, token.mint_amount)
@ -896,7 +887,7 @@ class Signer:
assert self.msg.script_data_hash is not None
self._validate_script_data_hash()
await self._show_if_showing_details(
layout.confirm_script_data_hash(self.ctx, self.msg.script_data_hash)
layout.confirm_script_data_hash(self.msg.script_data_hash)
)
self.tx_dict.add(_TX_BODY_KEY_SCRIPT_DATA_HASH, self.msg.script_data_hash)
@ -913,7 +904,7 @@ class Signer:
self, collateral_inputs_list: HashBuilderList[tuple[bytes, int]]
) -> None:
for _ in range(self.msg.collateral_inputs_count):
collateral_input: messages.CardanoTxCollateralInput = await self.ctx.call(
collateral_input: messages.CardanoTxCollateralInput = await ctx_call(
CardanoTxItemAck(), messages.CardanoTxCollateralInput
)
self._validate_collateral_input(collateral_input)
@ -933,7 +924,7 @@ class Signer:
) -> None:
if self.msg.total_collateral is None:
await self._show_if_showing_details(
layout.confirm_collateral_input(self.ctx, collateral_input)
layout.confirm_collateral_input(collateral_input)
)
# required signers
@ -944,12 +935,12 @@ class Signer:
from ..helpers.utils import get_public_key_hash
for _ in range(self.msg.required_signers_count):
required_signer: messages.CardanoTxRequiredSigner = await self.ctx.call(
required_signer: messages.CardanoTxRequiredSigner = await ctx_call(
CardanoTxItemAck(), messages.CardanoTxRequiredSigner
)
self._validate_required_signer(required_signer)
await self._show_if_showing_details(
layout.confirm_required_signer(self.ctx, required_signer)
layout.confirm_required_signer(required_signer)
)
key_hash = required_signer.key_hash or get_public_key_hash(
@ -985,9 +976,7 @@ class Signer:
# collateral return
async def _process_collateral_return(self) -> None:
output: CardanoTxOutput = await self.ctx.call(
CardanoTxItemAck(), CardanoTxOutput
)
output: CardanoTxOutput = await ctx_call(CardanoTxItemAck(), CardanoTxOutput)
self._validate_collateral_return(output)
should_show_init = self._should_show_collateral_return_init(output)
should_show_tokens = self._should_show_collateral_return_tokens(output)
@ -1031,9 +1020,7 @@ class Signer:
# We don't display missing datum warning since datums are forbidden.
if output.asset_groups_count > 0:
await layout.warn_tx_output_contains_tokens(
self.ctx, is_collateral_return=True
)
await layout.warn_tx_output_contains_tokens(is_collateral_return=True)
if output.address_parameters is not None:
address = addresses.derive_human_readable(
@ -1050,7 +1037,6 @@ class Signer:
address = output.address
await layout.confirm_sending(
self.ctx,
output.amount,
address,
"collateral-return",
@ -1078,12 +1064,12 @@ class Signer:
self, reference_inputs_list: HashBuilderList[tuple[bytes, int]]
) -> None:
for _ in range(self.msg.reference_inputs_count):
reference_input: messages.CardanoTxReferenceInput = await self.ctx.call(
reference_input: messages.CardanoTxReferenceInput = await ctx_call(
CardanoTxItemAck(), messages.CardanoTxReferenceInput
)
self._validate_reference_input(reference_input)
await self._show_if_showing_details(
layout.confirm_reference_input(self.ctx, reference_input)
layout.confirm_reference_input(reference_input)
)
reference_inputs_list.append(
(reference_input.prev_hash, reference_input.prev_index)
@ -1101,9 +1087,7 @@ class Signer:
response: CardanoTxResponseType = CardanoTxItemAck()
for _ in range(self.msg.witness_requests_count):
witness_request = await self.ctx.call(
response, messages.CardanoTxWitnessRequest
)
witness_request = await ctx_call(response, messages.CardanoTxWitnessRequest)
self._validate_witness_request(witness_request)
path = witness_request.path
await self._show_witness_request(path)
@ -1123,7 +1107,7 @@ class Signer:
self,
witness_path: list[int],
) -> None:
await layout.confirm_witness_request(self.ctx, witness_path)
await layout.confirm_witness_request(witness_path)
# helpers
@ -1241,7 +1225,7 @@ class Signer:
if safety_checks.is_strict():
raise DataError(f"Invalid {path_name.lower()}")
else:
await layout.warn_path(self.ctx, path, path_name)
await layout.warn_path(path, path_name)
def _fail_if_strict_and_unusual(
self, address_parameters: messages.CardanoAddressParametersType

@ -16,7 +16,6 @@ if TYPE_CHECKING:
from typing_extensions import Protocol
from trezor.protobuf import MessageType
from trezor.wire import Context
from .seed import Slip21Node
T = TypeVar("T")
@ -36,8 +35,8 @@ if TYPE_CHECKING:
MsgIn = TypeVar("MsgIn", bound=MessageType)
MsgOut = TypeVar("MsgOut", bound=MessageType)
Handler = Callable[[Context, MsgIn], Awaitable[MsgOut]]
HandlerWithKeychain = Callable[[Context, MsgIn, "Keychain"], Awaitable[MsgOut]]
Handler = Callable[[MsgIn], Awaitable[MsgOut]]
HandlerWithKeychain = Callable[[MsgIn, "Keychain"], Awaitable[MsgOut]]
class Deletable(Protocol):
def __del__(self) -> None:
@ -176,14 +175,13 @@ class Keychain:
async def get_keychain(
ctx: Context,
curve: str,
schemas: Iterable[paths.PathSchemaType],
slip21_namespaces: Iterable[paths.Slip21Path] = (),
) -> Keychain:
from .seed import get_seed
seed = await get_seed(ctx)
seed = await get_seed()
keychain = Keychain(seed, curve, schemas, slip21_namespaces)
return keychain
@ -205,10 +203,10 @@ def with_slip44_keychain(
schemas = [s.copy() for s in schemas]
def decorator(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]:
async def wrapper(ctx: Context, msg: MsgIn) -> MsgOut:
keychain = await get_keychain(ctx, curve, schemas)
async def wrapper(msg: MsgIn) -> MsgOut:
keychain = await get_keychain(curve, schemas)
with keychain:
return await func(ctx, msg, keychain)
return await func(msg, keychain)
return wrapper

@ -1,12 +1,8 @@
from micropython import const
from typing import TYPE_CHECKING
import storage.device as storage_device
from trezor.wire import DataError
if TYPE_CHECKING:
from trezor.wire import Context
_MAX_PASSPHRASE_LEN = const(50)
@ -14,7 +10,7 @@ def is_enabled() -> bool:
return storage_device.is_passphrase_enabled()
async def get(ctx: Context) -> str:
async def get() -> str:
from trezor import workflow
if not is_enabled():
@ -24,23 +20,24 @@ async def get(ctx: Context) -> str:
if storage_device.get_passphrase_always_on_device():
from trezor.ui.layouts import request_passphrase_on_device
passphrase = await request_passphrase_on_device(ctx, _MAX_PASSPHRASE_LEN)
passphrase = await request_passphrase_on_device(_MAX_PASSPHRASE_LEN)
else:
passphrase = await _request_on_host(ctx)
passphrase = await _request_on_host()
if len(passphrase.encode()) > _MAX_PASSPHRASE_LEN:
raise DataError(f"Maximum passphrase length is {_MAX_PASSPHRASE_LEN} bytes")
return passphrase
async def _request_on_host(ctx: Context) -> str:
async def _request_on_host() -> str:
from trezor.messages import PassphraseAck, PassphraseRequest
from trezor.ui.layouts import request_passphrase_on_host
from trezor.wire.context import call
request_passphrase_on_host()
request = PassphraseRequest()
ack = await ctx.call(request, PassphraseAck)
ack = await call(request, PassphraseAck)
passphrase = ack.passphrase # local_cache_attribute
if ack.on_device:
@ -48,7 +45,7 @@ async def _request_on_host(ctx: Context) -> str:
if passphrase is not None:
raise DataError("Passphrase provided when it should not be")
return await request_passphrase_on_device(ctx, _MAX_PASSPHRASE_LEN)
return await request_passphrase_on_device(_MAX_PASSPHRASE_LEN)
if passphrase is None:
raise DataError(
@ -63,14 +60,12 @@ async def _request_on_host(ctx: Context) -> str:
if storage_device.get_hide_passphrase_from_host():
explanation = "Passphrase provided by host will be used but will not be displayed due to the device settings."
await confirm_action(
ctx,
"passphrase_host1_hidden",
"Hidden wallet",
description=f"Access hidden wallet?\n{explanation}",
)
else:
await confirm_action(
ctx,
"passphrase_host1",
"Hidden wallet",
description="Next screen will show the passphrase.",
@ -78,7 +73,6 @@ async def _request_on_host(ctx: Context) -> str:
)
await confirm_blob(
ctx,
"passphrase_host2",
"Confirm passphrase",
passphrase,

@ -15,7 +15,6 @@ if TYPE_CHECKING:
TypeVar,
)
from typing_extensions import Protocol
from trezor import wire
Bip32Path = Sequence[int]
Slip21Path = Sequence[bytes]
@ -342,20 +341,19 @@ PATTERN_CASA = "m/45'/coin_type/account/change/address_index"
async def validate_path(
ctx: wire.Context,
keychain: KeychainValidatorType,
path: Bip32Path,
*additional_checks: bool,
) -> None:
keychain.verify_path(path)
if not keychain.is_in_keychain(path) or not all(additional_checks):
await show_path_warning(ctx, path)
await show_path_warning(path)
async def show_path_warning(ctx: wire.Context, path: Bip32Path) -> None:
async def show_path_warning(path: Bip32Path) -> None:
from trezor.ui.layouts import confirm_path_warning
await confirm_path_warning(ctx, address_n_to_str(path))
await confirm_path_warning(address_n_to_str(path))
def is_hardened(i: int) -> bool:

@ -1,16 +1,12 @@
import utime
from typing import TYPE_CHECKING
from typing import Any, NoReturn
import storage.cache as storage_cache
from trezor import config, utils, wire
if TYPE_CHECKING:
from typing import Any, NoReturn
from trezor.wire import Context, GenericContext
async def _request_sd_salt(
ctx: wire.GenericContext, raise_cancelled_on_unavailable: bool = False
raise_cancelled_on_unavailable: bool = False,
) -> bytearray | None:
"""Helper to get SD salt in a general manner, working for all models.
@ -23,7 +19,7 @@ async def _request_sd_salt(
from .sdcard import request_sd_salt, SdCardUnavailable
try:
return await request_sd_salt(ctx)
return await request_sd_salt()
except SdCardUnavailable:
if raise_cancelled_on_unavailable:
raise wire.PinCancelled("SD salt is unavailable")
@ -43,38 +39,37 @@ def can_lock_device() -> bool:
async def request_pin(
ctx: GenericContext,
prompt: str,
attempts_remaining: int | None = None,
allow_cancel: bool = True,
) -> str:
from trezor.ui.layouts import request_pin_on_device
return await request_pin_on_device(ctx, prompt, attempts_remaining, allow_cancel)
return await request_pin_on_device(prompt, attempts_remaining, allow_cancel)
async def request_pin_confirm(ctx: Context, *args: Any, **kwargs: Any) -> str:
async def request_pin_confirm(*args: Any, **kwargs: Any) -> str:
from trezor.ui.layouts import pin_mismatch_popup, confirm_reenter_pin
while True:
pin1 = await request_pin(ctx, "Enter new PIN", *args, **kwargs)
await confirm_reenter_pin(ctx)
pin2 = await request_pin(ctx, "Re-enter new PIN", *args, **kwargs)
pin1 = await request_pin("Enter new PIN", *args, **kwargs)
await confirm_reenter_pin()
pin2 = await request_pin("Re-enter new PIN", *args, **kwargs)
if pin1 == pin2:
return pin1
await pin_mismatch_popup(ctx)
await pin_mismatch_popup()
async def request_pin_and_sd_salt(
ctx: Context, prompt: str, allow_cancel: bool = True
prompt: str, allow_cancel: bool = True
) -> tuple[str, bytearray | None]:
if config.has_pin():
pin = await request_pin(ctx, prompt, config.get_pin_rem(), allow_cancel)
pin = await request_pin(prompt, config.get_pin_rem(), allow_cancel)
config.ensure_not_wipe_code(pin)
else:
pin = ""
salt = await _request_sd_salt(ctx)
salt = await _request_sd_salt()
return pin, salt
@ -85,7 +80,6 @@ def _set_last_unlock_time() -> None:
async def verify_user_pin(
ctx: GenericContext = wire.DUMMY_CONTEXT,
prompt: str = "Enter PIN",
allow_cancel: bool = True,
retry: bool = True,
@ -107,14 +101,12 @@ async def verify_user_pin(
if config.has_pin():
from trezor.ui.layouts import request_pin_on_device
pin = await request_pin_on_device(
ctx, prompt, config.get_pin_rem(), allow_cancel
)
pin = await request_pin_on_device(prompt, config.get_pin_rem(), allow_cancel)
config.ensure_not_wipe_code(pin)
else:
pin = ""
salt = await _request_sd_salt(ctx, raise_cancelled_on_unavailable=True)
salt = await _request_sd_salt(raise_cancelled_on_unavailable=True)
if config.unlock(pin, salt):
_set_last_unlock_time()
return
@ -123,7 +115,7 @@ async def verify_user_pin(
while retry:
pin = await request_pin_on_device( # type: ignore ["request_pin_on_device" is possibly unbound]
ctx, "Enter PIN", config.get_pin_rem(), allow_cancel, wrong_pin=True
"Enter PIN", config.get_pin_rem(), allow_cancel, wrong_pin=True
)
if config.unlock(pin, salt):
_set_last_unlock_time()
@ -132,11 +124,10 @@ async def verify_user_pin(
raise wire.PinInvalid
async def error_pin_invalid(ctx: Context) -> NoReturn:
async def error_pin_invalid() -> NoReturn:
from trezor.ui.layouts import show_error_and_raise
await show_error_and_raise(
ctx,
"warning_wrong_pin",
"The PIN you have entered is not valid.",
"Wrong PIN", # header
@ -145,11 +136,10 @@ async def error_pin_invalid(ctx: Context) -> NoReturn:
assert False
async def error_pin_matches_wipe_code(ctx: Context) -> NoReturn:
async def error_pin_matches_wipe_code() -> NoReturn:
from trezor.ui.layouts import show_error_and_raise
await show_error_and_raise(
ctx,
"warning_invalid_new_pin",
"The new PIN must be different from your wipe code.",
"Invalid PIN", # header

@ -7,10 +7,9 @@ class SdCardUnavailable(wire.ProcessError):
pass
async def _confirm_retry_wrong_card(ctx: wire.GenericContext) -> None:
async def _confirm_retry_wrong_card() -> None:
if SD_CARD_HOT_SWAPPABLE:
await confirm_action(
ctx,
"warning_wrong_sd",
"SD card protection",
"Wrong SD card.",
@ -21,7 +20,6 @@ async def _confirm_retry_wrong_card(ctx: wire.GenericContext) -> None:
)
else:
await show_error_and_raise(
ctx,
"warning_wrong_sd",
"Please unplug the device and insert the correct SD card.",
"Wrong SD card.",
@ -29,10 +27,9 @@ async def _confirm_retry_wrong_card(ctx: wire.GenericContext) -> None:
)
async def _confirm_retry_insert_card(ctx: wire.GenericContext) -> None:
async def _confirm_retry_insert_card() -> None:
if SD_CARD_HOT_SWAPPABLE:
await confirm_action(
ctx,
"warning_no_sd",
"SD card protection",
"SD card required.",
@ -43,7 +40,6 @@ async def _confirm_retry_insert_card(ctx: wire.GenericContext) -> None:
)
else:
await show_error_and_raise(
ctx,
"warning_no_sd",
"Please unplug the device and insert your SD card.",
"SD card required.",
@ -51,10 +47,9 @@ async def _confirm_retry_insert_card(ctx: wire.GenericContext) -> None:
)
async def _confirm_format_card(ctx: wire.GenericContext) -> None:
async def _confirm_format_card() -> None:
# Format card? yes/no
await confirm_action(
ctx,
"warning_format_sd",
"SD card error",
"Unknown filesystem.",
@ -66,7 +61,6 @@ async def _confirm_format_card(ctx: wire.GenericContext) -> None:
# Confirm formatting
await confirm_action(
ctx,
"confirm_format_sd",
"Format SD card",
"All data on the SD card will be lost.",
@ -79,11 +73,9 @@ async def _confirm_format_card(ctx: wire.GenericContext) -> None:
async def confirm_retry_sd(
ctx: wire.GenericContext,
exc: wire.ProcessError = SdCardUnavailable("Error accessing SD card."),
) -> None:
await confirm_action(
ctx,
"warning_sd_retry",
"SD card problem",
None,
@ -94,9 +86,7 @@ async def confirm_retry_sd(
)
async def ensure_sdcard(
ctx: wire.GenericContext, ensure_filesystem: bool = True
) -> None:
async def ensure_sdcard(ensure_filesystem: bool = True) -> None:
"""Ensure a SD card is ready for use.
This function runs the UI flow needed to ask the user to insert a SD card if there
@ -109,7 +99,7 @@ async def ensure_sdcard(
from trezor import sdcard
while not sdcard.is_present():
await _confirm_retry_insert_card(ctx)
await _confirm_retry_insert_card()
if not ensure_filesystem:
return
@ -126,7 +116,7 @@ async def ensure_sdcard(
# no error when mounting
return
await _confirm_format_card(ctx)
await _confirm_format_card()
# Proceed to formatting. Failure is caught by the outside OSError handler
with sdcard.filesystem(mounted=False):
@ -139,26 +129,24 @@ async def ensure_sdcard(
except OSError:
# formatting failed, or generic I/O error (SD card power-on failed)
await confirm_retry_sd(ctx)
await confirm_retry_sd()
async def request_sd_salt(
ctx: wire.GenericContext = wire.DUMMY_CONTEXT,
) -> bytearray | None:
async def request_sd_salt() -> bytearray | None:
import storage.sd_salt as storage_sd_salt
if not storage_sd_salt.is_enabled():
return None
while True:
await ensure_sdcard(ctx, ensure_filesystem=False)
await ensure_sdcard(ensure_filesystem=False)
try:
return storage_sd_salt.load_sd_salt()
except (storage_sd_salt.WrongSdCard, io.fatfs.NoFilesystem):
await _confirm_retry_wrong_card(ctx)
await _confirm_retry_wrong_card()
except OSError:
# Generic problem with loading the SD salt (hardware problem, or we could
# not read the file, or there is a staged salt which cannot be committed).
# In either case, there is no good way to recover. If the user clicks Retry,
# we will try again.
await confirm_retry_sd(ctx)
await confirm_retry_sd()

@ -10,7 +10,6 @@ from .passphrase import get as get_passphrase
if TYPE_CHECKING:
from .paths import Bip32Path, Slip21Path
from trezor.wire import Context
from trezor.crypto import bip32
@ -50,7 +49,7 @@ if not utils.BITCOIN_ONLY:
# We want to derive both the normal seed and the Cardano seed together, AND
# expose a method for Cardano to do the same
async def derive_and_store_roots(ctx: Context) -> None:
async def derive_and_store_roots() -> None:
from trezor import wire
if not storage_device.is_initialized():
@ -64,7 +63,7 @@ if not utils.BITCOIN_ONLY:
if not need_seed and not need_cardano_secret:
return
passphrase = await get_passphrase(ctx)
passphrase = await get_passphrase()
if need_seed:
common_seed = mnemonic.get_seed(passphrase)
@ -76,8 +75,8 @@ if not utils.BITCOIN_ONLY:
derive_and_store_secrets(passphrase)
@storage_cache.stored_async(storage_cache.APP_COMMON_SEED)
async def get_seed(ctx: Context) -> bytes:
await derive_and_store_roots(ctx)
async def get_seed() -> bytes:
await derive_and_store_roots()
common_seed = storage_cache.get(storage_cache.APP_COMMON_SEED)
assert common_seed is not None
return common_seed
@ -87,8 +86,8 @@ else:
# We use the simple version of `get_seed` that never needs to derive anything else.
@storage_cache.stored_async(storage_cache.APP_COMMON_SEED)
async def get_seed(ctx: Context) -> bytes:
passphrase = await get_passphrase(ctx)
async def get_seed() -> bytes:
passphrase = await get_passphrase()
return mnemonic.get_seed(passphrase)

@ -11,6 +11,7 @@ if __debug__:
from trezor import log, loop, utils, wire
from trezor.ui import display
from trezor.wire import context
from trezor.enums import MessageType
from trezor.messages import (
DebugLinkLayout,
@ -47,7 +48,7 @@ if __debug__:
layout_change_chan = loop.chan()
DEBUG_CONTEXT: wire.Context | None = None
DEBUG_CONTEXT: context.Context | None = None
LAYOUT_WATCHER_NONE = 0
LAYOUT_WATCHER_STATE = 1
@ -139,9 +140,7 @@ if __debug__:
await DEBUG_CONTEXT.write(DebugLinkState(tokens=content_tokens))
storage.layout_watcher = LAYOUT_WATCHER_NONE
async def dispatch_DebugLinkWatchLayout(
ctx: wire.Context, msg: DebugLinkWatchLayout
) -> Success:
async def dispatch_DebugLinkWatchLayout(msg: DebugLinkWatchLayout) -> Success:
from trezor import ui
layout_change_chan.putters.clear()
@ -152,16 +151,14 @@ if __debug__:
return Success()
async def dispatch_DebugLinkResetDebugEvents(
ctx: wire.Context, msg: DebugLinkResetDebugEvents
msg: DebugLinkResetDebugEvents,
) -> Success:
# Resetting the debug events makes sure that the previous
# events/layouts are not mixed with the new ones.
storage.reset_debug_events()
return Success()
async def dispatch_DebugLinkDecision(
ctx: wire.Context, msg: DebugLinkDecision
) -> None:
async def dispatch_DebugLinkDecision(msg: DebugLinkDecision) -> None:
from trezor import workflow
workflow.idle_timer.touch()
@ -194,7 +191,7 @@ if __debug__:
loop.schedule(return_layout_change())
async def dispatch_DebugLinkGetState(
ctx: wire.Context, msg: DebugLinkGetState
msg: DebugLinkGetState,
) -> DebugLinkState | None:
from trezor.messages import DebugLinkState
from apps.common import mnemonic, passphrase
@ -218,9 +215,7 @@ if __debug__:
return m
async def dispatch_DebugLinkRecordScreen(
ctx: wire.Context, msg: DebugLinkRecordScreen
) -> Success:
async def dispatch_DebugLinkRecordScreen(msg: DebugLinkRecordScreen) -> Success:
if msg.target_directory:
# In case emulator is restarted but we still want to record screenshots
# into the same directory as before, we need to increment the refresh index,
@ -235,18 +230,14 @@ if __debug__:
return Success()
async def dispatch_DebugLinkReseedRandom(
ctx: wire.Context, msg: DebugLinkReseedRandom
) -> Success:
async def dispatch_DebugLinkReseedRandom(msg: DebugLinkReseedRandom) -> Success:
if msg.value is not None:
from trezor.crypto import random
random.reseed(msg.value)
return Success()
async def dispatch_DebugLinkEraseSdCard(
ctx: wire.Context, msg: DebugLinkEraseSdCard
) -> Success:
async def dispatch_DebugLinkEraseSdCard(msg: DebugLinkEraseSdCard) -> Success:
from trezor import io
sdcard = io.sdcard # local_cache_attribute
@ -271,8 +262,8 @@ if __debug__:
def boot() -> None:
register = workflow_handlers.register # local_cache_attribute
register(MessageType.DebugLinkDecision, dispatch_DebugLinkDecision) # type: ignore [Argument of type "(ctx: Context, msg: DebugLinkDecision) -> Coroutine[Any, Any, None]" cannot be assigned to parameter "handler" of type "Handler[Msg@register]" in function "register"]
register(MessageType.DebugLinkGetState, dispatch_DebugLinkGetState) # type: ignore [Argument of type "(ctx: Context, msg: DebugLinkGetState) -> Coroutine[Any, Any, DebugLinkState | None]" cannot be assigned to parameter "handler" of type "Handler[Msg@register]" in function "register"]
register(MessageType.DebugLinkDecision, dispatch_DebugLinkDecision) # type: ignore [Argument of type "(msg: DebugLinkDecision) -> Coroutine[Any, Any, None]" cannot be assigned to parameter "handler" of type "Handler[Msg@register]" in function "register"]
register(MessageType.DebugLinkGetState, dispatch_DebugLinkGetState) # type: ignore [Argument of type "(msg: DebugLinkGetState) -> Coroutine[Any, Any, DebugLinkState | None]" cannot be assigned to parameter "handler" of type "Handler[Msg@register]" in function "register"]
register(MessageType.DebugLinkReseedRandom, dispatch_DebugLinkReseedRandom)
register(MessageType.DebugLinkRecordScreen, dispatch_DebugLinkRecordScreen)
register(MessageType.DebugLinkEraseSdCard, dispatch_DebugLinkEraseSdCard)

@ -2,10 +2,9 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from trezor.messages import LoadDevice, Success
from trezor.wire import Context
async def load_device(ctx: Context, msg: LoadDevice) -> Success:
async def load_device(msg: LoadDevice) -> Success:
import storage.device as storage_device
from trezor import config
from trezor.crypto import bip39, slip39
@ -38,7 +37,6 @@ async def load_device(ctx: Context, msg: LoadDevice) -> Success:
# _warn
await confirm_action(
ctx,
"warn_loading_seed",
"Loading seed",
"Loading private seed is not recommended.",

@ -1,13 +1,12 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from trezor import wire
from trezor.utils import Writer, HashWriter
from trezor.messages import EosTxActionAck
async def process_action(
ctx: wire.Context, sha: HashWriter, action: EosTxActionAck, is_last: bool
sha: HashWriter, action: EosTxActionAck, is_last: bool
) -> None:
from .. import helpers, writers
from . import layout
@ -26,71 +25,70 @@ async def process_action(
if account == "eosio":
if name == "buyram":
assert action.buy_ram is not None # _check_action
await layout.confirm_action_buyram(ctx, action.buy_ram)
await layout.confirm_action_buyram(action.buy_ram)
writers.write_action_buyram(w, action.buy_ram)
elif name == "buyrambytes":
assert action.buy_ram_bytes is not None # _check_action
await layout.confirm_action_buyrambytes(ctx, action.buy_ram_bytes)
await layout.confirm_action_buyrambytes(action.buy_ram_bytes)
writers.write_action_buyrambytes(w, action.buy_ram_bytes)
elif name == "sellram":
assert action.sell_ram is not None # _check_action
await layout.confirm_action_sellram(ctx, action.sell_ram)
await layout.confirm_action_sellram(action.sell_ram)
writers.write_action_sellram(w, action.sell_ram)
elif name == "delegatebw":
assert action.delegate is not None # _check_action
await layout.confirm_action_delegate(ctx, action.delegate)
await layout.confirm_action_delegate(action.delegate)
writers.write_action_delegate(w, action.delegate)
elif name == "undelegatebw":
assert action.undelegate is not None # _check_action
await layout.confirm_action_undelegate(ctx, action.undelegate)
await layout.confirm_action_undelegate(action.undelegate)
writers.write_action_undelegate(w, action.undelegate)
elif name == "refund":
assert action.refund is not None # _check_action
await layout.confirm_action_refund(ctx, action.refund)
await layout.confirm_action_refund(action.refund)
writers.write_action_refund(w, action.refund)
elif name == "voteproducer":
assert action.vote_producer is not None # _check_action
await layout.confirm_action_voteproducer(ctx, action.vote_producer)
await layout.confirm_action_voteproducer(action.vote_producer)
writers.write_action_voteproducer(w, action.vote_producer)
elif name == "updateauth":
assert action.update_auth is not None # _check_action
await layout.confirm_action_updateauth(ctx, action.update_auth)
await layout.confirm_action_updateauth(action.update_auth)
writers.write_action_updateauth(w, action.update_auth)
elif name == "deleteauth":
assert action.delete_auth is not None # _check_action
await layout.confirm_action_deleteauth(ctx, action.delete_auth)
await layout.confirm_action_deleteauth(action.delete_auth)
writers.write_action_deleteauth(w, action.delete_auth)
elif name == "linkauth":
assert action.link_auth is not None # _check_action
await layout.confirm_action_linkauth(ctx, action.link_auth)
await layout.confirm_action_linkauth(action.link_auth)
writers.write_action_linkauth(w, action.link_auth)
elif name == "unlinkauth":
assert action.unlink_auth is not None # _check_action
await layout.confirm_action_unlinkauth(ctx, action.unlink_auth)
await layout.confirm_action_unlinkauth(action.unlink_auth)
writers.write_action_unlinkauth(w, action.unlink_auth)
elif name == "newaccount":
assert action.new_account is not None # _check_action
await layout.confirm_action_newaccount(ctx, action.new_account)
await layout.confirm_action_newaccount(action.new_account)
writers.write_action_newaccount(w, action.new_account)
else:
raise ValueError("Unrecognized action type for eosio")
elif name == "transfer":
assert action.transfer is not None # _check_action
await layout.confirm_action_transfer(ctx, action.transfer, account)
await layout.confirm_action_transfer(action.transfer, account)
writers.write_action_transfer(w, action.transfer)
else:
await _process_unknown_action(ctx, w, action)
await _process_unknown_action(w, action)
writers.write_action_common(sha, action.common)
writers.write_bytes_prefixed(sha, w)
async def _process_unknown_action(
ctx: wire.Context, w: Writer, action: EosTxActionAck
) -> None:
async def _process_unknown_action(w: Writer, action: EosTxActionAck) -> None:
from trezor.crypto.hashlib import sha256
from trezor.utils import HashWriter
from trezor.messages import EosTxActionAck, EosTxActionRequest
from trezor.wire.context import call
from .. import writers
from . import layout
@ -106,9 +104,7 @@ async def _process_unknown_action(
bytes_left = unknown.data_size - len(data_chunk)
while bytes_left != 0:
action = await ctx.call(
EosTxActionRequest(data_size=bytes_left), EosTxActionAck
)
action = await call(EosTxActionRequest(data_size=bytes_left), EosTxActionAck)
if unknown is None:
raise ValueError("Bad response. Unknown struct expected.")
@ -120,7 +116,7 @@ async def _process_unknown_action(
if bytes_left < 0:
raise ValueError("Bad response. Buffer overflow.")
await layout.confirm_action_unknown(ctx, action.common, checksum.get_digest())
await layout.confirm_action_unknown(action.common, checksum.get_digest())
def _check_action(action: EosTxActionAck, name: str, account: str) -> bool:

@ -7,7 +7,6 @@ from ..helpers import eos_asset_to_string, eos_name_to_string
if TYPE_CHECKING:
from typing import Iterable
from trezor.wire import Context
from trezor.messages import (
EosActionBuyRam,
EosActionBuyRamBytes,
@ -35,13 +34,11 @@ is_last = False
# Because icon and br_code are almost always the same
# (and also calling with positional arguments takes less space)
async def _confirm_properties(
ctx: Context,
br_type: str,
title: str,
props: Iterable[PropertyType],
) -> None:
await confirm_properties(
ctx,
br_type,
title,
props,
@ -50,9 +47,8 @@ async def _confirm_properties(
)
async def confirm_action_buyram(ctx: Context, msg: EosActionBuyRam) -> None:
async def confirm_action_buyram(msg: EosActionBuyRam) -> None:
await _confirm_properties(
ctx,
"confirm_buyram",
"Buy RAM",
(
@ -63,9 +59,8 @@ async def confirm_action_buyram(ctx: Context, msg: EosActionBuyRam) -> None:
)
async def confirm_action_buyrambytes(ctx: Context, msg: EosActionBuyRamBytes) -> None:
async def confirm_action_buyrambytes(msg: EosActionBuyRamBytes) -> None:
await _confirm_properties(
ctx,
"confirm_buyrambytes",
"Buy RAM",
(
@ -76,7 +71,7 @@ async def confirm_action_buyrambytes(ctx: Context, msg: EosActionBuyRamBytes) ->
)
async def confirm_action_delegate(ctx: Context, msg: EosActionDelegate) -> None:
async def confirm_action_delegate(msg: EosActionDelegate) -> None:
props = [
("Sender:", eos_name_to_string(msg.sender)),
("Receiver:", eos_name_to_string(msg.receiver)),
@ -91,16 +86,14 @@ async def confirm_action_delegate(ctx: Context, msg: EosActionDelegate) -> None:
append(("Transfer:", "No"))
await _confirm_properties(
ctx,
"confirm_delegate",
"Delegate",
props,
)
async def confirm_action_sellram(ctx: Context, msg: EosActionSellRam) -> None:
async def confirm_action_sellram(msg: EosActionSellRam) -> None:
await _confirm_properties(
ctx,
"confirm_sellram",
"Sell RAM",
(
@ -110,9 +103,8 @@ async def confirm_action_sellram(ctx: Context, msg: EosActionSellRam) -> None:
)
async def confirm_action_undelegate(ctx: Context, msg: EosActionUndelegate) -> None:
async def confirm_action_undelegate(msg: EosActionUndelegate) -> None:
await _confirm_properties(
ctx,
"confirm_undelegate",
"Undelegate",
(
@ -124,22 +116,20 @@ async def confirm_action_undelegate(ctx: Context, msg: EosActionUndelegate) -> N
)
async def confirm_action_refund(ctx: Context, msg: EosActionRefund) -> None:
async def confirm_action_refund(msg: EosActionRefund) -> None:
await _confirm_properties(
ctx,
"confirm_refund",
"Refund",
(("Owner:", eos_name_to_string(msg.owner)),),
)
async def confirm_action_voteproducer(ctx: Context, msg: EosActionVoteProducer) -> None:
async def confirm_action_voteproducer(msg: EosActionVoteProducer) -> None:
producers = msg.producers # local_cache_attribute
if msg.proxy and not producers:
# PROXY
await _confirm_properties(
ctx,
"confirm_voteproducer",
"Vote for proxy",
(
@ -151,7 +141,6 @@ async def confirm_action_voteproducer(ctx: Context, msg: EosActionVoteProducer)
elif producers:
# PRODUCERS
await _confirm_properties(
ctx,
"confirm_voteproducer",
"Vote for producers",
(
@ -163,16 +152,13 @@ async def confirm_action_voteproducer(ctx: Context, msg: EosActionVoteProducer)
else:
# Cancel vote
await _confirm_properties(
ctx,
"confirm_voteproducer",
"Cancel vote",
(("Voter:", eos_name_to_string(msg.voter)),),
)
async def confirm_action_transfer(
ctx: Context, msg: EosActionTransfer, account: str
) -> None:
async def confirm_action_transfer(msg: EosActionTransfer, account: str) -> None:
props = [
("From:", eos_name_to_string(msg.sender)),
("To:", eos_name_to_string(msg.receiver)),
@ -182,14 +168,13 @@ async def confirm_action_transfer(
if msg.memo is not None:
props.append(("Memo", msg.memo[:512]))
await _confirm_properties(
ctx,
"confirm_transfer",
"Transfer",
props,
)
async def confirm_action_updateauth(ctx: Context, msg: EosActionUpdateAuth) -> None:
async def confirm_action_updateauth(msg: EosActionUpdateAuth) -> None:
props: list[PropertyType] = [
("Account:", eos_name_to_string(msg.account)),
("Permission:", eos_name_to_string(msg.permission)),
@ -197,16 +182,14 @@ async def confirm_action_updateauth(ctx: Context, msg: EosActionUpdateAuth) -> N
]
props.extend(authorization_fields(msg.auth))
await _confirm_properties(
ctx,
"confirm_updateauth",
"Update Auth",
props,
)
async def confirm_action_deleteauth(ctx: Context, msg: EosActionDeleteAuth) -> None:
async def confirm_action_deleteauth(msg: EosActionDeleteAuth) -> None:
await _confirm_properties(
ctx,
"confirm_deleteauth",
"Delete Auth",
(
@ -216,9 +199,8 @@ async def confirm_action_deleteauth(ctx: Context, msg: EosActionDeleteAuth) -> N
)
async def confirm_action_linkauth(ctx: Context, msg: EosActionLinkAuth) -> None:
async def confirm_action_linkauth(msg: EosActionLinkAuth) -> None:
await _confirm_properties(
ctx,
"confirm_linkauth",
"Link Auth",
(
@ -230,9 +212,8 @@ async def confirm_action_linkauth(ctx: Context, msg: EosActionLinkAuth) -> None:
)
async def confirm_action_unlinkauth(ctx: Context, msg: EosActionUnlinkAuth) -> None:
async def confirm_action_unlinkauth(msg: EosActionUnlinkAuth) -> None:
await _confirm_properties(
ctx,
"confirm_unlinkauth",
"Unlink Auth",
(
@ -243,7 +224,7 @@ async def confirm_action_unlinkauth(ctx: Context, msg: EosActionUnlinkAuth) -> N
)
async def confirm_action_newaccount(ctx: Context, msg: EosActionNewAccount) -> None:
async def confirm_action_newaccount(msg: EosActionNewAccount) -> None:
props: list[PropertyType] = [
("Creator:", eos_name_to_string(msg.creator)),
("Name:", eos_name_to_string(msg.name)),
@ -251,18 +232,14 @@ async def confirm_action_newaccount(ctx: Context, msg: EosActionNewAccount) -> N
props.extend(authorization_fields(msg.owner))
props.extend(authorization_fields(msg.active))
await _confirm_properties(
ctx,
"confirm_newaccount",
"New Account",
props,
)
async def confirm_action_unknown(
ctx: Context, action: EosActionCommon, checksum: bytes
) -> None:
async def confirm_action_unknown(action: EosActionCommon, checksum: bytes) -> None:
await confirm_properties(
ctx,
"confirm_unknown",
"Arbitrary data",
(

@ -5,20 +5,17 @@ from apps.common.keychain import auto_keychain
if TYPE_CHECKING:
from trezor.messages import EosGetPublicKey, EosPublicKey
from apps.common.keychain import Keychain
from trezor.wire import Context
@auto_keychain(__name__)
async def get_public_key(
ctx: Context, msg: EosGetPublicKey, keychain: Keychain
) -> EosPublicKey:
async def get_public_key(msg: EosGetPublicKey, keychain: Keychain) -> EosPublicKey:
from trezor.crypto.curve import secp256k1
from trezor.messages import EosPublicKey
from apps.common import paths
from .helpers import public_key_to_wif
from .layout import require_get_public_key
await paths.validate_path(ctx, keychain, msg.address_n)
await paths.validate_path(keychain, msg.address_n)
node = keychain.derive(msg.address_n)
@ -26,5 +23,5 @@ async def get_public_key(
wif = public_key_to_wif(public_key)
if msg.show_display:
await require_get_public_key(ctx, wif)
await require_get_public_key(wif)
return EosPublicKey(wif_public_key=wif, raw_public_key=public_key)

@ -1,22 +1,15 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from trezor.wire import Context
async def require_get_public_key(ctx: Context, public_key: str) -> None:
async def require_get_public_key(public_key: str) -> None:
from trezor.ui.layouts import show_pubkey
await show_pubkey(ctx, public_key)
await show_pubkey(public_key)
async def require_sign_tx(ctx: Context, num_actions: int) -> None:
async def require_sign_tx(num_actions: int) -> None:
from trezor.enums import ButtonRequestType
from trezor.strings import format_plural
from trezor.ui.layouts import confirm_action
await confirm_action(
ctx,
"confirm_tx",
"Sign transaction",
description="You are about to sign {}.",

@ -3,14 +3,14 @@ from typing import TYPE_CHECKING
from apps.common.keychain import auto_keychain
if TYPE_CHECKING:
from trezor.wire import Context
from trezor.messages import EosSignTx, EosSignedTx
from apps.common.keychain import Keychain
@auto_keychain(__name__)
async def sign_tx(ctx: Context, msg: EosSignTx, keychain: Keychain) -> EosSignedTx:
async def sign_tx(msg: EosSignTx, keychain: Keychain) -> EosSignedTx:
from trezor.wire import DataError
from trezor.wire.context import call
from trezor.crypto.curve import secp256k1
from trezor.crypto.hashlib import sha256
from trezor.messages import EosSignedTx, EosTxActionAck, EosTxActionRequest
@ -26,7 +26,7 @@ async def sign_tx(ctx: Context, msg: EosSignTx, keychain: Keychain) -> EosSigned
if not num_actions:
raise DataError("No actions")
await paths.validate_path(ctx, keychain, msg.address_n)
await paths.validate_path(keychain, msg.address_n)
node = keychain.derive(msg.address_n)
sha = HashWriter(sha256())
@ -36,13 +36,13 @@ async def sign_tx(ctx: Context, msg: EosSignTx, keychain: Keychain) -> EosSigned
write_header(sha, msg.header)
write_uvarint(sha, 0)
write_uvarint(sha, num_actions)
await require_sign_tx(ctx, num_actions)
await require_sign_tx(num_actions)
# actions
for index in range(num_actions):
action = await ctx.call(EosTxActionRequest(), EosTxActionAck)
action = await call(EosTxActionRequest(), EosTxActionAck)
is_last = index == num_actions - 1
await process_action(ctx, sha, action, is_last)
await process_action(sha, action, is_last)
write_uvarint(sha, 0)
write_bytes_fixed(sha, bytearray(32), 32)

@ -4,7 +4,6 @@ from .keychain import PATTERNS_ADDRESS, with_keychain_from_path
if TYPE_CHECKING:
from trezor.messages import EthereumGetAddress, EthereumAddress
from trezor.wire import Context
from apps.common.keychain import Keychain
from .definitions import Definitions
@ -12,7 +11,6 @@ if TYPE_CHECKING:
@with_keychain_from_path(*PATTERNS_ADDRESS)
async def get_address(
ctx: Context,
msg: EthereumGetAddress,
keychain: Keychain,
defs: Definitions,
@ -24,13 +22,13 @@ async def get_address(
address_n = msg.address_n # local_cache_attribute
await paths.validate_path(ctx, keychain, address_n)
await paths.validate_path(keychain, address_n)
node = keychain.derive(address_n)
address = address_from_bytes(node.ethereum_pubkeyhash(), defs.network)
if msg.show_display:
await show_address(ctx, address, path=paths.address_n_to_str(address_n))
await show_address(address, path=paths.address_n_to_str(address_n))
return EthereumAddress(address=address)

@ -2,10 +2,9 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from trezor.messages import EthereumGetPublicKey, EthereumPublicKey
from trezor.wire import Context
async def get_public_key(ctx: Context, msg: EthereumGetPublicKey) -> EthereumPublicKey:
async def get_public_key(msg: EthereumGetPublicKey) -> EthereumPublicKey:
from ubinascii import hexlify
from trezor.messages import EthereumPublicKey, GetPublicKey
from trezor.ui.layouts import show_pubkey
@ -13,9 +12,9 @@ async def get_public_key(ctx: Context, msg: EthereumGetPublicKey) -> EthereumPub
# we use the Bitcoin format for Ethereum xpubs
btc_pubkey_msg = GetPublicKey(address_n=msg.address_n)
resp = await bitcoin_get_public_key.get_public_key(ctx, btc_pubkey_msg)
resp = await bitcoin_get_public_key.get_public_key(btc_pubkey_msg)
if msg.show_display:
await show_pubkey(ctx, hexlify(resp.node.public_key).decode())
await show_pubkey(hexlify(resp.node.public_key).decode())
return EthereumPublicKey(node=resp.node, xpub=resp.xpub)

@ -12,8 +12,6 @@ if TYPE_CHECKING:
from apps.common.keychain import Keychain
from trezor.wire import Context
from trezor.messages import (
EthereumGetAddress,
EthereumSignMessage,
@ -36,7 +34,7 @@ if TYPE_CHECKING:
)
HandlerAddressN = Callable[
[Context, MsgInAddressN, Keychain, definitions.Definitions],
[MsgInAddressN, Keychain, definitions.Definitions],
Awaitable[MsgOut],
]
@ -48,7 +46,7 @@ if TYPE_CHECKING:
)
HandlerChainId = Callable[
[Context, MsgInSignTx, Keychain, definitions.Definitions],
[MsgInSignTx, Keychain, definitions.Definitions],
Awaitable[MsgOut],
]
@ -122,13 +120,13 @@ def with_keychain_from_path(
def decorator(
func: HandlerAddressN[MsgInAddressN, MsgOut]
) -> Handler[MsgInAddressN, MsgOut]:
async def wrapper(ctx: Context, msg: MsgInAddressN) -> MsgOut:
async def wrapper(msg: MsgInAddressN) -> MsgOut:
slip44 = _slip44_from_address_n(msg.address_n)
defs = _defs_from_message(msg, slip44=slip44)
schemas = _schemas_from_network(patterns, defs.network)
keychain = await get_keychain(ctx, CURVE, schemas)
keychain = await get_keychain(CURVE, schemas)
with keychain:
return await func(ctx, msg, keychain, defs)
return await func(msg, keychain, defs)
return wrapper
@ -139,11 +137,11 @@ def with_keychain_from_chain_id(
func: HandlerChainId[MsgInSignTx, MsgOut]
) -> Handler[MsgInSignTx, MsgOut]:
# this is only for SignTx, and only PATTERN_ADDRESS is allowed
async def wrapper(ctx: Context, msg: MsgInSignTx) -> MsgOut:
async def wrapper(msg: MsgInSignTx) -> MsgOut:
defs = _defs_from_message(msg, chain_id=msg.chain_id)
schemas = _schemas_from_network(PATTERNS_ADDRESS, defs.network)
keychain = await get_keychain(ctx, CURVE, schemas)
keychain = await get_keychain(CURVE, schemas)
with keychain:
return await func(ctx, msg, keychain, defs)
return await func(msg, keychain, defs)
return wrapper

@ -22,11 +22,9 @@ if TYPE_CHECKING:
EthereumStructMember,
EthereumTokenInfo,
)
from trezor.wire import Context
def require_confirm_tx(
ctx: Context,
to_bytes: bytes,
value: int,
network: EthereumNetworkInfo,
@ -40,7 +38,6 @@ def require_confirm_tx(
else:
to_str = "new contract?"
return confirm_output(
ctx,
to_str,
format_ethereum_amount(value, token, network),
br_code=ButtonRequestType.SignTx,
@ -48,7 +45,6 @@ def require_confirm_tx(
async def require_confirm_fee(
ctx: Context,
spending: int,
gas_price: int,
gas_limit: int,
@ -56,13 +52,11 @@ async def require_confirm_fee(
token: EthereumTokenInfo | None,
) -> None:
await confirm_amount(
ctx,
title="Confirm fee",
description="Gas price:",
amount=format_ethereum_amount(gas_price, None, network),
)
await confirm_total(
ctx,
total_amount=format_ethereum_amount(spending, token, network),
fee_amount=format_ethereum_amount(gas_price * gas_limit, None, network),
total_label="Amount sent:",
@ -71,7 +65,6 @@ async def require_confirm_fee(
async def require_confirm_eip1559_fee(
ctx: Context,
spending: int,
max_priority_fee: int,
max_gas_fee: int,
@ -80,19 +73,16 @@ async def require_confirm_eip1559_fee(
token: EthereumTokenInfo | None,
) -> None:
await confirm_amount(
ctx,
"Confirm fee",
format_ethereum_amount(max_gas_fee, None, network),
"Maximum fee per gas",
)
await confirm_amount(
ctx,
"Confirm fee",
format_ethereum_amount(max_priority_fee, None, network),
"Priority fee per gas",
)
await confirm_total(
ctx,
format_ethereum_amount(spending, token, network),
format_ethereum_amount(max_gas_fee * gas_limit, None, network),
total_label="Amount sent:",
@ -100,15 +90,12 @@ async def require_confirm_eip1559_fee(
)
def require_confirm_unknown_token(
ctx: Context, address_bytes: bytes
) -> Awaitable[None]:
def require_confirm_unknown_token(address_bytes: bytes) -> Awaitable[None]:
from ubinascii import hexlify
from trezor.ui.layouts import confirm_address
contract_address_hex = "0x" + hexlify(address_bytes).decode()
return confirm_address(
ctx,
"Unknown token",
contract_address_hex,
"Contract:",
@ -117,22 +104,20 @@ def require_confirm_unknown_token(
)
def require_confirm_address(ctx: Context, address_bytes: bytes) -> Awaitable[None]:
def require_confirm_address(address_bytes: bytes) -> Awaitable[None]:
from ubinascii import hexlify
from trezor.ui.layouts import confirm_address
address_hex = "0x" + hexlify(address_bytes).decode()
return confirm_address(
ctx,
"Signing address",
address_hex,
br_code=ButtonRequestType.SignTx,
)
def require_confirm_data(ctx: Context, data: bytes, data_total: int) -> Awaitable[None]:
def require_confirm_data(data: bytes, data_total: int) -> Awaitable[None]:
return confirm_blob(
ctx,
"confirm_data",
"Confirm data",
data,
@ -142,11 +127,10 @@ def require_confirm_data(ctx: Context, data: bytes, data_total: int) -> Awaitabl
)
async def confirm_typed_data_final(ctx: Context) -> None:
async def confirm_typed_data_final() -> None:
from trezor.ui.layouts import confirm_action
await confirm_action(
ctx,
"confirm_typed_data_final",
"Confirm typed data",
"Really sign EIP-712 typed data?",
@ -155,9 +139,8 @@ async def confirm_typed_data_final(ctx: Context) -> None:
)
def confirm_empty_typed_message(ctx: Context) -> Awaitable[None]:
def confirm_empty_typed_message() -> Awaitable[None]:
return confirm_text(
ctx,
"confirm_empty_typed_message",
"Confirm message",
"",
@ -165,7 +148,7 @@ def confirm_empty_typed_message(ctx: Context) -> Awaitable[None]:
)
async def should_show_domain(ctx: Context, name: bytes, version: bytes) -> bool:
async def should_show_domain(name: bytes, version: bytes) -> bool:
domain_name = decode_typed_data(name, "string")
domain_version = decode_typed_data(version, "string")
@ -175,7 +158,6 @@ async def should_show_domain(ctx: Context, name: bytes, version: bytes) -> bool:
(ui.DEMIBOLD, domain_version),
)
return await should_show_more(
ctx,
"Confirm domain",
para,
"Show full domain",
@ -184,7 +166,6 @@ async def should_show_domain(ctx: Context, name: bytes, version: bytes) -> bool:
async def should_show_struct(
ctx: Context,
description: str,
data_members: list[EthereumStructMember],
title: str = "Confirm struct",
@ -199,7 +180,6 @@ async def should_show_struct(
(ui.NORMAL, ", ".join(field.name for field in data_members)),
)
return await should_show_more(
ctx,
title,
para,
button_text,
@ -208,14 +188,12 @@ async def should_show_struct(
async def should_show_array(
ctx: Context,
parent_objects: Iterable[str],
data_type: str,
size: int,
) -> bool:
para = ((ui.NORMAL, format_plural("Array of {count} {plural}", size, data_type)),)
return await should_show_more(
ctx,
limit_str(".".join(parent_objects)),
para,
"Show full array",
@ -224,7 +202,6 @@ async def should_show_array(
async def confirm_typed_value(
ctx: Context,
name: str,
value: bytes,
parent_objects: list[str],
@ -247,7 +224,6 @@ async def confirm_typed_value(
if field.data_type in (EthereumDataType.ADDRESS, EthereumDataType.BYTES):
await confirm_blob(
ctx,
"confirm_typed_value",
title,
data,
@ -256,7 +232,6 @@ async def confirm_typed_value(
)
else:
await confirm_text(
ctx,
"confirm_typed_value",
title,
data,

@ -7,7 +7,6 @@ if TYPE_CHECKING:
EthereumSignMessage,
EthereumMessageSignature,
)
from trezor.wire import Context
from apps.common.keychain import Keychain
from .definitions import Definitions
@ -27,7 +26,6 @@ def message_digest(message: bytes) -> bytes:
@with_keychain_from_path(*PATTERNS_ADDRESS)
async def sign_message(
ctx: Context,
msg: EthereumSignMessage,
keychain: Keychain,
defs: Definitions,
@ -41,13 +39,11 @@ async def sign_message(
from .helpers import address_from_bytes
await paths.validate_path(ctx, keychain, msg.address_n)
await paths.validate_path(keychain, msg.address_n)
node = keychain.derive(msg.address_n)
address = address_from_bytes(node.ethereum_pubkeyhash(), defs.network)
await confirm_signverify(
ctx, "ETH", decode_message(msg.message), address, verify=False
)
await confirm_signverify("ETH", decode_message(msg.message), address, verify=False)
signature = secp256k1.sign(
node.private_key(),

@ -10,7 +10,6 @@ from .keychain import with_keychain_from_chain_id
if TYPE_CHECKING:
from apps.common.keychain import Keychain
from trezor.messages import EthereumSignTx, EthereumTxAck, EthereumTokenInfo
from trezor.wire import Context
from .definitions import Definitions
from .keychain import MsgInSignTx
@ -24,7 +23,6 @@ MAX_CHAIN_ID = (0xFFFF_FFFF - 36) // 2
@with_keychain_from_chain_id
async def sign_tx(
ctx: Context,
msg: EthereumSignTx,
keychain: Keychain,
defs: Definitions,
@ -45,19 +43,18 @@ async def sign_tx(
raise DataError("Fee overflow")
check_common_fields(msg)
await paths.validate_path(ctx, keychain, msg.address_n)
await paths.validate_path(keychain, msg.address_n)
# Handle ERC20s
token, address_bytes, recipient, value = await handle_erc20(ctx, msg, defs)
token, address_bytes, recipient, value = await handle_erc20(msg, defs)
data_total = msg.data_length
await require_confirm_tx(ctx, recipient, value, defs.network, token)
await require_confirm_tx(recipient, value, defs.network, token)
if token is None and msg.data_length > 0:
await require_confirm_data(ctx, msg.data_initial_chunk, data_total)
await require_confirm_data(msg.data_initial_chunk, data_total)
await require_confirm_fee(
ctx,
value,
int.from_bytes(msg.gas_price, "big"),
int.from_bytes(msg.gas_limit, "big"),
@ -87,7 +84,7 @@ async def sign_tx(
sha.extend(data)
while data_left > 0:
resp = await send_request_chunk(ctx, data_left)
resp = await send_request_chunk(data_left)
data_left -= len(resp.data_chunk)
sha.extend(resp.data_chunk)
@ -103,7 +100,6 @@ async def sign_tx(
async def handle_erc20(
ctx: Context,
msg: MsgInSignTx,
definitions: Definitions,
) -> tuple[EthereumTokenInfo | None, bytes, bytes, int]:
@ -127,7 +123,7 @@ async def handle_erc20(
value = int.from_bytes(data_initial_chunk[36:68], "big")
if token is tokens.UNKNOWN_TOKEN:
await require_confirm_unknown_token(ctx, address_bytes)
await require_confirm_unknown_token(address_bytes)
return token, address_bytes, recipient, value
@ -157,13 +153,14 @@ def _get_total_length(msg: EthereumSignTx, data_total: int) -> int:
return length
async def send_request_chunk(ctx: Context, data_left: int) -> EthereumTxAck:
async def send_request_chunk(data_left: int) -> EthereumTxAck:
from trezor.messages import EthereumTxAck
from trezor.wire.context import call
# TODO: layoutProgress ?
req = EthereumTxRequest()
req.data_length = min(data_left, 1024)
return await ctx.call(req, EthereumTxAck)
return await call(req, EthereumTxAck)
def _sign_digest(

@ -12,7 +12,6 @@ if TYPE_CHECKING:
EthereumAccessList,
EthereumTxRequest,
)
from trezor.wire import Context
from apps.common.keychain import Keychain
from .definitions import Definitions
@ -30,7 +29,6 @@ def access_list_item_length(item: EthereumAccessList) -> int:
@with_keychain_from_chain_id
async def sign_tx_eip1559(
ctx: Context,
msg: EthereumSignTxEIP1559,
keychain: Keychain,
defs: Definitions,
@ -56,19 +54,18 @@ async def sign_tx_eip1559(
raise wire.DataError("Fee overflow")
check_common_fields(msg)
await paths.validate_path(ctx, keychain, msg.address_n)
await paths.validate_path(keychain, msg.address_n)
# Handle ERC20s
token, address_bytes, recipient, value = await handle_erc20(ctx, msg, defs)
token, address_bytes, recipient, value = await handle_erc20(msg, defs)
data_total = msg.data_length
await require_confirm_tx(ctx, recipient, value, defs.network, token)
await require_confirm_tx(recipient, value, defs.network, token)
if token is None and msg.data_length > 0:
await require_confirm_data(ctx, msg.data_initial_chunk, data_total)
await require_confirm_data(msg.data_initial_chunk, data_total)
await require_confirm_eip1559_fee(
ctx,
value,
int.from_bytes(msg.max_priority_fee, "big"),
int.from_bytes(msg.max_gas_fee, "big"),
@ -108,7 +105,7 @@ async def sign_tx_eip1559(
sha.extend(data)
while data_left > 0:
resp = await send_request_chunk(ctx, data_left)
resp = await send_request_chunk(data_left)
data_left -= len(resp.data_chunk)
sha.extend(resp.data_chunk)

@ -2,6 +2,7 @@ from typing import TYPE_CHECKING
from trezor.enums import EthereumDataType
from trezor.wire import DataError
from trezor.wire.context import call
from .helpers import get_type_name
from .keychain import PATTERNS_ADDRESS, with_keychain_from_path
@ -9,7 +10,6 @@ from .layout import should_show_struct
if TYPE_CHECKING:
from apps.common.keychain import Keychain
from trezor.wire import Context
from trezor.utils import HashWriter
from .definitions import Definitions
@ -23,7 +23,6 @@ if TYPE_CHECKING:
@with_keychain_from_path(*PATTERNS_ADDRESS)
async def sign_typed_data(
ctx: Context,
msg: EthereumSignTypedData,
keychain: Keychain,
defs: Definitions,
@ -34,16 +33,16 @@ async def sign_typed_data(
from .layout import require_confirm_address
from trezor.messages import EthereumTypedDataSignature
await paths.validate_path(ctx, keychain, msg.address_n)
await paths.validate_path(keychain, msg.address_n)
node = keychain.derive(msg.address_n)
address_bytes: bytes = node.ethereum_pubkeyhash()
# Display address so user can validate it
await require_confirm_address(ctx, address_bytes)
await require_confirm_address(address_bytes)
data_hash = await _generate_typed_data_hash(
ctx, msg.primary_type, msg.metamask_v4_compat
msg.primary_type, msg.metamask_v4_compat
)
signature = secp256k1.sign(
@ -57,7 +56,7 @@ async def sign_typed_data(
async def _generate_typed_data_hash(
ctx: Context, primary_type: str, metamask_v4_compat: bool = True
primary_type: str, metamask_v4_compat: bool = True
) -> bytes:
"""
Generate typed data hash according to EIP-712 specification
@ -72,14 +71,13 @@ async def _generate_typed_data_hash(
)
typed_data_envelope = TypedDataEnvelope(
ctx,
primary_type,
metamask_v4_compat,
)
await typed_data_envelope.collect_types()
name, version = await _get_name_and_version_for_domain(ctx, typed_data_envelope)
show_domain = await should_show_domain(ctx, name, version)
name, version = await _get_name_and_version_for_domain(typed_data_envelope)
show_domain = await should_show_domain(name, version)
domain_separator = await typed_data_envelope.hash_struct(
"EIP712Domain",
[0],
@ -91,11 +89,10 @@ async def _generate_typed_data_hash(
# In this case, we ignore the "message" part and only use the "domain" part
# https://ethereum-magicians.org/t/eip-712-standards-clarification-primarytype-as-domaintype/3286
if primary_type == "EIP712Domain":
await confirm_empty_typed_message(ctx)
await confirm_empty_typed_message()
message_hash = b""
else:
show_message = await should_show_struct(
ctx,
primary_type,
typed_data_envelope.types[primary_type].members,
"Confirm message",
@ -108,7 +105,7 @@ async def _generate_typed_data_hash(
[primary_type],
)
await confirm_typed_data_final(ctx)
await confirm_typed_data_final()
return keccak256(b"\x19\x01" + domain_separator + message_hash)
@ -131,11 +128,9 @@ class TypedDataEnvelope:
def __init__(
self,
ctx: Context,
primary_type: str,
metamask_v4_compat: bool,
) -> None:
self.ctx = ctx
self.primary_type = primary_type
self.metamask_v4_compat = metamask_v4_compat
self.types: dict[str, EthereumTypedDataStructAck] = {}
@ -153,7 +148,7 @@ class TypedDataEnvelope:
)
req = EthereumTypedDataStructRequest(name=type_name)
current_type = await self.ctx.call(req, EthereumTypedDataStructAck)
current_type = await call(req, EthereumTypedDataStructAck)
self.types[type_name] = current_type
for member in current_type.members:
member_type = member.type
@ -254,8 +249,6 @@ class TypedDataEnvelope:
"""
from .layout import confirm_typed_value, should_show_array
ctx = self.ctx # local_cache_attribute
type_members = self.types[primary_type].members
member_value_path = member_path + [0]
current_parent_objects = parent_objects + [""]
@ -272,7 +265,6 @@ class TypedDataEnvelope:
if show_data:
show_struct = await should_show_struct(
ctx,
struct_name, # description
self.types[struct_name].members, # data_members
".".join(current_parent_objects), # title
@ -290,7 +282,7 @@ class TypedDataEnvelope:
elif field_type.data_type == EthereumDataType.ARRAY:
# Getting the length of the array first, if not fixed
if field_type.size is None:
array_size = await _get_array_size(ctx, member_value_path)
array_size = await _get_array_size(member_value_path)
else:
array_size = field_type.size
@ -300,7 +292,6 @@ class TypedDataEnvelope:
if show_data:
show_array = await should_show_array(
ctx,
current_parent_objects,
get_type_name(entry_type),
array_size,
@ -338,11 +329,10 @@ class TypedDataEnvelope:
current_parent_objects,
)
else:
value = await get_value(ctx, entry_type, el_member_path)
value = await get_value(entry_type, el_member_path)
encode_field(arr_w, entry_type, value)
if show_array:
await confirm_typed_value(
ctx,
field_name,
value,
parent_objects,
@ -351,11 +341,10 @@ class TypedDataEnvelope:
)
w.extend(arr_w.get_digest())
else:
value = await get_value(ctx, field_type, member_value_path)
value = await get_value(field_type, member_value_path)
encode_field(w, field_type, value)
if show_data:
await confirm_typed_value(
ctx,
field_name,
value,
parent_objects,
@ -503,18 +492,17 @@ def validate_field_type(field: EthereumFieldType) -> None:
raise DataError("Unexpected size in str/bool/addr")
async def _get_array_size(ctx: Context, member_path: list[int]) -> int:
async def _get_array_size(member_path: list[int]) -> int:
"""Get the length of an array at specific `member_path` from the client."""
from trezor.messages import EthereumFieldType
# Field type for getting the array length from client, so we can check the return value
ARRAY_LENGTH_TYPE = EthereumFieldType(data_type=EthereumDataType.UINT, size=2)
length_value = await get_value(ctx, ARRAY_LENGTH_TYPE, member_path)
length_value = await get_value(ARRAY_LENGTH_TYPE, member_path)
return int.from_bytes(length_value, "big")
async def get_value(
ctx: Context,
field: EthereumFieldType,
member_value_path: list[int],
) -> bytes:
@ -524,7 +512,7 @@ async def get_value(
req = EthereumTypedDataValueRequest(
member_path=member_value_path,
)
res = await ctx.call(req, EthereumTypedDataValueAck)
res = await call(req, EthereumTypedDataValueAck)
value = res.value
_validate_value(field=field, value=value)
@ -533,7 +521,7 @@ async def get_value(
async def _get_name_and_version_for_domain(
ctx: Context, typed_data_envelope: TypedDataEnvelope
typed_data_envelope: TypedDataEnvelope,
) -> tuple[bytes, bytes]:
domain_name = b"unknown"
domain_version = b"unknown"
@ -543,8 +531,8 @@ async def _get_name_and_version_for_domain(
for member_index, member in enumerate(domain_members):
member_value_path[-1] = member_index
if member.name == "name":
domain_name = await get_value(ctx, member.type, member_value_path)
domain_name = await get_value(member.type, member_value_path)
elif member.name == "version":
domain_version = await get_value(ctx, member.type, member_value_path)
domain_version = await get_value(member.type, member_value_path)
return domain_name, domain_version

@ -2,10 +2,9 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from trezor.messages import EthereumVerifyMessage, Success
from trezor.wire import Context
async def verify_message(ctx: Context, msg: EthereumVerifyMessage) -> Success:
async def verify_message(msg: EthereumVerifyMessage) -> Success:
from trezor.wire import DataError
from trezor.crypto.curve import secp256k1
from trezor.crypto.hashlib import sha3_256
@ -35,9 +34,7 @@ async def verify_message(ctx: Context, msg: EthereumVerifyMessage) -> Success:
address = address_from_bytes(address_bytes)
await confirm_signverify(
ctx, "ETH", decode_message(msg.message), address, verify=True
)
await confirm_signverify("ETH", decode_message(msg.message), address, verify=True)
await show_success(ctx, "verify_message", "The signature is valid.")
await show_success("verify_message", "The signature is valid.")
return Success(message="Message verified")

@ -3,10 +3,9 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from trezor.messages import Success
from trezor.messages import ApplyFlags
from trezor.wire import GenericContext
async def apply_flags(ctx: GenericContext, msg: ApplyFlags) -> Success:
async def apply_flags(msg: ApplyFlags) -> Success:
import storage.device
from storage.device import set_flags
from trezor.wire import NotInitialized

@ -9,7 +9,6 @@ import trezorui2
if TYPE_CHECKING:
from trezor.messages import ApplySettings, Success
from trezor.wire import Context, GenericContext
from trezor.enums import SafetyCheckLevel
@ -62,7 +61,7 @@ def _validate_homescreen(homescreen: bytes) -> None:
_validate_homescreen_model_specific(homescreen)
async def apply_settings(ctx: Context, msg: ApplySettings) -> Success:
async def apply_settings(msg: ApplySettings) -> Success:
import storage.device as storage_device
from apps.common import safety_checks
from trezor.messages import Success
@ -98,7 +97,7 @@ async def apply_settings(ctx: Context, msg: ApplySettings) -> Success:
if homescreen is not None:
_validate_homescreen(homescreen)
await _require_confirm_change_homescreen(ctx, homescreen)
await _require_confirm_change_homescreen(homescreen)
try:
storage_device.set_homescreen(homescreen)
except ValueError:
@ -107,19 +106,17 @@ async def apply_settings(ctx: Context, msg: ApplySettings) -> Success:
if label is not None:
if len(label) > storage_device.LABEL_MAXLENGTH:
raise DataError("Label too long")
await _require_confirm_change_label(ctx, label)
await _require_confirm_change_label(label)
storage_device.set_label(label)
if use_passphrase is not None:
await _require_confirm_change_passphrase(ctx, use_passphrase)
await _require_confirm_change_passphrase(use_passphrase)
storage_device.set_passphrase_enabled(use_passphrase)
if passphrase_always_on_device is not None:
if not storage_device.is_passphrase_enabled():
raise DataError("Passphrase is not enabled")
await _require_confirm_change_passphrase_source(
ctx, passphrase_always_on_device
)
await _require_confirm_change_passphrase_source(passphrase_always_on_device)
storage_device.set_passphrase_always_on_device(passphrase_always_on_device)
if auto_lock_delay_ms is not None:
@ -127,25 +124,25 @@ async def apply_settings(ctx: Context, msg: ApplySettings) -> Success:
raise ProcessError("Auto-lock delay too short")
if auto_lock_delay_ms > storage_device.AUTOLOCK_DELAY_MAXIMUM:
raise ProcessError("Auto-lock delay too long")
await _require_confirm_change_autolock_delay(ctx, auto_lock_delay_ms)
await _require_confirm_change_autolock_delay(auto_lock_delay_ms)
storage_device.set_autolock_delay_ms(auto_lock_delay_ms)
if msg_safety_checks is not None:
await _require_confirm_safety_checks(ctx, msg_safety_checks)
await _require_confirm_safety_checks(msg_safety_checks)
safety_checks.apply_setting(msg_safety_checks)
if display_rotation is not None:
await _require_confirm_change_display_rotation(ctx, display_rotation)
await _require_confirm_change_display_rotation(display_rotation)
storage_device.set_rotation(display_rotation)
if experimental_features is not None:
await _require_confirm_experimental_features(ctx, experimental_features)
await _require_confirm_experimental_features(experimental_features)
storage_device.set_experimental_features(experimental_features)
if hide_passphrase_from_host is not None:
if safety_checks.is_strict():
raise ProcessError("Safety checks are strict")
await _require_confirm_hide_passphrase_from_host(ctx, hide_passphrase_from_host)
await _require_confirm_hide_passphrase_from_host(hide_passphrase_from_host)
storage_device.set_hide_passphrase_from_host(hide_passphrase_from_host)
reload_settings_from_storage()
@ -153,12 +150,9 @@ async def apply_settings(ctx: Context, msg: ApplySettings) -> Success:
return Success(message="Settings applied")
async def _require_confirm_change_homescreen(
ctx: GenericContext, homescreen: bytes
) -> None:
async def _require_confirm_change_homescreen(homescreen: bytes) -> None:
if homescreen == b"":
await confirm_action(
ctx,
"set_homescreen",
"Set homescreen",
description="Do you really want to set default homescreen image?",
@ -166,14 +160,12 @@ async def _require_confirm_change_homescreen(
)
else:
await confirm_homescreen(
ctx,
homescreen,
)
async def _require_confirm_change_label(ctx: GenericContext, label: str) -> None:
async def _require_confirm_change_label(label: str) -> None:
await confirm_single(
ctx,
"set_label",
"Device name",
description="Change device name to {}?",
@ -182,11 +174,10 @@ async def _require_confirm_change_label(ctx: GenericContext, label: str) -> None
)
async def _require_confirm_change_passphrase(ctx: GenericContext, use: bool) -> None:
async def _require_confirm_change_passphrase(use: bool) -> None:
template = "Do you want to {} passphrase protection?"
description = template.format("enable" if use else "disable")
await confirm_action(
ctx,
"set_passphrase",
"Enable passphrase" if use else "Disable passphrase",
description=description,
@ -195,7 +186,7 @@ async def _require_confirm_change_passphrase(ctx: GenericContext, use: bool) ->
async def _require_confirm_change_passphrase_source(
ctx: GenericContext, passphrase_always_on_device: bool
passphrase_always_on_device: bool,
) -> None:
description = (
"Do you really want to enter passphrase always on the device?"
@ -203,7 +194,6 @@ async def _require_confirm_change_passphrase_source(
else "Do you want to revoke the passphrase on device setting?"
)
await confirm_action(
ctx,
"set_passphrase_source",
"Passphrase source",
description=description,
@ -211,9 +201,7 @@ async def _require_confirm_change_passphrase_source(
)
async def _require_confirm_change_display_rotation(
ctx: GenericContext, rotation: int
) -> None:
async def _require_confirm_change_display_rotation(rotation: int) -> None:
if rotation == 0:
label = "north"
elif rotation == 90:
@ -226,7 +214,6 @@ async def _require_confirm_change_display_rotation(
raise DataError("Unsupported display rotation")
await confirm_action(
ctx,
"set_rotation",
"Change rotation",
description="Do you want to change device rotation to {}?",
@ -235,13 +222,10 @@ async def _require_confirm_change_display_rotation(
)
async def _require_confirm_change_autolock_delay(
ctx: GenericContext, delay_ms: int
) -> None:
async def _require_confirm_change_autolock_delay(delay_ms: int) -> None:
from trezor.strings import format_duration_ms
await confirm_action(
ctx,
"set_autolock_delay",
"Auto-lock delay",
description="Do you really want to auto-lock your device after {}?",
@ -250,14 +234,11 @@ async def _require_confirm_change_autolock_delay(
)
async def _require_confirm_safety_checks(
ctx: GenericContext, level: SafetyCheckLevel
) -> None:
async def _require_confirm_safety_checks(level: SafetyCheckLevel) -> None:
from trezor.enums import SafetyCheckLevel
if level == SafetyCheckLevel.Strict:
await confirm_action(
ctx,
"set_safety_checks",
"Safety checks",
description="Do you really want to enforce strict safety checks (recommended)?",
@ -273,7 +254,6 @@ async def _require_confirm_safety_checks(
)
await confirm_action(
ctx,
"set_safety_checks",
"Safety override",
"Are you sure?",
@ -287,12 +267,9 @@ async def _require_confirm_safety_checks(
raise ValueError # enum value out of range
async def _require_confirm_experimental_features(
ctx: GenericContext, enable: bool
) -> None:
async def _require_confirm_experimental_features(enable: bool) -> None:
if enable:
await confirm_action(
ctx,
"set_experimental_features",
"Experimental mode",
"Only for development and beta testing!",
@ -302,12 +279,9 @@ async def _require_confirm_experimental_features(
)
async def _require_confirm_hide_passphrase_from_host(
ctx: GenericContext, enable: bool
) -> None:
async def _require_confirm_hide_passphrase_from_host(enable: bool) -> None:
if enable:
await confirm_action(
ctx,
"set_hide_passphrase_from_host",
"Hide passphrase",
description="Hide passphrase coming from host?",

@ -2,10 +2,9 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from trezor.messages import BackupDevice, Success
from trezor.wire import Context
async def backup_device(ctx: Context, msg: BackupDevice) -> Success:
async def backup_device(msg: BackupDevice) -> Success:
import storage.device as storage_device
from trezor import wire
from trezor.messages import Success
@ -26,10 +25,10 @@ async def backup_device(ctx: Context, msg: BackupDevice) -> Success:
storage_device.set_unfinished_backup(True)
storage_device.set_backed_up()
await backup_seed(ctx, mnemonic_type, mnemonic_secret)
await backup_seed(mnemonic_type, mnemonic_secret)
storage_device.set_unfinished_backup(False)
await layout.show_backup_success(ctx)
await layout.show_backup_success()
return Success(message="Seed successfully backed up")

@ -6,10 +6,9 @@ if TYPE_CHECKING:
from typing import Awaitable
from trezor.messages import ChangePin, Success
from trezor.wire import Context
async def change_pin(ctx: Context, msg: ChangePin) -> Success:
async def change_pin(msg: ChangePin) -> Success:
from storage.device import is_initialized
from trezor.messages import Success
from trezor.ui.layouts import show_success
@ -25,28 +24,28 @@ async def change_pin(ctx: Context, msg: ChangePin) -> Success:
raise wire.NotInitialized("Device is not initialized")
# confirm that user wants to change the pin
await _require_confirm_change_pin(ctx, msg)
await _require_confirm_change_pin(msg)
# get old pin
curpin, salt = await request_pin_and_sd_salt(ctx, "Enter PIN")
curpin, salt = await request_pin_and_sd_salt("Enter PIN")
# if changing pin, pre-check the entered pin before getting new pin
if curpin and not msg.remove:
if not config.check_pin(curpin, salt):
await error_pin_invalid(ctx)
await error_pin_invalid()
# get new pin
if not msg.remove:
newpin = await request_pin_confirm(ctx)
newpin = await request_pin_confirm()
else:
newpin = ""
# write into storage
if not config.change_pin(curpin, newpin, salt, salt):
if newpin:
await error_pin_matches_wipe_code(ctx)
await error_pin_matches_wipe_code()
else:
await error_pin_invalid(ctx)
await error_pin_invalid()
if newpin:
if curpin:
@ -59,11 +58,11 @@ async def change_pin(ctx: Context, msg: ChangePin) -> Success:
msg_screen = "PIN protection disabled."
msg_wire = "PIN removed"
await show_success(ctx, "success_pin", msg_screen)
await show_success("success_pin", msg_screen)
return Success(message=msg_wire)
def _require_confirm_change_pin(ctx: Context, msg: ChangePin) -> Awaitable[None]:
def _require_confirm_change_pin(msg: ChangePin) -> Awaitable[None]:
from trezor.ui.layouts import confirm_action, confirm_set_new_pin
has_pin = config.has_pin()
@ -73,7 +72,6 @@ def _require_confirm_change_pin(ctx: Context, msg: ChangePin) -> Awaitable[None]
if msg.remove and has_pin: # removing pin
return confirm_action(
ctx,
br_type,
title,
description="Do you want to disable PIN protection?",
@ -82,7 +80,6 @@ def _require_confirm_change_pin(ctx: Context, msg: ChangePin) -> Awaitable[None]
if not msg.remove and has_pin: # changing pin
return confirm_action(
ctx,
br_type,
title,
description="Do you want to change your PIN?",
@ -91,7 +88,6 @@ def _require_confirm_change_pin(ctx: Context, msg: ChangePin) -> Awaitable[None]
if not msg.remove and not has_pin: # setting new pin
return confirm_set_new_pin(
ctx,
br_type,
title,
"Do you want to enable PIN protection?",

@ -2,12 +2,11 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Awaitable
from trezor.wire import Context
from trezor.messages import ChangeWipeCode, Success
async def change_wipe_code(ctx: Context, msg: ChangeWipeCode) -> Success:
async def change_wipe_code(msg: ChangeWipeCode) -> Success:
from storage.device import is_initialized
from trezor.wire import NotInitialized
from trezor.ui.layouts import show_success
@ -23,24 +22,24 @@ async def change_wipe_code(ctx: Context, msg: ChangeWipeCode) -> Success:
# Confirm that user wants to set or remove the wipe code.
has_wipe_code = config.has_wipe_code()
await _require_confirm_action(ctx, msg, has_wipe_code)
await _require_confirm_action(msg, has_wipe_code)
# Get the unlocking PIN.
pin, salt = await request_pin_and_sd_salt(ctx, "Enter PIN")
pin, salt = await request_pin_and_sd_salt("Enter PIN")
if not msg.remove:
# Pre-check the entered PIN.
if config.has_pin() and not config.check_pin(pin, salt):
await error_pin_invalid(ctx)
await error_pin_invalid()
# Get new wipe code.
wipe_code = await _request_wipe_code_confirm(ctx, pin)
wipe_code = await _request_wipe_code_confirm(pin)
else:
wipe_code = ""
# Write into storage.
if not config.change_wipe_code(pin, salt, wipe_code):
await error_pin_invalid(ctx)
await error_pin_invalid()
if wipe_code:
if has_wipe_code:
@ -53,12 +52,12 @@ async def change_wipe_code(ctx: Context, msg: ChangeWipeCode) -> Success:
msg_screen = "Wipe code disabled."
msg_wire = "Wipe code removed"
await show_success(ctx, "success_wipe_code", msg_screen)
await show_success("success_wipe_code", msg_screen)
return Success(message=msg_wire)
def _require_confirm_action(
ctx: Context, msg: ChangeWipeCode, has_wipe_code: bool
msg: ChangeWipeCode, has_wipe_code: bool
) -> Awaitable[None]:
from trezor.wire import ProcessError
from trezor.ui.layouts import confirm_action, confirm_set_new_pin
@ -67,7 +66,6 @@ def _require_confirm_action(
if msg.remove and has_wipe_code:
return confirm_action(
ctx,
"disable_wipe_code",
title,
description="Do you want to disable wipe code protection?",
@ -76,7 +74,6 @@ def _require_confirm_action(
if not msg.remove and has_wipe_code:
return confirm_action(
ctx,
"change_wipe_code",
title,
description="Do you want to change the wipe code?",
@ -85,7 +82,6 @@ def _require_confirm_action(
if not msg.remove and not has_wipe_code:
return confirm_set_new_pin(
ctx,
"set_wipe_code",
title,
"Do you want to enable wipe code?",
@ -98,7 +94,7 @@ def _require_confirm_action(
raise ProcessError("Wipe code protection is already disabled")
async def _request_wipe_code_confirm(ctx: Context, pin: str) -> str:
async def _request_wipe_code_confirm(pin: str) -> str:
from apps.common.request_pin import request_pin
from trezor.ui.layouts import (
confirm_reenter_pin,
@ -107,12 +103,12 @@ async def _request_wipe_code_confirm(ctx: Context, pin: str) -> str:
)
while True:
code1 = await request_pin(ctx, "Enter new wipe code")
code1 = await request_pin("Enter new wipe code")
if code1 == pin:
await wipe_code_same_as_pin_popup(ctx)
await wipe_code_same_as_pin_popup()
continue
await confirm_reenter_pin(ctx, is_wipe_code=True)
code2 = await request_pin(ctx, "Re-enter wipe code")
await confirm_reenter_pin(is_wipe_code=True)
code2 = await request_pin("Re-enter wipe code")
if code1 == code2:
return code1
await pin_mismatch_popup(ctx, is_wipe_code=True)
await pin_mismatch_popup(is_wipe_code=True)

@ -2,10 +2,9 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from trezor.messages import GetNextU2FCounter, NextU2FCounter
from trezor.wire import Context
async def get_next_u2f_counter(ctx: Context, msg: GetNextU2FCounter) -> NextU2FCounter:
async def get_next_u2f_counter(msg: GetNextU2FCounter) -> NextU2FCounter:
import storage.device as storage_device
from trezor.wire import NotInitialized
from trezor.enums import ButtonRequestType
@ -16,7 +15,6 @@ async def get_next_u2f_counter(ctx: Context, msg: GetNextU2FCounter) -> NextU2FC
raise NotInitialized("Device is not initialized")
await confirm_action(
ctx,
"get_u2f_counter",
"Get next U2F counter",
description="Do you really want to increase and retrieve the U2F counter?",

@ -2,10 +2,9 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from trezor.messages import GetNonce, Nonce
from trezor.wire import Context
async def get_nonce(ctx: Context, msg: GetNonce) -> Nonce:
async def get_nonce(msg: GetNonce) -> Nonce:
from storage import cache
from trezor.crypto import random
from trezor.messages import Nonce

@ -3,21 +3,21 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from trezor.messages import RebootToBootloader
from typing import NoReturn
from trezor.wire import Context
async def reboot_to_bootloader(ctx: Context, msg: RebootToBootloader) -> NoReturn:
async def reboot_to_bootloader(msg: RebootToBootloader) -> NoReturn:
from trezor import io, loop, utils
from trezor.messages import Success
from trezor.ui.layouts import confirm_action
from trezor.wire.context import get_context
await confirm_action(
ctx,
"reboot",
"Go to bootloader",
"Do you want to restart Trezor in bootloader mode?",
verb="Restart",
)
ctx = get_context()
await ctx.write(Success(message="Rebooting"))
# make sure the outgoing USB buffer is flushed
await loop.wait(ctx.iface.iface_num() | io.POLL_WRITE)

@ -2,7 +2,6 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from trezor.messages import RecoveryDevice
from trezor.wire import Context
from trezor.messages import Success
# List of RecoveryDevice fields that can be set when doing dry-run recovery.
@ -11,7 +10,7 @@ if TYPE_CHECKING:
DRY_RUN_ALLOWED_FIELDS = ("dry_run", "word_count", "enforce_wordlist", "type")
async def recovery_device(ctx: Context, msg: RecoveryDevice) -> Success:
async def recovery_device(msg: RecoveryDevice) -> Success:
"""
Recover BIP39/SLIP39 seed into empty device.
Recovery is also possible with replugged Trezor. We call this process Persistence.
@ -52,15 +51,14 @@ async def recovery_device(ctx: Context, msg: RecoveryDevice) -> Success:
# --------------------------------------------------------
if storage_recovery.is_in_progress():
return await recovery_process(ctx)
return await recovery_process()
# --------------------------------------------------------
# _continue_dialog
if not dry_run:
await confirm_reset_device(ctx, "Wallet recovery", recovery=True)
await confirm_reset_device("Wallet recovery", recovery=True)
else:
await confirm_action(
ctx,
"confirm_seedcheck",
"Seed check",
description="Do you really want to check the recovery seed?",
@ -75,14 +73,14 @@ async def recovery_device(ctx: Context, msg: RecoveryDevice) -> Success:
# for dry run pin needs to be entered
if dry_run:
curpin, salt = await request_pin_and_sd_salt(ctx, "Enter PIN")
curpin, salt = await request_pin_and_sd_salt("Enter PIN")
if not config.check_pin(curpin, salt):
await error_pin_invalid(ctx)
await error_pin_invalid()
if not dry_run:
# set up pin if requested
if msg.pin_protection:
newpin = await request_pin_confirm(ctx, allow_cancel=False)
newpin = await request_pin_confirm(allow_cancel=False)
config.change_pin("", newpin, None, None)
storage_device.set_passphrase_enabled(bool(msg.passphrase_protection))
@ -95,4 +93,4 @@ async def recovery_device(ctx: Context, msg: RecoveryDevice) -> Success:
storage_recovery.set_dry_run(bool(dry_run))
workflow.set_default(recovery_homescreen)
return await recovery_process(ctx)
return await recovery_process()

@ -9,7 +9,6 @@ from .. import backup_types
from . import layout, recover
if TYPE_CHECKING:
from trezor.wire import GenericContext
from trezor.enums import BackupType
@ -21,18 +20,16 @@ async def recovery_homescreen() -> None:
workflow.set_default(homescreen)
return
# recovery process does not communicate on the wire
ctx = wire.DUMMY_CONTEXT
await recovery_process(ctx)
await recovery_process()
async def recovery_process(ctx: GenericContext) -> Success:
async def recovery_process() -> Success:
from trezor.enums import MessageType
import storage
wire.AVOID_RESTARTING_FOR = (MessageType.Initialize, MessageType.GetFeatures)
try:
return await _continue_recovery_process(ctx)
return await _continue_recovery_process()
except recover.RecoveryAborted:
dry_run = storage_recovery.is_dry_run()
if dry_run:
@ -42,7 +39,7 @@ async def recovery_process(ctx: GenericContext) -> Success:
raise wire.ActionCancelled
async def _continue_recovery_process(ctx: GenericContext) -> Success:
async def _continue_recovery_process() -> Success:
from trezor.errors import MnemonicError
# gather the current recovery state from storage
@ -58,48 +55,46 @@ async def _continue_recovery_process(ctx: GenericContext) -> Success:
if not is_first_step:
assert word_count is not None
# If we continue recovery, show starting screen with word count immediately.
await _request_share_first_screen(ctx, word_count)
await _request_share_first_screen(word_count)
secret = None
while secret is None:
if is_first_step:
# If we are starting recovery, ask for word count first...
# _request_word_count
await layout.homescreen_dialog(ctx, "Select", "Select number of words")
await layout.homescreen_dialog("Select", "Select number of words")
# ask for the number of words
word_count = await layout.request_word_count(ctx, dry_run)
word_count = await layout.request_word_count(dry_run)
# ...and only then show the starting screen with word count.
await _request_share_first_screen(ctx, word_count)
await _request_share_first_screen(word_count)
assert word_count is not None
# ask for mnemonic words one by one
words = await layout.request_mnemonic(ctx, word_count, backup_type)
words = await layout.request_mnemonic(word_count, backup_type)
# if they were invalid or some checks failed we continue and request them again
if not words:
continue
try:
secret, backup_type = await _process_words(ctx, words)
secret, backup_type = await _process_words(words)
# If _process_words succeeded, we now have both backup_type (from
# its result) and word_count (from _request_word_count earlier), which means
# that the first step is complete.
is_first_step = False
except MnemonicError:
await layout.show_invalid_mnemonic(ctx, word_count)
await layout.show_invalid_mnemonic(word_count)
assert backup_type is not None
if dry_run:
result = await _finish_recovery_dry_run(ctx, secret, backup_type)
result = await _finish_recovery_dry_run(secret, backup_type)
else:
result = await _finish_recovery(ctx, secret, backup_type)
result = await _finish_recovery(secret, backup_type)
return result
async def _finish_recovery_dry_run(
ctx: GenericContext, secret: bytes, backup_type: BackupType
) -> Success:
async def _finish_recovery_dry_run(secret: bytes, backup_type: BackupType) -> Success:
from trezor.crypto.hashlib import sha256
from trezor import utils
from apps.common import mnemonic
@ -126,7 +121,7 @@ async def _finish_recovery_dry_run(
storage_recovery.end_progress()
await layout.show_dry_run_result(ctx, result, is_slip39)
await layout.show_dry_run_result(result, is_slip39)
if result:
return Success(message="The seed is valid and matches the one in the device")
@ -134,9 +129,7 @@ async def _finish_recovery_dry_run(
raise wire.ProcessError("The seed does not match the one in the device")
async def _finish_recovery(
ctx: GenericContext, secret: bytes, backup_type: BackupType
) -> Success:
async def _finish_recovery(secret: bytes, backup_type: BackupType) -> Success:
from trezor.ui.layouts import show_success
from trezor.enums import BackupType
@ -157,15 +150,11 @@ async def _finish_recovery(
storage_recovery.end_progress()
await show_success(
ctx, "success_recovery", "You have finished recovering your wallet."
)
await show_success("success_recovery", "You have finished recovering your wallet.")
return Success(message="Device recovered")
async def _process_words(
ctx: GenericContext, words: str
) -> tuple[bytes | None, BackupType]:
async def _process_words(words: str) -> tuple[bytes | None, BackupType]:
word_count = len(words.split(" "))
is_slip39 = backup_types.is_slip39_word_count(word_count)
@ -179,28 +168,28 @@ async def _process_words(
if secret is None: # SLIP-39
assert share is not None
if share.group_count and share.group_count > 1:
await layout.show_group_share_success(ctx, share.index, share.group_index)
await _request_share_next_screen(ctx)
await layout.show_group_share_success(share.index, share.group_index)
await _request_share_next_screen()
return secret, backup_type
async def _request_share_first_screen(ctx: GenericContext, word_count: int) -> None:
async def _request_share_first_screen(word_count: int) -> None:
if backup_types.is_slip39_word_count(word_count):
remaining = storage_recovery.fetch_slip39_remaining_shares()
if remaining:
await _request_share_next_screen(ctx)
await _request_share_next_screen()
else:
await layout.homescreen_dialog(
ctx, "Enter share", "Enter any share", f"({word_count} words)"
"Enter share", "Enter any share", f"({word_count} words)"
)
else: # BIP-39
await layout.homescreen_dialog(
ctx, "Enter seed", "Enter recovery seed", f"({word_count} words)"
"Enter seed", "Enter recovery seed", f"({word_count} words)"
)
async def _request_share_next_screen(ctx: GenericContext) -> None:
async def _request_share_next_screen() -> None:
from trezor import strings
remaining = storage_recovery.fetch_slip39_remaining_shares()
@ -211,17 +200,16 @@ async def _request_share_next_screen(ctx: GenericContext) -> None:
if group_count > 1:
await layout.homescreen_dialog(
ctx,
"Enter",
"More shares needed",
info_func=_show_remaining_groups_and_shares,
)
else:
text = strings.format_plural("{count} more {plural}", remaining[0], "share")
await layout.homescreen_dialog(ctx, "Enter share", text, "needed to enter")
await layout.homescreen_dialog("Enter share", text, "needed to enter")
async def _show_remaining_groups_and_shares(ctx: GenericContext) -> None:
async def _show_remaining_groups_and_shares() -> None:
"""
Show info dialog for Slip39 Advanced - what shares are to be entered.
"""
@ -254,5 +242,5 @@ async def _show_remaining_groups_and_shares(ctx: GenericContext) -> None:
assert share # share needs to be set
return await layout.show_remaining_shares(
ctx, groups, shares_remaining, share.group_threshold
groups, shares_remaining, share.group_threshold
)

@ -14,13 +14,11 @@ from .. import backup_types
if TYPE_CHECKING:
from typing import Callable
from trezor.enums import BackupType
from trezor.wire import GenericContext
async def _confirm_abort(ctx: GenericContext, dry_run: bool = False) -> None:
async def _confirm_abort(dry_run: bool = False) -> None:
if dry_run:
await confirm_action(
ctx,
"abort_recovery",
"Abort seed check",
description="Do you really want to abort the seed check?",
@ -28,7 +26,6 @@ async def _confirm_abort(ctx: GenericContext, dry_run: bool = False) -> None:
)
else:
await confirm_action(
ctx,
"abort_recovery",
"Abort recovery",
"All progress will be lost.",
@ -39,21 +36,21 @@ async def _confirm_abort(ctx: GenericContext, dry_run: bool = False) -> None:
async def request_mnemonic(
ctx: GenericContext, word_count: int, backup_type: BackupType | None
word_count: int, backup_type: BackupType | None
) -> str | None:
from . import word_validity
from trezor.ui.layouts.common import button_request
from trezor.ui.layouts.recovery import request_word
from trezor.ui.layouts import mnemonic_word_entering
await mnemonic_word_entering(ctx)
await mnemonic_word_entering()
await button_request(ctx, "mnemonic", code=ButtonRequestType.MnemonicInput)
await button_request("mnemonic", code=ButtonRequestType.MnemonicInput)
words: list[str] = []
for i in range(word_count):
word = await request_word(
ctx, i, word_count, is_slip39=backup_types.is_slip39_word_count(word_count)
i, word_count, is_slip39=backup_types.is_slip39_word_count(word_count)
)
words.append(word)
@ -62,7 +59,6 @@ async def request_mnemonic(
except word_validity.AlreadyAdded:
# show_share_already_added
await show_recovery_warning(
ctx,
"warning_known_share",
"Share already entered, please enter a different share.",
)
@ -70,7 +66,6 @@ async def request_mnemonic(
except word_validity.IdentifierMismatch:
# show_identifier_mismatch
await show_recovery_warning(
ctx,
"warning_mismatched_share",
"You have entered a share from another Shamir Backup.",
)
@ -78,7 +73,6 @@ async def request_mnemonic(
except word_validity.ThresholdReached:
# show_group_threshold_reached
await show_recovery_warning(
ctx,
"warning_group_threshold",
"Threshold of this group has been reached. Input share from different group.",
)
@ -87,9 +81,7 @@ async def request_mnemonic(
return " ".join(words)
async def show_dry_run_result(
ctx: GenericContext, result: bool, is_slip39: bool
) -> None:
async def show_dry_run_result(result: bool, is_slip39: bool) -> None:
from trezor.ui.layouts import show_success
if result:
@ -99,34 +91,29 @@ async def show_dry_run_result(
text = (
"The entered recovery seed is valid and matches the one in the device."
)
await show_success(ctx, "success_dry_recovery", text, button="Continue")
await show_success("success_dry_recovery", text, button="Continue")
else:
if is_slip39:
text = "The entered recovery shares are valid but do not match what is currently in the device."
else:
text = "The entered recovery seed is valid but does not match the one in the device."
await show_recovery_warning(
ctx, "warning_dry_recovery", text, button="Continue"
)
await show_recovery_warning("warning_dry_recovery", text, button="Continue")
async def show_invalid_mnemonic(ctx: GenericContext, word_count: int) -> None:
async def show_invalid_mnemonic(word_count: int) -> None:
if backup_types.is_slip39_word_count(word_count):
await show_recovery_warning(
ctx,
"warning_invalid_share",
"You have entered an invalid recovery share.",
)
else:
await show_recovery_warning(
ctx,
"warning_invalid_seed",
"You have entered an invalid recovery seed.",
)
async def homescreen_dialog(
ctx: GenericContext,
button_label: str,
text: str,
subtext: str | None = None,
@ -139,14 +126,12 @@ async def homescreen_dialog(
while True:
dry_run = storage_recovery.is_dry_run()
if await continue_recovery(
ctx, button_label, text, subtext, info_func, dry_run
):
if await continue_recovery(button_label, text, subtext, info_func, dry_run):
# go forward in the recovery process
break
# user has chosen to abort, confirm the choice
try:
await _confirm_abort(ctx, dry_run)
await _confirm_abort(dry_run)
except ActionCancelled:
pass
else:

@ -13,7 +13,6 @@ if __debug__:
if TYPE_CHECKING:
from trezor.messages import ResetDevice
from trezor.wire import Context
from trezor.messages import Success
@ -23,7 +22,7 @@ BAK_T_SLIP39_ADVANCED = BackupType.Slip39_Advanced # global_import_cache
_DEFAULT_BACKUP_TYPE = BAK_T_BIP39
async def reset_device(ctx: Context, msg: ResetDevice) -> Success:
async def reset_device(msg: ResetDevice) -> Success:
from trezor import config, utils
from apps.common.request_pin import request_pin_confirm
from trezor.ui.layouts import (
@ -33,6 +32,7 @@ async def reset_device(ctx: Context, msg: ResetDevice) -> Success:
from trezor.crypto import bip39, random
from trezor.messages import Success, EntropyAck, EntropyRequest
from trezor.pin import render_empty_loader
from trezor.wire.context import call
backup_type = msg.backup_type # local_cache_attribute
@ -49,7 +49,7 @@ async def reset_device(ctx: Context, msg: ResetDevice) -> Success:
title = f"Create wallet{delimiter}(Super Shamir)"
else:
title = "Create wallet"
await confirm_reset_device(ctx, title)
await confirm_reset_device(title)
# Rendering empty loader so users do not feel a freezing screen
render_empty_loader("PROCESSING", "")
@ -59,7 +59,7 @@ async def reset_device(ctx: Context, msg: ResetDevice) -> Success:
# request and set new PIN
if msg.pin_protection:
newpin = await request_pin_confirm(ctx)
newpin = await request_pin_confirm()
if not config.change_pin("", newpin, None, None):
raise ProcessError("Failed to set PIN")
@ -68,10 +68,10 @@ async def reset_device(ctx: Context, msg: ResetDevice) -> Success:
if __debug__:
storage.debug.reset_internal_entropy = int_entropy
if msg.display_random:
await layout.show_internal_entropy(ctx, int_entropy)
await layout.show_internal_entropy(int_entropy)
# request external entropy and compute the master secret
entropy_ack = await ctx.call(EntropyRequest(), EntropyAck)
entropy_ack = await call(EntropyRequest(), EntropyAck)
ext_entropy = entropy_ack.entropy
# For SLIP-39 this is the Encrypted Master Secret
secret = _compute_secret_from_entropy(int_entropy, ext_entropy, msg.strength)
@ -94,11 +94,11 @@ async def reset_device(ctx: Context, msg: ResetDevice) -> Success:
# If doing backup, ask the user to confirm.
if perform_backup:
perform_backup = await confirm_backup(ctx)
perform_backup = await confirm_backup()
# generate and display backup information for the master secret
if perform_backup:
await backup_seed(ctx, backup_type, secret)
await backup_seed(backup_type, secret)
# write settings and master secret into storage
if msg.label is not None:
@ -113,19 +113,19 @@ async def reset_device(ctx: Context, msg: ResetDevice) -> Success:
# if we backed up the wallet, show success message
if perform_backup:
await layout.show_backup_success(ctx)
await layout.show_backup_success()
return Success(message="Initialized")
async def _backup_slip39_basic(ctx: Context, encrypted_master_secret: bytes) -> None:
async def _backup_slip39_basic(encrypted_master_secret: bytes) -> None:
# get number of shares
await layout.slip39_show_checklist(ctx, 0, BAK_T_SLIP39_BASIC)
shares_count = await layout.slip39_prompt_number_of_shares(ctx)
await layout.slip39_show_checklist(0, BAK_T_SLIP39_BASIC)
shares_count = await layout.slip39_prompt_number_of_shares()
# get threshold
await layout.slip39_show_checklist(ctx, 1, BAK_T_SLIP39_BASIC)
threshold = await layout.slip39_prompt_threshold(ctx, shares_count)
await layout.slip39_show_checklist(1, BAK_T_SLIP39_BASIC)
threshold = await layout.slip39_prompt_threshold(shares_count)
identifier = storage_device.get_slip39_identifier()
iteration_exponent = storage_device.get_slip39_iteration_exponent()
@ -142,27 +142,25 @@ async def _backup_slip39_basic(ctx: Context, encrypted_master_secret: bytes) ->
)[0]
# show and confirm individual shares
await layout.slip39_show_checklist(ctx, 2, BAK_T_SLIP39_BASIC)
await layout.slip39_basic_show_and_confirm_shares(ctx, mnemonics)
await layout.slip39_show_checklist(2, BAK_T_SLIP39_BASIC)
await layout.slip39_basic_show_and_confirm_shares(mnemonics)
async def _backup_slip39_advanced(ctx: Context, encrypted_master_secret: bytes) -> None:
async def _backup_slip39_advanced(encrypted_master_secret: bytes) -> None:
# get number of groups
await layout.slip39_show_checklist(ctx, 0, BAK_T_SLIP39_ADVANCED)
groups_count = await layout.slip39_advanced_prompt_number_of_groups(ctx)
await layout.slip39_show_checklist(0, BAK_T_SLIP39_ADVANCED)
groups_count = await layout.slip39_advanced_prompt_number_of_groups()
# get group threshold
await layout.slip39_show_checklist(ctx, 1, BAK_T_SLIP39_ADVANCED)
group_threshold = await layout.slip39_advanced_prompt_group_threshold(
ctx, groups_count
)
await layout.slip39_show_checklist(1, BAK_T_SLIP39_ADVANCED)
group_threshold = await layout.slip39_advanced_prompt_group_threshold(groups_count)
# get shares and thresholds
await layout.slip39_show_checklist(ctx, 2, BAK_T_SLIP39_ADVANCED)
await layout.slip39_show_checklist(2, BAK_T_SLIP39_ADVANCED)
groups = []
for i in range(groups_count):
share_count = await layout.slip39_prompt_number_of_shares(ctx, i)
share_threshold = await layout.slip39_prompt_threshold(ctx, share_count, i)
share_count = await layout.slip39_prompt_number_of_shares(i)
share_threshold = await layout.slip39_prompt_threshold(share_count, i)
groups.append((share_threshold, share_count))
identifier = storage_device.get_slip39_identifier()
@ -180,7 +178,7 @@ async def _backup_slip39_advanced(ctx: Context, encrypted_master_secret: bytes)
)
# show and confirm individual shares
await layout.slip39_advanced_show_and_confirm_shares(ctx, mnemonics)
await layout.slip39_advanced_show_and_confirm_shares(mnemonics)
def _validate_reset_device(msg: ResetDevice) -> None:
@ -222,12 +220,10 @@ def _compute_secret_from_entropy(
return secret
async def backup_seed(
ctx: Context, backup_type: BackupType, mnemonic_secret: bytes
) -> None:
async def backup_seed(backup_type: BackupType, mnemonic_secret: bytes) -> None:
if backup_type == BAK_T_SLIP39_BASIC:
await _backup_slip39_basic(ctx, mnemonic_secret)
await _backup_slip39_basic(mnemonic_secret)
elif backup_type == BAK_T_SLIP39_ADVANCED:
await _backup_slip39_advanced(ctx, mnemonic_secret)
await _backup_slip39_advanced(mnemonic_secret)
else:
await layout.bip39_show_and_confirm_mnemonic(ctx, mnemonic_secret.decode())
await layout.bip39_show_and_confirm_mnemonic(mnemonic_secret.decode())

@ -1,5 +1,5 @@
from micropython import const
from typing import TYPE_CHECKING
from typing import Sequence
from trezor.enums import ButtonRequestType
from trezor.ui.layouts import show_success
@ -12,18 +12,13 @@ from trezor.ui.layouts.reset import ( # noqa: F401
slip39_show_checklist,
)
if TYPE_CHECKING:
from typing import Sequence
from trezor.wire import GenericContext
_NUM_OF_CHOICES = const(3)
async def show_internal_entropy(ctx: GenericContext, entropy: bytes) -> None:
async def show_internal_entropy(entropy: bytes) -> None:
from trezor.ui.layouts import confirm_blob
await confirm_blob(
ctx,
"entropy",
"Internal entropy",
entropy,
@ -32,7 +27,6 @@ async def show_internal_entropy(ctx: GenericContext, entropy: bytes) -> None:
async def _confirm_word(
ctx: GenericContext,
share_index: int | None,
share_words: Sequence[str],
offset: int,
@ -56,14 +50,13 @@ async def _confirm_word(
random.shuffle(choices)
# let the user pick a word
selected_word: str = await select_word(
ctx, choices, share_index, checked_index, count, group_index
choices, share_index, checked_index, count, group_index
)
# confirm it is the correct one
return selected_word == checked_word
async def _share_words_confirmed(
ctx: GenericContext,
share_index: int | None,
share_words: Sequence[str],
num_of_shares: int | None = None,
@ -77,22 +70,20 @@ async def _share_words_confirmed(
"""
# TODO: confirm_action("Select the words bla bla")
if await _do_confirm_share_words(ctx, share_index, share_words, group_index):
if await _do_confirm_share_words(share_index, share_words, group_index):
await _show_confirmation_success(
ctx,
share_index,
num_of_shares,
group_index,
)
return True
else:
await _show_confirmation_failure(ctx)
await _show_confirmation_failure()
return False
async def _do_confirm_share_words(
ctx: GenericContext,
share_index: int | None,
share_words: Sequence[str],
group_index: int | None = None,
@ -106,7 +97,7 @@ async def _do_confirm_share_words(
offset = 0
count = len(share_words)
for part in utils.chunks(share_words, third):
if not await _confirm_word(ctx, share_index, part, offset, count, group_index):
if not await _confirm_word(share_index, part, offset, count, group_index):
return False
offset += len(part)
@ -114,7 +105,6 @@ async def _do_confirm_share_words(
async def _show_confirmation_success(
ctx: GenericContext,
share_index: int | None = None,
num_of_shares: int | None = None,
group_index: int | None = None,
@ -138,14 +128,13 @@ async def _show_confirmation_success(
subheader = f"Group {group_index + 1} - Share {share_index + 1} checked successfully."
text = "Continue with the next share."
return await show_success(ctx, "success_recovery", text, subheader)
return await show_success("success_recovery", text, subheader)
async def _show_confirmation_failure(ctx: GenericContext) -> None:
async def _show_confirmation_failure() -> None:
from trezor.ui.layouts.recovery import show_recovery_warning
await show_recovery_warning(
ctx,
"warning_backup_check",
"Please check again.",
"That is the wrong word.",
@ -154,34 +143,34 @@ async def _show_confirmation_failure(ctx: GenericContext) -> None:
)
async def show_backup_warning(ctx: GenericContext, slip39: bool = False) -> None:
async def show_backup_warning(slip39: bool = False) -> None:
from trezor.ui.layouts.reset import show_warning_backup
await show_warning_backup(ctx, slip39)
await show_warning_backup(slip39)
async def show_backup_success(ctx: GenericContext) -> None:
async def show_backup_success() -> None:
from trezor.ui.layouts.reset import show_success_backup
await show_success_backup(ctx)
await show_success_backup()
# BIP39
# ===
async def bip39_show_and_confirm_mnemonic(ctx: GenericContext, mnemonic: str) -> None:
async def bip39_show_and_confirm_mnemonic(mnemonic: str) -> None:
# warn user about mnemonic safety
await show_backup_warning(ctx)
await show_backup_warning()
words = mnemonic.split()
while True:
# display paginated mnemonic on the screen
await show_share_words(ctx, words)
await show_share_words(words)
# make the user confirm some words from the mnemonic
if await _share_words_confirmed(ctx, None, words):
if await _share_words_confirmed(None, words):
break # this share is confirmed, go to next one
@ -189,38 +178,36 @@ async def bip39_show_and_confirm_mnemonic(ctx: GenericContext, mnemonic: str) ->
# ===
async def slip39_basic_show_and_confirm_shares(
ctx: GenericContext, shares: Sequence[str]
) -> None:
async def slip39_basic_show_and_confirm_shares(shares: Sequence[str]) -> None:
# warn user about mnemonic safety
await show_backup_warning(ctx, True)
await show_backup_warning(True)
for index, share in enumerate(shares):
share_words = share.split(" ")
while True:
# display paginated share on the screen
await show_share_words(ctx, share_words, index)
await show_share_words(share_words, index)
# make the user confirm words from the share
if await _share_words_confirmed(ctx, index, share_words, len(shares)):
if await _share_words_confirmed(index, share_words, len(shares)):
break # this share is confirmed, go to next one
async def slip39_advanced_show_and_confirm_shares(
ctx: GenericContext, shares: Sequence[Sequence[str]]
shares: Sequence[Sequence[str]],
) -> None:
# warn user about mnemonic safety
await show_backup_warning(ctx, True)
await show_backup_warning(True)
for group_index, group in enumerate(shares):
for share_index, share in enumerate(group):
share_words = share.split(" ")
while True:
# display paginated share on the screen
await show_share_words(ctx, share_words, share_index, group_index)
await show_share_words(share_words, share_index, group_index)
# make the user confirm words from the share
if await _share_words_confirmed(
ctx, share_index, share_words, len(group), group_index
share_index, share_words, len(group), group_index
):
break # this share is confirmed, go to next one

@ -14,7 +14,6 @@ from apps.common.sdcard import ensure_sdcard
if TYPE_CHECKING:
from typing import Awaitable
from trezor.messages import SdProtect
from trezor.wire import Context
def _make_salt() -> tuple[bytes, bytes, bytes]:
@ -26,56 +25,54 @@ def _make_salt() -> tuple[bytes, bytes, bytes]:
return salt, auth_key, tag
async def _set_salt(
ctx: Context, salt: bytes, salt_tag: bytes, stage: bool = False
) -> None:
async def _set_salt(salt: bytes, salt_tag: bytes, stage: bool = False) -> None:
from apps.common.sdcard import confirm_retry_sd
while True:
await ensure_sdcard(ctx)
await ensure_sdcard()
try:
return storage_sd_salt.set_sd_salt(salt, salt_tag, stage)
except OSError:
await confirm_retry_sd(ctx, ProcessError("SD card I/O error."))
await confirm_retry_sd(ProcessError("SD card I/O error."))
async def sd_protect(ctx: Context, msg: SdProtect) -> Success:
async def sd_protect(msg: SdProtect) -> Success:
from trezor.wire import NotInitialized
if not storage_device.is_initialized():
raise NotInitialized("Device is not initialized")
if msg.operation == SdProtectOperationType.ENABLE:
return await _sd_protect_enable(ctx, msg)
return await _sd_protect_enable(msg)
elif msg.operation == SdProtectOperationType.DISABLE:
return await _sd_protect_disable(ctx, msg)
return await _sd_protect_disable(msg)
elif msg.operation == SdProtectOperationType.REFRESH:
return await _sd_protect_refresh(ctx, msg)
return await _sd_protect_refresh(msg)
else:
raise ProcessError("Unknown operation")
async def _sd_protect_enable(ctx: Context, msg: SdProtect) -> Success:
async def _sd_protect_enable(msg: SdProtect) -> Success:
from apps.common.request_pin import request_pin
if storage_sd_salt.is_enabled():
raise ProcessError("SD card protection already enabled")
# Confirm that user wants to proceed with the operation.
await require_confirm_sd_protect(ctx, msg)
await require_confirm_sd_protect(msg)
# Make sure SD card is present.
await ensure_sdcard(ctx)
await ensure_sdcard()
# Get the current PIN.
if config.has_pin():
pin = await request_pin(ctx, "Enter PIN", config.get_pin_rem())
pin = await request_pin("Enter PIN", config.get_pin_rem())
else:
pin = ""
# Check PIN and prepare salt file.
salt, salt_auth_key, salt_tag = _make_salt()
await _set_salt(ctx, salt, salt_tag)
await _set_salt(salt, salt_tag)
if not config.change_pin(pin, pin, None, salt):
# Wrong PIN. Clean up the prepared salt file.
@ -86,17 +83,15 @@ async def _sd_protect_enable(ctx: Context, msg: SdProtect) -> Success:
# SD-protection. If it fails for any reason, we suppress the
# exception, because primarily we need to raise wire.PinInvalid.
pass
await error_pin_invalid(ctx)
await error_pin_invalid()
storage_device.set_sd_salt_auth_key(salt_auth_key)
await show_success(
ctx, "success_sd", "You have successfully enabled SD protection."
)
await show_success("success_sd", "You have successfully enabled SD protection.")
return Success(message="SD card protection enabled")
async def _sd_protect_disable(ctx: Context, msg: SdProtect) -> Success:
async def _sd_protect_disable(msg: SdProtect) -> Success:
if not storage_sd_salt.is_enabled():
raise ProcessError("SD card protection not enabled")
@ -104,14 +99,14 @@ async def _sd_protect_disable(ctx: Context, msg: SdProtect) -> Success:
# protection. The cleanup will not happen in such case, but that does not matter.
# Confirm that user wants to proceed with the operation.
await require_confirm_sd_protect(ctx, msg)
await require_confirm_sd_protect(msg)
# Get the current PIN and salt from the SD card.
pin, salt = await request_pin_and_sd_salt(ctx, "Enter PIN")
pin, salt = await request_pin_and_sd_salt("Enter PIN")
# Check PIN and remove salt.
if not config.change_pin(pin, pin, salt, None):
await error_pin_invalid(ctx)
await error_pin_invalid()
storage_device.set_sd_salt_auth_key(None)
@ -124,31 +119,29 @@ async def _sd_protect_disable(ctx: Context, msg: SdProtect) -> Success:
# because overall SD-protection was successfully disabled.
pass
await show_success(
ctx, "success_sd", "You have successfully disabled SD protection."
)
await show_success("success_sd", "You have successfully disabled SD protection.")
return Success(message="SD card protection disabled")
async def _sd_protect_refresh(ctx: Context, msg: SdProtect) -> Success:
async def _sd_protect_refresh(msg: SdProtect) -> Success:
if not storage_sd_salt.is_enabled():
raise ProcessError("SD card protection not enabled")
# Confirm that user wants to proceed with the operation.
await require_confirm_sd_protect(ctx, msg)
await require_confirm_sd_protect(msg)
# Make sure SD card is present.
await ensure_sdcard(ctx)
await ensure_sdcard()
# Get the current PIN and salt from the SD card.
pin, old_salt = await request_pin_and_sd_salt(ctx, "Enter PIN")
pin, old_salt = await request_pin_and_sd_salt("Enter PIN")
# Check PIN and change salt.
new_salt, new_auth_key, new_salt_tag = _make_salt()
await _set_salt(ctx, new_salt, new_salt_tag, stage=True)
await _set_salt(new_salt, new_salt_tag, stage=True)
if not config.change_pin(pin, pin, old_salt, new_salt):
await error_pin_invalid(ctx)
await error_pin_invalid()
storage_device.set_sd_salt_auth_key(new_auth_key)
@ -161,13 +154,11 @@ async def _sd_protect_refresh(ctx: Context, msg: SdProtect) -> Success:
# SD-protection was successfully refreshed.
pass
await show_success(
ctx, "success_sd", "You have successfully refreshed SD protection."
)
await show_success("success_sd", "You have successfully refreshed SD protection.")
return Success(message="SD card protection refreshed")
def require_confirm_sd_protect(ctx: Context, msg: SdProtect) -> Awaitable[None]:
def require_confirm_sd_protect(msg: SdProtect) -> Awaitable[None]:
from trezor.ui.layouts import confirm_action
if msg.operation == SdProtectOperationType.ENABLE:
@ -179,4 +170,4 @@ def require_confirm_sd_protect(ctx: Context, msg: SdProtect) -> Awaitable[None]:
else:
raise ProcessError("Unknown operation")
return confirm_action(ctx, "set_sd", "SD card protection", description=text)
return confirm_action("set_sd", "SD card protection", description=text)

@ -2,10 +2,9 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from trezor.messages import SetU2FCounter, Success
from trezor.wire import Context
async def set_u2f_counter(ctx: Context, msg: SetU2FCounter) -> Success:
async def set_u2f_counter(msg: SetU2FCounter) -> Success:
import storage.device as storage_device
from trezor import wire
from trezor.enums import ButtonRequestType
@ -18,7 +17,6 @@ async def set_u2f_counter(ctx: Context, msg: SetU2FCounter) -> Success:
raise wire.ProcessError("No value provided")
await confirm_action(
ctx,
"set_u2f_counter",
"Set U2F counter",
description="Do you really want to set the U2F counter to {}?",

@ -2,16 +2,15 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from trezor.messages import ShowDeviceTutorial, Success
from trezor.wire import Context
async def show_tutorial(ctx: Context, msg: ShowDeviceTutorial) -> Success:
async def show_tutorial(msg: ShowDeviceTutorial) -> Success:
from trezor.messages import Success
# NOTE: tutorial is defined only for TR, and this function should
# also be called only in case of TR
from trezor.ui.layouts import tutorial
await tutorial(ctx)
await tutorial()
return Success(message="Tutorial shown")

@ -1,11 +1,10 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from trezor.wire import GenericContext
from trezor.messages import WipeDevice, Success
async def wipe_device(ctx: GenericContext, msg: WipeDevice) -> Success:
async def wipe_device(msg: WipeDevice) -> Success:
import storage
from trezor.enums import ButtonRequestType
from trezor.messages import Success
@ -14,7 +13,6 @@ async def wipe_device(ctx: GenericContext, msg: WipeDevice) -> Success:
from apps.base import reload_settings_from_storage
await confirm_action(
ctx,
"confirm_wipe",
"Wipe device",
"All data will be erased.",

@ -2,13 +2,12 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from trezor.messages import CipherKeyValue, CipheredKeyValue
from trezor.wire import Context
# This module implements the SLIP-0011 symmetric encryption of key-value pairs using a
# deterministic hierarchy, see https://github.com/satoshilabs/slips/blob/master/slip-0011.md.
async def cipher_key_value(ctx: Context, msg: CipherKeyValue) -> CipheredKeyValue:
async def cipher_key_value(msg: CipherKeyValue) -> CipheredKeyValue:
from trezor.wire import DataError
from trezor.messages import CipheredKeyValue
from trezor.crypto import aes, hmac
@ -16,7 +15,7 @@ async def cipher_key_value(ctx: Context, msg: CipherKeyValue) -> CipheredKeyValu
from apps.common.paths import AlwaysMatchingSchema
from trezor.ui.layouts import confirm_action
keychain = await get_keychain(ctx, "secp256k1", [AlwaysMatchingSchema])
keychain = await get_keychain("secp256k1", [AlwaysMatchingSchema])
if len(msg.value) % 16 > 0:
raise DataError("Value length must be a multiple of 16")
@ -35,9 +34,7 @@ async def cipher_key_value(ctx: Context, msg: CipherKeyValue) -> CipheredKeyValu
title = "Decrypt value"
verb = "CONFIRM"
await confirm_action(
ctx, "cipher_key_value", title, description=msg.key, verb=verb
)
await confirm_action("cipher_key_value", title, description=msg.key, verb=verb)
node = keychain.derive(msg.address_n)

@ -8,7 +8,6 @@ from apps.common.paths import PathSchema, unharden
if TYPE_CHECKING:
from trezor.messages import CosiCommit
from trezor.wire import Context
# This module implements the cosigner part of the CoSi collective signatures
# as described in https://dedis.cs.yale.edu/dissent/papers/witness.pdf
@ -55,16 +54,17 @@ def _decode_path(address_n: list[int]) -> str | None:
return None
async def cosi_commit(ctx: Context, msg: CosiCommit) -> CosiSignature:
async def cosi_commit(msg: CosiCommit) -> CosiSignature:
import storage.cache as storage_cache
from trezor.crypto import cosi
from trezor.crypto.curve import ed25519
from trezor.ui.layouts import confirm_blob, confirm_text
from trezor.wire.context import call
from apps.common import paths
from apps.common.keychain import get_keychain
keychain = await get_keychain(ctx, "ed25519", [SCHEMA_SLIP18, SCHEMA_SLIP26])
await paths.validate_path(ctx, keychain, msg.address_n)
keychain = await get_keychain("ed25519", [SCHEMA_SLIP18, SCHEMA_SLIP26])
await paths.validate_path(keychain, msg.address_n)
node = keychain.derive(msg.address_n)
seckey = node.private_key()
@ -78,7 +78,7 @@ async def cosi_commit(ctx: Context, msg: CosiCommit) -> CosiSignature:
if commitment is None:
raise RuntimeError
sign_msg = await ctx.call(
sign_msg = await call(
CosiCommitment(commitment=commitment, pubkey=pubkey), CosiSign
)
@ -87,14 +87,12 @@ async def cosi_commit(ctx: Context, msg: CosiCommit) -> CosiSignature:
path_description = _decode_path(sign_msg.address_n)
await confirm_text(
ctx,
"cosi_confirm_key",
"COSI KEYS",
paths.address_n_to_str(sign_msg.address_n),
path_description,
)
await confirm_blob(
ctx,
"cosi_sign",
"COSI DATA",
sign_msg.data,

@ -2,13 +2,12 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from trezor.messages import GetECDHSessionKey, ECDHSessionKey
from trezor.wire import Context
# This module implements the SLIP-0017 Elliptic Curve Diffie-Hellman algorithm, using a
# deterministic hierarchy, see https://github.com/satoshilabs/slips/blob/master/slip-0017.md.
async def get_ecdh_session_key(ctx: Context, msg: GetECDHSessionKey) -> ECDHSessionKey:
async def get_ecdh_session_key(msg: GetECDHSessionKey) -> ECDHSessionKey:
from trezor.ui.layouts import confirm_address
from .sign_identity import (
get_identity_path,
@ -24,13 +23,12 @@ async def get_ecdh_session_key(ctx: Context, msg: GetECDHSessionKey) -> ECDHSess
peer_public_key = msg.peer_public_key # local_cache_attribute
curve_name = msg.ecdsa_curve_name or "secp256k1"
keychain = await get_keychain(ctx, curve_name, [AlwaysMatchingSchema])
keychain = await get_keychain(curve_name, [AlwaysMatchingSchema])
identity = serialize_identity(msg_identity)
# require_confirm_ecdh_session_key
proto = msg_identity.proto.upper() if msg_identity.proto else "identity"
await confirm_address(
ctx,
f"Decrypt {proto}",
serialize_identity_without_proto(msg_identity),
None,

@ -1,18 +1,16 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from trezor.wire import Context
from trezor.messages import GetEntropy, Entropy
async def get_entropy(ctx: Context, msg: GetEntropy) -> Entropy:
async def get_entropy(msg: GetEntropy) -> Entropy:
from trezor.crypto import random
from trezor.enums import ButtonRequestType
from trezor.messages import Entropy
from trezor.ui.layouts import confirm_action
await confirm_action(
ctx,
"get_entropy",
"Confirm entropy",
"Do you really want to send entropy?",

@ -2,13 +2,12 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from trezor.messages import GetFirmwareHash, FirmwareHash
from trezor.wire import Context
from trezor.ui.layouts.common import ProgressLayout
_progress_obj: ProgressLayout | None = None
async def get_firmware_hash(ctx: Context, msg: GetFirmwareHash) -> FirmwareHash:
async def get_firmware_hash(msg: GetFirmwareHash) -> FirmwareHash:
from trezor.messages import FirmwareHash
from trezor.utils import firmware_hash
from trezor.ui.layouts.progress import progress

@ -6,14 +6,13 @@ from apps.common import coininfo
if TYPE_CHECKING:
from trezor.messages import IdentityType, SignIdentity, SignedIdentity
from trezor.wire import Context
from apps.common.paths import Bip32Path
# This module implements the SLIP-0013 authentication using a deterministic hierarchy, see
# https://github.com/satoshilabs/slips/blob/master/slip-0013.md.
async def sign_identity(ctx: Context, msg: SignIdentity) -> SignedIdentity:
async def sign_identity(msg: SignIdentity) -> SignedIdentity:
from trezor.messages import SignedIdentity
from trezor.ui.layouts import confirm_sign_identity
from apps.common.keychain import get_keychain
@ -25,13 +24,13 @@ async def sign_identity(ctx: Context, msg: SignIdentity) -> SignedIdentity:
challenge_hidden = msg.challenge_hidden # local_cache_attribute
curve_name = msg.ecdsa_curve_name or "secp256k1"
keychain = await get_keychain(ctx, curve_name, [AlwaysMatchingSchema])
keychain = await get_keychain(curve_name, [AlwaysMatchingSchema])
identity = serialize_identity(msg_identity)
# require_confirm_sign_identity
proto = msg_identity_proto.upper() if msg_identity_proto else "identity"
await confirm_sign_identity(
ctx, proto, serialize_identity_without_proto(msg_identity), challenge_visual
proto, serialize_identity_without_proto(msg_identity), challenge_visual
)
# END require_confirm_sign_identity

@ -44,7 +44,7 @@ if __debug__:
return Failure(**kwargs)
async def diag(ctx, msg, **kwargs) -> Failure:
async def diag(msg, **kwargs) -> Failure:
ins = msg.ins # local_cache_attribute
debug = log.debug # local_cache_attribute

@ -4,15 +4,12 @@ from apps.common.keychain import auto_keychain
if TYPE_CHECKING:
from trezor.messages import MoneroGetAddress, MoneroAddress
from trezor.wire import Context
from apps.common.keychain import Keychain
@auto_keychain(__name__)
async def get_address(
ctx: Context, msg: MoneroGetAddress, keychain: Keychain
) -> MoneroAddress:
async def get_address(msg: MoneroGetAddress, keychain: Keychain) -> MoneroAddress:
from trezor import wire
from trezor.messages import MoneroAddress
from trezor.ui.layouts import show_address
@ -26,7 +23,7 @@ async def get_address(
minor = msg.minor # local_cache_attribute
payment_id = msg.payment_id # local_cache_attribute
await paths.validate_path(ctx, keychain, msg.address_n)
await paths.validate_path(keychain, msg.address_n)
creds = misc.get_creds(keychain, msg.address_n, msg.network_type)
addr = creds.address
@ -69,7 +66,6 @@ async def get_address(
if msg.show_display:
await show_address(
ctx,
addr,
address_qr="monero:" + addr,
path=paths.address_n_to_str(msg.address_n),

@ -24,12 +24,11 @@ _GET_TX_KEY_REASON_TX_DERIVATION = const(1)
if TYPE_CHECKING:
from trezor.messages import MoneroGetTxKeyRequest, MoneroGetTxKeyAck
from apps.common.keychain import Keychain
from trezor.wire import Context
@auto_keychain(__name__)
async def get_tx_keys(
ctx: Context, msg: MoneroGetTxKeyRequest, keychain: Keychain
msg: MoneroGetTxKeyRequest, keychain: Keychain
) -> MoneroGetTxKeyAck:
from trezor import utils, wire
from trezor.messages import MoneroGetTxKeyAck
@ -38,10 +37,10 @@ async def get_tx_keys(
from apps.monero import layout, misc
from apps.monero.xmr import chacha_poly, crypto, crypto_helpers
await paths.validate_path(ctx, keychain, msg.address_n)
await paths.validate_path(keychain, msg.address_n)
do_deriv = msg.reason == _GET_TX_KEY_REASON_TX_DERIVATION
await layout.require_confirm_tx_key(ctx, export_key=not do_deriv)
await layout.require_confirm_tx_key(export_key=not do_deriv)
creds = misc.get_creds(keychain, msg.address_n, msg.network_type)

@ -3,24 +3,21 @@ from typing import TYPE_CHECKING
from apps.common.keychain import auto_keychain
if TYPE_CHECKING:
from trezor.wire import Context
from trezor.messages import MoneroGetWatchKey, MoneroWatchKey
from apps.common.keychain import Keychain
@auto_keychain(__name__)
async def get_watch_only(
ctx: Context, msg: MoneroGetWatchKey, keychain: Keychain
) -> MoneroWatchKey:
async def get_watch_only(msg: MoneroGetWatchKey, keychain: Keychain) -> MoneroWatchKey:
from apps.common import paths
from apps.monero import layout, misc
from apps.monero.xmr import crypto_helpers
from trezor.messages import MoneroWatchKey
await paths.validate_path(ctx, keychain, msg.address_n)
await paths.validate_path(keychain, msg.address_n)
await layout.require_confirm_watchkey(ctx)
await layout.require_confirm_watchkey()
creds = misc.get_creds(keychain, msg.address_n, msg.network_type)
address = creds.address

@ -14,7 +14,6 @@ if TYPE_CHECKING:
MoneroKeyImageSyncStepRequest,
)
from trezor.ui.layouts.common import ProgressLayout
from trezor.wire import Context
from apps.common.keychain import Keychain
@ -23,7 +22,7 @@ if TYPE_CHECKING:
@auto_keychain(__name__)
async def key_image_sync(
ctx: Context, msg: MoneroKeyImageExportInitRequest, keychain: Keychain
msg: MoneroKeyImageExportInitRequest, keychain: Keychain
) -> MoneroKeyImageSyncFinalAck:
import gc
from trezor.messages import (
@ -31,16 +30,17 @@ async def key_image_sync(
MoneroKeyImageSyncFinalRequest,
MoneroKeyImageSyncStepRequest,
)
from trezor.wire.context import call
state = KeyImageSync()
res = await _init_step(state, ctx, msg, keychain)
res = await _init_step(state, msg, keychain)
progress = layout.monero_keyimage_sync_progress()
while state.current_output + 1 < state.num_outputs:
step = await ctx.call(res, MoneroKeyImageSyncStepRequest)
res = _sync_step(state, ctx, step, progress)
step = await call(res, MoneroKeyImageSyncStepRequest)
res = _sync_step(state, step, progress)
gc.collect()
await ctx.call(res, MoneroKeyImageSyncFinalRequest)
await call(res, MoneroKeyImageSyncFinalRequest)
# _final_step
if state.current_output + 1 != state.num_outputs:
@ -66,7 +66,6 @@ class KeyImageSync:
async def _init_step(
s: KeyImageSync,
ctx: Context,
msg: MoneroKeyImageExportInitRequest,
keychain: Keychain,
) -> MoneroKeyImageExportInitAck:
@ -76,11 +75,11 @@ async def _init_step(
from apps.monero.xmr import monero
from apps.monero import misc
await paths.validate_path(ctx, keychain, msg.address_n)
await paths.validate_path(keychain, msg.address_n)
s.creds = misc.get_creds(keychain, msg.address_n, msg.network_type)
await layout.require_confirm_keyimage_sync(ctx)
await layout.require_confirm_keyimage_sync()
s.num_outputs = msg.num
s.expected_hash = msg.hash
@ -96,7 +95,6 @@ async def _init_step(
def _sync_step(
s: KeyImageSync,
ctx: Context,
tds: MoneroKeyImageSyncStepRequest,
progress: ProgressLayout,
) -> MoneroKeyImageSyncStepAck:

@ -17,7 +17,6 @@ if TYPE_CHECKING:
MoneroTransactionData,
MoneroTransactionDestinationEntry,
)
from trezor.wire import Context
from .signing.state import State
@ -60,9 +59,8 @@ def _format_amount(value: int) -> str:
return f"{strings.format_amount(value, 12)} XMR"
async def require_confirm_watchkey(ctx: Context) -> None:
async def require_confirm_watchkey() -> None:
await confirm_action(
ctx,
"get_watchkey",
"Confirm export",
description="Do you really want to export watch-only credentials?",
@ -70,9 +68,8 @@ async def require_confirm_watchkey(ctx: Context) -> None:
)
async def require_confirm_keyimage_sync(ctx: Context) -> None:
async def require_confirm_keyimage_sync() -> None:
await confirm_action(
ctx,
"key_image_sync",
"Confirm ki sync",
description="Do you really want to\nsync key images?",
@ -80,9 +77,8 @@ async def require_confirm_keyimage_sync(ctx: Context) -> None:
)
async def require_confirm_live_refresh(ctx: Context) -> None:
async def require_confirm_live_refresh() -> None:
await confirm_action(
ctx,
"live_refresh",
"Confirm refresh",
description="Do you really want to\nstart refresh?",
@ -90,14 +86,13 @@ async def require_confirm_live_refresh(ctx: Context) -> None:
)
async def require_confirm_tx_key(ctx: Context, export_key: bool = False) -> None:
async def require_confirm_tx_key(export_key: bool = False) -> None:
description = (
"Do you really want to export tx_key?"
if export_key
else "Do you really want to export tx_der\nfor tx_proof?"
)
await confirm_action(
ctx,
"export_tx_key",
"Confirm export",
description=description,
@ -106,7 +101,6 @@ async def require_confirm_tx_key(ctx: Context, export_key: bool = False) -> None
async def require_confirm_transaction(
ctx: Context,
state: State,
tsx_data: MoneroTransactionData,
network_type: MoneroNetworkType,
@ -122,7 +116,7 @@ async def require_confirm_transaction(
payment_id = tsx_data.payment_id # local_cache_attribute
if tsx_data.unlock_time != 0:
await _require_confirm_unlock_time(ctx, tsx_data.unlock_time)
await _require_confirm_unlock_time(tsx_data.unlock_time)
for idx, dst in enumerate(outputs):
is_change = change_idx is not None and idx == change_idx
@ -135,21 +129,20 @@ async def require_confirm_transaction(
cur_payment = payment_id
else:
cur_payment = None
await _require_confirm_output(ctx, dst, network_type, cur_payment)
await _require_confirm_output(dst, network_type, cur_payment)
if (
payment_id
and not tsx_data.integrated_indices
and payment_id != DUMMY_PAYMENT_ID
):
await _require_confirm_payment_id(ctx, payment_id)
await _require_confirm_payment_id(payment_id)
await _require_confirm_fee(ctx, tsx_data.fee)
await _require_confirm_fee(tsx_data.fee)
progress.step(state, 0)
async def _require_confirm_output(
ctx: Context,
dst: MoneroTransactionDestinationEntry,
network_type: MoneroNetworkType,
payment_id: bytes | None,
@ -167,18 +160,16 @@ async def _require_confirm_output(
)
await confirm_output(
ctx,
addr,
_format_amount(dst.amount),
br_code=BRT_SignTx,
)
async def _require_confirm_payment_id(ctx: Context, payment_id: bytes) -> None:
async def _require_confirm_payment_id(payment_id: bytes) -> None:
from trezor.ui.layouts import confirm_blob
await confirm_blob(
ctx,
"confirm_payment_id",
"Payment ID",
payment_id,
@ -186,9 +177,8 @@ async def _require_confirm_payment_id(ctx: Context, payment_id: bytes) -> None:
)
async def _require_confirm_fee(ctx: Context, fee: int) -> None:
async def _require_confirm_fee(fee: int) -> None:
await confirm_metadata(
ctx,
"confirm_final",
"Confirm fee",
"{}",
@ -197,9 +187,8 @@ async def _require_confirm_fee(ctx: Context, fee: int) -> None:
)
async def _require_confirm_unlock_time(ctx: Context, unlock_time: int) -> None:
async def _require_confirm_unlock_time(unlock_time: int) -> None:
await confirm_metadata(
ctx,
"confirm_locktime",
"Confirm unlock time",
"Unlock time for this transaction is set to {}",

@ -12,7 +12,6 @@ if TYPE_CHECKING:
MoneroLiveRefreshStartAck,
)
from trezor.ui.layouts.common import ProgressLayout
from trezor.wire import Context
from apps.common.keychain import Keychain
from .xmr.credentials import AccountCreds
@ -20,25 +19,26 @@ if TYPE_CHECKING:
@auto_keychain(__name__)
async def live_refresh(
ctx: Context, msg: MoneroLiveRefreshStartRequest, keychain: Keychain
msg: MoneroLiveRefreshStartRequest, keychain: Keychain
) -> MoneroLiveRefreshFinalAck:
import gc
from trezor.enums import MessageType
from trezor.messages import MoneroLiveRefreshFinalAck, MoneroLiveRefreshStepRequest
from trezor.wire.context import call_any
state = LiveRefreshState()
res = await _init_step(state, ctx, msg, keychain)
res = await _init_step(state, msg, keychain)
progress = layout.monero_live_refresh_progress()
while True:
step = await ctx.call_any(
step = await call_any(
res,
MessageType.MoneroLiveRefreshStepRequest,
MessageType.MoneroLiveRefreshFinalRequest,
)
del res
if MoneroLiveRefreshStepRequest.is_type_of(step):
res = _refresh_step(state, ctx, step, progress)
res = _refresh_step(state, step, progress)
else:
return MoneroLiveRefreshFinalAck()
gc.collect()
@ -52,7 +52,6 @@ class LiveRefreshState:
async def _init_step(
s: LiveRefreshState,
ctx: Context,
msg: MoneroLiveRefreshStartRequest,
keychain: Keychain,
) -> MoneroLiveRefreshStartAck:
@ -60,10 +59,10 @@ async def _init_step(
from apps.common import paths
from trezor.messages import MoneroLiveRefreshStartAck
await paths.validate_path(ctx, keychain, msg.address_n)
await paths.validate_path(keychain, msg.address_n)
if not storage_cache.get(storage_cache.APP_MONERO_LIVE_REFRESH):
await layout.require_confirm_live_refresh(ctx)
await layout.require_confirm_live_refresh()
storage_cache.set(storage_cache.APP_MONERO_LIVE_REFRESH, b"\x01")
s.creds = misc.get_creds(keychain, msg.address_n, msg.network_type)
@ -73,7 +72,6 @@ async def _init_step(
def _refresh_step(
s: LiveRefreshState,
ctx: Context,
msg: MoneroLiveRefreshStepRequest,
progress: ProgressLayout,
) -> MoneroLiveRefreshStepAck:

@ -7,18 +7,16 @@ if TYPE_CHECKING:
from trezor.messages import MoneroTransactionFinalAck
from apps.common.keychain import Keychain
from apps.monero.signing.state import State
from trezor.wire import Context
@auto_keychain(__name__)
async def sign_tx(
ctx: Context, received_msg, keychain: Keychain
) -> MoneroTransactionFinalAck:
async def sign_tx(received_msg, keychain: Keychain) -> MoneroTransactionFinalAck:
import gc
from trezor import log, utils
from trezor.wire.context import get_context
from apps.monero.signing.state import State
state = State(ctx)
state = State()
mods = utils.unimport_begin()
progress = MoneroTransactionProgress()
@ -36,11 +34,12 @@ async def sign_tx(
if accept_msgs is None:
break
ctx = get_context()
await ctx.write(result_msg)
del (result_msg, received_msg)
utils.unimport_end(mods)
received_msg = await ctx.read_any(accept_msgs)
received_msg = await ctx.read(accept_msgs)
utils.unimport_end(mods)
return result_msg

@ -1,7 +1,6 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from trezor.wire import Context
from apps.monero.xmr.crypto import Point, Scalar
from apps.monero.xmr.credentials import AccountCreds
from trezor.messages import MoneroTransactionDestinationEntry
@ -19,20 +18,16 @@ class State:
STEP_ALL_OUT = 500
STEP_SIGN = 600
def __init__(self, ctx: Context) -> None:
def __init__(self) -> None:
from apps.monero.xmr.keccak_hasher import KeccakXmrArchive
from apps.monero.xmr.mlsag_hasher import PreMlsagHasher
from apps.monero.xmr import crypto
self.ctx = ctx
"""
Account credentials
type: AccountCreds
- view private/public key
- spend private/public key
- and its corresponding address
"""
# Account credentials
# type: AccountCreds
# - view private/public key
# - spend private/public key
# - and its corresponding address
self.creds: AccountCreds | None = None
# HMAC/encryption keys used to protect offloaded data

@ -37,7 +37,7 @@ async def init_transaction(
mem_trace = state.mem_trace # local_cache_attribute
outputs = tsx_data.outputs # local_cache_attribute
await paths.validate_path(state.ctx, keychain, address_n)
await paths.validate_path(keychain, address_n)
state.creds = misc.get_creds(keychain, address_n, network_type)
state.client_version = tsx_data.client_version or 0
@ -57,7 +57,6 @@ async def init_transaction(
# Ask for confirmation
await layout.require_confirm_transaction(
state.ctx,
state,
tsx_data,
state.creds.network_type,

@ -6,14 +6,11 @@ from . import CURVE, PATTERNS, SLIP44_ID
if TYPE_CHECKING:
from apps.common.keychain import Keychain
from trezor.wire import Context
from trezor.messages import NEMGetAddress, NEMAddress
@with_slip44_keychain(*PATTERNS, slip44_id=SLIP44_ID, curve=CURVE)
async def get_address(
ctx: Context, msg: NEMGetAddress, keychain: Keychain
) -> NEMAddress:
async def get_address(msg: NEMGetAddress, keychain: Keychain) -> NEMAddress:
from trezor.messages import NEMAddress
from trezor.ui.layouts import show_address
from apps.common.paths import address_n_to_str, validate_path
@ -24,14 +21,13 @@ async def get_address(
network = msg.network # local_cache_attribute
validate_network(network)
await validate_path(ctx, keychain, address_n, check_path(address_n, network))
await validate_path(keychain, address_n, check_path(address_n, network))
node = keychain.derive(address_n)
address = node.nem_address(network)
if msg.show_display:
await show_address(
ctx,
address,
case_sensitive=False,
path=address_n_to_str(address_n),

@ -1,18 +1,12 @@
from typing import TYPE_CHECKING
from trezor.enums import ButtonRequestType
from trezor.strings import format_amount
from trezor.ui.layouts import confirm_metadata
from .helpers import NEM_MAX_DIVISIBILITY
if TYPE_CHECKING:
from trezor.wire import Context
async def require_confirm_text(ctx: Context, action: str) -> None:
async def require_confirm_text(action: str) -> None:
await confirm_metadata(
ctx,
"confirm_nem",
"Confirm action",
action,
@ -20,9 +14,8 @@ async def require_confirm_text(ctx: Context, action: str) -> None:
)
async def require_confirm_fee(ctx: Context, action: str, fee: int) -> None:
async def require_confirm_fee(action: str, fee: int) -> None:
await confirm_metadata(
ctx,
"confirm_fee",
"Confirm fee",
action + "\n{}",
@ -31,21 +24,19 @@ async def require_confirm_fee(ctx: Context, action: str, fee: int) -> None:
)
async def require_confirm_content(ctx: Context, headline: str, content: list) -> None:
async def require_confirm_content(headline: str, content: list) -> None:
from trezor.ui.layouts import confirm_properties
await confirm_properties(
ctx,
"confirm_content",
headline,
content,
)
async def require_confirm_final(ctx: Context, fee: int) -> None:
async def require_confirm_final(fee: int) -> None:
# we use SignTx, not ConfirmOutput, for compatibility with T1
await confirm_metadata(
ctx,
"confirm_final",
"Final confirm",
"Sign this transaction\n{}\nfor network fee?",

@ -8,24 +8,21 @@ if TYPE_CHECKING:
NEMMosaicSupplyChange,
NEMTransactionCommon,
)
from trezor.wire import Context
async def mosaic_creation(
ctx: Context,
public_key: bytes,
common: NEMTransactionCommon,
creation: NEMMosaicCreation,
) -> bytes:
await layout.ask_mosaic_creation(ctx, common, creation)
await layout.ask_mosaic_creation(common, creation)
return serialize.serialize_mosaic_creation(common, creation, public_key)
async def supply_change(
ctx: Context,
public_key: bytes,
common: NEMTransactionCommon,
change: NEMMosaicSupplyChange,
) -> bytes:
await layout.ask_supply_change(ctx, common, change)
await layout.ask_supply_change(common, change)
return serialize.serialize_mosaic_supply_change(common, change, public_key)

@ -9,11 +9,10 @@ if TYPE_CHECKING:
NEMMosaicSupplyChange,
NEMTransactionCommon,
)
from trezor.wire import Context
async def ask_mosaic_creation(
ctx: Context, common: NEMTransactionCommon, creation: NEMMosaicCreation
common: NEMTransactionCommon, creation: NEMMosaicCreation
) -> None:
from ..layout import require_confirm_fee
@ -21,15 +20,15 @@ async def ask_mosaic_creation(
("Create mosaic", creation.definition.mosaic),
("under namespace", creation.definition.namespace),
]
await require_confirm_content(ctx, "Create mosaic", creation_message)
await _require_confirm_properties(ctx, creation.definition)
await require_confirm_fee(ctx, "Confirm creation fee", creation.fee)
await require_confirm_content("Create mosaic", creation_message)
await _require_confirm_properties(creation.definition)
await require_confirm_fee("Confirm creation fee", creation.fee)
await require_confirm_final(ctx, common.fee)
await require_confirm_final(common.fee)
async def ask_supply_change(
ctx: Context, common: NEMTransactionCommon, change: NEMMosaicSupplyChange
common: NEMTransactionCommon, change: NEMMosaicSupplyChange
) -> None:
from trezor.enums import NEMSupplyChangeType
from ..layout import require_confirm_text
@ -38,21 +37,19 @@ async def ask_supply_change(
("Modify supply for", change.mosaic),
("under namespace", change.namespace),
]
await require_confirm_content(ctx, "Supply change", supply_message)
await require_confirm_content("Supply change", supply_message)
if change.type == NEMSupplyChangeType.SupplyChange_Decrease:
action = "Decrease"
elif change.type == NEMSupplyChangeType.SupplyChange_Increase:
action = "Increase"
else:
raise ValueError("Invalid supply change type")
await require_confirm_text(ctx, f"{action} supply by {change.delta} whole units?")
await require_confirm_text(f"{action} supply by {change.delta} whole units?")
await require_confirm_final(ctx, common.fee)
await require_confirm_final(common.fee)
async def _require_confirm_properties(
ctx: Context, definition: NEMMosaicDefinition
) -> None:
async def _require_confirm_properties(definition: NEMMosaicDefinition) -> None:
from trezor.enums import NEMMosaicLevy
from trezor.ui.layouts import confirm_properties
@ -97,7 +94,6 @@ async def _require_confirm_properties(
append(("Levy type:", levy_type))
await confirm_properties(
ctx,
"confirm_properties",
"Confirm properties",
properties,

@ -8,11 +8,10 @@ if TYPE_CHECKING:
NEMSignTx,
NEMTransactionCommon,
)
from trezor.wire import Context
async def ask(ctx: Context, msg: NEMSignTx) -> None:
await layout.ask_multisig(ctx, msg)
async def ask(msg: NEMSignTx) -> None:
await layout.ask_multisig(msg)
def initiate(public_key: bytes, common: NEMTransactionCommon, inner_tx: bytes) -> bytes:
@ -26,13 +25,12 @@ def cosign(
async def aggregate_modification(
ctx: Context,
public_key: bytes,
common: NEMTransactionCommon,
aggr: NEMAggregateModification,
multisig: bool,
) -> bytes:
await layout.ask_aggregate_modification(ctx, common, aggr, multisig)
await layout.ask_aggregate_modification(common, aggr, multisig)
w = serialize.serialize_aggregate_modification(common, aggr, public_key)
for m in aggr.modifications:

@ -8,24 +8,22 @@ if TYPE_CHECKING:
NEMSignTx,
NEMTransactionCommon,
)
from trezor.wire import Context
async def ask_multisig(ctx: Context, msg: NEMSignTx) -> None:
async def ask_multisig(msg: NEMSignTx) -> None:
from ..layout import require_confirm_fee
assert msg.multisig is not None # sign_tx
assert msg.multisig.signer is not None # sign_tx
address = nem.compute_address(msg.multisig.signer, msg.transaction.network)
if msg.cosigning:
await _require_confirm_address(ctx, "Cosign transaction for", address)
await _require_confirm_address("Cosign transaction for", address)
else:
await _require_confirm_address(ctx, "Initiate transaction for", address)
await require_confirm_fee(ctx, "Confirm multisig fee", msg.transaction.fee)
await _require_confirm_address("Initiate transaction for", address)
await require_confirm_fee("Confirm multisig fee", msg.transaction.fee)
async def ask_aggregate_modification(
ctx: Context,
common: NEMTransactionCommon,
mod: NEMAggregateModification,
multisig: bool,
@ -34,7 +32,7 @@ async def ask_aggregate_modification(
from ..layout import require_confirm_final, require_confirm_text
if not multisig:
await require_confirm_text(ctx, "Convert account to multisig account?")
await require_confirm_text("Convert account to multisig account?")
for m in mod.modifications:
if m.type == NEMModificationType.CosignatoryModification_Add:
@ -42,24 +40,23 @@ async def ask_aggregate_modification(
else:
action = "Remove"
address = nem.compute_address(m.public_key, common.network)
await _require_confirm_address(ctx, action + " cosignatory", address)
await _require_confirm_address(action + " cosignatory", address)
if mod.relative_change:
if multisig:
action = "Modify the number of cosignatories by "
else:
action = "Set minimum cosignatories to "
await require_confirm_text(ctx, action + str(mod.relative_change) + "?")
await require_confirm_text(action + str(mod.relative_change) + "?")
await require_confirm_final(ctx, common.fee)
await require_confirm_final(common.fee)
async def _require_confirm_address(ctx: Context, action: str, address: str) -> None:
async def _require_confirm_address(action: str, address: str) -> None:
from trezor.enums import ButtonRequestType
from trezor.ui.layouts import confirm_address
await confirm_address(
ctx,
"Confirm address",
address,
action,

@ -2,16 +2,14 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from trezor.messages import NEMProvisionNamespace, NEMTransactionCommon
from trezor.wire import Context
async def namespace(
ctx: Context,
public_key: bytes,
common: NEMTransactionCommon,
namespace: NEMProvisionNamespace,
) -> bytes:
from . import layout, serialize
await layout.ask_provision_namespace(ctx, common, namespace)
await layout.ask_provision_namespace(common, namespace)
return serialize.serialize_provision_namespace(common, namespace, public_key)

@ -2,11 +2,10 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from trezor.messages import NEMProvisionNamespace, NEMTransactionCommon
from trezor.wire import Context
async def ask_provision_namespace(
ctx: Context, common: NEMTransactionCommon, namespace: NEMProvisionNamespace
common: NEMTransactionCommon, namespace: NEMProvisionNamespace
) -> None:
from ..layout import (
require_confirm_content,
@ -19,11 +18,11 @@ async def ask_provision_namespace(
("Create namespace", namespace.namespace),
("under namespace", namespace.parent),
]
await require_confirm_content(ctx, "Confirm namespace", content)
await require_confirm_content("Confirm namespace", content)
else:
content = [("Create namespace", namespace.namespace)]
await require_confirm_content(ctx, "Confirm namespace", content)
await require_confirm_content("Confirm namespace", content)
await require_confirm_fee(ctx, "Confirm rental fee", namespace.fee)
await require_confirm_fee("Confirm rental fee", namespace.fee)
await require_confirm_final(ctx, common.fee)
await require_confirm_final(common.fee)

@ -7,11 +7,10 @@ from . import CURVE, PATTERNS, SLIP44_ID
if TYPE_CHECKING:
from trezor.messages import NEMSignTx, NEMSignedTx
from apps.common.keychain import Keychain
from trezor.wire import Context
@with_slip44_keychain(*PATTERNS, slip44_id=SLIP44_ID, curve=CURVE)
async def sign_tx(ctx: Context, msg: NEMSignTx, keychain: Keychain) -> NEMSignedTx:
async def sign_tx(msg: NEMSignTx, keychain: Keychain) -> NEMSignedTx:
from trezor.wire import DataError
from trezor.crypto.curve import ed25519
from trezor.messages import NEMSignedTx
@ -28,7 +27,6 @@ async def sign_tx(ctx: Context, msg: NEMSignTx, keychain: Keychain) -> NEMSigned
transaction = msg.transaction # local_cache_attribute
await validate_path(
ctx,
keychain,
transaction.address_n,
check_path(transaction.address_n, transaction.network),
@ -41,22 +39,21 @@ async def sign_tx(ctx: Context, msg: NEMSignTx, keychain: Keychain) -> NEMSigned
raise DataError("No signer provided")
public_key = msg_multisig.signer
common = msg_multisig
await multisig.ask(ctx, msg)
await multisig.ask(msg)
else:
public_key = seed.remove_ed25519_prefix(node.public_key())
common = transaction
if msg.transfer:
tx = await transfer.transfer(ctx, public_key, common, msg.transfer, node)
tx = await transfer.transfer(public_key, common, msg.transfer, node)
elif msg.provision_namespace:
tx = await namespace.namespace(ctx, public_key, common, msg.provision_namespace)
tx = await namespace.namespace(public_key, common, msg.provision_namespace)
elif msg.mosaic_creation:
tx = await mosaic.mosaic_creation(ctx, public_key, common, msg.mosaic_creation)
tx = await mosaic.mosaic_creation(public_key, common, msg.mosaic_creation)
elif msg.supply_change:
tx = await mosaic.supply_change(ctx, public_key, common, msg.supply_change)
tx = await mosaic.supply_change(public_key, common, msg.supply_change)
elif msg.aggregate_modification:
tx = await multisig.aggregate_modification(
ctx,
public_key,
common,
msg.aggregate_modification,
@ -64,7 +61,7 @@ async def sign_tx(ctx: Context, msg: NEMSignTx, keychain: Keychain) -> NEMSigned
)
elif msg.importance_transfer:
tx = await transfer.importance_transfer(
ctx, public_key, common, msg.importance_transfer
public_key, common, msg.importance_transfer
)
else:
raise DataError("No transaction provided")

@ -4,12 +4,10 @@ from . import layout, serialize
if TYPE_CHECKING:
from trezor.messages import NEMImportanceTransfer, NEMTransactionCommon, NEMTransfer
from trezor.wire import Context
from trezor.crypto import bip32
async def transfer(
ctx: Context,
public_key: bytes,
common: NEMTransactionCommon,
transfer: NEMTransfer,
@ -18,7 +16,7 @@ async def transfer(
transfer.mosaics = serialize.canonicalize_mosaics(transfer.mosaics)
payload, encrypted = serialize.get_transfer_payload(transfer, node)
await layout.ask_transfer(ctx, common, transfer, encrypted)
await layout.ask_transfer(common, transfer, encrypted)
w = serialize.serialize_transfer(common, transfer, public_key, payload, encrypted)
for mosaic in transfer.mosaics:
@ -27,10 +25,9 @@ async def transfer(
async def importance_transfer(
ctx: Context,
public_key: bytes,
common: NEMTransactionCommon,
imp: NEMImportanceTransfer,
) -> bytes:
await layout.ask_importance_transfer(ctx, common, imp)
await layout.ask_importance_transfer(common, imp)
return serialize.serialize_importance_transfer(common, imp, public_key)

@ -14,11 +14,9 @@ if TYPE_CHECKING:
NEMTransactionCommon,
NEMTransfer,
)
from trezor.wire import Context
async def ask_transfer(
ctx: Context,
common: NEMTransactionCommon,
transfer: NEMTransfer,
encrypted: bool,
@ -29,7 +27,6 @@ async def ask_transfer(
if transfer.payload:
# require_confirm_payload
await confirm_text(
ctx,
"confirm_payload",
"Confirm payload",
bytes(transfer.payload).decode(),
@ -38,20 +35,19 @@ async def ask_transfer(
)
for mosaic in transfer.mosaics:
await _ask_transfer_mosaic(ctx, common, transfer, mosaic)
await _ask_transfer_mosaic(common, transfer, mosaic)
# require_confirm_transfer
await confirm_output(
ctx,
transfer.recipient,
f"{format_amount(_get_xem_amount(transfer), NEM_MAX_DIVISIBILITY)} XEM",
)
await require_confirm_final(ctx, common.fee)
await require_confirm_final(common.fee)
async def _ask_transfer_mosaic(
ctx: Context, common: NEMTransactionCommon, transfer: NEMTransfer, mosaic: NEMMosaic
common: NEMTransactionCommon, transfer: NEMTransfer, mosaic: NEMMosaic
) -> None:
from trezor.enums import NEMMosaicLevy
from trezor.ui.layouts import confirm_action, confirm_properties
@ -66,7 +62,6 @@ async def _ask_transfer_mosaic(
if definition:
await confirm_properties(
ctx,
"confirm_mosaic",
"Confirm mosaic",
(
@ -93,7 +88,6 @@ async def _ask_transfer_mosaic(
)
await confirm_properties(
ctx,
"confirm_mosaic_levy",
"Confirm mosaic",
(("Confirm mosaic\nlevy fee of", levy_msg),),
@ -101,7 +95,6 @@ async def _ask_transfer_mosaic(
else:
await confirm_action(
ctx,
"confirm_mosaic_unknown",
"Confirm mosaic",
"Unknown mosaic!",
@ -110,7 +103,6 @@ async def _ask_transfer_mosaic(
)
await confirm_properties(
ctx,
"confirm_mosaic_transfer",
"Confirm mosaic",
(
@ -133,7 +125,7 @@ def _get_xem_amount(transfer: NEMTransfer) -> int:
async def ask_importance_transfer(
ctx: Context, common: NEMTransactionCommon, imp: NEMImportanceTransfer
common: NEMTransactionCommon, imp: NEMImportanceTransfer
) -> None:
from trezor.enums import NEMImportanceTransferMode
from ..layout import require_confirm_text
@ -142,5 +134,5 @@ async def ask_importance_transfer(
m = "Activate"
else:
m = "Deactivate"
await require_confirm_text(ctx, m + " remote harvesting?")
await require_confirm_final(ctx, common.fee)
await require_confirm_text(m + " remote harvesting?")
await require_confirm_final(common.fee)

@ -5,26 +5,23 @@ from apps.common.keychain import auto_keychain
if TYPE_CHECKING:
from trezor.messages import RippleGetAddress, RippleAddress
from apps.common.keychain import Keychain
from trezor.wire import Context
@auto_keychain(__name__)
async def get_address(
ctx: Context, msg: RippleGetAddress, keychain: Keychain
) -> RippleAddress:
async def get_address(msg: RippleGetAddress, keychain: Keychain) -> RippleAddress:
# NOTE: local imports here saves 20 bytes
from trezor.messages import RippleAddress
from trezor.ui.layouts import show_address
from apps.common import paths
from .helpers import address_from_public_key
await paths.validate_path(ctx, keychain, msg.address_n)
await paths.validate_path(keychain, msg.address_n)
node = keychain.derive(msg.address_n)
pubkey = node.public_key()
address = address_from_public_key(pubkey)
if msg.show_display:
await show_address(ctx, address, path=paths.address_n_to_str(msg.address_n))
await show_address(address, path=paths.address_n_to_str(msg.address_n))
return RippleAddress(address=address)

@ -1,26 +1,19 @@
from typing import TYPE_CHECKING
from trezor.enums import ButtonRequestType
from trezor.strings import format_amount
from trezor.ui.layouts import confirm_metadata, confirm_total
from .helpers import DECIMALS
if TYPE_CHECKING:
from trezor.wire import Context
async def require_confirm_total(ctx: Context, total: int, fee: int) -> None:
async def require_confirm_total(total: int, fee: int) -> None:
await confirm_total(
ctx,
format_amount(total, DECIMALS) + " XRP",
format_amount(fee, DECIMALS) + " XRP",
)
async def require_confirm_destination_tag(ctx: Context, tag: int) -> None:
async def require_confirm_destination_tag(tag: int) -> None:
await confirm_metadata(
ctx,
"confirm_destination_tag",
"Confirm tag",
"Destination tag:\n{}",
@ -29,7 +22,7 @@ async def require_confirm_destination_tag(ctx: Context, tag: int) -> None:
)
async def require_confirm_tx(ctx: Context, to: str, value: int) -> None:
async def require_confirm_tx(to: str, value: int) -> None:
from trezor.ui.layouts import confirm_output
await confirm_output(ctx, to, format_amount(value, DECIMALS) + " XRP")
await confirm_output(to, format_amount(value, DECIMALS) + " XRP")

@ -5,14 +5,11 @@ from apps.common.keychain import auto_keychain
if TYPE_CHECKING:
from trezor.messages import RippleSignTx, RippleSignedTx
from apps.common.keychain import Keychain
from trezor.wire import Context
# NOTE: it is one big function because that way it is the most flash-space-efficient
@auto_keychain(__name__)
async def sign_tx(
ctx: Context, msg: RippleSignTx, keychain: Keychain
) -> RippleSignedTx:
async def sign_tx(msg: RippleSignTx, keychain: Keychain) -> RippleSignedTx:
from trezor.crypto import der
from trezor.crypto.curve import secp256k1
from trezor.crypto.hashlib import sha512
@ -26,7 +23,7 @@ async def sign_tx(
if payment.amount > helpers.MAX_ALLOWED_AMOUNT:
raise ProcessError("Amount exceeds maximum allowed amount.")
await paths.validate_path(ctx, keychain, msg.address_n)
await paths.validate_path(keychain, msg.address_n)
node = keychain.derive(msg.address_n)
source_address = helpers.address_from_public_key(node.public_key())
@ -46,9 +43,9 @@ async def sign_tx(
raise ProcessError("Fee must be in the range of 10 to 10,000 drops")
if payment.destination_tag is not None:
await layout.require_confirm_destination_tag(ctx, payment.destination_tag)
await layout.require_confirm_tx(ctx, payment.destination, payment.amount)
await layout.require_confirm_total(ctx, payment.amount + msg.fee, msg.fee)
await layout.require_confirm_destination_tag(payment.destination_tag)
await layout.require_confirm_tx(payment.destination, payment.amount)
await layout.require_confirm_total(payment.amount + msg.fee, msg.fee)
# Signs and encodes signature into DER format
first_half_of_sha512 = sha512(to_sign).digest()[:32]

@ -4,20 +4,17 @@ from apps.common.keychain import auto_keychain
if TYPE_CHECKING:
from trezor.messages import StellarGetAddress, StellarAddress
from trezor.wire import Context
from apps.common.keychain import Keychain
@auto_keychain(__name__)
async def get_address(
ctx: Context, msg: StellarGetAddress, keychain: Keychain
) -> StellarAddress:
async def get_address(msg: StellarGetAddress, keychain: Keychain) -> StellarAddress:
from apps.common import paths, seed
from trezor.messages import StellarAddress
from trezor.ui.layouts import show_address
from . import helpers
await paths.validate_path(ctx, keychain, msg.address_n)
await paths.validate_path(keychain, msg.address_n)
node = keychain.derive(msg.address_n)
pubkey = seed.remove_ed25519_prefix(node.public_key())
@ -25,6 +22,6 @@ async def get_address(
if msg.show_display:
path = paths.address_n_to_str(msg.address_n)
await show_address(ctx, address, case_sensitive=False, path=path)
await show_address(address, case_sensitive=False, path=path)
return StellarAddress(address=address)

@ -7,21 +7,18 @@ from trezor.enums import ButtonRequestType
from . import consts
if TYPE_CHECKING:
from trezor.wire import Context
from trezor.enums import StellarMemoType
from trezor.messages import StellarAsset
async def require_confirm_init(
ctx: Context,
address: str,
network_passphrase: str,
accounts_match: bool,
) -> None:
description = "Initialize signing with" + " your account" if accounts_match else ""
await layouts.confirm_address(
ctx,
"Confirm Stellar",
address,
description,
@ -38,7 +35,6 @@ async def require_confirm_init(
if network:
await layouts.confirm_metadata(
ctx,
"confirm_init_network",
"Confirm network",
"Transaction is on {}",
@ -47,9 +43,8 @@ async def require_confirm_init(
)
async def require_confirm_timebounds(ctx: Context, start: int, end: int) -> None:
async def require_confirm_timebounds(start: int, end: int) -> None:
await layouts.confirm_properties(
ctx,
"confirm_timebounds",
"Confirm timebounds",
(
@ -65,9 +60,7 @@ async def require_confirm_timebounds(ctx: Context, start: int, end: int) -> None
)
async def require_confirm_memo(
ctx: Context, memo_type: StellarMemoType, memo_text: str
) -> None:
async def require_confirm_memo(memo_type: StellarMemoType, memo_text: str) -> None:
from trezor.enums import StellarMemoType
if memo_type == StellarMemoType.TEXT:
@ -80,7 +73,6 @@ async def require_confirm_memo(
description = "Memo (RETURN)"
else:
return await layouts.confirm_action(
ctx,
"confirm_memo",
"Confirm memo",
"No memo set!",
@ -89,7 +81,6 @@ async def require_confirm_memo(
)
await layouts.confirm_blob(
ctx,
"confirm_memo",
"Confirm memo",
memo_text,
@ -97,10 +88,9 @@ async def require_confirm_memo(
)
async def require_confirm_final(ctx: Context, fee: int, num_operations: int) -> None:
async def require_confirm_final(fee: int, num_operations: int) -> None:
op_str = strings.format_plural("{count} {plural}", num_operations, "operation")
await layouts.confirm_metadata(
ctx,
"confirm_final",
"Final confirm",
"Sign this transaction made up of " + op_str + " and pay {}\nfor fee?",

@ -2,11 +2,10 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from trezor.utils import Writer
from trezor.wire import Context
from consts import StellarMessageType
async def process_operation(ctx: Context, w: Writer, op: StellarMessageType) -> None:
async def process_operation(w: Writer, op: StellarMessageType) -> None:
# Importing the stuff inside (only) function saves around 100 bytes here
# (probably because the local lookup is more efficient than a global lookup)
@ -17,48 +16,48 @@ async def process_operation(ctx: Context, w: Writer, op: StellarMessageType) ->
from . import layout, serialize
if op.source_account:
await layout.confirm_source_account(ctx, op.source_account)
await layout.confirm_source_account(op.source_account)
serialize.write_account(w, op.source_account)
writers.write_uint32(w, consts.get_op_code(op))
# NOTE: each branch below has 45 bytes (26 the actions, 19 the condition)
if messages.StellarAccountMergeOp.is_type_of(op):
await layout.confirm_account_merge_op(ctx, op)
await layout.confirm_account_merge_op(op)
serialize.write_account_merge_op(w, op)
elif messages.StellarAllowTrustOp.is_type_of(op):
await layout.confirm_allow_trust_op(ctx, op)
await layout.confirm_allow_trust_op(op)
serialize.write_allow_trust_op(w, op)
elif messages.StellarBumpSequenceOp.is_type_of(op):
await layout.confirm_bump_sequence_op(ctx, op)
await layout.confirm_bump_sequence_op(op)
serialize.write_bump_sequence_op(w, op)
elif messages.StellarChangeTrustOp.is_type_of(op):
await layout.confirm_change_trust_op(ctx, op)
await layout.confirm_change_trust_op(op)
serialize.write_change_trust_op(w, op)
elif messages.StellarCreateAccountOp.is_type_of(op):
await layout.confirm_create_account_op(ctx, op)
await layout.confirm_create_account_op(op)
serialize.write_create_account_op(w, op)
elif messages.StellarCreatePassiveSellOfferOp.is_type_of(op):
await layout.confirm_create_passive_sell_offer_op(ctx, op)
await layout.confirm_create_passive_sell_offer_op(op)
serialize.write_create_passive_sell_offer_op(w, op)
elif messages.StellarManageDataOp.is_type_of(op):
await layout.confirm_manage_data_op(ctx, op)
await layout.confirm_manage_data_op(op)
serialize.write_manage_data_op(w, op)
elif messages.StellarManageBuyOfferOp.is_type_of(op):
await layout.confirm_manage_buy_offer_op(ctx, op)
await layout.confirm_manage_buy_offer_op(op)
serialize.write_manage_buy_offer_op(w, op)
elif messages.StellarManageSellOfferOp.is_type_of(op):
await layout.confirm_manage_sell_offer_op(ctx, op)
await layout.confirm_manage_sell_offer_op(op)
serialize.write_manage_sell_offer_op(w, op)
elif messages.StellarPathPaymentStrictReceiveOp.is_type_of(op):
await layout.confirm_path_payment_strict_receive_op(ctx, op)
await layout.confirm_path_payment_strict_receive_op(op)
serialize.write_path_payment_strict_receive_op(w, op)
elif messages.StellarPathPaymentStrictSendOp.is_type_of(op):
await layout.confirm_path_payment_strict_send_op(ctx, op)
await layout.confirm_path_payment_strict_send_op(op)
serialize.write_path_payment_strict_send_op(w, op)
elif messages.StellarPaymentOp.is_type_of(op):
await layout.confirm_payment_op(ctx, op)
await layout.confirm_payment_op(op)
serialize.write_payment_op(w, op)
elif messages.StellarSetOptionsOp.is_type_of(op):
await layout.confirm_set_options_op(ctx, op)
await layout.confirm_set_options_op(op)
serialize.write_set_options_op(w, op)
else:
raise ValueError("Unknown operation")

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save