1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-24 23:38:09 +00:00

chore(core): In apps.bitcoin create a separate class for transaction information.

This commit is contained in:
Andrew Kozlik 2020-09-30 18:50:25 +02:00 committed by Andrew Kozlik
parent 469c131678
commit bd3fe1d789
8 changed files with 270 additions and 175 deletions

View File

@ -9,8 +9,8 @@ from trezor.utils import ensure
if False: if False:
from apps.common.coininfo import CoinInfo from apps.common.coininfo import CoinInfo
from typing import Dict from typing import Dict
from trezor.messages.TxInputType import EnumTypeInputScriptType from trezor.messages.TxInput import EnumTypeInputScriptType, TxInput
from trezor.messages.TxOutputType import EnumTypeOutputScriptType from trezor.messages.TxOutput import EnumTypeOutputScriptType
BITCOIN_NAMES = ("Bitcoin", "Regtest", "Testnet") BITCOIN_NAMES = ("Bitcoin", "Regtest", "Testnet")
@ -85,3 +85,19 @@ def decode_bech32_address(prefix: str, address: str) -> bytes:
raise wire.ProcessError("Invalid address witness program") raise wire.ProcessError("Invalid address witness program")
assert raw is not None assert raw is not None
return bytes(raw) return bytes(raw)
def input_is_segwit(txi: TxInput) -> bool:
return txi.script_type in SEGWIT_INPUT_SCRIPT_TYPES or (
txi.script_type == InputScriptType.EXTERNAL and txi.witness is not None
)
def input_is_nonsegwit(txi: TxInput) -> bool:
return txi.script_type in NONSEGWIT_INPUT_SCRIPT_TYPES or (
txi.script_type == InputScriptType.EXTERNAL and txi.witness is None
)
def input_is_external(txi: TxInput) -> bool:
return txi.script_type == InputScriptType.EXTERNAL

View File

@ -7,6 +7,7 @@ from apps.common import safety_checks
from .. import addresses from .. import addresses
from ..authorization import FEE_PER_ANONYMITY_DECIMALS from ..authorization import FEE_PER_ANONYMITY_DECIMALS
from . import helpers, tx_weight from . import helpers, tx_weight
from .tx_info import TxInfo
if False: if False:
from trezor.messages.SignTx import SignTx from trezor.messages.SignTx import SignTx
@ -17,9 +18,6 @@ if False:
from ..authorization import CoinJoinAuthorization from ..authorization import CoinJoinAuthorization
# Setting nSequence to this value for every input in a transaction disables nLockTime.
_SEQUENCE_FINAL = const(0xFFFFFFFF)
# An Approver object computes the transaction totals and either prompts the user # An Approver object computes the transaction totals and either prompts the user
# to confirm transaction parameters (output addresses, amounts and fees) or uses # to confirm transaction parameters (output addresses, amounts and fees) or uses
@ -27,27 +25,23 @@ _SEQUENCE_FINAL = const(0xFFFFFFFF)
# these parameters to be executed. # these parameters to be executed.
class Approver: class Approver:
def __init__(self, tx: SignTx, coin: CoinInfo) -> None: def __init__(self, tx: SignTx, coin: CoinInfo) -> None:
self.tx = tx
self.coin = coin self.coin = coin
self.weight = tx_weight.TxWeightCalculator(tx.inputs_count, tx.outputs_count) self.weight = tx_weight.TxWeightCalculator(tx.inputs_count, tx.outputs_count)
self.min_sequence = _SEQUENCE_FINAL # the minimum nSequence of all inputs
# amounts # amounts in the current transaction
self.total_in = 0 # sum of input amounts self.total_in = 0 # sum of input amounts
self.external_in = 0 # sum of external input amounts self.external_in = 0 # sum of external input amounts
self.total_out = 0 # sum of output amounts self.total_out = 0 # sum of output amounts
self.change_out = 0 # change output amount self.change_out = 0 # sum of change output amounts
async def add_internal_input(self, txi: TxInput) -> None: async def add_internal_input(self, txi: TxInput) -> None:
self.weight.add_input(txi) self.weight.add_input(txi)
self.total_in += txi.amount self.total_in += txi.amount
self.min_sequence = min(self.min_sequence, txi.sequence)
def add_external_input(self, txi: TxInput) -> None: def add_external_input(self, txi: TxInput) -> None:
self.weight.add_input(txi) self.weight.add_input(txi)
self.total_in += txi.amount self.total_in += txi.amount
self.external_in += txi.amount self.external_in += txi.amount
self.min_sequence = min(self.min_sequence, txi.sequence)
def add_change_output(self, txo: TxOutput, script_pubkey: bytes) -> None: def add_change_output(self, txo: TxOutput, script_pubkey: bytes) -> None:
self.weight.add_output(script_pubkey) self.weight.add_output(script_pubkey)
@ -58,7 +52,7 @@ class Approver:
self.weight.add_output(script_pubkey) self.weight.add_output(script_pubkey)
self.total_out += txo.amount self.total_out += txo.amount
async def approve_tx(self) -> None: async def approve_tx(self, tx_info: TxInfo) -> None:
raise NotImplementedError raise NotImplementedError
@ -84,7 +78,7 @@ class BasicApprover(Approver):
await super().add_external_output(txo, script_pubkey) await super().add_external_output(txo, script_pubkey)
await helpers.confirm_output(txo, self.coin) await helpers.confirm_output(txo, self.coin)
async def approve_tx(self) -> None: async def approve_tx(self, tx_info: TxInfo) -> None:
fee = self.total_in - self.total_out fee = self.total_in - self.total_out
# some coins require negative fees for reward TX # some coins require negative fees for reward TX
@ -103,10 +97,9 @@ class BasicApprover(Approver):
await helpers.confirm_feeoverthreshold(fee, self.coin) await helpers.confirm_feeoverthreshold(fee, self.coin)
if self.change_count > self.MAX_SILENT_CHANGE_COUNT: if self.change_count > self.MAX_SILENT_CHANGE_COUNT:
await helpers.confirm_change_count_over_threshold(self.change_count) await helpers.confirm_change_count_over_threshold(self.change_count)
if self.tx.lock_time > 0: if tx_info.tx.lock_time > 0:
lock_time_disabled = self.min_sequence == _SEQUENCE_FINAL
await helpers.confirm_nondefault_locktime( await helpers.confirm_nondefault_locktime(
self.tx.lock_time, lock_time_disabled tx_info.tx.lock_time, tx_info.lock_time_disabled()
) )
if not self.external_in: if not self.external_in:
await helpers.confirm_total(total, fee, self.coin) await helpers.confirm_total(total, fee, self.coin)
@ -158,7 +151,7 @@ class CoinJoinApprover(Approver):
await super().add_external_output(txo, script_pubkey) await super().add_external_output(txo, script_pubkey)
self._add_output(txo, script_pubkey) self._add_output(txo, script_pubkey)
async def approve_tx(self) -> None: async def approve_tx(self, tx_info: TxInfo) -> None:
# The mining fee of the transaction as a whole. # The mining fee of the transaction as a whole.
mining_fee = self.total_in - self.total_out mining_fee = self.total_in - self.total_out
@ -185,10 +178,10 @@ class CoinJoinApprover(Approver):
if not self.anonymity: if not self.anonymity:
raise wire.ProcessError("No anonymity gain") raise wire.ProcessError("No anonymity gain")
if self.tx.lock_time > 0: if tx_info.tx.lock_time > 0:
raise wire.ProcessError("nLockTime not allowed in CoinJoin") raise wire.ProcessError("nLockTime not allowed in CoinJoin")
if not self.authorization.approve_sign_tx(self.tx, our_fees): if not self.authorization.approve_sign_tx(tx_info.tx, our_fees):
raise wire.ProcessError("Fees exceed authorized limit") raise wire.ProcessError("Fees exceed authorized limit")
# Coordinator fee calculation. # Coordinator fee calculation.

