diff --git a/core/src/apps/bitcoin/scripts.py b/core/src/apps/bitcoin/scripts.py index 022a72a28..56ba8a125 100644 --- a/core/src/apps/bitcoin/scripts.py +++ b/core/src/apps/bitcoin/scripts.py @@ -1,10 +1,9 @@ from trezor import utils, wire from trezor.crypto import base58, cashaddr from trezor.crypto.hashlib import sha256 -from trezor.messages import InputScriptType, OutputScriptType +from trezor.messages import InputScriptType from trezor.messages.MultisigRedeemScriptType import MultisigRedeemScriptType from trezor.messages.TxInputType import TxInputType -from trezor.messages.TxOutputType import TxOutputType from apps.common import address_type from apps.common.coininfo import CoinInfo @@ -65,21 +64,18 @@ def input_derive_script( raise wire.ProcessError("Invalid script type") -def output_derive_script(txo: TxOutputType, coin: CoinInfo) -> bytes: - if txo.script_type == OutputScriptType.PAYTOOPRETURN: - return output_script_paytoopreturn(txo.op_return_data) - - if coin.bech32_prefix and txo.address.startswith(coin.bech32_prefix): +def output_derive_script(address: str, coin: CoinInfo) -> bytes: + if coin.bech32_prefix and address.startswith(coin.bech32_prefix): # p2wpkh or p2wsh - witprog = common.decode_bech32_address(coin.bech32_prefix, txo.address) + witprog = common.decode_bech32_address(coin.bech32_prefix, address) return output_script_native_p2wpkh_or_p2wsh(witprog) if ( not utils.BITCOIN_ONLY and coin.cashaddr_prefix is not None - and txo.address.startswith(coin.cashaddr_prefix + ":") + and address.startswith(coin.cashaddr_prefix + ":") ): - prefix, addr = txo.address.split(":") + prefix, addr = address.split(":") version, data = cashaddr.decode(prefix, addr) if version == cashaddr.ADDRESS_TYPE_P2KH: version = coin.address_type @@ -90,7 +86,7 @@ def output_derive_script(txo: TxOutputType, coin: CoinInfo) -> bytes: raw_address = bytes([version]) + data else: try: - raw_address = base58.decode_check(txo.address, coin.b58_hash) + raw_address = base58.decode_check(address, coin.b58_hash) except ValueError: raise wire.DataError("Invalid address") diff --git a/core/src/apps/bitcoin/sign_tx/bitcoin.py b/core/src/apps/bitcoin/sign_tx/bitcoin.py index 240096a1d..14c2256af 100644 --- a/core/src/apps/bitcoin/sign_tx/bitcoin.py +++ b/core/src/apps/bitcoin/sign_tx/bitcoin.py @@ -180,7 +180,9 @@ class Bitcoin: if txi.script_type not in helpers.INTERNAL_INPUT_SCRIPT_TYPES: raise wire.DataError("Wrong input script type") - prev_amount = await self.get_prevtx_output_value(txi.prev_hash, txi.prev_index) + prev_amount, script_pubkey = await self.get_prevtx_output( + txi.prev_hash, txi.prev_index + ) if txi.amount is not None and prev_amount != txi.amount: raise wire.DataError("Invalid amount specified") @@ -328,7 +330,9 @@ class Bitcoin: script_pubkey = self.output_derive_script(txo) self.write_tx_output(self.serialized_tx, txo, script_pubkey) - async def get_prevtx_output_value(self, prev_hash: bytes, prev_index: int) -> int: + async def get_prevtx_output( + self, prev_hash: bytes, prev_index: int + ) -> Tuple[int, bytes]: amount_out = 0 # output amount # STAGE_REQUEST_2_PREV_META in legacy @@ -358,6 +362,7 @@ class Bitcoin: self.write_tx_output(txh, txo_bin, txo_bin.script_pubkey) if i == prev_index: amount_out = txo_bin.amount + script_pubkey = txo_bin.script_pubkey self.check_prevtx_output(txo_bin) await self.write_prev_tx_footer(txh, tx, prev_hash) @@ -368,7 +373,7 @@ class Bitcoin: ): raise wire.ProcessError("Encountered invalid prev_hash") - return amount_out + return amount_out, script_pubkey def check_prevtx_output(self, txo_bin: TxOutputBinType) -> None: # Validations to perform on the UTXO when checking the previous transaction output amount. @@ -425,6 +430,9 @@ class Bitcoin: # === def output_derive_script(self, txo: TxOutputType) -> bytes: + if txo.script_type == OutputScriptType.PAYTOOPRETURN: + return scripts.output_script_paytoopreturn(txo.op_return_data) + if txo.address_n: # change output try: @@ -438,7 +446,7 @@ class Bitcoin: input_script_type, self.coin, node, txo.multisig ) - return scripts.output_derive_script(txo, self.coin) + return scripts.output_derive_script(txo.address, self.coin) def output_is_change(self, txo: TxOutputType) -> bool: if txo.script_type not in helpers.CHANGE_OUTPUT_SCRIPT_TYPES: diff --git a/core/tests/test_apps.bitcoin.segwit.bip143.native_p2wpkh.py b/core/tests/test_apps.bitcoin.segwit.bip143.native_p2wpkh.py index fa8f413ba..1a53e42a0 100644 --- a/core/tests/test_apps.bitcoin.segwit.bip143.native_p2wpkh.py +++ b/core/tests/test_apps.bitcoin.segwit.bip143.native_p2wpkh.py @@ -69,7 +69,7 @@ class TestSegwitBip143NativeP2WPKH(unittest.TestCase): for txo in [self.out1, self.out2]: txo_bin = TxOutputBinType() txo_bin.amount = txo.amount - script_pubkey = output_derive_script(txo, coin) + script_pubkey = output_derive_script(txo.address, coin) bip143.hash143_add_output(txo_bin, script_pubkey) outputs_hash = get_tx_hash(bip143.h_outputs, double=coin.sign_hash_double) @@ -86,7 +86,7 @@ class TestSegwitBip143NativeP2WPKH(unittest.TestCase): for txo in [self.out1, self.out2]: txo_bin = TxOutputBinType() txo_bin.amount = txo.amount - script_pubkey = output_derive_script(txo, coin) + script_pubkey = output_derive_script(txo.address, coin) bip143.hash143_add_output(txo_bin, script_pubkey) # test data public key hash diff --git a/core/tests/test_apps.bitcoin.segwit.bip143.p2wpkh_in_p2sh.py b/core/tests/test_apps.bitcoin.segwit.bip143.p2wpkh_in_p2sh.py index 9cd71444a..ca88a0961 100644 --- a/core/tests/test_apps.bitcoin.segwit.bip143.p2wpkh_in_p2sh.py +++ b/core/tests/test_apps.bitcoin.segwit.bip143.p2wpkh_in_p2sh.py @@ -58,7 +58,7 @@ class TestSegwitBip143(unittest.TestCase): for txo in [self.out1, self.out2]: txo_bin = TxOutputBinType() txo_bin.amount = txo.amount - script_pubkey = output_derive_script(txo, coin) + script_pubkey = output_derive_script(txo.address, coin) bip143.hash143_add_output(txo_bin, script_pubkey) outputs_hash = get_tx_hash(bip143.h_outputs, double=coin.sign_hash_double) @@ -72,7 +72,7 @@ class TestSegwitBip143(unittest.TestCase): for txo in [self.out1, self.out2]: txo_bin = TxOutputBinType() txo_bin.amount = txo.amount - script_pubkey = output_derive_script(txo, coin) + script_pubkey = output_derive_script(txo.address, coin) bip143.hash143_add_output(txo_bin, script_pubkey) # test data public key hash diff --git a/core/tests/test_apps.bitcoin.txweight.py b/core/tests/test_apps.bitcoin.txweight.py index 3a4ec73ca..b173d3932 100644 --- a/core/tests/test_apps.bitcoin.txweight.py +++ b/core/tests/test_apps.bitcoin.txweight.py @@ -34,7 +34,7 @@ class TestCalculateTxWeight(unittest.TestCase): calculator = TxWeightCalculator(1, 1) calculator.add_input(inp1) - calculator.add_output(output_derive_script(out1, coin)) + calculator.add_output(output_derive_script(out1.address, coin)) serialized_tx = '010000000182488650ef25a58fef6788bd71b8212038d7f2bbe4750bc7bcb44701e85ef6d5000000006b4830450221009a0b7be0d4ed3146ee262b42202841834698bb3ee39c24e7437df208b8b7077102202b79ab1e7736219387dffe8d615bbdba87e11477104b867ef47afed1a5ede7810121023230848585885f63803a0a8aecdd6538792d5c539215c91698e315bf0253b43dffffffff0160cc0500000000001976a914de9b2a8da088824e8fe51debea566617d851537888ac00000000' tx_weight = len(serialized_tx) / 2 * 4 # non-segwit tx's weight is simple length*4 @@ -73,8 +73,8 @@ class TestCalculateTxWeight(unittest.TestCase): calculator = TxWeightCalculator(1, 2) calculator.add_input(inp1) - calculator.add_output(output_derive_script(out1, coin)) - calculator.add_output(output_derive_script(out2, coin)) + calculator.add_output(output_derive_script(out1.address, coin)) + calculator.add_output(output_derive_script(out2.address, coin)) self.assertEqual(calculator.get_total(), 670) # non-segwit: header, inputs, outputs, locktime 4*(4+65+67+4) = 560 @@ -113,8 +113,8 @@ class TestCalculateTxWeight(unittest.TestCase): calculator = TxWeightCalculator(1, 2) calculator.add_input(inp1) - calculator.add_output(output_derive_script(out1, coin)) - calculator.add_output(output_derive_script(out2, coin)) + calculator.add_output(output_derive_script(out1.address, coin)) + calculator.add_output(output_derive_script(out2.address, coin)) self.assertEqual(calculator.get_total(), 566) # non-segwit: header, inputs, outputs, locktime 4*(4+42+64+4) = 456