1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-18 05:28:40 +00:00

core/sign_tx: cleanup

This commit is contained in:
Andrew Kozlik 2020-04-15 13:11:37 +02:00 committed by Andrew Kozlik
parent 987b70f1f5
commit 27e6720f3d
4 changed files with 80 additions and 74 deletions

View File

@ -89,11 +89,11 @@ class Bitcoinlike(signing.Bitcoin):
if not self.coin.negative_fee: if not self.coin.negative_fee:
super().on_negative_fee() super().on_negative_fee()
def get_raw_address(self, o: TxOutputType) -> bytes: def get_raw_address(self, txo: TxOutputType) -> bytes:
if self.coin.cashaddr_prefix is not None and o.address.startswith( if self.coin.cashaddr_prefix is not None and txo.address.startswith(
self.coin.cashaddr_prefix + ":" self.coin.cashaddr_prefix + ":"
): ):
prefix, addr = o.address.split(":") prefix, addr = txo.address.split(":")
version, data = cashaddr.decode(prefix, addr) version, data = cashaddr.decode(prefix, addr)
if version == cashaddr.ADDRESS_TYPE_P2KH: if version == cashaddr.ADDRESS_TYPE_P2KH:
version = self.coin.address_type version = self.coin.address_type
@ -103,7 +103,7 @@ class Bitcoinlike(signing.Bitcoin):
raise signing.SigningError("Unknown cashaddr address type") raise signing.SigningError("Unknown cashaddr address type")
return bytes([version]) + data return bytes([version]) + data
else: else:
return super().get_raw_address(o) return super().get_raw_address(txo)
def get_hash_type(self) -> int: def get_hash_type(self) -> int:
SIGHASH_FORKID = const(0x40) SIGHASH_FORKID = const(0x40)

View File

@ -183,8 +183,8 @@ class Decred(Bitcoin):
"Cannot use utxo that has script_version != 0", "Cannot use utxo that has script_version != 0",
) )
def write_tx_input(self, w: writers.Writer, i: TxInputType) -> None: def write_tx_input(self, w: writers.Writer, txi: TxInputType) -> None:
writers.write_tx_input_decred(w, i) writers.write_tx_input_decred(w, txi)
def write_sign_tx_header(self, w: writers.Writer, has_segwit: bool) -> None: def write_sign_tx_header(self, w: writers.Writer, has_segwit: bool) -> None:
writers.write_uint32(w, self.tx.version) # nVersion writers.write_uint32(w, self.tx.version) # nVersion

View File

