From c2f5174b434796c745ae362863c643b6af048380 Mon Sep 17 00:00:00 2001 From: Jan Pochyla Date: Tue, 21 May 2019 09:48:55 +0200 Subject: [PATCH] core: optimize memory usage in signing --- core/src/apps/wallet/sign_tx/__init__.py | 12 +++++- core/src/apps/wallet/sign_tx/helpers.py | 7 ++++ core/src/apps/wallet/sign_tx/layout.py | 28 +++++++++---- core/src/apps/wallet/sign_tx/scripts.py | 51 ++++++++++++++++++------ core/src/apps/wallet/sign_tx/signing.py | 9 ++++- 5 files changed, 83 insertions(+), 24 deletions(-) diff --git a/core/src/apps/wallet/sign_tx/__init__.py b/core/src/apps/wallet/sign_tx/__init__.py index 5e468298d..02c85814f 100644 --- a/core/src/apps/wallet/sign_tx/__init__.py +++ b/core/src/apps/wallet/sign_tx/__init__.py @@ -1,4 +1,4 @@ -from trezor import wire +from trezor import utils, wire from trezor.messages.MessageType import TxAck from trezor.messages.RequestType import TXFINISHED from trezor.messages.TxRequest import TxRequest @@ -38,19 +38,29 @@ async def sign_tx(ctx, msg, keychain): break res = await ctx.call(req, TxAck) elif isinstance(req, helpers.UiConfirmOutput): + mods = utils.unimport_begin() res = await layout.confirm_output(ctx, req.output, req.coin) + utils.unimport_end(mods) progress.report_init() elif isinstance(req, helpers.UiConfirmTotal): + mods = utils.unimport_begin() res = await layout.confirm_total(ctx, req.spending, req.fee, req.coin) + utils.unimport_end(mods) progress.report_init() elif isinstance(req, helpers.UiConfirmFeeOverThreshold): + mods = utils.unimport_begin() res = await layout.confirm_feeoverthreshold(ctx, req.fee, req.coin) + utils.unimport_end(mods) progress.report_init() elif isinstance(req, helpers.UiConfirmNonDefaultLocktime): + mods = utils.unimport_begin() res = await layout.confirm_nondefault_locktime(ctx, req.lock_time) + utils.unimport_end(mods) progress.report_init() elif isinstance(req, helpers.UiConfirmForeignAddress): + mods = utils.unimport_begin() res = await paths.show_path_warning(ctx, req.address_n) + utils.unimport_end(mods) progress.report_init() else: raise TypeError("Invalid signing instruction") diff --git a/core/src/apps/wallet/sign_tx/helpers.py b/core/src/apps/wallet/sign_tx/helpers.py index 07b9da300..8621c2564 100644 --- a/core/src/apps/wallet/sign_tx/helpers.py +++ b/core/src/apps/wallet/sign_tx/helpers.py @@ -1,3 +1,5 @@ +import gc + from trezor.messages import InputScriptType from trezor.messages.RequestType import ( TXEXTRADATA, @@ -85,6 +87,7 @@ def request_tx_meta(tx_req: TxRequest, tx_hash: bytes = None): tx_req.details.request_index = None ack = yield tx_req tx_req.serialized = None + gc.collect() return sanitize_tx_meta(ack.tx) @@ -100,6 +103,7 @@ def request_tx_extra_data( tx_req.serialized = None tx_req.details.extra_data_offset = None tx_req.details.extra_data_len = None + gc.collect() return ack.tx.extra_data @@ -109,6 +113,7 @@ def request_tx_input(tx_req: TxRequest, i: int, tx_hash: bytes = None): tx_req.details.tx_hash = tx_hash ack = yield tx_req tx_req.serialized = None + gc.collect() return sanitize_tx_input(ack.tx) @@ -118,6 +123,7 @@ def request_tx_output(tx_req: TxRequest, i: int, tx_hash: bytes = None): tx_req.details.tx_hash = tx_hash ack = yield tx_req tx_req.serialized = None + gc.collect() if tx_hash is None: return sanitize_tx_output(ack.tx) else: @@ -129,6 +135,7 @@ def request_tx_finish(tx_req: TxRequest): tx_req.details = None yield tx_req tx_req.serialized = None + gc.collect() # Data sanitizers diff --git a/core/src/apps/wallet/sign_tx/layout.py b/core/src/apps/wallet/sign_tx/layout.py index b39369c1b..7744383db 100644 --- a/core/src/apps/wallet/sign_tx/layout.py +++ b/core/src/apps/wallet/sign_tx/layout.py @@ -3,12 +3,8 @@ from ubinascii import hexlify from trezor import ui from trezor.messages import ButtonRequestType, OutputScriptType -from trezor.ui.text import Text from trezor.utils import chunks, format_amount -from apps.common.confirm import confirm, hold_to_confirm -from apps.wallet.sign_tx import addresses, omni - _LOCKTIME_TIMESTAMP_MIN_VALUE = const(500000000) @@ -25,6 +21,10 @@ def split_op_return(data): async def confirm_output(ctx, output, coin): + from trezor.ui.text import Text + from apps.common.confirm import confirm + from apps.wallet.sign_tx import addresses, omni + if output.script_type == OutputScriptType.PAYTOOPRETURN: data = output.op_return_data if omni.is_valid(data): @@ -48,7 +48,10 @@ async def confirm_output(ctx, output, coin): async def confirm_total(ctx, spending, fee, coin): - text = Text("Confirm transaction", ui.ICON_SEND, icon_color=ui.GREEN) + from trezor.ui.text import Text + from apps.common.confirm import hold_to_confirm + + text = Text("Confirm transaction", ui.ICON_SEND, ui.GREEN) text.normal("Total amount:") text.bold(format_coin_amount(spending, coin)) text.normal("including fee:") @@ -57,7 +60,10 @@ async def confirm_total(ctx, spending, fee, coin): async def confirm_feeoverthreshold(ctx, fee, coin): - text = Text("High fee", ui.ICON_SEND, icon_color=ui.GREEN) + from trezor.ui.text import Text + from apps.common.confirm import confirm + + text = Text("High fee", ui.ICON_SEND, ui.GREEN) text.normal("The fee of") text.bold(format_coin_amount(fee, coin)) text.normal("is unexpectedly high.", "Continue?") @@ -65,13 +71,19 @@ async def confirm_feeoverthreshold(ctx, fee, coin): async def confirm_foreign_address(ctx, address_n, coin): - text = Text("Confirm sending", ui.ICON_SEND, icon_color=ui.RED) + from trezor.ui.text import Text + from apps.common.confirm import confirm + + text = Text("Confirm sending", ui.ICON_SEND, ui.RED) text.normal("Trying to spend", "coins from another chain.", "Continue?") return await confirm(ctx, text, ButtonRequestType.SignTx) async def confirm_nondefault_locktime(ctx, lock_time): - text = Text("Confirm locktime", ui.ICON_SEND, icon_color=ui.GREEN) + from trezor.ui.text import Text + from apps.common.confirm import confirm + + text = Text("Confirm locktime", ui.ICON_SEND, ui.GREEN) text.normal("Locktime for this transaction is set to") if lock_time < _LOCKTIME_TIMESTAMP_MIN_VALUE: text.normal("blockheight:") diff --git a/core/src/apps/wallet/sign_tx/scripts.py b/core/src/apps/wallet/sign_tx/scripts.py index 7a8407e9f..e7e303a62 100644 --- a/core/src/apps/wallet/sign_tx/scripts.py +++ b/core/src/apps/wallet/sign_tx/scripts.py @@ -171,7 +171,18 @@ def witness_p2wsh( # witness program + signatures + redeem script num_of_witness_items = 1 + len(signatures) + 1 - w = bytearray() + # length of the redeem script + pubkeys = multisig_get_pubkeys(multisig) + redeem_script_length = output_script_multisig_length(pubkeys, multisig.m) + + # length of the result + total_length = 1 + 1 # number of items, version + for s in signatures: + total_length += 1 + len(s) + 1 # length, signature, sighash + total_length += 1 + redeem_script_length # length, script + + w = empty_bytearray(total_length) + write_varint(w, num_of_witness_items) write_varint(w, 0) # version 0 witness program @@ -179,10 +190,9 @@ def witness_p2wsh( append_signature(w, s, sighash) # size of the witness included # redeem script - pubkeys = multisig_get_pubkeys(multisig) - redeem_script = output_script_multisig(pubkeys, multisig.m) - write_varint(w, len(redeem_script)) - write_bytes(w, redeem_script) + write_varint(w, redeem_script_length) + output_script_multisig(pubkeys, multisig.m, w) + return w @@ -198,13 +208,25 @@ def input_script_multisig( signature_index: int, sighash: int, coin: CoinInfo, -): +) -> bytearray: signatures = multisig.signatures # other signatures if len(signatures[signature_index]) > 0: raise ScriptsError("Invalid multisig parameters") signatures[signature_index] = signature # our signature - w = bytearray() + # length of the redeem script + pubkeys = multisig_get_pubkeys(multisig) + redeem_script_length = output_script_multisig_length(pubkeys, multisig.m) + + # length of the result + total_length = 0 + if not coin.decred: + total_length += 1 # OP_FALSE + for s in signatures: + total_length += 1 + len(s) + 1 # length, signature, sighash + total_length += 1 + redeem_script_length # length, script + + w = empty_bytearray(total_length) if not coin.decred: # Starts with OP_FALSE because of an old OP_CHECKMULTISIG bug, which @@ -217,15 +239,13 @@ def input_script_multisig( append_signature(w, s, sighash) # redeem script - pubkeys = multisig_get_pubkeys(multisig) - redeem_script = output_script_multisig(pubkeys, multisig.m) - write_op_push(w, len(redeem_script)) - write_bytes(w, redeem_script) + write_op_push(w, redeem_script_length) + output_script_multisig(pubkeys, multisig.m, w) return w -def output_script_multisig(pubkeys, m: int) -> bytearray: +def output_script_multisig(pubkeys, m: int, w: bytearray = None) -> bytearray: n = len(pubkeys) if n < 1 or n > 15 or m < 1 or m > 15: raise ScriptsError("Invalid multisig parameters") @@ -233,7 +253,8 @@ def output_script_multisig(pubkeys, m: int) -> bytearray: if len(pubkey) != 33: raise ScriptsError("Invalid multisig parameters") - w = bytearray() + if w is None: + w = empty_bytearray(output_script_multisig_length(pubkeys, m)) w.append(0x50 + m) # numbers 1 to 16 are pushed as 0x50 + value for p in pubkeys: append_pubkey(w, p) @@ -242,6 +263,10 @@ def output_script_multisig(pubkeys, m: int) -> bytearray: return w +def output_script_multisig_length(pubkeys, m: int) -> int: + return 1 + len(pubkeys) * (1 + 33) + 1 + 1 # see output_script_multisig + + # OP_RETURN # === diff --git a/core/src/apps/wallet/sign_tx/signing.py b/core/src/apps/wallet/sign_tx/signing.py index 78019283e..bd021f138 100644 --- a/core/src/apps/wallet/sign_tx/signing.py +++ b/core/src/apps/wallet/sign_tx/signing.py @@ -1,3 +1,4 @@ +import gc from micropython import const from trezor import utils @@ -304,6 +305,7 @@ async def sign_tx(tx: SignTx, keychain: seed.Keychain): tx_ser.signature = signature # serialize input with correct signature + gc.collect() txi_sign.script_sig = input_derive_script( coin, txi_sign, key_sign_pub, signature ) @@ -365,6 +367,7 @@ async def sign_tx(tx: SignTx, keychain: seed.Keychain): tx_ser.signature = signature # serialize input with correct signature + gc.collect() txi_sign.script_sig = input_derive_script( coin, txi_sign, key_sign_pub, signature ) @@ -470,6 +473,7 @@ async def sign_tx(tx: SignTx, keychain: seed.Keychain): tx_ser.signature = signature # serialize input with correct signature + gc.collect() txi_sign.script_sig = input_derive_script( coin, txi_sign, key_sign_pub, signature ) @@ -812,8 +816,9 @@ def input_derive_script( if i.multisig: # p2wsh in p2sh pubkeys = multisig.multisig_get_pubkeys(i.multisig) - witness_script = scripts.output_script_multisig(pubkeys, i.multisig.m) - witness_script_hash = sha256(witness_script).digest() + witness_script_hasher = utils.HashWriter(sha256()) + scripts.output_script_multisig(pubkeys, i.multisig.m, witness_script_hasher) + witness_script_hash = witness_script_hasher.get_digest() return scripts.input_script_p2wsh_in_p2sh(witness_script_hash) # p2wpkh in p2sh