View File

@ -11,12 +11,12 @@ from trezor.utils import HashWriter, ensure
from apps.common.writers import write_bitcoin_varint from apps.common.writers import write_bitcoin_varint
from .. import addresses, common, multisig, scripts, writers from .. import addresses, common, multisig, scripts, writers
from ..common import BIP32_WALLET_DEPTH, SIGHASH_ALL, ecdsa_sign from ..common import SIGHASH_ALL, ecdsa_sign, input_is_external, input_is_segwit
from ..ownership import verify_nonownership from ..ownership import verify_nonownership
from ..verification import SignatureVerifier from ..verification import SignatureVerifier
from . import approvers, helpers, progress from . import approvers, helpers, progress
from .hash143 import Hash143 from .hash143 import Hash143
from .matchcheck import MultisigFingerprintChecker, WalletPathChecker from .tx_info import TxInfo
if False: if False:
from typing import List, Optional, Set, Tuple, Union from typing import List, Optional, Set, Tuple, Union
@ -33,28 +33,21 @@ if False:
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
# the chain id used for change
_BIP32_CHANGE_CHAIN = const(1)
# the maximum allowed change address. this should be large enough for normal
# use and still allow to quickly brute-force the correct bip32 path
_BIP32_MAX_LAST_ELEMENT = const(1000000)
# the number of bytes to preallocate for serialized transaction chunks # the number of bytes to preallocate for serialized transaction chunks
_MAX_SERIALIZED_CHUNK_SIZE = const(2048) _MAX_SERIALIZED_CHUNK_SIZE = const(2048)
class Bitcoin: class Bitcoin:
async def signer(self) -> None: async def signer(self) -> None:
# Add inputs to hash143 and h_approved and compute the sum of input amounts. # Add inputs to hash143 and h_tx_check and compute the sum of input amounts.
await self.step1_process_inputs() await self.step1_process_inputs()
# Add outputs to hash143 and h_approved, approve outputs and compute # Add outputs to hash143 and h_tx_check, approve outputs and compute
# sum of output amounts. # sum of output amounts.
await self.step2_approve_outputs() await self.step2_approve_outputs()
# Check fee, approve lock_time and total. # Check fee, approve lock_time and total.
await self.approver.approve_tx() await self.approver.approve_tx(self.tx_info)
# Verify the transaction input amounts by requesting each previous transaction # Verify the transaction input amounts by requesting each previous transaction
# and checking its output amount. Verify external inputs which have already # and checking its output amount. Verify external inputs which have already
@ -80,17 +73,11 @@ class Bitcoin:
coin: CoinInfo, coin: CoinInfo,
approver: approvers.Approver, approver: approvers.Approver,
) -> None: ) -> None:
self.tx = helpers.sanitize_sign_tx(tx, coin) self.tx_info = TxInfo(self, helpers.sanitize_sign_tx(tx, coin))
self.keychain = keychain self.keychain = keychain
self.coin = coin self.coin = coin
self.approver = approver self.approver = approver
# checksum of multisig inputs, used to validate change-output
self.multisig_fingerprint = MultisigFingerprintChecker()
# common prefix of input paths, used to validate change-output
self.wallet_path = WalletPathChecker()
# set of indices of inputs which are segwit # set of indices of inputs which are segwit
self.segwit = set() # type: Set[int] self.segwit = set() # type: Set[int]
@ -104,20 +91,7 @@ class Bitcoin:
self.tx_req.serialized = TxRequestSerializedType() self.tx_req.serialized = TxRequestSerializedType()
self.tx_req.serialized.serialized_tx = self.serialized_tx self.tx_req.serialized.serialized_tx = self.serialized_tx
# h_approved is used to make sure that the inputs and outputs streamed for progress.init(tx.inputs_count, tx.outputs_count)
# approval in Steps 1 and 2 are the same as the ones streamed for signing
# legacy inputs in Step 4.
self.h_approved = self.create_hash_writer() # not a real tx hash
# h_inputs is a digest of the inputs streamed for approval in Step 1, which
# is used to ensure that the inputs streamed for verification in Step 3 are
# the same as those in Step 1.
self.h_inputs = None # type: Optional[bytes]
# BIP-0143 transaction hashing
self.hash143 = self.create_hash143()
progress.init(self.tx.inputs_count, self.tx.outputs_count)
def create_hash_writer(self) -> HashWriter: def create_hash_writer(self) -> HashWriter:
return HashWriter(sha256()) return HashWriter(sha256())
@ -126,12 +100,11 @@ class Bitcoin:
return Hash143() return Hash143()
async def step1_process_inputs(self) -> None: async def step1_process_inputs(self) -> None:
for i in range(self.tx.inputs_count): for i in range(self.tx_info.tx.inputs_count):
# STAGE_REQUEST_1_INPUT in legacy # STAGE_REQUEST_1_INPUT in legacy
txi = await helpers.request_tx_input(self.tx_req, i, self.coin) txi = await helpers.request_tx_input(self.tx_req, i, self.coin)
self.hash143.add_input(txi) # all inputs are included (non-segwit as well) self.tx_info.add_input(txi)
writers.write_tx_input_check(self.h_approved, txi)
if input_is_segwit(txi): if input_is_segwit(txi):
self.segwit.add(i) self.segwit.add(i)
@ -142,10 +115,10 @@ class Bitcoin:
else: else:
await self.process_internal_input(txi) await self.process_internal_input(txi)
self.h_inputs = self.h_approved.get_digest() self.tx_info.h_inputs = self.tx_info.h_tx_check.get_digest()
async def step2_approve_outputs(self) -> None: async def step2_approve_outputs(self) -> None:
for i in range(self.tx.outputs_count): for i in range(self.tx_info.tx.outputs_count):
# STAGE_REQUEST_2_OUTPUT in legacy # STAGE_REQUEST_2_OUTPUT in legacy
txo = await helpers.request_tx_output(self.tx_req, i, self.coin) txo = await helpers.request_tx_output(self.tx_req, i, self.coin)
script_pubkey = self.output_derive_script(txo) script_pubkey = self.output_derive_script(txo)
@ -153,9 +126,9 @@ class Bitcoin:
async def step3_verify_inputs(self) -> None: async def step3_verify_inputs(self) -> None:
# should come out the same as h_inputs, checked before continuing # should come out the same as h_inputs, checked before continuing
h_check = self.create_hash_writer() h_check = HashWriter(sha256())
for i in range(self.tx.inputs_count): for i in range(self.tx_info.tx.inputs_count):
progress.advance() progress.advance()
txi = await helpers.request_tx_input(self.tx_req, i, self.coin) txi = await helpers.request_tx_input(self.tx_req, i, self.coin)
@ -170,14 +143,13 @@ class Bitcoin:
await self.verify_external_input(i, txi, script_pubkey) await self.verify_external_input(i, txi, script_pubkey)
# check that the inputs were the same as those streamed for approval # check that the inputs were the same as those streamed for approval
if h_check.get_digest() != self.h_inputs: if h_check.get_digest() != self.tx_info.h_inputs:
raise wire.ProcessError("Transaction has changed during signing") raise wire.ProcessError("Transaction has changed during signing")
async def step4_serialize_inputs(self) -> None: async def step4_serialize_inputs(self) -> None:
self.write_tx_header(self.serialized_tx, self.tx, bool(self.segwit)) self.write_tx_header(self.serialized_tx, self.tx_info.tx, bool(self.segwit))
write_bitcoin_varint(self.serialized_tx, self.tx.inputs_count) write_bitcoin_varint(self.serialized_tx, self.tx_info.tx.inputs_count)
for i in range(self.tx_info.tx.inputs_count):
for i in range(self.tx.inputs_count):
progress.advance() progress.advance()
if i in self.external: if i in self.external:
await self.serialize_external_input(i) await self.serialize_external_input(i)
@ -187,17 +159,17 @@ class Bitcoin:
await self.sign_nonsegwit_input(i) await self.sign_nonsegwit_input(i)
async def step5_serialize_outputs(self) -> None: async def step5_serialize_outputs(self) -> None:
write_bitcoin_varint(self.serialized_tx, self.tx.outputs_count) write_bitcoin_varint(self.serialized_tx, self.tx_info.tx.outputs_count)
for i in range(self.tx.outputs_count): for i in range(self.tx_info.tx.outputs_count):
progress.advance() progress.advance()
await self.serialize_output(i) await self.serialize_output(i)
async def step6_sign_segwit_inputs(self) -> None: async def step6_sign_segwit_inputs(self) -> None:
if not self.segwit: if not self.segwit:
progress.advance(self.tx.inputs_count) progress.advance(self.tx_info.tx.inputs_count)
return return
for i in range(self.tx.inputs_count): for i in range(self.tx_info.tx.inputs_count):
progress.advance() progress.advance()
if i in self.segwit: if i in self.segwit:
if i in self.external: if i in self.external:
@ -210,13 +182,10 @@ class Bitcoin:
self.serialized_tx.append(0) self.serialized_tx.append(0)
async def step7_finish(self) -> None: async def step7_finish(self) -> None:
self.write_tx_footer(self.serialized_tx, self.tx) self.write_tx_footer(self.serialized_tx, self.tx_info.tx)
await helpers.request_tx_finish(self.tx_req) await helpers.request_tx_finish(self.tx_req)
async def process_internal_input(self, txi: TxInput) -> None: async def process_internal_input(self, txi: TxInput) -> None:
self.wallet_path.add_input(txi)
self.multisig_fingerprint.add_input(txi)
if txi.script_type not in common.INTERNAL_INPUT_SCRIPT_TYPES: if txi.script_type not in common.INTERNAL_INPUT_SCRIPT_TYPES:
raise wire.DataError("Wrong input script type") raise wire.DataError("Wrong input script type")
@ -226,30 +195,32 @@ class Bitcoin:
self.approver.add_external_input(txi) self.approver.add_external_input(txi)
async def approve_output(self, txo: TxOutput, script_pubkey: bytes) -> None: async def approve_output(self, txo: TxOutput, script_pubkey: bytes) -> None:
if self.output_is_change(txo): if self.tx_info.output_is_change(txo):
# output is change and does not need approval # Output is change and does not need approval.
self.approver.add_change_output(txo, script_pubkey) self.approver.add_change_output(txo, script_pubkey)
else: else:
await self.approver.add_external_output(txo, script_pubkey) await self.approver.add_external_output(txo, script_pubkey)
self.write_tx_output(self.h_approved, txo, script_pubkey) self.tx_info.add_output(txo, script_pubkey)
self.hash143.add_output(txo, script_pubkey)
async def get_tx_digest( async def get_tx_digest(
self, self,
i: int, i: int,
txi: TxInput, txi: TxInput,
tx: Union[SignTx, PrevTx], tx_info: TxInfo,
hash143: Hash143,
h_approved: HashWriter,
public_keys: List[bytes], public_keys: List[bytes],
threshold: int, threshold: int,
script_pubkey: bytes, script_pubkey: bytes,
tx_hash: Optional[bytes] = None, tx_hash: Optional[bytes] = None,
) -> bytes: ) -> bytes:
if txi.witness: if txi.witness:
return hash143.preimage_hash( return tx_info.hash143.preimage_hash(
txi, public_keys, threshold, tx, self.coin, self.get_sighash_type(txi), txi,
public_keys,
threshold,
tx_info.tx,
self.coin,
self.get_sighash_type(txi),
) )
else: else:
digest, _, _ = await self.get_legacy_tx_digest( digest, _, _ = await self.get_legacy_tx_digest(
@ -279,9 +250,7 @@ class Bitcoin:
tx_digest = await self.get_tx_digest( tx_digest = await self.get_tx_digest(
i, i,
txi, txi,
self.tx, self.tx_info,
self.hash143,
self.h_approved,
verifier.public_keys, verifier.public_keys,
verifier.threshold, verifier.threshold,
script_pubkey, script_pubkey,
@ -301,9 +270,7 @@ class Bitcoin:
if not input_is_segwit(txi): if not input_is_segwit(txi):
raise wire.ProcessError("Transaction has changed during signing") raise wire.ProcessError("Transaction has changed during signing")
self.wallet_path.check_input(txi) self.tx_info.check_input(txi)
# NOTE: No need to check the multisig fingerprint, because we won't be signing
# the script here. Signatures are produced in STAGE_REQUEST_SEGWIT_WITNESS.
node = self.keychain.derive(txi.address_n) node = self.keychain.derive(txi.address_n)
key_sign_pub = node.public_key() key_sign_pub = node.public_key()
@ -311,8 +278,7 @@ class Bitcoin:
self.write_tx_input(self.serialized_tx, txi, script_sig) self.write_tx_input(self.serialized_tx, txi, script_sig)
def sign_bip143_input(self, txi: TxInput) -> Tuple[bytes, bytes]: def sign_bip143_input(self, txi: TxInput) -> Tuple[bytes, bytes]:
self.wallet_path.check_input(txi) self.tx_info.check_input(txi)
self.multisig_fingerprint.check_input(txi)
node = self.keychain.derive(txi.address_n) node = self.keychain.derive(txi.address_n)
public_key = node.public_key() public_key = node.public_key()
@ -323,8 +289,13 @@ class Bitcoin:
else: else:
public_keys = [public_key] public_keys = [public_key]
threshold = 1 threshold = 1
hash143_hash = self.hash143.preimage_hash( hash143_hash = self.tx_info.hash143.preimage_hash(
txi, public_keys, threshold, self.tx, self.coin, self.get_sighash_type(txi) txi,
public_keys,
threshold,
self.tx_info.tx,
self.coin,
self.get_sighash_type(txi),
) )
signature = ecdsa_sign(node, hash143_hash) signature = ecdsa_sign(node, hash143_hash)
@ -357,21 +328,20 @@ class Bitcoin:
async def get_legacy_tx_digest( async def get_legacy_tx_digest(
self, self,
index: int, index: int,
tx: Union[SignTx, PrevTx], tx_info: Union[TxInfo, OriginalTxInfo],
h_approved: HashWriter,
script_pubkey: Optional[bytes] = None, script_pubkey: Optional[bytes] = None,
tx_hash: Optional[bytes] = None, tx_hash: Optional[bytes] = None,
) -> Tuple[bytes, TxInput, Optional[bip32.HDNode]]: ) -> Tuple[bytes, TxInput, Optional[bip32.HDNode]]:
# the transaction digest which gets signed for this input # the transaction digest which gets signed for this input
h_sign = self.create_hash_writer() h_sign = self.create_hash_writer()
# should come out the same as h_approved, checked before signing the digest # should come out the same as h_tx_check, checked before signing the digest
h_check = self.create_hash_writer() h_check = HashWriter(sha256())
self.write_tx_header(h_sign, tx, witness_marker=False) self.write_tx_header(h_sign, tx_info.tx, witness_marker=False)
write_bitcoin_varint(h_sign, tx.inputs_count) write_bitcoin_varint(h_sign, tx_info.tx.inputs_count)
for i in range(tx.inputs_count): for i in range(tx_info.tx.inputs_count):
# STAGE_REQUEST_4_INPUT in legacy # STAGE_REQUEST_4_INPUT in legacy
txi = await helpers.request_tx_input(self.tx_req, i, self.coin, tx_hash) txi = await helpers.request_tx_input(self.tx_req, i, self.coin, tx_hash)
writers.write_tx_input_check(h_check, txi) writers.write_tx_input_check(h_check, txi)
@ -380,9 +350,7 @@ class Bitcoin:
txi_sign = txi txi_sign = txi
node = None node = None
if not script_pubkey: if not script_pubkey:
if isinstance(tx, SignTx): self.tx_info.check_input(txi)
self.wallet_path.check_input(txi)
self.multisig_fingerprint.check_input(txi)
node = self.keychain.derive(txi.address_n) node = self.keychain.derive(txi.address_n)
key_sign_pub = node.public_key() key_sign_pub = node.public_key()
if txi.multisig: if txi.multisig:
@ -405,29 +373,27 @@ class Bitcoin:
else: else:
self.write_tx_input(h_sign, txi, bytes()) self.write_tx_input(h_sign, txi, bytes())
write_bitcoin_varint(h_sign, tx.outputs_count) write_bitcoin_varint(h_sign, tx_info.tx.outputs_count)
for i in range(tx.outputs_count): for i in range(tx_info.tx.outputs_count):
# STAGE_REQUEST_4_OUTPUT in legacy # STAGE_REQUEST_4_OUTPUT in legacy
txo = await helpers.request_tx_output(self.tx_req, i, self.coin, tx_hash) txo = await helpers.request_tx_output(self.tx_req, i, self.coin, tx_hash)
script_pubkey = self.output_derive_script(txo) script_pubkey = self.output_derive_script(txo)
self.write_tx_output(h_check, txo, script_pubkey) self.write_tx_output(h_check, txo, script_pubkey)
self.write_tx_output(h_sign, txo, script_pubkey) self.write_tx_output(h_sign, txo, script_pubkey)
writers.write_uint32(h_sign, tx.lock_time) writers.write_uint32(h_sign, tx_info.tx.lock_time)
writers.write_uint32(h_sign, self.get_sighash_type(txi_sign)) writers.write_uint32(h_sign, self.get_sighash_type(txi_sign))
# check that the inputs were the same as those streamed for approval # check that the inputs were the same as those streamed for approval
if h_approved.get_digest() != h_check.get_digest(): if tx_info.h_tx_check.get_digest() != h_check.get_digest():
raise wire.ProcessError("Transaction has changed during signing") raise wire.ProcessError("Transaction has changed during signing")
tx_digest = writers.get_tx_hash(h_sign, double=self.coin.sign_hash_double) tx_digest = writers.get_tx_hash(h_sign, double=self.coin.sign_hash_double)
return tx_digest, txi_sign, node return tx_digest, txi_sign, node
async def sign_nonsegwit_input(self, i: int) -> None: async def sign_nonsegwit_input(self, i: int) -> None:
tx_digest, txi, node = await self.get_legacy_tx_digest( tx_digest, txi, node = await self.get_legacy_tx_digest(i, self.tx_info)
i, self.tx, self.h_approved
)
assert node is not None assert node is not None
# compute the signature from the tx digest # compute the signature from the tx digest
@ -576,19 +542,6 @@ class Bitcoin:
return scripts.output_derive_script(txo.address, self.coin) return scripts.output_derive_script(txo.address, self.coin)
def output_is_change(self, txo: TxOutput) -> bool:
if txo.script_type not in common.CHANGE_OUTPUT_SCRIPT_TYPES:
return False
if txo.multisig and not self.multisig_fingerprint.output_matches(txo):
return False
return (
self.wallet_path.output_matches(txo)
and len(txo.address_n) >= BIP32_WALLET_DEPTH
and txo.address_n[-2] <= _BIP32_CHANGE_CHAIN
and txo.address_n[-1] <= _BIP32_MAX_LAST_ELEMENT
and txo.amount > 0
)
# Tx Inputs # Tx Inputs
# === # ===
@ -603,19 +556,3 @@ class Bitcoin:
pubkey, pubkey,
signature, signature,
) )
def input_is_segwit(txi: TxInput) -> bool:
return txi.script_type in common.SEGWIT_INPUT_SCRIPT_TYPES or (
txi.script_type == InputScriptType.EXTERNAL and txi.witness is not None
)
def input_is_nonsegwit(txi: TxInput) -> bool:
return txi.script_type in common.NONSEGWIT_INPUT_SCRIPT_TYPES or (
txi.script_type == InputScriptType.EXTERNAL and txi.witness is None
)
def input_is_external(txi: TxInput) -> bool:
return txi.script_type == InputScriptType.EXTERNAL

View File

@ -8,12 +8,13 @@ from trezor.messages.TxInput import TxInput
from apps.common.writers import write_bitcoin_varint from apps.common.writers import write_bitcoin_varint
from .. import multisig, writers from .. import multisig, writers
from ..common import input_is_nonsegwit
from . import helpers from . import helpers
from .bitcoin import Bitcoin, Hash143, input_is_nonsegwit from .bitcoin import Bitcoin
if False: if False:
from typing import List, Optional, Union from typing import List, Optional, Union
from trezor.utils import HashWriter from .tx_info import OriginalTxInfo, TxInfo
_SIGHASH_FORKID = const(0x40) _SIGHASH_FORKID = const(0x40)
@ -45,21 +46,24 @@ class Bitcoinlike(Bitcoin):
self, self,
i: int, i: int,
txi: TxInput, txi: TxInput,
tx: Union[SignTx, PrevTx], tx_info: Union[TxInfo, OriginalTxInfo],
hash143: Hash143,
h_approved: HashWriter,
public_keys: List[bytes], public_keys: List[bytes],
threshold: int, threshold: int,
script_pubkey: bytes, script_pubkey: bytes,
tx_hash: Optional[bytes] = None, tx_hash: Optional[bytes] = None,
) -> bytes: ) -> bytes:
if self.coin.force_bip143: if self.coin.force_bip143:
return hash143.preimage_hash( return tx_info.hash143.preimage_hash(
txi, public_keys, threshold, tx, self.coin, self.get_sighash_type(txi), txi,
public_keys,
threshold,
tx_info.tx,
self.coin,
self.get_sighash_type(txi),
) )
else: else:
return await super().get_tx_digest( return await super().get_tx_digest(
i, txi, tx, hash143, h_approved, public_keys, threshold, script_pubkey i, txi, tx_info, public_keys, threshold, script_pubkey
) )
def get_sighash_type(self, txi: TxInput) -> int: def get_sighash_type(self, txi: TxInput) -> int:

