1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-10 23:40:58 +00:00

src/apps/wallet/sign_tx: more changes for zcash overwinter

This commit is contained in:
Pavol Rusnak 2018-06-06 16:53:36 +02:00
parent a3af8faf23
commit a5952d16db
No known key found for this signature in database
GPG Key ID: 91F3B339B9A02A3D
5 changed files with 113 additions and 25 deletions

View File

@ -1,4 +1,3 @@
from micropython import const from micropython import const
HARDENED = const(0x80000000) HARDENED = const(0x80000000)
OVERWINTERED = const(0x80000000)

View File

@ -0,0 +1,90 @@
from micropython import const
from trezor.crypto.hashlib import blake2b
from trezor.messages.SignTx import SignTx
from trezor.messages.TxInputType import TxInputType
from trezor.messages.TxOutputBinType import TxOutputBinType
from trezor.messages import InputScriptType, FailureType
from trezor.utils import HashWriter
from apps.common.coininfo import CoinInfo
from apps.wallet.sign_tx.writers import write_bytes, write_bytes_rev, write_uint32, write_uint64, write_varint, write_tx_output, get_tx_hash
from apps.wallet.sign_tx.scripts import output_script_p2pkh, output_script_multisig
from apps.wallet.sign_tx.multisig import multisig_get_pubkeys
OVERWINTERED = const(0x80000000)
class Zip143Error(ValueError):
pass
class Zip143:
def __init__(self):
self.h_prevouts = HashWriter(blake2b, b'', 32, b'ZcashPrevoutHash')
self.h_sequence = HashWriter(blake2b, b'', 32, b'ZcashSequencHash')
self.h_outputs = HashWriter(blake2b, b'', 32, b'ZcashOutputsHash')
def add_prevouts(self, txi: TxInputType):
write_bytes_rev(self.h_prevouts, txi.prev_hash)
write_uint32(self.h_prevouts, txi.prev_index)
def add_sequence(self, txi: TxInputType):
write_uint32(self.h_sequence, txi.sequence)
def add_output(self, txo_bin: TxOutputBinType):
write_tx_output(self.h_outputs, txo_bin)
def get_prevouts_hash(self) -> bytes:
return get_tx_hash(self.h_prevouts, False)
def get_sequence_hash(self) -> bytes:
return get_tx_hash(self.h_sequence, False)
def get_outputs_hash(self) -> bytes:
return get_tx_hash(self.h_outputs, False)
def preimage_hash(self, coin: CoinInfo, tx: SignTx, txi: TxInputType, pubkeyhash: bytes, sighash: int) -> bytes:
h_preimage = HashWriter(blake2b, b'', 32, b'ZcashSigHash\x19\x1b\xa8\x5b') # BRANCH_ID = 0x5ba81b19
assert tx.overwintered
write_uint32(h_preimage, tx.version | OVERWINTERED) # 1. nVersion | fOverwintered
write_uint32(h_preimage, coin.version_group_id) # 2. nVersionGroupId
write_bytes(h_preimage, bytearray(self.get_prevouts_hash())) # 3. hashPrevouts
write_bytes(h_preimage, bytearray(self.get_sequence_hash())) # 4. hashSequence
write_bytes(h_preimage, bytearray(self.get_outputs_hash())) # 5. hashOutputs
write_bytes(h_preimage, b'\x00' * 32) # 6. hashJoinSplits
write_uint32(h_preimage, tx.lock_time) # 7. nLockTime
write_uint32(h_preimage, tx.expiry) # 8. expiryHeight
write_uint32(h_preimage, sighash) # 9. nHashType
write_bytes_rev(h_preimage, txi.prev_hash) # 10a. outpoint
write_uint32(h_preimage, txi.prev_index)
script_code = self.derive_script_code(txi, pubkeyhash) # 10b. scriptCode
write_varint(h_preimage, len(script_code))
write_bytes(h_preimage, script_code)
write_uint64(h_preimage, txi.amount) # 10c. value
write_uint32(h_preimage, txi.sequence) # 10d. nSequence
return get_tx_hash(h_preimage, False)
# see https://github.com/bitcoin/bips/blob/master/bip-0143.mediawiki#specification
# item 5 for details
def derive_script_code(self, txi: TxInputType, pubkeyhash: bytes) -> bytearray:
if txi.multisig:
return output_script_multisig(multisig_get_pubkeys(txi.multisig), txi.multisig.m)
p2pkh = txi.script_type == InputScriptType.SPENDADDRESS
if p2pkh:
return output_script_p2pkh(pubkeyhash)
else:
raise Zip143Error(FailureType.DataError,
'Unknown input script type for zip143 script code')

View File