@ -44,10 +44,15 @@ CHANGE_OUTPUT_TO_INPUT_SCRIPT_TYPES = {
INTERNAL_INPUT_SCRIPT_TYPES = tuple(CHANGE_OUTPUT_TO_INPUT_SCRIPT_TYPES.values()) INTERNAL_INPUT_SCRIPT_TYPES = tuple(CHANGE_OUTPUT_TO_INPUT_SCRIPT_TYPES.values())
CHANGE_OUTPUT_SCRIPT_TYPES = tuple(CHANGE_OUTPUT_TO_INPUT_SCRIPT_TYPES.keys()) CHANGE_OUTPUT_SCRIPT_TYPES = tuple(CHANGE_OUTPUT_TO_INPUT_SCRIPT_TYPES.keys())
SEGWIT_INPUT_SCRIPT_TYPES = { SEGWIT_INPUT_SCRIPT_TYPES = (
InputScriptType.SPENDP2SHWITNESS, InputScriptType.SPENDP2SHWITNESS,
InputScriptType.SPENDWITNESS, InputScriptType.SPENDWITNESS,
} )
NONSEGWIT_INPUT_SCRIPT_TYPES = (
InputScriptType.SPENDADDRESS,
InputScriptType.SPENDMULTISIG,
)
# Machine instructions # Machine instructions
# === # ===

View File

@ -209,15 +209,9 @@ class Bitcoin:
else: else:
self.multisig_fp.mismatch = True self.multisig_fp.mismatch = True
if txi.script_type in ( if input_is_segwit(txi):
InputScriptType.SPENDWITNESS,
InputScriptType.SPENDP2SHWITNESS,
):
await self.process_segwit_input(i, txi) await self.process_segwit_input(i, txi)
elif txi.script_type in ( elif input_is_nonsegwit(txi):
InputScriptType.SPENDADDRESS,
InputScriptType.SPENDMULTISIG,
):
await self.process_nonsegwit_input(i, txi) await self.process_nonsegwit_input(i, txi)
else: else:
raise SigningError(FailureType.DataError, "Wrong input script type") raise SigningError(FailureType.DataError, "Wrong input script type")
@ -251,23 +245,23 @@ class Bitcoin:
def on_negative_fee(self) -> None: def on_negative_fee(self) -> None:
raise SigningError(FailureType.NotEnoughFunds, "Not enough funds") raise SigningError(FailureType.NotEnoughFunds, "Not enough funds")
async def serialize_segwit_input(self, i_sign: int) -> None: async def serialize_segwit_input(self, i: int) -> None:
# STAGE_REQUEST_SEGWIT_INPUT # STAGE_REQUEST_SEGWIT_INPUT
txi_sign = await helpers.request_tx_input(self.tx_req, i_sign, self.coin) txi = await helpers.request_tx_input(self.tx_req, i, self.coin)
if not input_is_segwit(txi_sign): if not input_is_segwit(txi):
raise SigningError( raise SigningError(
FailureType.ProcessError, "Transaction has changed during signing" FailureType.ProcessError, "Transaction has changed during signing"
) )
self.input_check_wallet_path(txi_sign) self.input_check_wallet_path(txi)
# NOTE: No need to check the multisig fingerprint, because we won't be signing # NOTE: No need to check the multisig fingerprint, because we won't be signing
# the script here. Signatures are produced in STAGE_REQUEST_SEGWIT_WITNESS. # the script here. Signatures are produced in STAGE_REQUEST_SEGWIT_WITNESS.
key_sign = self.keychain.derive(txi_sign.address_n, self.coin.curve_name) node = self.keychain.derive(txi.address_n, self.coin.curve_name)
key_sign_pub = key_sign.public_key() key_sign_pub = node.public_key()
txi_sign.script_sig = self.input_derive_script(txi_sign, key_sign_pub) txi.script_sig = self.input_derive_script(txi, key_sign_pub)
self.write_tx_input(self.serialized_tx, txi_sign) self.write_tx_input(self.serialized_tx, txi)
async def sign_segwit_input(self, i: int) -> None: async def sign_segwit_input(self, i: int) -> None:
# STAGE_REQUEST_SEGWIT_WITNESS # STAGE_REQUEST_SEGWIT_WITNESS
@ -282,8 +276,8 @@ class Bitcoin:
) )
self.bip143_in -= txi.amount self.bip143_in -= txi.amount
key_sign = self.keychain.derive(txi.address_n, self.coin.curve_name) node = self.keychain.derive(txi.address_n, self.coin.curve_name)
key_sign_pub = key_sign.public_key() key_sign_pub = node.public_key()
hash143_hash = self.hash143.preimage_hash( hash143_hash = self.hash143.preimage_hash(
self.coin, self.coin,
self.tx, self.tx,
@ -292,16 +286,18 @@ class Bitcoin:
self.get_hash_type(), self.get_hash_type(),
) )
signature = ecdsa_sign(key_sign, hash143_hash) signature = ecdsa_sign(node, hash143_hash)
if txi.multisig: if txi.multisig:
# find out place of our signature based on the pubkey # find out place of our signature based on the pubkey
signature_index = multisig.multisig_pubkey_index(txi.multisig, key_sign_pub) signature_index = multisig.multisig_pubkey_index(txi.multisig, key_sign_pub)
self.serialized_tx[:] = scripts.witness_p2wsh( self.serialized_tx.extend(
txi.multisig, signature, signature_index, self.get_hash_type() scripts.witness_p2wsh(
txi.multisig, signature, signature_index, self.get_hash_type()
)
) )
else: else:
self.serialized_tx[:] = scripts.witness_p2wpkh( self.serialized_tx.extend(
signature, key_sign_pub, self.get_hash_type() scripts.witness_p2wpkh(signature, key_sign_pub, self.get_hash_type())
) )
self.tx_req.serialized.signature_index = i self.tx_req.serialized.signature_index = i
@ -323,8 +319,8 @@ class Bitcoin:
if i == i_sign: if i == i_sign:
txi_sign = txi txi_sign = txi
self.input_check_multisig_fingerprint(txi_sign) self.input_check_multisig_fingerprint(txi_sign)
key_sign = self.keychain.derive(txi.address_n, self.coin.curve_name) node = self.keychain.derive(txi.address_n, self.coin.curve_name)
key_sign_pub = key_sign.public_key() key_sign_pub = node.public_key()
# for the signing process the script_sig is equal # for the signing process the script_sig is equal
# to the previous tx's scriptPubKey (P2PKH) or a redeem script (P2SH) # to the previous tx's scriptPubKey (P2PKH) or a redeem script (P2SH)
if txi_sign.script_type == InputScriptType.SPENDMULTISIG: if txi_sign.script_type == InputScriptType.SPENDMULTISIG:
@ -370,7 +366,7 @@ class Bitcoin:
# compute the signature from the tx digest # compute the signature from the tx digest
signature = ecdsa_sign( signature = ecdsa_sign(
key_sign, writers.get_tx_hash(h_sign, double=self.coin.sign_hash_double) node, writers.get_tx_hash(h_sign, double=self.coin.sign_hash_double)
) )
# serialize input with correct signature # serialize input with correct signature
@ -415,13 +411,13 @@ class Bitcoin:
writers.write_varint(txh, tx.outputs_cnt) writers.write_varint(txh, tx.outputs_cnt)
for o in range(tx.outputs_cnt): for i in range(tx.outputs_cnt):
# STAGE_REQUEST_2_PREV_OUTPUT # STAGE_REQUEST_2_PREV_OUTPUT
txo_bin = await helpers.request_tx_output( txo_bin = await helpers.request_tx_output(
self.tx_req, o, self.coin, prev_hash self.tx_req, i, self.coin, prev_hash
) )
writers.write_tx_output(txh, txo_bin) writers.write_tx_output(txh, txo_bin)
if o == prev_index: if i == prev_index:
amount_out = txo_bin.amount amount_out = txo_bin.amount
self.check_prevtx_output(txo_bin) self.check_prevtx_output(txo_bin)
@ -447,8 +443,8 @@ class Bitcoin:
SIGHASH_ALL = const(0x01) SIGHASH_ALL = const(0x01)
return SIGHASH_ALL return SIGHASH_ALL
def write_tx_input(self, w: writers.Writer, i: TxInputType) -> None: def write_tx_input(self, w: writers.Writer, txi: TxInputType) -> None:
writers.write_tx_input(w, i) writers.write_tx_input(w, txi)
def write_sign_tx_header(self, w: writers.Writer, has_segwit: bool) -> None: def write_sign_tx_header(self, w: writers.Writer, has_segwit: bool) -> None:
self.write_tx_header(w, self.tx, has_segwit) self.write_tx_header(w, self.tx, has_segwit)
@ -473,22 +469,22 @@ class Bitcoin:
# TX Outputs # TX Outputs
# === # ===
def output_derive_script(self, o: TxOutputType) -> bytes: def output_derive_script(self, txo: TxOutputType) -> bytes:
if o.script_type == OutputScriptType.PAYTOOPRETURN: if txo.script_type == OutputScriptType.PAYTOOPRETURN:
return scripts.output_script_paytoopreturn(o.op_return_data) return scripts.output_script_paytoopreturn(txo.op_return_data)
if o.address_n: if txo.address_n:
# change output # change output
o.address = self.get_address_for_change(o) txo.address = self.get_address_for_change(txo)
if self.coin.bech32_prefix and o.address.startswith(self.coin.bech32_prefix): if self.coin.bech32_prefix and txo.address.startswith(self.coin.bech32_prefix):
# p2wpkh or p2wsh # p2wpkh or p2wsh
witprog = addresses.decode_bech32_address( witprog = addresses.decode_bech32_address(
self.coin.bech32_prefix, o.address self.coin.bech32_prefix, txo.address
) )
return scripts.output_script_native_p2wpkh_or_p2wsh(witprog) return scripts.output_script_native_p2wpkh_or_p2wsh(witprog)
raw_address = self.get_raw_address(o) raw_address = self.get_raw_address(txo)
if address_type.check(self.coin.address_type, raw_address): if address_type.check(self.coin.address_type, raw_address):
# p2pkh # p2pkh
@ -504,55 +500,55 @@ class Bitcoin:
raise SigningError(FailureType.DataError, "Invalid address type") raise SigningError(FailureType.DataError, "Invalid address type")
def get_raw_address(self, o: TxOutputType) -> bytes: def get_raw_address(self, txo: TxOutputType) -> bytes:
try: try:
return base58.decode_check(o.address, self.coin.b58_hash) return base58.decode_check(txo.address, self.coin.b58_hash)
except ValueError: except ValueError:
raise SigningError(FailureType.DataError, "Invalid address") raise SigningError(FailureType.DataError, "Invalid address")
def get_address_for_change(self, o: TxOutputType) -> str: def get_address_for_change(self, txo: TxOutputType) -> str:
try: try:
input_script_type = helpers.CHANGE_OUTPUT_TO_INPUT_SCRIPT_TYPES[ input_script_type = helpers.CHANGE_OUTPUT_TO_INPUT_SCRIPT_TYPES[
o.script_type txo.script_type
] ]
except KeyError: except KeyError:
raise SigningError(FailureType.DataError, "Invalid script type") raise SigningError(FailureType.DataError, "Invalid script type")
node = self.keychain.derive(o.address_n, self.coin.curve_name) node = self.keychain.derive(txo.address_n, self.coin.curve_name)
return addresses.get_address(input_script_type, self.coin, node, o.multisig) return addresses.get_address(input_script_type, self.coin, node, txo.multisig)
def output_is_change(self, o: TxOutputType) -> bool: def output_is_change(self, txo: TxOutputType) -> bool:
if o.script_type not in helpers.CHANGE_OUTPUT_SCRIPT_TYPES: if txo.script_type not in helpers.CHANGE_OUTPUT_SCRIPT_TYPES:
return False return False
if o.multisig and not self.multisig_fp.matches(o.multisig): if txo.multisig and not self.multisig_fp.matches(txo.multisig):
return False return False
return ( return (
self.wallet_path is not None self.wallet_path is not None
and self.wallet_path == o.address_n[:-_BIP32_WALLET_DEPTH] and self.wallet_path == txo.address_n[:-_BIP32_WALLET_DEPTH]
and o.address_n[-2] <= _BIP32_CHANGE_CHAIN and txo.address_n[-2] <= _BIP32_CHANGE_CHAIN
and o.address_n[-1] <= _BIP32_MAX_LAST_ELEMENT and txo.address_n[-1] <= _BIP32_MAX_LAST_ELEMENT
) )
# Tx Inputs # Tx Inputs
# === # ===
def input_derive_script( def input_derive_script(
self, i: TxInputType, pubkey: bytes, signature: bytes = None self, txi: TxInputType, pubkey: bytes, signature: bytes = None
) -> bytes: ) -> bytes:
if i.script_type == InputScriptType.SPENDADDRESS: if txi.script_type == InputScriptType.SPENDADDRESS:
# p2pkh or p2sh # p2pkh or p2sh
return scripts.input_script_p2pkh_or_p2sh( return scripts.input_script_p2pkh_or_p2sh(
pubkey, signature, self.get_hash_type() pubkey, signature, self.get_hash_type()
) )
if i.script_type == InputScriptType.SPENDP2SHWITNESS: if txi.script_type == InputScriptType.SPENDP2SHWITNESS:
# p2wpkh or p2wsh using p2sh # p2wpkh or p2wsh using p2sh
if i.multisig: if txi.multisig:
# p2wsh in p2sh # p2wsh in p2sh
pubkeys = multisig.multisig_get_pubkeys(i.multisig) pubkeys = multisig.multisig_get_pubkeys(txi.multisig)
witness_script_hasher = self.create_hash_writer() witness_script_hasher = self.create_hash_writer()
scripts.write_output_script_multisig( scripts.write_output_script_multisig(
witness_script_hasher, pubkeys, i.multisig.m witness_script_hasher, pubkeys, txi.multisig.m
) )
witness_script_hash = witness_script_hasher.get_digest() witness_script_hash = witness_script_hasher.get_digest()
return scripts.input_script_p2wsh_in_p2sh(witness_script_hash) return scripts.input_script_p2wsh_in_p2sh(witness_script_hash)
@ -561,14 +557,18 @@ class Bitcoin:
return scripts.input_script_p2wpkh_in_p2sh( return scripts.input_script_p2wpkh_in_p2sh(
addresses.ecdsa_hash_pubkey(pubkey, self.coin) addresses.ecdsa_hash_pubkey(pubkey, self.coin)
) )
elif i.script_type == InputScriptType.SPENDWITNESS: elif txi.script_type == InputScriptType.SPENDWITNESS:
# native p2wpkh or p2wsh # native p2wpkh or p2wsh
return scripts.input_script_native_p2wpkh_or_p2wsh() return scripts.input_script_native_p2wpkh_or_p2wsh()
elif i.script_type == InputScriptType.SPENDMULTISIG: elif txi.script_type == InputScriptType.SPENDMULTISIG:
# p2sh multisig # p2sh multisig
signature_index = multisig.multisig_pubkey_index(i.multisig, pubkey) signature_index = multisig.multisig_pubkey_index(txi.multisig, pubkey)
return scripts.input_script_multisig( return scripts.input_script_multisig(
i.multisig, signature, signature_index, self.get_hash_type(), self.coin txi.multisig,
signature,
signature_index,
self.get_hash_type(),
self.coin,
) )
else: else:
raise SigningError(FailureType.ProcessError, "Invalid script type") raise SigningError(FailureType.ProcessError, "Invalid script type")
@ -603,11 +603,12 @@ class Bitcoin:
) )
def input_is_segwit(i: TxInputType) -> bool: def input_is_segwit(txi: TxInputType) -> bool:
return ( return txi.script_type in helpers.SEGWIT_INPUT_SCRIPT_TYPES
i.script_type == InputScriptType.SPENDWITNESS
or i.script_type == InputScriptType.SPENDP2SHWITNESS
) def input_is_nonsegwit(txi: TxInputType) -> bool:
return txi.script_type in helpers.NONSEGWIT_INPUT_SCRIPT_TYPES
def ecdsa_sign(node: bip32.HDNode, digest: bytes) -> bytes: def ecdsa_sign(node: bip32.HDNode, digest: bytes) -> bytes: