1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-11 07:50:57 +00:00

feat(core): Implement the BIP-341 common signature message computation.

This commit is contained in:
Andrew Kozlik 2021-10-29 21:39:36 +02:00 committed by Andrew Kozlik
parent 99e4ed6f42
commit 381e8bc85a
9 changed files with 96 additions and 10 deletions

View File

@ -3,8 +3,9 @@ from micropython import const
from trezor import wire from trezor import wire
from trezor.crypto import bech32, bip32, der from trezor.crypto import bech32, bip32, der
from trezor.crypto.curve import secp256k1 from trezor.crypto.curve import secp256k1
from trezor.crypto.hashlib import sha256
from trezor.enums import InputScriptType, OutputScriptType from trezor.enums import InputScriptType, OutputScriptType
from trezor.utils import ensure from trezor.utils import HashWriter, ensure
if False: if False:
from typing import Tuple from typing import Tuple
@ -124,3 +125,10 @@ def input_is_taproot(txi: TxInput) -> bool:
def input_is_external(txi: TxInput) -> bool: def input_is_external(txi: TxInput) -> bool:
return txi.script_type == InputScriptType.EXTERNAL return txi.script_type == InputScriptType.EXTERNAL
def tagged_hashwriter(tag: bytes) -> HashWriter:
tag_digest = sha256(tag).digest()
ctx = sha256(tag_digest)
ctx.update(tag_digest)
return HashWriter(ctx)

View File