@ -6,7 +6,6 @@ from trezor.messages import InputScriptType, FailureType
from trezor.utils import HashWriter from trezor.utils import HashWriter
from apps.common.coininfo import CoinInfo from apps.common.coininfo import CoinInfo
from apps.common import OVERWINTERED
from apps.wallet.sign_tx.writers import write_bytes, write_bytes_rev, write_uint32, write_uint64, write_varint, write_tx_output, get_tx_hash from apps.wallet.sign_tx.writers import write_bytes, write_bytes_rev, write_uint32, write_uint64, write_varint, write_tx_output, get_tx_hash
from apps.wallet.sign_tx.scripts import output_script_p2pkh, output_script_multisig from apps.wallet.sign_tx.scripts import output_script_p2pkh, output_script_multisig
from apps.wallet.sign_tx.multisig import multisig_get_pubkeys from apps.wallet.sign_tx.multisig import multisig_get_pubkeys
@ -45,10 +44,8 @@ class Bip143:
def preimage_hash(self, coin: CoinInfo, tx: SignTx, txi: TxInputType, pubkeyhash: bytes, sighash: int) -> bytes: def preimage_hash(self, coin: CoinInfo, tx: SignTx, txi: TxInputType, pubkeyhash: bytes, sighash: int) -> bytes:
h_preimage = HashWriter(sha256) h_preimage = HashWriter(sha256)
if tx.overwintered: assert not tx.overwintered
write_uint32(h_preimage, tx.version | OVERWINTERED) # nVersion | fOverwintered
write_uint32(h_preimage, coin.version_group_id) # nVersionGroupId
else:
write_uint32(h_preimage, tx.version) # nVersion write_uint32(h_preimage, tx.version) # nVersion
write_bytes(h_preimage, bytearray(self.get_prevouts_hash())) # hashPrevouts write_bytes(h_preimage, bytearray(self.get_prevouts_hash())) # hashPrevouts
@ -65,9 +62,6 @@ class Bip143:
write_bytes(h_preimage, bytearray(self.get_outputs_hash())) # hashOutputs write_bytes(h_preimage, bytearray(self.get_outputs_hash())) # hashOutputs
write_uint32(h_preimage, tx.lock_time) # nLockTime write_uint32(h_preimage, tx.lock_time) # nLockTime
if tx.overwintered:
write_uint32(h_preimage, tx.expiry) # expiryHeight
write_varint(h_preimage, 0) # nJoinSplit
write_uint32(h_preimage, sighash) # nHashType write_uint32(h_preimage, sighash) # nHashType
return get_tx_hash(h_preimage, True) return get_tx_hash(h_preimage, True)

View File

