1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-22 14:28:07 +00:00

refactor(core): Replace Bitcoin signing progress functions with a class.

This commit is contained in:
Andrew Kozlik 2022-10-11 16:37:20 +02:00 committed by Andrew Kozlik
parent a1a34774b8
commit 5dddb06e2b
4 changed files with 116 additions and 111 deletions

View File

@ -91,6 +91,6 @@ async def sign_tx(
res = await ctx.call(req, request_class) res = await ctx.call(req, request_class)
elif isinstance(req, helpers.UiConfirm): elif isinstance(req, helpers.UiConfirm):
res = await req.confirm_dialog(ctx) res = await req.confirm_dialog(ctx)
progress.report_init() progress.progress.report_init()
else: else:
raise TypeError("Invalid signing instruction") raise TypeError("Invalid signing instruction")

View File

@ -19,7 +19,8 @@ from ..common import (
) )
from ..ownership import verify_nonownership from ..ownership import verify_nonownership
from ..verification import SignatureVerifier from ..verification import SignatureVerifier
from . import approvers, helpers, progress from . import approvers, helpers
from .progress import progress
from .sig_hasher import BitcoinSigHasher from .sig_hasher import BitcoinSigHasher
from .tx_info import OriginalTxInfo, TxInfo from .tx_info import OriginalTxInfo, TxInfo

View File

@ -12,9 +12,10 @@ from apps.common.writers import write_compact_size
from .. import multisig, scripts_decred, writers from .. import multisig, scripts_decred, writers
from ..common import SigHashType, ecdsa_hash_pubkey, ecdsa_sign from ..common import SigHashType, ecdsa_hash_pubkey, ecdsa_sign
from . import approvers, helpers, progress from . import approvers, helpers
from .approvers import BasicApprover from .approvers import BasicApprover
from .bitcoin import Bitcoin from .bitcoin import Bitcoin
from .progress import progress
DECRED_SERIALIZE_FULL = const(0 << 16) DECRED_SERIALIZE_FULL = const(0 << 16)
DECRED_SERIALIZE_NO_WITNESS = const(1 << 16) DECRED_SERIALIZE_NO_WITNESS = const(1 << 16)

View File