@ -403,6 +403,7 @@ class Bitcoin:
) -> bytes: ) -> bytes:
if txi.witness: if txi.witness:
return tx_info.hash143.preimage_hash( return tx_info.hash143.preimage_hash(
i,
txi, txi,
public_keys, public_keys,
threshold, threshold,
@ -475,6 +476,7 @@ class Bitcoin:
public_keys = [public_key] public_keys = [public_key]
threshold = 1 threshold = 1
hash143_hash = self.tx_info.hash143.preimage_hash( hash143_hash = self.tx_info.hash143.preimage_hash(
0,
txi, txi,
public_keys, public_keys,
threshold, threshold,

View File

@ -47,10 +47,10 @@ class Bitcoinlike(Bitcoin):
public_keys: Sequence[bytes | memoryview], public_keys: Sequence[bytes | memoryview],
threshold: int, threshold: int,
script_pubkey: bytes, script_pubkey: bytes,
tx_hash: bytes | None = None,
) -> bytes: ) -> bytes:
if self.coin.force_bip143: if self.coin.force_bip143:
return tx_info.hash143.preimage_hash( return tx_info.hash143.preimage_hash(
i,
txi, txi,
public_keys, public_keys,
threshold, threshold,

View File

@ -62,6 +62,7 @@ class DecredHash:
def preimage_hash( def preimage_hash(
self, self,
i: int,
txi: TxInput, txi: TxInput,
public_keys: Sequence[bytes | memoryview], public_keys: Sequence[bytes | memoryview],
threshold: int, threshold: int,

View File

@ -19,6 +19,7 @@ if False:
def preimage_hash( def preimage_hash(
self, self,
i: int,
txi: TxInput, txi: TxInput,
public_keys: Sequence[bytes | memoryview], public_keys: Sequence[bytes | memoryview],
threshold: int, threshold: int,
@ -33,7 +34,9 @@ if False:
class Bip143Hash: class Bip143Hash:
def __init__(self) -> None: def __init__(self) -> None:
self.h_prevouts = HashWriter(sha256()) self.h_prevouts = HashWriter(sha256())
self.h_sequence = HashWriter(sha256()) self.h_amounts = HashWriter(sha256())
self.h_scriptpubkeys = HashWriter(sha256())
self.h_sequences = HashWriter(sha256())
self.h_outputs = HashWriter(sha256()) self.h_outputs = HashWriter(sha256())
def add_input(self, txi: TxInput, script_pubkey: bytes) -> None: def add_input(self, txi: TxInput, script_pubkey: bytes) -> None:
@ -41,12 +44,29 @@ class Bip143Hash:
self.h_prevouts, txi.prev_hash, writers.TX_HASH_SIZE self.h_prevouts, txi.prev_hash, writers.TX_HASH_SIZE
) )
writers.write_uint32(self.h_prevouts, txi.prev_index) writers.write_uint32(self.h_prevouts, txi.prev_index)
writers.write_uint32(self.h_sequence, txi.sequence) writers.write_uint64(self.h_amounts, txi.amount)
writers.write_bytes_prefixed(self.h_scriptpubkeys, script_pubkey)
writers.write_uint32(self.h_sequences, txi.sequence)
def add_output(self, txo: TxOutput, script_pubkey: bytes) -> None: def add_output(self, txo: TxOutput, script_pubkey: bytes) -> None:
writers.write_tx_output(self.h_outputs, txo, script_pubkey) writers.write_tx_output(self.h_outputs, txo, script_pubkey)
def preimage_hash( def preimage_hash(
self,
i: int,
txi: TxInput,
public_keys: Sequence[bytes | memoryview],
threshold: int,
tx: SignTx | PrevTx,
coin: coininfo.CoinInfo,
sighash_type: int,
) -> bytes:
if input_is_taproot(txi):
return self.bip341_hash(i, tx, sighash_type)
else:
return self.bip143_hash(txi, public_keys, threshold, tx, coin, sighash_type)
def bip143_hash(
self, self,
txi: TxInput, txi: TxInput,
public_keys: Sequence[bytes | memoryview], public_keys: Sequence[bytes | memoryview],
@ -68,7 +88,7 @@ class Bip143Hash:
# hashSequence # hashSequence
sequence_hash = writers.get_tx_hash( sequence_hash = writers.get_tx_hash(
self.h_sequence, double=coin.sign_hash_double self.h_sequences, double=coin.sign_hash_double
) )
writers.write_bytes_fixed(h_preimage, sequence_hash, writers.TX_HASH_SIZE) writers.write_bytes_fixed(h_preimage, sequence_hash, writers.TX_HASH_SIZE)
@ -98,3 +118,56 @@ class Bip143Hash:
writers.write_uint32(h_preimage, sighash_type) writers.write_uint32(h_preimage, sighash_type)
return writers.get_tx_hash(h_preimage, double=coin.sign_hash_double) return writers.get_tx_hash(h_preimage, double=coin.sign_hash_double)
def bip341_hash(
self,
i: int,
tx: SignTx | PrevTx,
sighash_type: int,
) -> bytes:
h_sigmsg = tagged_hashwriter(b"TapSighash")
# sighash epoch 0
writers.write_uint8(h_sigmsg, 0)
# nHashType
writers.write_uint8(h_sigmsg, sighash_type & 0xFF)
# nVersion
writers.write_uint32(h_sigmsg, tx.version)
# nLockTime
writers.write_uint32(h_sigmsg, tx.lock_time)
# sha_prevouts
writers.write_bytes_fixed(
h_sigmsg, self.h_prevouts.get_digest(), writers.TX_HASH_SIZE
)
# sha_amounts
writers.write_bytes_fixed(
h_sigmsg, self.h_amounts.get_digest(), writers.TX_HASH_SIZE
)
# sha_scriptpubkeys
writers.write_bytes_fixed(
h_sigmsg, self.h_scriptpubkeys.get_digest(), writers.TX_HASH_SIZE
)
# sha_sequences
writers.write_bytes_fixed(
h_sigmsg, self.h_sequences.get_digest(), writers.TX_HASH_SIZE
)
# sha_outputs
writers.write_bytes_fixed(
h_sigmsg, self.h_outputs.get_digest(), writers.TX_HASH_SIZE
)
# spend_type 0 (no tapscript message extension, no annex)
writers.write_uint8(h_sigmsg, 0)
# input_index
writers.write_uint32(h_sigmsg, i)
return h_sigmsg.get_digest()

View File

@ -49,6 +49,7 @@ class Zip243Hash:
def preimage_hash( def preimage_hash(
self, self,
i: int,
txi: TxInput, txi: TxInput,
public_keys: Sequence[bytes | memoryview], public_keys: Sequence[bytes | memoryview],
threshold: int, threshold: int,
@ -145,6 +146,7 @@ class Zcashlike(Bitcoinlike):
tx_hash: bytes | None = None, tx_hash: bytes | None = None,
) -> bytes: ) -> bytes:
return tx_info.hash143.preimage_hash( return tx_info.hash143.preimage_hash(
0,
txi, txi,
public_keys, public_keys,
threshold, threshold,

View File

@ -60,7 +60,7 @@ class TestSegwitBip143NativeP2WPKH(unittest.TestCase):
bip143 = Bip143Hash() bip143 = Bip143Hash()
bip143.add_input(self.inp1, b"") bip143.add_input(self.inp1, b"")
bip143.add_input(self.inp2, b"") bip143.add_input(self.inp2, b"")
sequence_hash = get_tx_hash(bip143.h_sequence, double=coin.sign_hash_double) sequence_hash = get_tx_hash(bip143.h_sequences, double=coin.sign_hash_double)
self.assertEqual(hexlify(sequence_hash), b'52b0a642eea2fb7ae638c36f6252b6750293dbe574a806984b8e4d8548339a3b') self.assertEqual(hexlify(sequence_hash), b'52b0a642eea2fb7ae638c36f6252b6750293dbe574a806984b8e4d8548339a3b')
def test_outputs(self): def test_outputs(self):
@ -95,7 +95,7 @@ class TestSegwitBip143NativeP2WPKH(unittest.TestCase):
# test data public key hash # test data public key hash
# only for input 2 - input 1 is not segwit # only for input 2 - input 1 is not segwit
result = bip143.preimage_hash(self.inp2, [node.public_key()], 1, self.tx, coin, SIGHASH_ALL) result = bip143.preimage_hash(1, self.inp2, [node.public_key()], 1, self.tx, coin, SIGHASH_ALL)
self.assertEqual(hexlify(result), b'2fa3f1351618b2532228d7182d3221d95c21fd3d496e7e22e9ded873cf022a8b') self.assertEqual(hexlify(result), b'2fa3f1351618b2532228d7182d3221d95c21fd3d496e7e22e9ded873cf022a8b')

View File

@ -50,7 +50,7 @@ class TestSegwitBip143(unittest.TestCase):
coin = coins.by_name(self.tx.coin_name) coin = coins.by_name(self.tx.coin_name)
bip143 = Bip143Hash() bip143 = Bip143Hash()
bip143.add_input(self.inp1, b"") bip143.add_input(self.inp1, b"")
sequence_hash = get_tx_hash(bip143.h_sequence, double=coin.sign_hash_double) sequence_hash = get_tx_hash(bip143.h_sequences, double=coin.sign_hash_double)
self.assertEqual(hexlify(sequence_hash), b'18606b350cd8bf565266bc352f0caddcf01e8fa789dd8a15386327cf8cabe198') self.assertEqual(hexlify(sequence_hash), b'18606b350cd8bf565266bc352f0caddcf01e8fa789dd8a15386327cf8cabe198')
def test_bip143_outputs(self): def test_bip143_outputs(self):
@ -80,7 +80,7 @@ class TestSegwitBip143(unittest.TestCase):
node = keychain.derive(self.inp1.address_n) node = keychain.derive(self.inp1.address_n)
# test data public key hash # test data public key hash
result = bip143.preimage_hash(self.inp1, [node.public_key()], 1, self.tx, coin, SIGHASH_ALL) result = bip143.preimage_hash(0, self.inp1, [node.public_key()], 1, self.tx, coin, SIGHASH_ALL)
self.assertEqual(hexlify(result), b'6e28aca7041720995d4acf59bbda64eef5d6f23723d23f2e994757546674bbd9') self.assertEqual(hexlify(result), b'6e28aca7041720995d4acf59bbda64eef5d6f23723d23f2e994757546674bbd9')

View File

@ -213,7 +213,7 @@ class TestZcashZip243(unittest.TestCase):
self.assertEqual(hexlify(get_tx_hash(zip243.h_prevouts)), v["prevouts_hash"]) self.assertEqual(hexlify(get_tx_hash(zip243.h_prevouts)), v["prevouts_hash"])
self.assertEqual(hexlify(get_tx_hash(zip243.h_sequence)), v["sequence_hash"]) self.assertEqual(hexlify(get_tx_hash(zip243.h_sequence)), v["sequence_hash"])
self.assertEqual(hexlify(get_tx_hash(zip243.h_outputs)), v["outputs_hash"]) self.assertEqual(hexlify(get_tx_hash(zip243.h_outputs)), v["outputs_hash"])
self.assertEqual(hexlify(zip243.preimage_hash(txi, [unhexlify(i["pubkey"])], 1, tx, coin, SIGHASH_ALL)), v["preimage_hash"]) self.assertEqual(hexlify(zip243.preimage_hash(0, txi, [unhexlify(i["pubkey"])], 1, tx, coin, SIGHASH_ALL)), v["preimage_hash"])
if __name__ == "__main__": if __name__ == "__main__":