View File

@ -22,7 +22,7 @@ DECRED_SCRIPT_VERSION = const(0)
DECRED_SIGHASH_ALL = const(1) DECRED_SIGHASH_ALL = const(1)
if False: if False:
from typing import Union from typing import Optional, Union
from trezor.messages.SignTx import SignTx from trezor.messages.SignTx import SignTx
from trezor.messages.TxInput import TxInput from trezor.messages.TxInput import TxInput
@ -54,15 +54,17 @@ class Decred(Bitcoin):
approver: approvers.Approver, approver: approvers.Approver,
) -> None: ) -> None:
ensure(coin.decred) ensure(coin.decred)
self.h_prefix = HashWriter(blake256()) self.h_prefix = HashWriter(blake256())
writers.write_uint32(self.h_prefix, tx.version | DECRED_SERIALIZE_NO_WITNESS)
write_bitcoin_varint(self.h_prefix, tx.inputs_count)
super().__init__(tx, keychain, coin, approver) super().__init__(tx, keychain, coin, approver)
self.write_tx_header(self.serialized_tx, self.tx, witness_marker=True) self.write_tx_header(self.serialized_tx, self.tx_info.tx, witness_marker=True)
write_bitcoin_varint(self.serialized_tx, self.tx.inputs_count) write_bitcoin_varint(self.serialized_tx, self.tx_info.tx.inputs_count)
writers.write_uint32(
self.h_prefix, self.tx_info.tx.version | DECRED_SERIALIZE_NO_WITNESS
)
write_bitcoin_varint(self.h_prefix, self.tx_info.tx.inputs_count)
def create_hash_writer(self) -> HashWriter: def create_hash_writer(self) -> HashWriter:
return HashWriter(blake256()) return HashWriter(blake256())
@ -71,11 +73,11 @@ class Decred(Bitcoin):
return DecredHash(self.h_prefix) return DecredHash(self.h_prefix)
async def step2_approve_outputs(self) -> None: async def step2_approve_outputs(self) -> None:
write_bitcoin_varint(self.serialized_tx, self.tx.outputs_count) write_bitcoin_varint(self.serialized_tx, self.tx_info.tx.outputs_count)
write_bitcoin_varint(self.h_prefix, self.tx.outputs_count) write_bitcoin_varint(self.h_prefix, self.tx_info.tx.outputs_count)
await super().step2_approve_outputs() await super().step2_approve_outputs()
self.write_tx_footer(self.serialized_tx, self.tx) self.write_tx_footer(self.serialized_tx, self.tx_info.tx)
self.write_tx_footer(self.h_prefix, self.tx) self.write_tx_footer(self.h_prefix, self.tx_info.tx)
async def process_internal_input(self, txi: TxInput) -> None: async def process_internal_input(self, txi: TxInput) -> None:
await super().process_internal_input(txi) await super().process_internal_input(txi)
@ -86,22 +88,26 @@ class Decred(Bitcoin):
async def process_external_input(self, txi: TxInput) -> None: async def process_external_input(self, txi: TxInput) -> None:
raise wire.DataError("External inputs not supported") raise wire.DataError("External inputs not supported")
async def approve_output(self, txo: TxOutput, script_pubkey: bytes) -> None: async def approve_output(
await super().approve_output(txo, script_pubkey) self,
txo: TxOutput,
script_pubkey: bytes,
orig_txo: Optional[TxOutput],
) -> None:
await super().approve_output(txo, script_pubkey, orig_txo)
self.write_tx_output(self.serialized_tx, txo, script_pubkey) self.write_tx_output(self.serialized_tx, txo, script_pubkey)
async def step4_serialize_inputs(self) -> None: async def step4_serialize_inputs(self) -> None:
write_bitcoin_varint(self.serialized_tx, self.tx.inputs_count) write_bitcoin_varint(self.serialized_tx, self.tx_info.tx.inputs_count)
prefix_hash = self.h_prefix.get_digest() prefix_hash = self.h_prefix.get_digest()
for i_sign in range(self.tx.inputs_count): for i_sign in range(self.tx_info.tx.inputs_count):
progress.advance() progress.advance()
txi_sign = await helpers.request_tx_input(self.tx_req, i_sign, self.coin) txi_sign = await helpers.request_tx_input(self.tx_req, i_sign, self.coin)
self.wallet_path.check_input(txi_sign) self.tx_info.check_input(txi_sign)
self.multisig_fingerprint.check_input(txi_sign)
key_sign = self.keychain.derive(txi_sign.address_n) key_sign = self.keychain.derive(txi_sign.address_n)
key_sign_pub = key_sign.public_key() key_sign_pub = key_sign.public_key()
@ -121,11 +127,11 @@ class Decred(Bitcoin):
h_witness = self.create_hash_writer() h_witness = self.create_hash_writer()
writers.write_uint32( writers.write_uint32(
h_witness, self.tx.version | DECRED_SERIALIZE_WITNESS_SIGNING h_witness, self.tx_info.tx.version | DECRED_SERIALIZE_WITNESS_SIGNING
) )
write_bitcoin_varint(h_witness, self.tx.inputs_count) write_bitcoin_varint(h_witness, self.tx_info.tx.inputs_count)
for ii in range(self.tx.inputs_count): for ii in range(self.tx_info.tx.inputs_count):
if ii == i_sign: if ii == i_sign:
writers.write_bytes_prefixed(h_witness, prev_pkscript) writers.write_bytes_prefixed(h_witness, prev_pkscript)
else: else:

