From 514f2ac649184955a044784367af04498204b8db Mon Sep 17 00:00:00 2001 From: Andrew Kozlik Date: Thu, 9 Apr 2020 13:47:48 +0200 Subject: [PATCH] core/sign_tx: Refactor to use template method. --- core/src/apps/wallet/sign_tx/bitcoinlike.py | 26 +-- core/src/apps/wallet/sign_tx/decred.py | 50 +++-- core/src/apps/wallet/sign_tx/signing.py | 224 +++++++++++--------- core/src/apps/wallet/sign_tx/zcash.py | 10 +- 4 files changed, 173 insertions(+), 137 deletions(-) diff --git a/core/src/apps/wallet/sign_tx/bitcoinlike.py b/core/src/apps/wallet/sign_tx/bitcoinlike.py index dddbf4cbf..5ad0271e1 100644 --- a/core/src/apps/wallet/sign_tx/bitcoinlike.py +++ b/core/src/apps/wallet/sign_tx/bitcoinlike.py @@ -7,7 +7,6 @@ from trezor.messages.SignTx import SignTx from trezor.messages.TransactionType import TransactionType from trezor.messages.TxInputType import TxInputType from trezor.messages.TxOutputType import TxOutputType -from trezor.messages.TxRequestSerializedType import TxRequestSerializedType from apps.wallet.sign_tx import addresses, helpers, multisig, signing, writers @@ -16,20 +15,20 @@ if False: class Bitcoinlike(signing.Bitcoin): - async def phase1_process_segwit_input(self, i: int, txi: TxInputType) -> None: + async def process_segwit_input(self, i: int, txi: TxInputType) -> None: if not self.coin.segwit: raise signing.SigningError( FailureType.DataError, "Segwit not enabled on this coin" ) - await super().phase1_process_segwit_input(i, txi) + await super().process_segwit_input(i, txi) - async def phase1_process_nonsegwit_input(self, i: int, txi: TxInputType) -> None: + async def process_nonsegwit_input(self, i: int, txi: TxInputType) -> None: if self.coin.force_bip143: - await self.phase1_process_bip143_input(i, txi) + await self.process_bip143_input(i, txi) else: - await super().phase1_process_nonsegwit_input(i, txi) + await super().process_nonsegwit_input(i, txi) - async def phase1_process_bip143_input(self, i: int, txi: TxInputType) -> None: + async def process_bip143_input(self, i: int, txi: TxInputType) -> None: if not txi.amount: raise signing.SigningError( FailureType.DataError, "Expected input with amount" @@ -38,13 +37,13 @@ class Bitcoinlike(signing.Bitcoin): self.bip143_in += txi.amount self.total_in += txi.amount - async def phase2_sign_nonsegwit_input(self, i_sign: int) -> None: + async def sign_nonsegwit_input(self, i_sign: int) -> None: if self.coin.force_bip143: - await self.phase2_sign_bip143_input(i_sign) + await self.sign_bip143_input(i_sign) else: - await super().phase2_sign_nonsegwit_input(i_sign) + await super().sign_nonsegwit_input(i_sign) - async def phase2_sign_bip143_input(self, i_sign: int) -> None: + async def sign_bip143_input(self, i_sign: int) -> None: # STAGE_REQUEST_SEGWIT_INPUT txi_sign = await helpers.request_tx_input(self.tx_req, i_sign, self.coin) self.input_check_wallet_path(txi_sign) @@ -87,7 +86,10 @@ class Bitcoinlike(signing.Bitcoin): if i_sign == 0: # serializing first input => prepend headers self.write_sign_tx_header(w_txi_sign, True in self.segwit.values()) writers.write_tx_input(w_txi_sign, txi_sign) - self.tx_req.serialized = TxRequestSerializedType(i_sign, signature, w_txi_sign) + self.tx_ser.signature_index = i_sign + self.tx_ser.signature = signature + self.tx_ser.serialized_tx = w_txi_sign + self.tx_req.serialized = self.tx_ser def on_negative_fee(self) -> None: # some coins require negative fees for reward TX diff --git a/core/src/apps/wallet/sign_tx/decred.py b/core/src/apps/wallet/sign_tx/decred.py index 76fa39dd9..db73b4992 100644 --- a/core/src/apps/wallet/sign_tx/decred.py +++ b/core/src/apps/wallet/sign_tx/decred.py @@ -8,7 +8,6 @@ from trezor.messages.TransactionType import TransactionType from trezor.messages.TxInputType import TxInputType from trezor.messages.TxOutputBinType import TxOutputBinType from trezor.messages.TxOutputType import TxOutputType -from trezor.messages.TxRequestSerializedType import TxRequestSerializedType from trezor.utils import HashWriter, ensure from apps.common import coininfo, seed @@ -66,7 +65,7 @@ class Decred(Bitcoin): super().initialize(tx, keychain, coin) # This is required because the last serialized output obtained in - # phase 1 will only be sent to the client in phase 2 + # step 2 will only be sent to the client in step 4 self.last_output_bytes = None # type: bytearray def init_hash143(self) -> None: @@ -75,19 +74,26 @@ class Decred(Bitcoin): def create_hash_writer(self) -> HashWriter: return HashWriter(blake256()) - async def phase1(self) -> None: - await super().phase1() + async def step2_confirm_outputs(self) -> None: + await super().step2_confirm_outputs() self.hash143.add_locktime_expiry(self.tx) - async def phase1_process_input(self, i: int, txi: TxInputType) -> None: - await super().phase1_process_input(i, txi) + async def process_input(self, i: int, txi: TxInputType) -> None: + await super().process_input(i, txi) + + # Decred serializes inputs early. w_txi = writers.empty_bytearray(8 if i == 0 else 0 + 9 + len(txi.prev_hash)) if i == 0: # serializing first input => prepend headers self.write_sign_tx_header(w_txi, False) + self.write_tx_input(w_txi, txi) - self.tx_req.serialized = TxRequestSerializedType(None, None, w_txi) - async def phase1_confirm_output( + self.tx_ser.signature_index = None + self.tx_ser.signature = None + self.tx_ser.serialized_tx = w_txi + self.tx_req.serialized = self.tx_ser + + async def confirm_output( self, i: int, txo: TxOutputType, txo_bin: TxOutputBinType ) -> None: if txo.decred_script_version is not None and txo.decred_script_version != 0: @@ -103,12 +109,17 @@ class Decred(Bitcoin): self.hash143.add_output_count(self.tx) writers.write_tx_output(w_txo_bin, txo_bin) - self.tx_req.serialized = TxRequestSerializedType(serialized_tx=w_txo_bin) + + self.tx_ser.signature_index = None + self.tx_ser.signature = None + self.tx_ser.serialized_tx = w_txo_bin + self.tx_req.serialized = self.tx_ser + self.last_output_bytes = w_txo_bin - await super().phase1_confirm_output(i, txo, txo_bin) + await super().confirm_output(i, txo, txo_bin) - async def phase2(self) -> None: + async def step4_serialize_inputs(self) -> None: self.tx_req.serialized = None prefix_hash = self.hash143.get_prefix_hash() @@ -178,11 +189,20 @@ class Decred(Bitcoin): writers.write_varint(w_txi_sign, self.tx.inputs_count) writers.write_tx_input_decred_witness(w_txi_sign, txi_sign) - self.tx_req.serialized = TxRequestSerializedType( - i_sign, signature, w_txi_sign - ) - await helpers.request_tx_finish(self.tx_req) + self.tx_ser.signature_index = i_sign + self.tx_ser.signature = signature + self.tx_ser.serialized_tx = w_txi_sign + self.tx_req.serialized = self.tx_ser + + async def step5_serialize_outputs(self) -> None: + pass + + async def step6_sign_segwit_inputs(self) -> None: + pass + + def write_sign_tx_footer(self, w: writers.Writer) -> None: + pass def check_prevtx_output(self, txo_bin: TxOutputBinType) -> None: if ( diff --git a/core/src/apps/wallet/sign_tx/signing.py b/core/src/apps/wallet/sign_tx/signing.py index 7794b2343..10744ba55 100644 --- a/core/src/apps/wallet/sign_tx/signing.py +++ b/core/src/apps/wallet/sign_tx/signing.py @@ -28,7 +28,7 @@ from apps.wallet.sign_tx import ( ) if False: - from typing import Dict, List, Optional, Tuple, Union + from typing import Dict, List, Optional, Union # the number of bip32 levels used in a wallet (chain and address) _BIP32_WALLET_DEPTH = const(2) @@ -60,16 +60,27 @@ class Bitcoin: progress.init(self.tx.inputs_count, self.tx.outputs_count) - # Phase 1 - # - check inputs, previous transactions, and outputs - # - ask for confirmations - # - check fee - await self.phase1() + # Add inputs to hash143 and h_confirmed and compute the sum of input amounts. + await self.step1_process_inputs() - # Phase 2 - # - sign inputs - # - check that nothing changed - await self.phase2() + # Add outputs to hash143 and h_confirmed, check previous transaction output + # amounts, confirm outputs and compute sum of output amounts. + await self.step2_confirm_outputs() + + # Check fee, confirm lock_time and total. + await self.step3_confirm_tran() + + # Check that inputs are unchanged. Serialize inputs and sign the non-segwit ones. + await self.step4_serialize_inputs() + + # Serialize outputs. + await self.step5_serialize_outputs() + + # Sign segwit inputs and serialize witness data. + await self.step6_sign_segwit_inputs() + + # Write footer and send remaining data. + await self.step7_finish() def initialize( self, tx: SignTx, keychain: seed.Keychain, coin: coininfo.CoinInfo @@ -78,27 +89,29 @@ class Bitcoin: self.tx = helpers.sanitize_sign_tx(tx, self.coin) self.keychain = keychain - self.multisig_fp = ( - multisig.MultisigFingerprint() - ) # control checksum of multisig inputs - self.wallet_path = ( - [] - ) # type: Optional[List[int]] # common prefix of input paths - self.bip143_in = 0 # sum of segwit input amounts - self.segwit = ( - {} - ) # type: Dict[int, bool] # dict of booleans stating if input is segwit + # checksum of multisig inputs, used to validate change-output + self.multisig_fp = multisig.MultisigFingerprint() + + # common prefix of input paths, used to validate change-output + self.wallet_path = [] # type: Optional[List[int]] + + # dict of booleans stating if input is segwit + self.segwit = {} # type: Dict[int, bool] + self.total_in = 0 # sum of input amounts + self.bip143_in = 0 # sum of segwit input amounts self.total_out = 0 # sum of output amounts self.change_out = 0 # change output amount + self.weight = tx_weight.TxWeightCalculator(tx.inputs_count, tx.outputs_count) self.tx_req = TxRequest() self.tx_req.details = TxRequestDetailsType() + self.tx_ser = TxRequestSerializedType() - # h_first is used to make sure the inputs and outputs streamed in Phase 1 - # are the same as in Phase 2 when signing legacy inputs. it is thus not required to fully hash the - # tx, as the SignTx info is streamed only once - self.h_first = self.create_hash_writer() # not a real tx hash + # h_confirmed is used to make sure that the inputs and outputs streamed for + # confirmation in Steps 1 and 2 are the same as the ones streamed for signing + # legacy inputs in Step 4. + self.h_confirmed = self.create_hash_writer() # not a real tx hash self.init_hash143() @@ -108,36 +121,32 @@ class Bitcoin: def create_hash_writer(self) -> utils.HashWriter: return utils.HashWriter(sha256()) - async def phase1(self) -> None: - weight = tx_weight.TxWeightCalculator( - self.tx.inputs_count, self.tx.outputs_count - ) - - # compute sum of input amounts (total_in) - # add inputs to hash143 and h_first + async def step1_process_inputs(self) -> None: for i in range(self.tx.inputs_count): # STAGE_REQUEST_1_INPUT progress.advance() txi = await helpers.request_tx_input(self.tx_req, i, self.coin) - weight.add_input(txi) - await self.phase1_process_input(i, txi) + self.weight.add_input(txi) + await self.process_input(i, txi) + async def step2_confirm_outputs(self) -> None: txo_bin = TxOutputBinType() for i in range(self.tx.outputs_count): # STAGE_REQUEST_3_OUTPUT txo = await helpers.request_tx_output(self.tx_req, i, self.coin) txo_bin.amount = txo.amount txo_bin.script_pubkey = self.output_derive_script(txo) - weight.add_output(txo_bin.script_pubkey) - await self.phase1_confirm_output(i, txo, txo_bin) + self.weight.add_output(txo_bin.script_pubkey) + await self.confirm_output(i, txo, txo_bin) + async def step3_confirm_tran(self) -> None: fee = self.total_in - self.total_out if fee < 0: self.on_negative_fee() # fee > (coin.maxfee per byte * tx size) - if fee > (self.coin.maxfee_kb / 1000) * (weight.get_total() / 4): + if fee > (self.coin.maxfee_kb / 1000) * (self.weight.get_total() / 4): if not await helpers.confirm_feeoverthreshold(fee, self.coin): raise SigningError(FailureType.ActionCancelled, "Signing cancelled") @@ -150,9 +159,43 @@ class Bitcoin: ): raise SigningError(FailureType.ActionCancelled, "Total cancelled") - async def phase1_process_input(self, i: int, txi: TxInputType) -> None: + async def step4_serialize_inputs(self) -> None: + self.tx_req.serialized = None + + for i in range(self.tx.inputs_count): + progress.advance() + if self.segwit[i]: + await self.serialize_segwit_input(i) + else: + await self.sign_nonsegwit_input(i) + + async def step5_serialize_outputs(self) -> None: + for i in range(self.tx.outputs_count): + progress.advance() + await self.serialize_output(i) + + async def step6_sign_segwit_inputs(self) -> None: + any_segwit = True in self.segwit.values() + for i in range(self.tx.inputs_count): + progress.advance() + if self.segwit[i]: + await self.sign_segwit_input(i) + elif any_segwit: + # TODO what if a non-segwit input follows after a segwit input? + self.tx_ser.serialized_tx += bytearray( + 1 + ) # empty witness for non-segwit inputs + self.tx_ser.signature_index = None + self.tx_ser.signature = None + self.tx_req.serialized = self.tx_ser + + async def step7_finish(self) -> None: + self.write_sign_tx_footer(self.tx_ser.serialized_tx) + await helpers.request_tx_finish(self.tx_req) + + async def process_input(self, i: int, txi: TxInputType) -> None: self.input_extract_wallet_path(txi) - writers.write_tx_input_check(self.h_first, txi) + writers.write_tx_input_check(self.h_confirmed, txi) self.hash143.add_prevouts(txi) # all inputs are included (non-segwit as well) self.hash143.add_sequence(txi) @@ -168,29 +211,29 @@ class Bitcoin: InputScriptType.SPENDWITNESS, InputScriptType.SPENDP2SHWITNESS, ): - await self.phase1_process_segwit_input(i, txi) + await self.process_segwit_input(i, txi) elif txi.script_type in ( InputScriptType.SPENDADDRESS, InputScriptType.SPENDMULTISIG, ): - await self.phase1_process_nonsegwit_input(i, txi) + await self.process_nonsegwit_input(i, txi) else: raise SigningError(FailureType.DataError, "Wrong input script type") - async def phase1_process_segwit_input(self, i: int, txi: TxInputType) -> None: + async def process_segwit_input(self, i: int, txi: TxInputType) -> None: if not txi.amount: raise SigningError(FailureType.DataError, "Segwit input without amount") self.segwit[i] = True self.bip143_in += txi.amount self.total_in += txi.amount - async def phase1_process_nonsegwit_input(self, i: int, txi: TxInputType) -> None: + async def process_nonsegwit_input(self, i: int, txi: TxInputType) -> None: self.segwit[i] = False self.total_in += await self.get_prevtx_output_value( txi.prev_hash, txi.prev_index ) - async def phase1_confirm_output( + async def confirm_output( self, i: int, txo: TxOutputType, txo_bin: TxOutputBinType ) -> None: if self.change_out == 0 and self.output_is_change(txo): @@ -199,57 +242,14 @@ class Bitcoin: elif not await helpers.confirm_output(txo, self.coin): raise SigningError(FailureType.ActionCancelled, "Output cancelled") - writers.write_tx_output(self.h_first, txo_bin) + writers.write_tx_output(self.h_confirmed, txo_bin) self.hash143.add_output(txo_bin) self.total_out += txo_bin.amount def on_negative_fee(self) -> None: raise SigningError(FailureType.NotEnoughFunds, "Not enough funds") - async def phase2(self) -> None: - self.tx_req.serialized = None - - # Serialize inputs and sign non-segwit inputs. - for i in range(self.tx.inputs_count): - progress.advance() - if self.segwit[i]: - await self.phase2_serialize_segwit_input(i) - else: - await self.phase2_sign_nonsegwit_input(i) - - # Serialize outputs. - tx_ser = TxRequestSerializedType() - for i in range(self.tx.outputs_count): - # STAGE_REQUEST_5_OUTPUT - progress.advance() - tx_ser.serialized_tx = await self.phase2_serialize_output(i) - self.tx_req.serialized = tx_ser - - # Sign segwit inputs. - any_segwit = True in self.segwit.values() - for i in range(self.tx.inputs_count): - progress.advance() - if self.segwit[i]: - # STAGE_REQUEST_SEGWIT_WITNESS - witness, signature = await self.phase2_sign_segwit_input(i) - tx_ser.serialized_tx = witness - tx_ser.signature_index = i - tx_ser.signature = signature - elif any_segwit: - # TODO what if a non-segwit input follows after a segwit input? - tx_ser.serialized_tx += bytearray( - 1 - ) # empty witness for non-segwit inputs - tx_ser.signature_index = None - tx_ser.signature = None - - self.tx_req.serialized = tx_ser - - self.write_sign_tx_footer(tx_ser.serialized_tx) - - await helpers.request_tx_finish(self.tx_req) - - async def phase2_serialize_segwit_input(self, i_sign: int) -> None: + async def serialize_segwit_input(self, i_sign: int) -> None: # STAGE_REQUEST_SEGWIT_INPUT txi_sign = await helpers.request_tx_input(self.tx_req, i_sign, self.coin) @@ -271,9 +271,13 @@ class Bitcoin: if i_sign == 0: # serializing first input => prepend headers self.write_sign_tx_header(w_txi, True) self.write_tx_input(w_txi, txi_sign) - self.tx_req.serialized = TxRequestSerializedType(serialized_tx=w_txi) + self.tx_ser.signature_index = None + self.tx_ser.signature = None + self.tx_ser.serialized_tx = w_txi + self.tx_req.serialized = self.tx_ser - async def phase2_sign_segwit_input(self, i: int) -> Tuple[bytearray, bytes]: + async def sign_segwit_input(self, i: int) -> None: + # STAGE_REQUEST_SEGWIT_WITNESS txi = await helpers.request_tx_input(self.tx_req, i, self.coin) self.input_check_wallet_path(txi) @@ -307,13 +311,16 @@ class Bitcoin: signature, key_sign_pub, self.get_hash_type() ) - return witness, signature + self.tx_ser.signature_index = i + self.tx_ser.signature = signature + self.tx_ser.serialized_tx = witness + self.tx_req.serialized = self.tx_ser - async def phase2_sign_nonsegwit_input(self, i_sign: int) -> None: + async def sign_nonsegwit_input(self, i_sign: int) -> None: # hash of what we are signing with this input h_sign = self.create_hash_writer() - # same as h_first, checked before signing the digest - h_second = self.create_hash_writer() + # should come out the same as h_confirmed, checked before signing the digest + h_check = self.create_hash_writer() self.write_sign_tx_header(h_sign, has_segwit=False) @@ -321,7 +328,7 @@ class Bitcoin: # STAGE_REQUEST_4_INPUT txi = await helpers.request_tx_input(self.tx_req, i, self.coin) self.input_check_wallet_path(txi) - writers.write_tx_input_check(h_second, txi) + writers.write_tx_input_check(h_check, txi) if i == i_sign: txi_sign = txi self.input_check_multisig_fingerprint(txi_sign) @@ -354,14 +361,14 @@ class Bitcoin: txo = await helpers.request_tx_output(self.tx_req, i, self.coin) txo_bin.amount = txo.amount txo_bin.script_pubkey = self.output_derive_script(txo) - writers.write_tx_output(h_second, txo_bin) + writers.write_tx_output(h_check, txo_bin) writers.write_tx_output(h_sign, txo_bin) writers.write_uint32(h_sign, self.tx.lock_time) writers.write_uint32(h_sign, self.get_hash_type()) # check the control digests - if writers.get_tx_hash(self.h_first, False) != writers.get_tx_hash(h_second): + if self.h_confirmed.get_digest() != h_check.get_digest(): raise SigningError( FailureType.ProcessError, "Transaction has changed during signing" ) @@ -375,7 +382,7 @@ class Bitcoin: key_sign, writers.get_tx_hash(h_sign, double=self.coin.sign_hash_double) ) - # serialize input wittx_reqh correct signature + # serialize input with correct signature gc.collect() txi_sign.script_sig = self.input_derive_script( txi_sign, key_sign_pub, signature @@ -386,24 +393,31 @@ class Bitcoin: if i_sign == 0: # serializing first input => prepend headers self.write_sign_tx_header(w_txi_sign, True in self.segwit.values()) self.write_tx_input(w_txi_sign, txi_sign) - self.tx_req.serialized = TxRequestSerializedType(i_sign, signature, w_txi_sign) - async def phase2_serialize_output(self, i: int) -> bytearray: + self.tx_ser.signature_index = i_sign + self.tx_ser.signature = signature + self.tx_ser.serialized_tx = w_txi_sign + self.tx_req.serialized = self.tx_ser + + async def serialize_output(self, i: int) -> None: + # STAGE_REQUEST_5_OUTPUT txo = await helpers.request_tx_output(self.tx_req, i, self.coin) txo_bin = TxOutputBinType() txo_bin.amount = txo.amount txo_bin.script_pubkey = self.output_derive_script(txo) - # serialize output w_txo_bin = writers.empty_bytearray(5 + 8 + 5 + len(txo_bin.script_pubkey) + 4) if i == 0: # serializing first output => prepend outputs count writers.write_varint(w_txo_bin, self.tx.outputs_count) writers.write_tx_output(w_txo_bin, txo_bin) - return w_txo_bin + self.tx_ser.signature_index = None + self.tx_ser.signature = None + self.tx_ser.serialized_tx = w_txo_bin + self.tx_req.serialized = self.tx_ser async def get_prevtx_output_value(self, prev_hash: bytes, prev_index: int) -> int: - amount_out = 0 # sum of output amounts + amount_out = 0 # output amount # STAGE_REQUEST_2_PREV_META tx = await helpers.request_tx_meta(self.tx_req, self.coin, prev_hash) @@ -597,7 +611,7 @@ class Bitcoin: def input_check_wallet_path(self, txi: TxInputType) -> None: if self.wallet_path is None: - return # there was a mismatch in Phase 1, ignore it now + return # there was a mismatch in Step 1, ignore it now address_n = txi.address_n[:-_BIP32_WALLET_DEPTH] if self.wallet_path != address_n: raise SigningError( @@ -606,7 +620,7 @@ class Bitcoin: def input_check_multisig_fingerprint(self, txi: TxInputType) -> None: if self.multisig_fp.mismatch is False: - # All inputs in Phase 1 had matching multisig fingerprints, allowing a multisig change-output. + # All inputs in Step 1 had matching multisig fingerprints, allowing a multisig change-output. if not txi.multisig or not self.multisig_fp.matches(txi.multisig): # This input no longer has a matching multisig fingerprint. raise SigningError( diff --git a/core/src/apps/wallet/sign_tx/zcash.py b/core/src/apps/wallet/sign_tx/zcash.py index 5884b7706..f7e028fa2 100644 --- a/core/src/apps/wallet/sign_tx/zcash.py +++ b/core/src/apps/wallet/sign_tx/zcash.py @@ -201,11 +201,11 @@ class Overwintered(Bitcoinlike): "Unsupported version for overwintered transaction", ) - async def phase1_process_nonsegwit_input(self, i: int, txi: TxInputType) -> None: - await self.phase1_process_bip143_input(i, txi) + async def process_nonsegwit_input(self, i: int, txi: TxInputType) -> None: + await self.process_bip143_input(i, txi) - async def phase2_sign_nonsegwit_input(self, i_sign: int) -> None: - await self.phase2_sign_bip143_input(i_sign) + async def sign_nonsegwit_input(self, i_sign: int) -> None: + await self.sign_bip143_input(i_sign) def write_tx_header( self, w: Writer, tx: Union[SignTx, TransactionType], has_segwit: bool @@ -215,7 +215,7 @@ class Overwintered(Bitcoinlike): write_uint32(w, tx.version_group_id) # nVersionGroupId def write_sign_tx_footer(self, w: Writer) -> None: - super().write_sign_tx_footer(w) + write_uint32(w, self.tx.lock_time) if self.tx.version == 3: write_uint32(w, self.tx.expiry) # expiryHeight