feat(core): Implement the BIP-341 common signature message computation.

pull/1918/head
Andrew Kozlik 3 years ago committed by Andrew Kozlik
parent 99e4ed6f42
commit 381e8bc85a

@ -3,8 +3,9 @@ from micropython import const
from trezor import wire
from trezor.crypto import bech32, bip32, der
from trezor.crypto.curve import secp256k1
from trezor.crypto.hashlib import sha256
from trezor.enums import InputScriptType, OutputScriptType
from trezor.utils import ensure
from trezor.utils import HashWriter, ensure
if False:
from typing import Tuple
@ -124,3 +125,10 @@ def input_is_taproot(txi: TxInput) -> bool:
def input_is_external(txi: TxInput) -> bool:
return txi.script_type == InputScriptType.EXTERNAL
def tagged_hashwriter(tag: bytes) -> HashWriter:
tag_digest = sha256(tag).digest()
ctx = sha256(tag_digest)
ctx.update(tag_digest)
return HashWriter(ctx)

@ -403,6 +403,7 @@ class Bitcoin:
) -> bytes:
if txi.witness:
return tx_info.hash143.preimage_hash(
i,
txi,
public_keys,
threshold,
@ -475,6 +476,7 @@ class Bitcoin:
public_keys = [public_key]
threshold = 1
hash143_hash = self.tx_info.hash143.preimage_hash(
0,
txi,
public_keys,
threshold,

@ -47,10 +47,10 @@ class Bitcoinlike(Bitcoin):
public_keys: Sequence[bytes | memoryview],
threshold: int,
script_pubkey: bytes,
tx_hash: bytes | None = None,
) -> bytes:
if self.coin.force_bip143:
return tx_info.hash143.preimage_hash(
i,
txi,
public_keys,
threshold,

@ -62,6 +62,7 @@ class DecredHash:
def preimage_hash(
self,
i: int,
txi: TxInput,
public_keys: Sequence[bytes | memoryview],
threshold: int,

@ -19,6 +19,7 @@ if False:
def preimage_hash(
self,
i: int,
txi: TxInput,
public_keys: Sequence[bytes | memoryview],
threshold: int,
@ -33,7 +34,9 @@ if False:
class Bip143Hash:
def __init__(self) -> None:
self.h_prevouts = HashWriter(sha256())
self.h_sequence = HashWriter(sha256())
self.h_amounts = HashWriter(sha256())
self.h_scriptpubkeys = HashWriter(sha256())
self.h_sequences = HashWriter(sha256())
self.h_outputs = HashWriter(sha256())
def add_input(self, txi: TxInput, script_pubkey: bytes) -> None:
@ -41,12 +44,29 @@ class Bip143Hash:
self.h_prevouts, txi.prev_hash, writers.TX_HASH_SIZE
)
writers.write_uint32(self.h_prevouts, txi.prev_index)
writers.write_uint32(self.h_sequence, txi.sequence)
writers.write_uint64(self.h_amounts, txi.amount)
writers.write_bytes_prefixed(self.h_scriptpubkeys, script_pubkey)
writers.write_uint32(self.h_sequences, txi.sequence)
def add_output(self, txo: TxOutput, script_pubkey: bytes) -> None:
writers.write_tx_output(self.h_outputs, txo, script_pubkey)
def preimage_hash(
self,
i: int,
txi: TxInput,
public_keys: Sequence[bytes | memoryview],
threshold: int,
tx: SignTx | PrevTx,
coin: coininfo.CoinInfo,
sighash_type: int,
) -> bytes:
if input_is_taproot(txi):
return self.bip341_hash(i, tx, sighash_type)
else:
return self.bip143_hash(txi, public_keys, threshold, tx, coin, sighash_type)
def bip143_hash(
self,
txi: TxInput,
public_keys: Sequence[bytes | memoryview],
@ -68,7 +88,7 @@ class Bip143Hash:
# hashSequence
sequence_hash = writers.get_tx_hash(
self.h_sequence, double=coin.sign_hash_double
self.h_sequences, double=coin.sign_hash_double
)
writers.write_bytes_fixed(h_preimage, sequence_hash, writers.TX_HASH_SIZE)
@ -98,3 +118,56 @@ class Bip143Hash:
writers.write_uint32(h_preimage, sighash_type)
return writers.get_tx_hash(h_preimage, double=coin.sign_hash_double)
def bip341_hash(
self,
i: int,
tx: SignTx | PrevTx,
sighash_type: int,
) -> bytes:
h_sigmsg = tagged_hashwriter(b"TapSighash")
# sighash epoch 0
writers.write_uint8(h_sigmsg, 0)
# nHashType
writers.write_uint8(h_sigmsg, sighash_type & 0xFF)
# nVersion
writers.write_uint32(h_sigmsg, tx.version)
# nLockTime
writers.write_uint32(h_sigmsg, tx.lock_time)
# sha_prevouts
writers.write_bytes_fixed(
h_sigmsg, self.h_prevouts.get_digest(), writers.TX_HASH_SIZE
)
# sha_amounts
writers.write_bytes_fixed(
h_sigmsg, self.h_amounts.get_digest(), writers.TX_HASH_SIZE
)
# sha_scriptpubkeys
writers.write_bytes_fixed(
h_sigmsg, self.h_scriptpubkeys.get_digest(), writers.TX_HASH_SIZE
)
# sha_sequences
writers.write_bytes_fixed(
h_sigmsg, self.h_sequences.get_digest(), writers.TX_HASH_SIZE
)
# sha_outputs
writers.write_bytes_fixed(
h_sigmsg, self.h_outputs.get_digest(), writers.TX_HASH_SIZE
)
# spend_type 0 (no tapscript message extension, no annex)
writers.write_uint8(h_sigmsg, 0)
# input_index
writers.write_uint32(h_sigmsg, i)
return h_sigmsg.get_digest()

@ -49,6 +49,7 @@ class Zip243Hash:
def preimage_hash(
self,
i: int,
txi: TxInput,
public_keys: Sequence[bytes | memoryview],
threshold: int,
@ -145,6 +146,7 @@ class Zcashlike(Bitcoinlike):
tx_hash: bytes | None = None,
) -> bytes:
return tx_info.hash143.preimage_hash(
0,
txi,
public_keys,
threshold,

@ -60,7 +60,7 @@ class TestSegwitBip143NativeP2WPKH(unittest.TestCase):
bip143 = Bip143Hash()
bip143.add_input(self.inp1, b"")
bip143.add_input(self.inp2, b"")
sequence_hash = get_tx_hash(bip143.h_sequence, double=coin.sign_hash_double)
sequence_hash = get_tx_hash(bip143.h_sequences, double=coin.sign_hash_double)
self.assertEqual(hexlify(sequence_hash), b'52b0a642eea2fb7ae638c36f6252b6750293dbe574a806984b8e4d8548339a3b')
def test_outputs(self):
@ -95,7 +95,7 @@ class TestSegwitBip143NativeP2WPKH(unittest.TestCase):
# test data public key hash
# only for input 2 - input 1 is not segwit
result = bip143.preimage_hash(self.inp2, [node.public_key()], 1, self.tx, coin, SIGHASH_ALL)
result = bip143.preimage_hash(1, self.inp2, [node.public_key()], 1, self.tx, coin, SIGHASH_ALL)
self.assertEqual(hexlify(result), b'2fa3f1351618b2532228d7182d3221d95c21fd3d496e7e22e9ded873cf022a8b')

@ -50,7 +50,7 @@ class TestSegwitBip143(unittest.TestCase):
coin = coins.by_name(self.tx.coin_name)
bip143 = Bip143Hash()
bip143.add_input(self.inp1, b"")
sequence_hash = get_tx_hash(bip143.h_sequence, double=coin.sign_hash_double)
sequence_hash = get_tx_hash(bip143.h_sequences, double=coin.sign_hash_double)
self.assertEqual(hexlify(sequence_hash), b'18606b350cd8bf565266bc352f0caddcf01e8fa789dd8a15386327cf8cabe198')
def test_bip143_outputs(self):
@ -80,7 +80,7 @@ class TestSegwitBip143(unittest.TestCase):
node = keychain.derive(self.inp1.address_n)
# test data public key hash
result = bip143.preimage_hash(self.inp1, [node.public_key()], 1, self.tx, coin, SIGHASH_ALL)
result = bip143.preimage_hash(0, self.inp1, [node.public_key()], 1, self.tx, coin, SIGHASH_ALL)
self.assertEqual(hexlify(result), b'6e28aca7041720995d4acf59bbda64eef5d6f23723d23f2e994757546674bbd9')

@ -213,7 +213,7 @@ class TestZcashZip243(unittest.TestCase):
self.assertEqual(hexlify(get_tx_hash(zip243.h_prevouts)), v["prevouts_hash"])
self.assertEqual(hexlify(get_tx_hash(zip243.h_sequence)), v["sequence_hash"])
self.assertEqual(hexlify(get_tx_hash(zip243.h_outputs)), v["outputs_hash"])
self.assertEqual(hexlify(zip243.preimage_hash(txi, [unhexlify(i["pubkey"])], 1, tx, coin, SIGHASH_ALL)), v["preimage_hash"])
self.assertEqual(hexlify(zip243.preimage_hash(0, txi, [unhexlify(i["pubkey"])], 1, tx, coin, SIGHASH_ALL)), v["preimage_hash"])
if __name__ == "__main__":

Loading…
Cancel
Save