View File

@ -0,0 +1,135 @@
from micropython import const
from trezor import wire
from trezor.crypto.hashlib import sha256
from trezor.utils import HashWriter
from .. import common, writers
from ..common import BIP32_WALLET_DEPTH, input_is_external
from .matchcheck import MultisigFingerprintChecker, WalletPathChecker
if False:
from typing import Optional, Protocol, Union
from trezor.messages.SignTx import SignTx
from trezor.messages.PrevTx import PrevTx
from trezor.messages.TxInput import TxInput
from trezor.messages.TxOutput import TxOutput
from trezor.messages.PrevInput import PrevInput
from trezor.messages.PrevOutput import PrevOutput
from .hash143 import Hash143
from apps.common.coininfo import CoinInfo
class Signer(Protocol):
coin = ... # type: CoinInfo
def create_hash_writer(self) -> HashWriter:
...
def create_hash143(self) -> Hash143:
...
def write_tx_header(
self,
w: writers.Writer,
tx: Union[SignTx, PrevTx],
witness_marker: bool,
) -> None:
...
@staticmethod
def write_tx_input(
w: writers.Writer,
txi: Union[TxInput, PrevInput],
script: bytes,
) -> None:
...
@staticmethod
def write_tx_output(
w: writers.Writer,
txo: Union[TxOutput, PrevOutput],
script_pubkey: bytes,
) -> None:
...
async def write_prev_tx_footer(
self, w: writers.Writer, tx: PrevTx, prev_hash: bytes
) -> None:
...
# The chain id used for change.
_BIP32_CHANGE_CHAIN = const(1)
# The maximum allowed change address. This should be large enough for normal
# use and still allow to quickly brute-force the correct BIP32 path.
_BIP32_MAX_LAST_ELEMENT = const(1000000)
# Setting nSequence to this value for every input in a transaction disables nLockTime.
_SEQUENCE_FINAL = const(0xFFFFFFFF)
class TxInfoBase:
def __init__(self, signer: Signer) -> None:
# Checksum of multisig inputs, used to validate change-output.
self.multisig_fingerprint = MultisigFingerprintChecker()
# Common prefix of input paths, used to validate change-output.
self.wallet_path = WalletPathChecker()
# h_tx_check is used to make sure that the inputs and outputs streamed in
# different steps are the same every time, e.g. the ones streamed for approval
# in Steps 1 and 2 and the ones streamed for signing legacy inputs in Step 4.
self.h_tx_check = HashWriter(sha256()) # not a real tx hash
# BIP-0143 transaction hashing.
self.hash143 = signer.create_hash143()
# The minimum nSequence of all inputs.
self.min_sequence = _SEQUENCE_FINAL
def add_input(self, txi: TxInput) -> None:
self.hash143.add_input(txi) # all inputs are included (non-segwit as well)
writers.write_tx_input_check(self.h_tx_check, txi)
self.min_sequence = min(self.min_sequence, txi.sequence)
if not input_is_external(txi):
self.wallet_path.add_input(txi)
self.multisig_fingerprint.add_input(txi)
def add_output(self, txo: TxOutput, script_pubkey: bytes) -> None:
self.hash143.add_output(txo, script_pubkey)
writers.write_tx_output(self.h_tx_check, txo, script_pubkey)
def check_input(self, txi: TxInput) -> None:
self.wallet_path.check_input(txi)
self.multisig_fingerprint.check_input(txi)
def output_is_change(self, txo: TxOutput) -> bool:
if txo.script_type not in common.CHANGE_OUTPUT_SCRIPT_TYPES:
return False
if txo.multisig and not self.multisig_fingerprint.output_matches(txo):
return False
return (
self.wallet_path.output_matches(txo)
and len(txo.address_n) >= BIP32_WALLET_DEPTH
and txo.address_n[-2] <= _BIP32_CHANGE_CHAIN
and txo.address_n[-1] <= _BIP32_MAX_LAST_ELEMENT
and txo.amount > 0
)
def lock_time_disabled(self) -> bool:
return self.min_sequence == _SEQUENCE_FINAL
# Used to keep track of the transaction currently being signed.
class TxInfo(TxInfoBase):
def __init__(self, signer: Signer, tx: SignTx) -> None:
super().__init__(signer)
self.tx = tx
# h_inputs is a digest of the inputs streamed for approval in Step 1, which
# is used to ensure that the inputs streamed for verification in Step 3 are
# the same as those in Step 1.
self.h_inputs = None # type: Optional[bytes]

