diff --git a/src/apps/common/__init__.py b/src/apps/common/__init__.py index 52ed864f39..7be160a222 100644 --- a/src/apps/common/__init__.py +++ b/src/apps/common/__init__.py @@ -1,4 +1,3 @@ from micropython import const HARDENED = const(0x80000000) -OVERWINTERED = const(0x80000000) diff --git a/src/apps/wallet/sign_tx/overwinter_zip143.py b/src/apps/wallet/sign_tx/overwinter_zip143.py new file mode 100644 index 0000000000..81eeb88e5a --- /dev/null +++ b/src/apps/wallet/sign_tx/overwinter_zip143.py @@ -0,0 +1,90 @@ +from micropython import const + +from trezor.crypto.hashlib import blake2b +from trezor.messages.SignTx import SignTx +from trezor.messages.TxInputType import TxInputType +from trezor.messages.TxOutputBinType import TxOutputBinType +from trezor.messages import InputScriptType, FailureType +from trezor.utils import HashWriter + +from apps.common.coininfo import CoinInfo +from apps.wallet.sign_tx.writers import write_bytes, write_bytes_rev, write_uint32, write_uint64, write_varint, write_tx_output, get_tx_hash +from apps.wallet.sign_tx.scripts import output_script_p2pkh, output_script_multisig +from apps.wallet.sign_tx.multisig import multisig_get_pubkeys + + +OVERWINTERED = const(0x80000000) + + +class Zip143Error(ValueError): + pass + + +class Zip143: + + def __init__(self): + self.h_prevouts = HashWriter(blake2b, b'', 32, b'ZcashPrevoutHash') + self.h_sequence = HashWriter(blake2b, b'', 32, b'ZcashSequencHash') + self.h_outputs = HashWriter(blake2b, b'', 32, b'ZcashOutputsHash') + + def add_prevouts(self, txi: TxInputType): + write_bytes_rev(self.h_prevouts, txi.prev_hash) + write_uint32(self.h_prevouts, txi.prev_index) + + def add_sequence(self, txi: TxInputType): + write_uint32(self.h_sequence, txi.sequence) + + def add_output(self, txo_bin: TxOutputBinType): + write_tx_output(self.h_outputs, txo_bin) + + def get_prevouts_hash(self) -> bytes: + return get_tx_hash(self.h_prevouts, False) + + def get_sequence_hash(self) -> bytes: + return get_tx_hash(self.h_sequence, False) + + def get_outputs_hash(self) -> bytes: + return get_tx_hash(self.h_outputs, False) + + def preimage_hash(self, coin: CoinInfo, tx: SignTx, txi: TxInputType, pubkeyhash: bytes, sighash: int) -> bytes: + h_preimage = HashWriter(blake2b, b'', 32, b'ZcashSigHash\x19\x1b\xa8\x5b') # BRANCH_ID = 0x5ba81b19 + + assert tx.overwintered + + write_uint32(h_preimage, tx.version | OVERWINTERED) # 1. nVersion | fOverwintered + write_uint32(h_preimage, coin.version_group_id) # 2. nVersionGroupId + write_bytes(h_preimage, bytearray(self.get_prevouts_hash())) # 3. hashPrevouts + write_bytes(h_preimage, bytearray(self.get_sequence_hash())) # 4. hashSequence + write_bytes(h_preimage, bytearray(self.get_outputs_hash())) # 5. hashOutputs + write_bytes(h_preimage, b'\x00' * 32) # 6. hashJoinSplits + write_uint32(h_preimage, tx.lock_time) # 7. nLockTime + write_uint32(h_preimage, tx.expiry) # 8. expiryHeight + write_uint32(h_preimage, sighash) # 9. nHashType + + write_bytes_rev(h_preimage, txi.prev_hash) # 10a. outpoint + write_uint32(h_preimage, txi.prev_index) + + script_code = self.derive_script_code(txi, pubkeyhash) # 10b. scriptCode + write_varint(h_preimage, len(script_code)) + write_bytes(h_preimage, script_code) + + write_uint64(h_preimage, txi.amount) # 10c. value + + write_uint32(h_preimage, txi.sequence) # 10d. nSequence + + return get_tx_hash(h_preimage, False) + + # see https://github.com/bitcoin/bips/blob/master/bip-0143.mediawiki#specification + # item 5 for details + def derive_script_code(self, txi: TxInputType, pubkeyhash: bytes) -> bytearray: + + if txi.multisig: + return output_script_multisig(multisig_get_pubkeys(txi.multisig), txi.multisig.m) + + p2pkh = txi.script_type == InputScriptType.SPENDADDRESS + if p2pkh: + return output_script_p2pkh(pubkeyhash) + + else: + raise Zip143Error(FailureType.DataError, + 'Unknown input script type for zip143 script code') diff --git a/src/apps/wallet/sign_tx/segwit_bip143.py b/src/apps/wallet/sign_tx/segwit_bip143.py index 07617229f8..3d8f8383c3 100644 --- a/src/apps/wallet/sign_tx/segwit_bip143.py +++ b/src/apps/wallet/sign_tx/segwit_bip143.py @@ -6,7 +6,6 @@ from trezor.messages import InputScriptType, FailureType from trezor.utils import HashWriter from apps.common.coininfo import CoinInfo -from apps.common import OVERWINTERED from apps.wallet.sign_tx.writers import write_bytes, write_bytes_rev, write_uint32, write_uint64, write_varint, write_tx_output, get_tx_hash from apps.wallet.sign_tx.scripts import output_script_p2pkh, output_script_multisig from apps.wallet.sign_tx.multisig import multisig_get_pubkeys @@ -45,11 +44,9 @@ class Bip143: def preimage_hash(self, coin: CoinInfo, tx: SignTx, txi: TxInputType, pubkeyhash: bytes, sighash: int) -> bytes: h_preimage = HashWriter(sha256) - if tx.overwintered: - write_uint32(h_preimage, tx.version | OVERWINTERED) # nVersion | fOverwintered - write_uint32(h_preimage, coin.version_group_id) # nVersionGroupId - else: - write_uint32(h_preimage, tx.version) # nVersion + assert not tx.overwintered + + write_uint32(h_preimage, tx.version) # nVersion write_bytes(h_preimage, bytearray(self.get_prevouts_hash())) # hashPrevouts write_bytes(h_preimage, bytearray(self.get_sequence_hash())) # hashSequence @@ -65,9 +62,6 @@ class Bip143: write_bytes(h_preimage, bytearray(self.get_outputs_hash())) # hashOutputs write_uint32(h_preimage, tx.lock_time) # nLockTime - if tx.overwintered: - write_uint32(h_preimage, tx.expiry) # expiryHeight - write_varint(h_preimage, 0) # nJoinSplit write_uint32(h_preimage, sighash) # nHashType return get_tx_hash(h_preimage, True) diff --git a/src/apps/wallet/sign_tx/signing.py b/src/apps/wallet/sign_tx/signing.py index 54b1cbe821..db22584f76 100644 --- a/src/apps/wallet/sign_tx/signing.py +++ b/src/apps/wallet/sign_tx/signing.py @@ -9,13 +9,14 @@ from trezor.messages import OutputScriptType from trezor.messages.TxRequestDetailsType import TxRequestDetailsType from trezor.messages.TxRequestSerializedType import TxRequestSerializedType -from apps.common import address_type, coins, OVERWINTERED +from apps.common import address_type, coins from apps.common.coininfo import CoinInfo from apps.wallet.sign_tx.addresses import * from apps.wallet.sign_tx.helpers import * from apps.wallet.sign_tx.multisig import * from apps.wallet.sign_tx.scripts import * -from apps.wallet.sign_tx.segwit_bip143 import * +from apps.wallet.sign_tx.segwit_bip143 import Bip143, Bip143Error # noqa:F401 +from apps.wallet.sign_tx.overwinter_zip143 import Zip143, Zip143Error, OVERWINTERED # noqa:F401 from apps.wallet.sign_tx.tx_weight_calculator import * from apps.wallet.sign_tx.writers import * from apps.wallet.sign_tx import progress @@ -55,7 +56,11 @@ async def check_tx_fee(tx: SignTx, root: bip32.HDNode): # tx, as the SignTx info is streamed only once h_first = HashWriter(sha256) # not a real tx hash - bip143 = Bip143() # bip143 transaction hashing + if tx.overwintered: + hash143 = Zip143() # zip143 transaction hashing + else: + hash143 = Bip143() # bip143 transaction hashing + multifp = MultisigFingerprint() # control checksum of multisig inputs weight = TxWeightCalculator(tx.inputs_count, tx.outputs_count) @@ -78,8 +83,8 @@ async def check_tx_fee(tx: SignTx, root: bip32.HDNode): wallet_path = input_extract_wallet_path(txi, wallet_path) write_tx_input_check(h_first, txi) weight.add_input(txi) - bip143.add_prevouts(txi) # all inputs are included (non-segwit as well) - bip143.add_sequence(txi) + hash143.add_prevouts(txi) # all inputs are included (non-segwit as well) + hash143.add_sequence(txi) if txi.multisig: multifp.add(txi.multisig) @@ -101,7 +106,7 @@ async def check_tx_fee(tx: SignTx, root: bip32.HDNode): if coin.force_bip143 or tx.overwintered: if not txi.amount: raise SigningError(FailureType.DataError, - 'BIP 143 input without amount') + 'BIP/ZIP 143 input without amount') segwit[i] = False segwit_in += txi.amount total_in += txi.amount @@ -129,7 +134,7 @@ async def check_tx_fee(tx: SignTx, root: bip32.HDNode): 'Output cancelled') write_tx_output(h_first, txo_bin) - bip143.add_output(txo_bin) + hash143.add_output(txo_bin) total_out += txo_bin.amount fee = total_in - total_out @@ -147,7 +152,7 @@ async def check_tx_fee(tx: SignTx, root: bip32.HDNode): raise SigningError(FailureType.ActionCancelled, 'Total cancelled') - return h_first, bip143, segwit, total_in, wallet_path + return h_first, hash143, segwit, total_in, wallet_path async def sign_tx(tx: SignTx, root: bip32.HDNode): @@ -157,7 +162,7 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode): # Phase 1 - h_first, bip143, segwit, authorized_in, wallet_path = await check_tx_fee(tx, root) + h_first, hash143, segwit, authorized_in, wallet_path = await check_tx_fee(tx, root) # Phase 2 # - sign inputs @@ -214,14 +219,14 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode): key_sign = node_derive(root, txi_sign.address_n) key_sign_pub = key_sign.public_key() - bip143_hash = bip143.preimage_hash( + hash143_hash = hash143.preimage_hash( coin, tx, txi_sign, ecdsa_hash_pubkey(key_sign_pub), get_hash_type(coin)) # if multisig, check if singing with a key that is included in multisig if txi_sign.multisig: multisig_pubkey_index(txi_sign.multisig, key_sign_pub) - signature = ecdsa_sign(key_sign, bip143_hash) + signature = ecdsa_sign(key_sign, hash143_hash) tx_ser.signature_index = i_sign tx_ser.signature = signature @@ -357,10 +362,10 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode): key_sign = node_derive(root, txi.address_n) key_sign_pub = key_sign.public_key() - bip143_hash = bip143.preimage_hash( + hash143_hash = hash143.preimage_hash( coin, tx, txi, ecdsa_hash_pubkey(key_sign_pub), get_hash_type(coin)) - signature = ecdsa_sign(key_sign, bip143_hash) + signature = ecdsa_sign(key_sign, hash143_hash) if txi.multisig: # find out place of our signature based on the pubkey signature_index = multisig_pubkey_index(txi.multisig, key_sign_pub) diff --git a/src/trezor/utils.py b/src/trezor/utils.py index 7a60fb2d27..d19ac73a09 100644 --- a/src/trezor/utils.py +++ b/src/trezor/utils.py @@ -73,8 +73,8 @@ def format_ordinal(number): class HashWriter: - def __init__(self, hashfunc): - self.ctx = hashfunc() + def __init__(self, hashfunc, *hashargs): + self.ctx = hashfunc(*hashargs) self.buf = bytearray(1) # used in append() def extend(self, buf: bytearray):