@ -12,25 +12,32 @@ if TYPE_CHECKING:
# the input, prevtx metadata, prevtx input, prevtx output, prevtx change-output # the input, prevtx metadata, prevtx input, prevtx output, prevtx change-output
_PREV_TX_MULTIPLIER = 5 _PREV_TX_MULTIPLIER = 5
_progress = 0
_steps = 0
_signing = False
_prev_tx_step = 0
class Progress:
def __init__(self):
self.progress = 0
self.steps = 0
self.signing = False
def init(tx: SignTx) -> None: # We don't know how long it will take to fetch the previous transactions,
global _progress, _steps, _signing # so for each one we reserve _PREV_TX_MULTIPLIER steps in the signing
_progress = 0 # progress. Once we fetch a prev_tx's metadata, we subdivide the reserved
_signing = False # space and then prev_tx_step represents the progress of fetching one
# prev_tx input or output in the overall signing progress.
self.prev_tx_step = 0
def init(self, tx: SignTx) -> None:
self.progress = 0
self.signing = False
# Step 1 and 2 - load inputs and outputs # Step 1 and 2 - load inputs and outputs
_steps = tx.inputs_count + tx.outputs_count self.steps = tx.inputs_count + tx.outputs_count
report_init()
report()
self.report_init()
self.report()
def init_signing( def init_signing(
self,
external: int, external: int,
segwit: int, segwit: int,
taproot_only: bool, taproot_only: bool,
@ -41,26 +48,27 @@ def init_signing(
orig_txs: list[OriginalTxInfo], orig_txs: list[OriginalTxInfo],
) -> None: ) -> None:
if __debug__: if __debug__:
assert_finished() self.assert_finished()
global _progress, _steps, _signing self.progress = 0
_progress = 0 self.steps = 0
_steps = 0 self.signing = True
_signing = True
# Step 3 - verify inputs # Step 3 - verify inputs
if taproot_only or (coin.overwintered and tx.version == 5): if taproot_only or (coin.overwintered and tx.version == 5):
if has_presigned: if has_presigned:
_steps += external self.steps += external
else: else:
_steps = tx.inputs_count * _PREV_TX_MULTIPLIER self.steps = tx.inputs_count * _PREV_TX_MULTIPLIER
for orig in orig_txs: for orig in orig_txs:
_steps += orig.tx.inputs_count self.steps += orig.tx.inputs_count
# Steps 3 and 4 - get_legacy_tx_digest() for each legacy input. # Steps 3 and 4 - get_legacy_tx_digest() for each legacy input.
if not (coin.force_bip143 or coin.overwintered or coin.decred): if not (coin.force_bip143 or coin.overwintered or coin.decred):
_steps += (tx.inputs_count - segwit) * (tx.inputs_count + tx.outputs_count) self.steps += (tx.inputs_count - segwit) * (
tx.inputs_count + tx.outputs_count
)
if segwit != tx.inputs_count: if segwit != tx.inputs_count:
# The transaction has a legacy input. # The transaction has a legacy input.
@ -68,67 +76,62 @@ def init_signing(
# Simplification: We assume that all original transaction inputs # Simplification: We assume that all original transaction inputs
# are legacy, since mixed script types are not supported in Suite. # are legacy, since mixed script types are not supported in Suite.
for orig in orig_txs: for orig in orig_txs:
_steps += orig.tx.inputs_count * ( self.steps += orig.tx.inputs_count * (
orig.tx.inputs_count + orig.tx.outputs_count orig.tx.inputs_count + orig.tx.outputs_count
) )
# Steps 4 and 6 - serialize and sign inputs # Steps 4 and 6 - serialize and sign inputs
if serialize: if serialize:
_steps += tx.inputs_count + segwit self.steps += tx.inputs_count + segwit
else: else:
_steps += tx.inputs_count - external self.steps += tx.inputs_count - external
# Step 5 - serialize outputs # Step 5 - serialize outputs
if serialize and not coin.decred: if serialize and not coin.decred:
_steps += tx.outputs_count self.steps += tx.outputs_count
report_init() self.report_init()
report() self.report()
def init_prev_tx(self, inputs: int, outputs: int) -> None:
self.prev_tx_step = _PREV_TX_MULTIPLIER / (inputs + outputs)
def init_prev_tx(inputs: int, outputs: int) -> None: def advance(self) -> None:
global _prev_tx_step self.progress += 1
_prev_tx_step = _PREV_TX_MULTIPLIER / (inputs + outputs) self.report()
def advance_prev_tx(self) -> None:
self.progress += self.prev_tx_step
self.report()
def advance() -> None: def report_init(self) -> None:
global _progress
_progress += 1
report()
def advance_prev_tx() -> None:
global _progress
_progress += _prev_tx_step
report()
def report_init() -> None:
from trezor import workflow from trezor import workflow
workflow.close_others() workflow.close_others()
ui.display.clear() ui.display.clear()
if _signing: if self.signing:
ui.header("Signing transaction") ui.header("Signing transaction")
else: else:
ui.header("Loading transaction") ui.header("Loading transaction")
def report(self) -> None:
def report() -> None:
from trezor import utils from trezor import utils
if utils.DISABLE_ANIMATION: if utils.DISABLE_ANIMATION:
return return
p = int(1000 * _progress / _steps) p = int(1000 * self.progress / self.steps)
ui.display.loader(p, False, 18, ui.WHITE, ui.BG) ui.display.loader(p, False, 18, ui.WHITE, ui.BG)
if __debug__: if __debug__:
def assert_finished() -> None: def assert_finished(self) -> None:
if abs(_progress - _steps) > 0.5: if abs(self.progress - self.steps) > 0.5:
operation = "signing" if _signing else "loading"
from trezor import wire from trezor import wire
operation = "signing" if self.signing else "loading"
raise wire.FirmwareError( raise wire.FirmwareError(
f"Transaction {operation} progress finished at {_progress}/{_steps}." f"Transaction {operation} progress finished at {self.progress}/{self.steps}."
) )
progress = Progress()