View File

@ -31,6 +31,7 @@ from .hash143 import Hash143
if False: if False:
from apps.common import coininfo from apps.common import coininfo
from typing import List, Optional, Union from typing import List, Optional, Union
from .tx_info import OriginalTxInfo, TxInfo
from ..writers import Writer from ..writers import Writer
OVERWINTERED = const(0x80000000) OVERWINTERED = const(0x80000000)
@ -111,14 +112,14 @@ class Zcashlike(Bitcoinlike):
ensure(coin.overwintered) ensure(coin.overwintered)
super().__init__(tx, keychain, coin, approver) super().__init__(tx, keychain, coin, approver)
if self.tx.version != 4: if tx.version != 4:
raise wire.DataError("Unsupported transaction version.") raise wire.DataError("Unsupported transaction version.")
def create_hash143(self) -> Hash143: def create_hash143(self) -> Hash143:
return Zip243Hash() return Zip243Hash()
async def step7_finish(self) -> None: async def step7_finish(self) -> None:
self.write_tx_footer(self.serialized_tx, self.tx) self.write_tx_footer(self.serialized_tx, self.tx_info.tx)
write_uint64(self.serialized_tx, 0) # valueBalance write_uint64(self.serialized_tx, 0) # valueBalance
write_bitcoin_varint(self.serialized_tx, 0) # nShieldedSpend write_bitcoin_varint(self.serialized_tx, 0) # nShieldedSpend
@ -134,16 +135,19 @@ class Zcashlike(Bitcoinlike):
self, self,
i: int, i: int,
txi: TxInput, txi: TxInput,
tx: Union[SignTx, PrevTx], tx_info: Union[TxInfo, OriginalTxInfo],
hash143: Hash143,
h_approved: HashWriter,
public_keys: List[bytes], public_keys: List[bytes],
threshold: int, threshold: int,
script_pubkey: bytes, script_pubkey: bytes,
tx_hash: Optional[bytes] = None, tx_hash: Optional[bytes] = None,
) -> bytes: ) -> bytes:
return hash143.preimage_hash( return tx_info.hash143.preimage_hash(
txi, public_keys, threshold, tx, self.coin, self.get_sighash_type(txi) txi,
public_keys,
threshold,
tx_info.tx,
self.coin,
self.get_sighash_type(txi),
) )
def write_tx_header( def write_tx_header(
@ -165,9 +169,6 @@ class Zcashlike(Bitcoinlike):
if tx.version >= 3: if tx.version >= 3:
write_uint32(w, tx.expiry) # expiryHeight write_uint32(w, tx.expiry) # expiryHeight
assert self.tx.version_group_id is not None
assert self.tx.expiry is not None
def derive_script_code( def derive_script_code(
txi: TxInput, public_keys: List[bytes], threshold: int, coin: CoinInfo txi: TxInput, public_keys: List[bytes], threshold: int, coin: CoinInfo

View File

@ -10,6 +10,8 @@ from trezor.messages import InputScriptType, OutputScriptType
from apps.common import coins from apps.common import coins
from apps.bitcoin.authorization import CoinJoinAuthorization from apps.bitcoin.authorization import CoinJoinAuthorization
from apps.bitcoin.sign_tx.approvers import CoinJoinApprover from apps.bitcoin.sign_tx.approvers import CoinJoinApprover
from apps.bitcoin.sign_tx.bitcoin import Bitcoin
from apps.bitcoin.sign_tx.tx_info import TxInfo
class TestApprover(unittest.TestCase): class TestApprover(unittest.TestCase):
@ -104,6 +106,7 @@ class TestApprover(unittest.TestCase):
authorization = CoinJoinAuthorization(self.msg_auth, None, self.coin) authorization = CoinJoinAuthorization(self.msg_auth, None, self.coin)
tx = SignTx(outputs_count=len(outputs), inputs_count=len(inputs), coin_name=self.coin.coin_name, lock_time=0) tx = SignTx(outputs_count=len(outputs), inputs_count=len(inputs), coin_name=self.coin.coin_name, lock_time=0)
approver = CoinJoinApprover(tx, self.coin, authorization) approver = CoinJoinApprover(tx, self.coin, authorization)
signer = Bitcoin(tx, None, self.coin, approver)
for txi in inputs: for txi in inputs:
if txi.script_type == InputScriptType.EXTERNAL: if txi.script_type == InputScriptType.EXTERNAL:
@ -117,7 +120,7 @@ class TestApprover(unittest.TestCase):
else: else:
await_result(approver.add_external_output(txo, script_pubkey=bytes(22))) await_result(approver.add_external_output(txo, script_pubkey=bytes(22)))
await_result(approver.approve_tx()) await_result(approver.approve_tx(TxInfo(signer, tx)))
def test_coinjoin_input_account_depth_mismatch(self): def test_coinjoin_input_account_depth_mismatch(self):
authorization = CoinJoinAuthorization(self.msg_auth, None, self.coin) authorization = CoinJoinAuthorization(self.msg_auth, None, self.coin)