diff --git a/core/src/apps/bitcoin/sign_tx/__init__.py b/core/src/apps/bitcoin/sign_tx/__init__.py index 4bd4b0b79c..a2a24377f6 100644 --- a/core/src/apps/bitcoin/sign_tx/__init__.py +++ b/core/src/apps/bitcoin/sign_tx/__init__.py @@ -4,11 +4,11 @@ from trezor.messages.SignTx import SignTx from trezor.messages.TxAck import TxAck from trezor.messages.TxRequest import TxRequest -from apps.common import coininfo, paths +from apps.common import coininfo from ..common import BITCOIN_NAMES from ..keychain import with_keychain -from . import approvers, bitcoin, helpers, layout, progress +from . import approvers, bitcoin, helpers, progress if not utils.BITCOIN_ONLY: from . import bitcoinlike, decred, zcash @@ -52,32 +52,8 @@ async def sign_tx( if req.request_type == TXFINISHED: break res = await ctx.call(req, TxAck, field_cache) - elif isinstance(req, helpers.UiConfirmOutput): - res = await layout.confirm_output(ctx, req.output, req.coin) - progress.report_init() - elif isinstance(req, helpers.UiConfirmTotal): - res = await layout.confirm_total(ctx, req.spending, req.fee, req.coin) - progress.report_init() - elif isinstance(req, helpers.UiConfirmJointTotal): - res = await layout.confirm_joint_total( - ctx, req.spending, req.total, req.coin - ) - progress.report_init() - elif isinstance(req, helpers.UiConfirmFeeOverThreshold): - res = await layout.confirm_feeoverthreshold(ctx, req.fee, req.coin) - progress.report_init() - elif isinstance(req, helpers.UiConfirmChangeCountOverThreshold): - res = await layout.confirm_change_count_over_threshold( - ctx, req.change_count - ) - progress.report_init() - elif isinstance(req, helpers.UiConfirmNonDefaultLocktime): - res = await layout.confirm_nondefault_locktime( - ctx, req.lock_time, req.lock_time_disabled - ) - progress.report_init() - elif isinstance(req, helpers.UiConfirmForeignAddress): - res = await paths.show_path_warning(ctx, req.address_n) + elif isinstance(req, helpers.UiConfirm): + res = await req.confirm_dialog(ctx) progress.report_init() else: raise TypeError("Invalid signing instruction") diff --git a/core/src/apps/bitcoin/sign_tx/helpers.py b/core/src/apps/bitcoin/sign_tx/helpers.py index 124b620fb9..b0359019f7 100644 --- a/core/src/apps/bitcoin/sign_tx/helpers.py +++ b/core/src/apps/bitcoin/sign_tx/helpers.py @@ -14,10 +14,12 @@ from trezor.messages.TxOutputBinType import TxOutputBinType from trezor.messages.TxOutputType import TxOutputType from trezor.messages.TxRequest import TxRequest +from apps.common import paths from apps.common.coininfo import CoinInfo from .. import common from ..writers import TX_HASH_SIZE +from . import layout if False: from typing import Any, Awaitable @@ -27,59 +29,87 @@ if False: # === -class UiConfirmOutput: +class UiConfirm: + def confirm_dialog(self, ctx: wire.Context) -> Awaitable[Any]: + raise NotImplementedError + + +class UiConfirmOutput(UiConfirm): def __init__(self, output: TxOutputType, coin: CoinInfo): self.output = output self.coin = coin + def confirm_dialog(self, ctx: wire.Context) -> Awaitable[Any]: + return layout.confirm_output(ctx, self.output, self.coin) + __eq__ = utils.obj_eq -class UiConfirmTotal: +class UiConfirmTotal(UiConfirm): def __init__(self, spending: int, fee: int, coin: CoinInfo): self.spending = spending self.fee = fee self.coin = coin + def confirm_dialog(self, ctx: wire.Context) -> Awaitable[Any]: + return layout.confirm_total(ctx, self.spending, self.fee, self.coin) + __eq__ = utils.obj_eq -class UiConfirmJointTotal: +class UiConfirmJointTotal(UiConfirm): def __init__(self, spending: int, total: int, coin: CoinInfo): self.spending = spending self.total = total self.coin = coin + def confirm_dialog(self, ctx: wire.Context) -> Awaitable[Any]: + return layout.confirm_joint_total(ctx, self.spending, self.total, self.coin) + __eq__ = utils.obj_eq -class UiConfirmFeeOverThreshold: +class UiConfirmFeeOverThreshold(UiConfirm): def __init__(self, fee: int, coin: CoinInfo): self.fee = fee self.coin = coin + def confirm_dialog(self, ctx: wire.Context) -> Awaitable[Any]: + return layout.confirm_feeoverthreshold(ctx, self.fee, self.coin) + __eq__ = utils.obj_eq -class UiConfirmChangeCountOverThreshold: +class UiConfirmChangeCountOverThreshold(UiConfirm): def __init__(self, change_count: int): self.change_count = change_count + def confirm_dialog(self, ctx: wire.Context) -> Awaitable[Any]: + return layout.confirm_change_count_over_threshold(ctx, self.change_count) + __eq__ = utils.obj_eq -class UiConfirmForeignAddress: +class UiConfirmForeignAddress(UiConfirm): def __init__(self, address_n: list): self.address_n = address_n + def confirm_dialog(self, ctx: wire.Context) -> Awaitable[Any]: + return paths.show_path_warning(ctx, self.address_n) + __eq__ = utils.obj_eq -class UiConfirmNonDefaultLocktime: - def __init__(self, lock_time: int, lock_time_disabled): +class UiConfirmNonDefaultLocktime(UiConfirm): + def __init__(self, lock_time: int, lock_time_disabled: bool): self.lock_time = lock_time self.lock_time_disabled = lock_time_disabled + def confirm_dialog(self, ctx: wire.Context) -> Awaitable[Any]: + return layout.confirm_nondefault_locktime( + ctx, self.lock_time, self.lock_time_disabled + ) + __eq__ = utils.obj_eq