core: optimize memory usage in signing

pull/176/head
Jan Pochyla 5 years ago committed by Pavol Rusnak
parent a6e51434f2
commit c2f5174b43
No known key found for this signature in database
GPG Key ID: 91F3B339B9A02A3D

@ -1,4 +1,4 @@
from trezor import wire from trezor import utils, wire
from trezor.messages.MessageType import TxAck from trezor.messages.MessageType import TxAck
from trezor.messages.RequestType import TXFINISHED from trezor.messages.RequestType import TXFINISHED
from trezor.messages.TxRequest import TxRequest from trezor.messages.TxRequest import TxRequest
@ -38,19 +38,29 @@ async def sign_tx(ctx, msg, keychain):
break break
res = await ctx.call(req, TxAck) res = await ctx.call(req, TxAck)
elif isinstance(req, helpers.UiConfirmOutput): elif isinstance(req, helpers.UiConfirmOutput):
mods = utils.unimport_begin()
res = await layout.confirm_output(ctx, req.output, req.coin) res = await layout.confirm_output(ctx, req.output, req.coin)
utils.unimport_end(mods)
progress.report_init() progress.report_init()
elif isinstance(req, helpers.UiConfirmTotal): elif isinstance(req, helpers.UiConfirmTotal):
mods = utils.unimport_begin()
res = await layout.confirm_total(ctx, req.spending, req.fee, req.coin) res = await layout.confirm_total(ctx, req.spending, req.fee, req.coin)
utils.unimport_end(mods)
progress.report_init() progress.report_init()
elif isinstance(req, helpers.UiConfirmFeeOverThreshold): elif isinstance(req, helpers.UiConfirmFeeOverThreshold):
mods = utils.unimport_begin()
res = await layout.confirm_feeoverthreshold(ctx, req.fee, req.coin) res = await layout.confirm_feeoverthreshold(ctx, req.fee, req.coin)
utils.unimport_end(mods)
progress.report_init() progress.report_init()
elif isinstance(req, helpers.UiConfirmNonDefaultLocktime): elif isinstance(req, helpers.UiConfirmNonDefaultLocktime):
mods = utils.unimport_begin()
res = await layout.confirm_nondefault_locktime(ctx, req.lock_time) res = await layout.confirm_nondefault_locktime(ctx, req.lock_time)
utils.unimport_end(mods)
progress.report_init() progress.report_init()
elif isinstance(req, helpers.UiConfirmForeignAddress): elif isinstance(req, helpers.UiConfirmForeignAddress):
mods = utils.unimport_begin()
res = await paths.show_path_warning(ctx, req.address_n) res = await paths.show_path_warning(ctx, req.address_n)
utils.unimport_end(mods)
progress.report_init() progress.report_init()
else: else:
raise TypeError("Invalid signing instruction") raise TypeError("Invalid signing instruction")