@ -9,13 +9,14 @@ from trezor.messages import OutputScriptType
from trezor.messages.TxRequestDetailsType import TxRequestDetailsType from trezor.messages.TxRequestDetailsType import TxRequestDetailsType
from trezor.messages.TxRequestSerializedType import TxRequestSerializedType from trezor.messages.TxRequestSerializedType import TxRequestSerializedType
from apps.common import address_type, coins, OVERWINTERED from apps.common import address_type, coins
from apps.common.coininfo import CoinInfo from apps.common.coininfo import CoinInfo
from apps.wallet.sign_tx.addresses import * from apps.wallet.sign_tx.addresses import *
from apps.wallet.sign_tx.helpers import * from apps.wallet.sign_tx.helpers import *
from apps.wallet.sign_tx.multisig import * from apps.wallet.sign_tx.multisig import *
from apps.wallet.sign_tx.scripts import * from apps.wallet.sign_tx.scripts import *
from apps.wallet.sign_tx.segwit_bip143 import * from apps.wallet.sign_tx.segwit_bip143 import Bip143, Bip143Error # noqa:F401
from apps.wallet.sign_tx.overwinter_zip143 import Zip143, Zip143Error, OVERWINTERED # noqa:F401
from apps.wallet.sign_tx.tx_weight_calculator import * from apps.wallet.sign_tx.tx_weight_calculator import *
from apps.wallet.sign_tx.writers import * from apps.wallet.sign_tx.writers import *
from apps.wallet.sign_tx import progress from apps.wallet.sign_tx import progress
@ -55,7 +56,11 @@ async def check_tx_fee(tx: SignTx, root: bip32.HDNode):
# tx, as the SignTx info is streamed only once # tx, as the SignTx info is streamed only once
h_first = HashWriter(sha256) # not a real tx hash h_first = HashWriter(sha256) # not a real tx hash
bip143 = Bip143() # bip143 transaction hashing if tx.overwintered:
hash143 = Zip143() # zip143 transaction hashing
else:
hash143 = Bip143() # bip143 transaction hashing
multifp = MultisigFingerprint() # control checksum of multisig inputs multifp = MultisigFingerprint() # control checksum of multisig inputs
weight = TxWeightCalculator(tx.inputs_count, tx.outputs_count) weight = TxWeightCalculator(tx.inputs_count, tx.outputs_count)
@ -78,8 +83,8 @@ async def check_tx_fee(tx: SignTx, root: bip32.HDNode):
wallet_path = input_extract_wallet_path(txi, wallet_path) wallet_path = input_extract_wallet_path(txi, wallet_path)
write_tx_input_check(h_first, txi) write_tx_input_check(h_first, txi)
weight.add_input(txi) weight.add_input(txi)
bip143.add_prevouts(txi) # all inputs are included (non-segwit as well) hash143.add_prevouts(txi) # all inputs are included (non-segwit as well)
bip143.add_sequence(txi) hash143.add_sequence(txi)
if txi.multisig: if txi.multisig:
multifp.add(txi.multisig) multifp.add(txi.multisig)
@ -101,7 +106,7 @@ async def check_tx_fee(tx: SignTx, root: bip32.HDNode):
if coin.force_bip143 or tx.overwintered: if coin.force_bip143 or tx.overwintered:
if not txi.amount: if not txi.amount:
raise SigningError(FailureType.DataError, raise SigningError(FailureType.DataError,
'BIP 143 input without amount') 'BIP/ZIP 143 input without amount')
segwit[i] = False segwit[i] = False
segwit_in += txi.amount segwit_in += txi.amount
total_in += txi.amount total_in += txi.amount
@ -129,7 +134,7 @@ async def check_tx_fee(tx: SignTx, root: bip32.HDNode):
'Output cancelled') 'Output cancelled')
write_tx_output(h_first, txo_bin) write_tx_output(h_first, txo_bin)
bip143.add_output(txo_bin) hash143.add_output(txo_bin)
total_out += txo_bin.amount total_out += txo_bin.amount
fee = total_in - total_out fee = total_in - total_out
@ -147,7 +152,7 @@ async def check_tx_fee(tx: SignTx, root: bip32.HDNode):
raise SigningError(FailureType.ActionCancelled, raise SigningError(FailureType.ActionCancelled,
'Total cancelled') 'Total cancelled')
return h_first, bip143, segwit, total_in, wallet_path return h_first, hash143, segwit, total_in, wallet_path
async def sign_tx(tx: SignTx, root: bip32.HDNode): async def sign_tx(tx: SignTx, root: bip32.HDNode):
@ -157,7 +162,7 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
# Phase 1 # Phase 1
h_first, bip143, segwit, authorized_in, wallet_path = await check_tx_fee(tx, root) h_first, hash143, segwit, authorized_in, wallet_path = await check_tx_fee(tx, root)
# Phase 2 # Phase 2
# - sign inputs # - sign inputs
@ -214,14 +219,14 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
key_sign = node_derive(root, txi_sign.address_n) key_sign = node_derive(root, txi_sign.address_n)
key_sign_pub = key_sign.public_key() key_sign_pub = key_sign.public_key()
bip143_hash = bip143.preimage_hash( hash143_hash = hash143.preimage_hash(
coin, tx, txi_sign, ecdsa_hash_pubkey(key_sign_pub), get_hash_type(coin)) coin, tx, txi_sign, ecdsa_hash_pubkey(key_sign_pub), get_hash_type(coin))
# if multisig, check if singing with a key that is included in multisig # if multisig, check if singing with a key that is included in multisig
if txi_sign.multisig: if txi_sign.multisig:
multisig_pubkey_index(txi_sign.multisig, key_sign_pub) multisig_pubkey_index(txi_sign.multisig, key_sign_pub)
signature = ecdsa_sign(key_sign, bip143_hash) signature = ecdsa_sign(key_sign, hash143_hash)
tx_ser.signature_index = i_sign tx_ser.signature_index = i_sign
tx_ser.signature = signature tx_ser.signature = signature
@ -357,10 +362,10 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
key_sign = node_derive(root, txi.address_n) key_sign = node_derive(root, txi.address_n)
key_sign_pub = key_sign.public_key() key_sign_pub = key_sign.public_key()
bip143_hash = bip143.preimage_hash( hash143_hash = hash143.preimage_hash(
coin, tx, txi, ecdsa_hash_pubkey(key_sign_pub), get_hash_type(coin)) coin, tx, txi, ecdsa_hash_pubkey(key_sign_pub), get_hash_type(coin))
signature = ecdsa_sign(key_sign, bip143_hash) signature = ecdsa_sign(key_sign, hash143_hash)
if txi.multisig: if txi.multisig:
# find out place of our signature based on the pubkey # find out place of our signature based on the pubkey
signature_index = multisig_pubkey_index(txi.multisig, key_sign_pub) signature_index = multisig_pubkey_index(txi.multisig, key_sign_pub)

View File

@ -73,8 +73,8 @@ def format_ordinal(number):
class HashWriter: class HashWriter:
def __init__(self, hashfunc): def __init__(self, hashfunc, *hashargs):
self.ctx = hashfunc() self.ctx = hashfunc(*hashargs)
self.buf = bytearray(1) # used in append() self.buf = bytearray(1) # used in append()
def extend(self, buf: bytearray): def extend(self, buf: bytearray):