@ -1,3 +1,5 @@
import gc
from trezor.messages import InputScriptType from trezor.messages import InputScriptType
from trezor.messages.RequestType import ( from trezor.messages.RequestType import (
TXEXTRADATA, TXEXTRADATA,
@ -85,6 +87,7 @@ def request_tx_meta(tx_req: TxRequest, tx_hash: bytes = None):
tx_req.details.request_index = None tx_req.details.request_index = None
ack = yield tx_req ack = yield tx_req
tx_req.serialized = None tx_req.serialized = None
gc.collect()
return sanitize_tx_meta(ack.tx) return sanitize_tx_meta(ack.tx)
@ -100,6 +103,7 @@ def request_tx_extra_data(
tx_req.serialized = None tx_req.serialized = None
tx_req.details.extra_data_offset = None tx_req.details.extra_data_offset = None
tx_req.details.extra_data_len = None tx_req.details.extra_data_len = None
gc.collect()
return ack.tx.extra_data return ack.tx.extra_data
@ -109,6 +113,7 @@ def request_tx_input(tx_req: TxRequest, i: int, tx_hash: bytes = None):
tx_req.details.tx_hash = tx_hash tx_req.details.tx_hash = tx_hash
ack = yield tx_req ack = yield tx_req
tx_req.serialized = None tx_req.serialized = None
gc.collect()
return sanitize_tx_input(ack.tx) return sanitize_tx_input(ack.tx)
@ -118,6 +123,7 @@ def request_tx_output(tx_req: TxRequest, i: int, tx_hash: bytes = None):
tx_req.details.tx_hash = tx_hash tx_req.details.tx_hash = tx_hash
ack = yield tx_req ack = yield tx_req
tx_req.serialized = None tx_req.serialized = None
gc.collect()
if tx_hash is None: if tx_hash is None:
return sanitize_tx_output(ack.tx) return sanitize_tx_output(ack.tx)
else: else:
@ -129,6 +135,7 @@ def request_tx_finish(tx_req: TxRequest):
tx_req.details = None tx_req.details = None
yield tx_req yield tx_req
tx_req.serialized = None tx_req.serialized = None
gc.collect()
# Data sanitizers # Data sanitizers

@ -3,12 +3,8 @@ from ubinascii import hexlify
from trezor import ui from trezor import ui
from trezor.messages import ButtonRequestType, OutputScriptType from trezor.messages import ButtonRequestType, OutputScriptType
from trezor.ui.text import Text
from trezor.utils import chunks, format_amount from trezor.utils import chunks, format_amount
from apps.common.confirm import confirm, hold_to_confirm
from apps.wallet.sign_tx import addresses, omni
_LOCKTIME_TIMESTAMP_MIN_VALUE = const(500000000) _LOCKTIME_TIMESTAMP_MIN_VALUE = const(500000000)
@ -25,6 +21,10 @@ def split_op_return(data):
async def confirm_output(ctx, output, coin): async def confirm_output(ctx, output, coin):
from trezor.ui.text import Text
from apps.common.confirm import confirm
from apps.wallet.sign_tx import addresses, omni
if output.script_type == OutputScriptType.PAYTOOPRETURN: if output.script_type == OutputScriptType.PAYTOOPRETURN:
data = output.op_return_data data = output.op_return_data
if omni.is_valid(data): if omni.is_valid(data):
@ -48,7 +48,10 @@ async def confirm_output(ctx, output, coin):
async def confirm_total(ctx, spending, fee, coin): async def confirm_total(ctx, spending, fee, coin):
text = Text("Confirm transaction", ui.ICON_SEND, icon_color=ui.GREEN) from trezor.ui.text import Text
from apps.common.confirm import hold_to_confirm
text = Text("Confirm transaction", ui.ICON_SEND, ui.GREEN)
text.normal("Total amount:") text.normal("Total amount:")
text.bold(format_coin_amount(spending, coin)) text.bold(format_coin_amount(spending, coin))
text.normal("including fee:") text.normal("including fee:")
@ -57,7 +60,10 @@ async def confirm_total(ctx, spending, fee, coin):
async def confirm_feeoverthreshold(ctx, fee, coin): async def confirm_feeoverthreshold(ctx, fee, coin):
text = Text("High fee", ui.ICON_SEND, icon_color=ui.GREEN) from trezor.ui.text import Text
from apps.common.confirm import confirm
text = Text("High fee", ui.ICON_SEND, ui.GREEN)
text.normal("The fee of") text.normal("The fee of")
text.bold(format_coin_amount(fee, coin)) text.bold(format_coin_amount(fee, coin))
text.normal("is unexpectedly high.", "Continue?") text.normal("is unexpectedly high.", "Continue?")
@ -65,13 +71,19 @@ async def confirm_feeoverthreshold(ctx, fee, coin):
async def confirm_foreign_address(ctx, address_n, coin): async def confirm_foreign_address(ctx, address_n, coin):
text = Text("Confirm sending", ui.ICON_SEND, icon_color=ui.RED) from trezor.ui.text import Text
from apps.common.confirm import confirm
text = Text("Confirm sending", ui.ICON_SEND, ui.RED)
text.normal("Trying to spend", "coins from another chain.", "Continue?") text.normal("Trying to spend", "coins from another chain.", "Continue?")
return await confirm(ctx, text, ButtonRequestType.SignTx) return await confirm(ctx, text, ButtonRequestType.SignTx)
async def confirm_nondefault_locktime(ctx, lock_time): async def confirm_nondefault_locktime(ctx, lock_time):
text = Text("Confirm locktime", ui.ICON_SEND, icon_color=ui.GREEN) from trezor.ui.text import Text
from apps.common.confirm import confirm
text = Text("Confirm locktime", ui.ICON_SEND, ui.GREEN)
text.normal("Locktime for this transaction is set to") text.normal("Locktime for this transaction is set to")
if lock_time < _LOCKTIME_TIMESTAMP_MIN_VALUE: if lock_time < _LOCKTIME_TIMESTAMP_MIN_VALUE:
text.normal("blockheight:") text.normal("blockheight:")

@ -171,7 +171,18 @@ def witness_p2wsh(
# witness program + signatures + redeem script # witness program + signatures + redeem script
num_of_witness_items = 1 + len(signatures) + 1 num_of_witness_items = 1 + len(signatures) + 1
w = bytearray() # length of the redeem script
pubkeys = multisig_get_pubkeys(multisig)
redeem_script_length = output_script_multisig_length(pubkeys, multisig.m)
# length of the result
total_length = 1 + 1 # number of items, version
for s in signatures:
total_length += 1 + len(s) + 1 # length, signature, sighash
total_length += 1 + redeem_script_length # length, script
w = empty_bytearray(total_length)
write_varint(w, num_of_witness_items) write_varint(w, num_of_witness_items)
write_varint(w, 0) # version 0 witness program write_varint(w, 0) # version 0 witness program
@ -179,10 +190,9 @@ def witness_p2wsh(
append_signature(w, s, sighash) # size of the witness included append_signature(w, s, sighash) # size of the witness included
# redeem script # redeem script
pubkeys = multisig_get_pubkeys(multisig) write_varint(w, redeem_script_length)
redeem_script = output_script_multisig(pubkeys, multisig.m) output_script_multisig(pubkeys, multisig.m, w)
write_varint(w, len(redeem_script))
write_bytes(w, redeem_script)
return w return w
@ -198,13 +208,25 @@ def input_script_multisig(
signature_index: int, signature_index: int,
sighash: int, sighash: int,
coin: CoinInfo, coin: CoinInfo,
): ) -> bytearray:
signatures = multisig.signatures # other signatures signatures = multisig.signatures # other signatures
if len(signatures[signature_index]) > 0: if len(signatures[signature_index]) > 0:
raise ScriptsError("Invalid multisig parameters") raise ScriptsError("Invalid multisig parameters")
signatures[signature_index] = signature # our signature signatures[signature_index] = signature # our signature
w = bytearray() # length of the redeem script
pubkeys = multisig_get_pubkeys(multisig)
redeem_script_length = output_script_multisig_length(pubkeys, multisig.m)
# length of the result
total_length = 0
if not coin.decred:
total_length += 1 # OP_FALSE
for s in signatures:
total_length += 1 + len(s) + 1 # length, signature, sighash
total_length += 1 + redeem_script_length # length, script
w = empty_bytearray(total_length)
if not coin.decred: if not coin.decred:
# Starts with OP_FALSE because of an old OP_CHECKMULTISIG bug, which # Starts with OP_FALSE because of an old OP_CHECKMULTISIG bug, which
@ -217,15 +239,13 @@ def input_script_multisig(
append_signature(w, s, sighash) append_signature(w, s, sighash)
# redeem script # redeem script
pubkeys = multisig_get_pubkeys(multisig) write_op_push(w, redeem_script_length)
redeem_script = output_script_multisig(pubkeys, multisig.m) output_script_multisig(pubkeys, multisig.m, w)
write_op_push(w, len(redeem_script))
write_bytes(w, redeem_script)
return w return w
def output_script_multisig(pubkeys, m: int) -> bytearray: def output_script_multisig(pubkeys, m: int, w: bytearray = None) -> bytearray:
n = len(pubkeys) n = len(pubkeys)
if n < 1 or n > 15 or m < 1 or m > 15: if n < 1 or n > 15 or m < 1 or m > 15:
raise ScriptsError("Invalid multisig parameters") raise ScriptsError("Invalid multisig parameters")
@ -233,7 +253,8 @@ def output_script_multisig(pubkeys, m: int) -> bytearray:
if len(pubkey) != 33: if len(pubkey) != 33:
raise ScriptsError("Invalid multisig parameters") raise ScriptsError("Invalid multisig parameters")
w = bytearray() if w is None:
w = empty_bytearray(output_script_multisig_length(pubkeys, m))
w.append(0x50 + m) # numbers 1 to 16 are pushed as 0x50 + value w.append(0x50 + m) # numbers 1 to 16 are pushed as 0x50 + value
for p in pubkeys: for p in pubkeys:
append_pubkey(w, p) append_pubkey(w, p)
@ -242,6 +263,10 @@ def output_script_multisig(pubkeys, m: int) -> bytearray:
return w return w
def output_script_multisig_length(pubkeys, m: int) -> int:
return 1 + len(pubkeys) * (1 + 33) + 1 + 1 # see output_script_multisig
# OP_RETURN # OP_RETURN
# === # ===

@ -1,3 +1,4 @@
import gc
from micropython import const from micropython import const
from trezor import utils from trezor import utils
@ -304,6 +305,7 @@ async def sign_tx(tx: SignTx, keychain: seed.Keychain):
tx_ser.signature = signature tx_ser.signature = signature
# serialize input with correct signature # serialize input with correct signature
gc.collect()
txi_sign.script_sig = input_derive_script( txi_sign.script_sig = input_derive_script(
coin, txi_sign, key_sign_pub, signature coin, txi_sign, key_sign_pub, signature
) )
@ -365,6 +367,7 @@ async def sign_tx(tx: SignTx, keychain: seed.Keychain):
tx_ser.signature = signature tx_ser.signature = signature
# serialize input with correct signature # serialize input with correct signature
gc.collect()
txi_sign.script_sig = input_derive_script( txi_sign.script_sig = input_derive_script(
coin, txi_sign, key_sign_pub, signature coin, txi_sign, key_sign_pub, signature
) )
@ -470,6 +473,7 @@ async def sign_tx(tx: SignTx, keychain: seed.Keychain):
tx_ser.signature = signature tx_ser.signature = signature
# serialize input with correct signature # serialize input with correct signature
gc.collect()
txi_sign.script_sig = input_derive_script( txi_sign.script_sig = input_derive_script(
coin, txi_sign, key_sign_pub, signature coin, txi_sign, key_sign_pub, signature
) )
@ -812,8 +816,9 @@ def input_derive_script(
if i.multisig: if i.multisig:
# p2wsh in p2sh # p2wsh in p2sh
pubkeys = multisig.multisig_get_pubkeys(i.multisig) pubkeys = multisig.multisig_get_pubkeys(i.multisig)
witness_script = scripts.output_script_multisig(pubkeys, i.multisig.m) witness_script_hasher = utils.HashWriter(sha256())
witness_script_hash = sha256(witness_script).digest() scripts.output_script_multisig(pubkeys, i.multisig.m, witness_script_hasher)
witness_script_hash = witness_script_hasher.get_digest()
return scripts.input_script_p2wsh_in_p2sh(witness_script_hash) return scripts.input_script_p2wsh_in_p2sh(witness_script_hash)
# p2wpkh in p2sh # p2wpkh in p2sh

Loading…
Cancel
Save