1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-24 23:38:09 +00:00

chore(core): Convert SIGHASH_* consts to SigHashType enum.

This commit is contained in:
Andrew Kozlik 2021-11-11 13:03:47 +01:00 committed by Andrew Kozlik
parent aaceb5bcc6
commit 221977ad9d
13 changed files with 140 additions and 103 deletions

View File

@ -8,19 +8,43 @@ from trezor.enums import InputScriptType, OutputScriptType
from trezor.utils import HashWriter, ensure from trezor.utils import HashWriter, ensure
if False: if False:
from enum import IntEnum
from typing import Tuple from typing import Tuple
from apps.common.coininfo import CoinInfo from apps.common.coininfo import CoinInfo
from trezor.messages import TxInput from trezor.messages import TxInput
else:
IntEnum = object # type: ignore
BITCOIN_NAMES = ("Bitcoin", "Regtest", "Testnet") BITCOIN_NAMES = ("Bitcoin", "Regtest", "Testnet")
# Signature hash type with the same semantics as the SIGHASH_ALL, but instead
# of having to include the byte in the signature, it is implied.
SIGHASH_ALL_TAPROOT = const(0x00)
# Default signature hash type in Bitcoin which signs all inputs and all outputs of the transaction. class SigHashType(IntEnum):
SIGHASH_ALL = const(0x01) """Enumeration type listing the supported signature hash types."""
# Signature hash type with the same semantics as SIGHASH_ALL, but instead
# of having to include the byte in the signature, it is implied.
SIGHASH_ALL_TAPROOT = 0x00
# Default signature hash type in Bitcoin which signs all inputs and all
# outputs of the transaction.
SIGHASH_ALL = 0x01
# Signature hash flag used in some Bitcoin-like altcoins for replay
# protection.
SIGHASH_FORKID = 0x40
# Signature hash type with the same semantics as SIGHASH_ALL. Used in some
# Bitcoin-like altcoins for replay protection.
SIGHASH_ALL_FORKID = 0x41
@classmethod
def from_int(cls, sighash_type: int) -> "SigHashType":
for val in cls.__dict__.values(): # type: SigHashType
if val == sighash_type:
return val
raise ValueError("Unsupported sighash type.")
# The number of bip32 levels used in a wallet (chain and address) # The number of bip32 levels used in a wallet (chain and address)
BIP32_WALLET_DEPTH = const(2) BIP32_WALLET_DEPTH = const(2)

View File

@ -8,6 +8,7 @@ from apps.common.readers import read_bitcoin_varint
from apps.common.writers import write_bitcoin_varint from apps.common.writers import write_bitcoin_varint
from . import common from . import common
from .common import SigHashType
from .multisig import ( from .multisig import (
multisig_get_pubkey_count, multisig_get_pubkey_count,
multisig_get_pubkeys, multisig_get_pubkeys,
@ -37,13 +38,13 @@ def write_input_script_prefixed(
script_type: InputScriptType, script_type: InputScriptType,
multisig: MultisigRedeemScriptType | None, multisig: MultisigRedeemScriptType | None,
coin: CoinInfo, coin: CoinInfo,
hash_type: int, sighash_type: SigHashType,
pubkey: bytes, pubkey: bytes,
signature: bytes, signature: bytes,
) -> None: ) -> None:
if script_type == InputScriptType.SPENDADDRESS: if script_type == InputScriptType.SPENDADDRESS:
# p2pkh or p2sh # p2pkh or p2sh
write_input_script_p2pkh_or_p2sh_prefixed(w, pubkey, signature, hash_type) write_input_script_p2pkh_or_p2sh_prefixed(w, pubkey, signature, sighash_type)
elif script_type == InputScriptType.SPENDP2SHWITNESS: elif script_type == InputScriptType.SPENDP2SHWITNESS:
# p2wpkh or p2wsh using p2sh # p2wpkh or p2wsh using p2sh
@ -69,7 +70,7 @@ def write_input_script_prefixed(
assert multisig is not None # checked in sanitize_tx_input assert multisig is not None # checked in sanitize_tx_input
signature_index = multisig_pubkey_index(multisig, pubkey) signature_index = multisig_pubkey_index(multisig, pubkey)
write_input_script_multisig_prefixed( write_input_script_multisig_prefixed(
w, multisig, signature, signature_index, hash_type, coin w, multisig, signature, signature_index, sighash_type, coin
) )
else: else:
raise wire.ProcessError("Invalid script type") raise wire.ProcessError("Invalid script type")
@ -150,19 +151,21 @@ def write_bip143_script_code_prefixed(
def write_input_script_p2pkh_or_p2sh_prefixed( def write_input_script_p2pkh_or_p2sh_prefixed(
w: Writer, pubkey: bytes, signature: bytes, hash_type: int w: Writer, pubkey: bytes, signature: bytes, sighash_type: SigHashType
) -> None: ) -> None:
write_bitcoin_varint(w, 1 + len(signature) + 1 + 1 + len(pubkey)) write_bitcoin_varint(w, 1 + len(signature) + 1 + 1 + len(pubkey))
append_signature(w, signature, hash_type) append_signature(w, signature, sighash_type)
append_pubkey(w, pubkey) append_pubkey(w, pubkey)
def parse_input_script_p2pkh(script_sig: bytes) -> tuple[memoryview, memoryview, int]: def parse_input_script_p2pkh(
script_sig: bytes,
) -> tuple[memoryview, memoryview, SigHashType]:
try: try:
r = utils.BufferReader(script_sig) r = utils.BufferReader(script_sig)
n = read_op_push(r) n = read_op_push(r)
signature = r.read_memoryview(n - 1) signature = r.read_memoryview(n - 1)
hash_type = r.get() sighash_type = SigHashType.from_int(r.get())
n = read_op_push(r) n = read_op_push(r)
pubkey = r.read_memoryview() pubkey = r.read_memoryview()
@ -171,7 +174,7 @@ def parse_input_script_p2pkh(script_sig: bytes) -> tuple[memoryview, memoryview,
except (ValueError, EOFError): except (ValueError, EOFError):
wire.DataError("Invalid scriptSig.") wire.DataError("Invalid scriptSig.")
return pubkey, signature, hash_type return pubkey, signature, sighash_type
def write_output_script_p2pkh( def write_output_script_p2pkh(
@ -311,14 +314,14 @@ def write_input_script_p2wsh_in_p2sh(
def write_witness_p2wpkh( def write_witness_p2wpkh(
w: Writer, signature: bytes, pubkey: bytes, hash_type: int w: Writer, signature: bytes, pubkey: bytes, sighash_type: SigHashType
) -> None: ) -> None:
write_bitcoin_varint(w, 0x02) # num of segwit items, in P2WPKH it's always 2 write_bitcoin_varint(w, 0x02) # num of segwit items, in P2WPKH it's always 2
write_signature_prefixed(w, signature, hash_type) write_signature_prefixed(w, signature, sighash_type)
write_bytes_prefixed(w, pubkey) write_bytes_prefixed(w, pubkey)
def parse_witness_p2wpkh(witness: bytes) -> tuple[memoryview, memoryview, int]: def parse_witness_p2wpkh(witness: bytes) -> tuple[memoryview, memoryview, SigHashType]:
try: try:
r = utils.BufferReader(witness) r = utils.BufferReader(witness)
@ -328,7 +331,7 @@ def parse_witness_p2wpkh(witness: bytes) -> tuple[memoryview, memoryview, int]:
n = read_bitcoin_varint(r) n = read_bitcoin_varint(r)
signature = r.read_memoryview(n - 1) signature = r.read_memoryview(n - 1)
hash_type = r.get() sighash_type = SigHashType.from_int(r.get())
pubkey = read_memoryview_prefixed(r) pubkey = read_memoryview_prefixed(r)
if r.remaining_count(): if r.remaining_count():
@ -336,7 +339,7 @@ def parse_witness_p2wpkh(witness: bytes) -> tuple[memoryview, memoryview, int]:
except (ValueError, EOFError): except (ValueError, EOFError):
raise wire.DataError("Invalid witness.") raise wire.DataError("Invalid witness.")
return pubkey, signature, hash_type return pubkey, signature, sighash_type
def write_witness_multisig( def write_witness_multisig(
@ -344,7 +347,7 @@ def write_witness_multisig(
multisig: MultisigRedeemScriptType, multisig: MultisigRedeemScriptType,
signature: bytes, signature: bytes,
signature_index: int, signature_index: int,
hash_type: int, sighash_type: SigHashType,
) -> None: ) -> None:
# get other signatures, stretch with empty bytes to the number of the pubkeys # get other signatures, stretch with empty bytes to the number of the pubkeys
signatures = multisig.signatures + [b""] * ( signatures = multisig.signatures + [b""] * (
@ -367,7 +370,7 @@ def write_witness_multisig(
for s in signatures: for s in signatures:
if s: if s:
write_signature_prefixed(w, s, hash_type) # size of the witness included write_signature_prefixed(w, s, sighash_type) # size of the witness included
# redeem script # redeem script
pubkeys = multisig_get_pubkeys(multisig) pubkeys = multisig_get_pubkeys(multisig)
@ -376,7 +379,7 @@ def write_witness_multisig(
def parse_witness_multisig( def parse_witness_multisig(
witness: bytes, witness: bytes,
) -> tuple[memoryview, list[tuple[memoryview, int]]]: ) -> tuple[memoryview, list[tuple[memoryview, SigHashType]]]:
try: try:
r = utils.BufferReader(witness) r = utils.BufferReader(witness)
@ -391,8 +394,8 @@ def parse_witness_multisig(
for _ in range(item_count - 2): for _ in range(item_count - 2):
n = read_bitcoin_varint(r) n = read_bitcoin_varint(r)
signature = r.read_memoryview(n - 1) signature = r.read_memoryview(n - 1)
hash_type = r.get() sighash_type = SigHashType.from_int(r.get())
signatures.append((signature, hash_type)) signatures.append((signature, sighash_type))
script = read_memoryview_prefixed(r) script = read_memoryview_prefixed(r)
if r.remaining_count(): if r.remaining_count():
@ -407,13 +410,13 @@ def parse_witness_multisig(
# === # ===
def write_witness_p2tr(w: Writer, signature: bytes, hash_type: int) -> None: def write_witness_p2tr(w: Writer, signature: bytes, sighash_type: SigHashType) -> None:
# Taproot key path spending without annex. # Taproot key path spending without annex.
write_bitcoin_varint(w, 0x01) # num of segwit items write_bitcoin_varint(w, 0x01) # num of segwit items
write_signature_prefixed(w, signature, hash_type) write_signature_prefixed(w, signature, sighash_type)
def parse_witness_p2tr(witness: bytes) -> tuple[memoryview, int]: def parse_witness_p2tr(witness: bytes) -> tuple[memoryview, SigHashType]:
try: try:
r = utils.BufferReader(witness) r = utils.BufferReader(witness)
@ -426,14 +429,17 @@ def parse_witness_p2tr(witness: bytes) -> tuple[memoryview, int]:
raise ValueError raise ValueError
signature = r.read_memoryview(64) signature = r.read_memoryview(64)
hash_type = r.get() if n == 65 else common.SIGHASH_ALL_TAPROOT if n == 65:
sighash_type = SigHashType.from_int(r.get())
else:
sighash_type = SigHashType.SIGHASH_ALL_TAPROOT
if r.remaining_count(): if r.remaining_count():
raise ValueError raise ValueError
except (ValueError, EOFError): except (ValueError, EOFError):
raise wire.DataError("Invalid witness.") raise wire.DataError("Invalid witness.")
return signature, hash_type return signature, sighash_type
# Multisig # Multisig
@ -447,7 +453,7 @@ def write_input_script_multisig_prefixed(
multisig: MultisigRedeemScriptType, multisig: MultisigRedeemScriptType,
signature: bytes, signature: bytes,
signature_index: int, signature_index: int,
hash_type: int, sighash_type: SigHashType,
coin: CoinInfo, coin: CoinInfo,
) -> None: ) -> None:
signatures = multisig.signatures # other signatures signatures = multisig.signatures # other signatures
@ -463,7 +469,7 @@ def write_input_script_multisig_prefixed(
total_length = 1 # OP_FALSE total_length = 1 # OP_FALSE
for s in signatures: for s in signatures:
if s: if s:
total_length += 1 + len(s) + 1 # length, signature, hash_type total_length += 1 + len(s) + 1 # length, signature, sighash_type
total_length += op_push_length(redeem_script_length) + redeem_script_length total_length += op_push_length(redeem_script_length) + redeem_script_length
write_bitcoin_varint(w, total_length) write_bitcoin_varint(w, total_length)
@ -474,7 +480,7 @@ def write_input_script_multisig_prefixed(
for s in signatures: for s in signatures:
if s: if s:
append_signature(w, s, hash_type) append_signature(w, s, sighash_type)
# redeem script # redeem script
write_op_push(w, redeem_script_length) write_op_push(w, redeem_script_length)
@ -483,7 +489,7 @@ def write_input_script_multisig_prefixed(
def parse_input_script_multisig( def parse_input_script_multisig(
script_sig: bytes, script_sig: bytes,
) -> tuple[memoryview, list[tuple[memoryview, int]]]: ) -> tuple[memoryview, list[tuple[memoryview, SigHashType]]]:
try: try:
r = utils.BufferReader(script_sig) r = utils.BufferReader(script_sig)
@ -495,8 +501,8 @@ def parse_input_script_multisig(
n = read_op_push(r) n = read_op_push(r)
while r.remaining_count() > n: while r.remaining_count() > n:
signature = r.read_memoryview(n - 1) signature = r.read_memoryview(n - 1)
hash_type = r.get() sighash_type = SigHashType.from_int(r.get())
signatures.append((signature, hash_type)) signatures.append((signature, sighash_type))
n = read_op_push(r) n = read_op_push(r)
script = r.read_memoryview() script = r.read_memoryview()
@ -600,7 +606,7 @@ def write_bip322_signature_proof(
signature: bytes, signature: bytes,
) -> None: ) -> None:
write_input_script_prefixed( write_input_script_prefixed(
w, script_type, multisig, coin, common.SIGHASH_ALL, public_key, signature w, script_type, multisig, coin, SigHashType.SIGHASH_ALL, public_key, signature
) )
if script_type in common.SEGWIT_INPUT_SCRIPT_TYPES: if script_type in common.SEGWIT_INPUT_SCRIPT_TYPES:
@ -608,10 +614,10 @@ def write_bip322_signature_proof(
# find the place of our signature based on the public key # find the place of our signature based on the public key
signature_index = multisig_pubkey_index(multisig, public_key) signature_index = multisig_pubkey_index(multisig, public_key)
write_witness_multisig( write_witness_multisig(
w, multisig, signature, signature_index, common.SIGHASH_ALL w, multisig, signature, signature_index, SigHashType.SIGHASH_ALL
) )
else: else:
write_witness_p2wpkh(w, signature, public_key, common.SIGHASH_ALL) write_witness_p2wpkh(w, signature, public_key, SigHashType.SIGHASH_ALL)
else: else:
# Zero entries in witness stack. # Zero entries in witness stack.
w.append(0x00) w.append(0x00)
@ -627,21 +633,23 @@ def read_bip322_signature_proof(r: utils.BufferReader) -> tuple[memoryview, memo
# === # ===
def write_signature_prefixed(w: Writer, signature: bytes, hash_type: int) -> None: def write_signature_prefixed(
w: Writer, signature: bytes, sighash_type: SigHashType
) -> None:
length = len(signature) length = len(signature)
if hash_type != common.SIGHASH_ALL_TAPROOT: if sighash_type != SigHashType.SIGHASH_ALL_TAPROOT:
length += 1 length += 1
write_bitcoin_varint(w, length) write_bitcoin_varint(w, length)
write_bytes_unchecked(w, signature) write_bytes_unchecked(w, signature)
if hash_type != common.SIGHASH_ALL_TAPROOT: if sighash_type != SigHashType.SIGHASH_ALL_TAPROOT:
w.append(hash_type) w.append(sighash_type)
def append_signature(w: Writer, signature: bytes, hash_type: int) -> None: def append_signature(w: Writer, signature: bytes, sighash_type: SigHashType) -> None:
write_op_push(w, len(signature) + 1) write_op_push(w, len(signature) + 1)
write_bytes_unchecked(w, signature) write_bytes_unchecked(w, signature)
w.append(hash_type) w.append(sighash_type)
def append_pubkey(w: Writer, pubkey: bytes | memoryview) -> None: def append_pubkey(w: Writer, pubkey: bytes | memoryview) -> None:

View File

@ -6,6 +6,7 @@ from trezor.enums import InputScriptType
from apps.common.writers import write_bytes_fixed, write_uint64_le from apps.common.writers import write_bytes_fixed, write_uint64_le
from . import scripts from . import scripts
from .common import SigHashType
from .multisig import multisig_get_pubkeys, multisig_pubkey_index from .multisig import multisig_get_pubkeys, multisig_pubkey_index
from .scripts import ( # noqa: F401 from .scripts import ( # noqa: F401
output_script_paytoopreturn, output_script_paytoopreturn,
@ -27,21 +28,21 @@ def write_input_script_prefixed(
script_type: InputScriptType, script_type: InputScriptType,
multisig: MultisigRedeemScriptType | None, multisig: MultisigRedeemScriptType | None,
coin: CoinInfo, coin: CoinInfo,
hash_type: int, sighash_type: SigHashType,
pubkey: bytes, pubkey: bytes,
signature: bytes, signature: bytes,
) -> None: ) -> None:
if script_type == InputScriptType.SPENDADDRESS: if script_type == InputScriptType.SPENDADDRESS:
# p2pkh or p2sh # p2pkh or p2sh
scripts.write_input_script_p2pkh_or_p2sh_prefixed( scripts.write_input_script_p2pkh_or_p2sh_prefixed(
w, pubkey, signature, hash_type w, pubkey, signature, sighash_type
) )
elif script_type == InputScriptType.SPENDMULTISIG: elif script_type == InputScriptType.SPENDMULTISIG:
# p2sh multisig # p2sh multisig
assert multisig is not None # checked in sanitize_tx_input assert multisig is not None # checked in sanitize_tx_input
signature_index = multisig_pubkey_index(multisig, pubkey) signature_index = multisig_pubkey_index(multisig, pubkey)
write_input_script_multisig_prefixed( write_input_script_multisig_prefixed(
w, multisig, signature, signature_index, hash_type, coin w, multisig, signature, signature_index, sighash_type, coin
) )
else: else:
raise wire.ProcessError("Invalid script type") raise wire.ProcessError("Invalid script type")
@ -52,7 +53,7 @@ def write_input_script_multisig_prefixed(
multisig: MultisigRedeemScriptType, multisig: MultisigRedeemScriptType,
signature: bytes, signature: bytes,
signature_index: int, signature_index: int,
hash_type: int, sighash_type: SigHashType,
coin: CoinInfo, coin: CoinInfo,
) -> None: ) -> None:
signatures = multisig.signatures # other signatures signatures = multisig.signatures # other signatures
@ -74,7 +75,7 @@ def write_input_script_multisig_prefixed(
for s in signatures: for s in signatures:
if s: if s:
scripts.append_signature(w, s, hash_type) scripts.append_signature(w, s, sighash_type)
# redeem script # redeem script
write_op_push(w, redeem_script_length) write_op_push(w, redeem_script_length)

View File

@ -10,8 +10,7 @@ from apps.common.writers import write_bitcoin_varint
from .. import addresses, common, multisig, scripts, writers from .. import addresses, common, multisig, scripts, writers
from ..common import ( from ..common import (
SIGHASH_ALL, SigHashType,
SIGHASH_ALL_TAPROOT,
bip340_sign, bip340_sign,
ecdsa_sign, ecdsa_sign,
input_is_external, input_is_external,
@ -408,7 +407,9 @@ class Bitcoin:
verifier = SignatureVerifier( verifier = SignatureVerifier(
script_pubkey, txi.script_sig, txi.witness, self.coin script_pubkey, txi.script_sig, txi.witness, self.coin
) )
verifier.ensure_hash_type((SIGHASH_ALL_TAPROOT, self.get_hash_type(txi))) verifier.ensure_hash_type(
(SigHashType.SIGHASH_ALL_TAPROOT, self.get_sighash_type(txi))
)
tx_digest = await self.get_tx_digest( tx_digest = await self.get_tx_digest(
orig.verification_index, orig.verification_index,
txi, txi,
@ -456,7 +457,7 @@ class Bitcoin:
threshold, threshold,
tx_info.tx, tx_info.tx,
self.coin, self.coin,
self.get_sighash_type(txi), self.get_hash_type(txi),
) )
else: else:
digest, _, _ = await self.get_legacy_tx_digest(i, tx_info, script_pubkey) digest, _, _ = await self.get_legacy_tx_digest(i, tx_info, script_pubkey)
@ -479,7 +480,9 @@ class Bitcoin:
script_pubkey, txi.script_sig, txi.witness, self.coin script_pubkey, txi.script_sig, txi.witness, self.coin
) )
verifier.ensure_hash_type((SIGHASH_ALL_TAPROOT, self.get_hash_type(txi))) verifier.ensure_hash_type(
(SigHashType.SIGHASH_ALL_TAPROOT, self.get_sighash_type(txi))
)
tx_digest = await self.get_tx_digest( tx_digest = await self.get_tx_digest(
i, i,
@ -536,7 +539,7 @@ class Bitcoin:
threshold, threshold,
self.tx_info.tx, self.tx_info.tx,
self.coin, self.coin,
self.get_sighash_type(txi), self.get_hash_type(txi),
) )
signature = ecdsa_sign(node, hash143_digest) signature = ecdsa_sign(node, hash143_digest)
@ -563,7 +566,7 @@ class Bitcoin:
if txi.script_type == InputScriptType.SPENDTAPROOT: if txi.script_type == InputScriptType.SPENDTAPROOT:
signature = self.sign_taproot_input(i, txi) signature = self.sign_taproot_input(i, txi)
scripts.write_witness_p2tr( scripts.write_witness_p2tr(
self.serialized_tx, signature, self.get_hash_type(txi) self.serialized_tx, signature, self.get_sighash_type(txi)
) )
else: else:
public_key, signature = self.sign_bip143_input(i, txi) public_key, signature = self.sign_bip143_input(i, txi)
@ -578,11 +581,14 @@ class Bitcoin:
txi.multisig, txi.multisig,
signature, signature,
signature_index, signature_index,
self.get_hash_type(txi), self.get_sighash_type(txi),
) )
else: else:
scripts.write_witness_p2wpkh( scripts.write_witness_p2wpkh(
self.serialized_tx, signature, public_key, self.get_hash_type(txi) self.serialized_tx,
signature,
public_key,
self.get_sighash_type(txi),
) )
self.set_serialized_signature(i, signature) self.set_serialized_signature(i, signature)
@ -645,7 +651,7 @@ class Bitcoin:
self.write_tx_output(h_sign, txo, script_pubkey) self.write_tx_output(h_sign, txo, script_pubkey)
writers.write_uint32(h_sign, tx_info.tx.lock_time) writers.write_uint32(h_sign, tx_info.tx.lock_time)
writers.write_uint32(h_sign, self.get_sighash_type(txi_sign)) writers.write_uint32(h_sign, self.get_hash_type(txi_sign))
# check that the inputs were the same as those streamed for approval # check that the inputs were the same as those streamed for approval
if tx_info.get_tx_check_digest() != h_check.get_digest(): if tx_info.get_tx_check_digest() != h_check.get_digest():
@ -730,18 +736,19 @@ class Bitcoin:
# Tx Helpers # Tx Helpers
# === # ===
def get_sighash_type(self, txi: TxInput) -> int:
if common.input_is_taproot(txi):
return SIGHASH_ALL_TAPROOT
else:
return SIGHASH_ALL
def get_hash_type(self, txi: TxInput) -> int: def get_hash_type(self, txi: TxInput) -> int:
# The nHashType in BIP 143.
if common.input_is_taproot(txi):
return SigHashType.SIGHASH_ALL_TAPROOT
else:
return SigHashType.SIGHASH_ALL
def get_sighash_type(self, txi: TxInput) -> SigHashType:
""" Return the nHashType flags.""" """ Return the nHashType flags."""
# The nHashType is the 8 least significant bits of the sighash type. # The nHashType is the 8 least significant bits of the sighash type.
# Some coins set the 24 most significant bits of the sighash type to # Some coins set the 24 most significant bits of the sighash type to
# the fork ID value. # the fork ID value.
return self.get_sighash_type(txi) & 0xFF return self.get_hash_type(txi) & 0xFF # type: ignore
def write_tx_input_derived( def write_tx_input_derived(
self, self,
@ -757,7 +764,7 @@ class Bitcoin:
txi.script_type, txi.script_type,
txi.multisig, txi.multisig,
self.coin, self.coin,
self.get_hash_type(txi), self.get_sighash_type(txi),
pubkey, pubkey,
signature, signature,
) )

View File

@ -1,12 +1,10 @@
from micropython import const
from trezor import wire from trezor import wire
from trezor.messages import PrevTx, SignTx, TxInput from trezor.messages import PrevTx, SignTx, TxInput
from apps.common.writers import write_bitcoin_varint from apps.common.writers import write_bitcoin_varint
from .. import multisig, writers from .. import multisig, writers
from ..common import NONSEGWIT_INPUT_SCRIPT_TYPES from ..common import NONSEGWIT_INPUT_SCRIPT_TYPES, SigHashType
from . import helpers from . import helpers
from .bitcoin import Bitcoin from .bitcoin import Bitcoin
@ -14,8 +12,6 @@ if False:
from typing import Sequence from typing import Sequence
from .tx_info import OriginalTxInfo, TxInfo from .tx_info import OriginalTxInfo, TxInfo
_SIGHASH_FORKID = const(0x40)
class Bitcoinlike(Bitcoin): class Bitcoinlike(Bitcoin):
async def sign_nonsegwit_bip143_input(self, i_sign: int) -> None: async def sign_nonsegwit_bip143_input(self, i_sign: int) -> None:
@ -55,17 +51,17 @@ class Bitcoinlike(Bitcoin):
threshold, threshold,
tx_info.tx, tx_info.tx,
self.coin, self.coin,
self.get_sighash_type(txi), self.get_hash_type(txi),
) )
else: else:
return await super().get_tx_digest( return await super().get_tx_digest(
i, txi, tx_info, public_keys, threshold, script_pubkey i, txi, tx_info, public_keys, threshold, script_pubkey
) )
def get_sighash_type(self, txi: TxInput) -> int: def get_hash_type(self, txi: TxInput) -> int:
hashtype = super().get_sighash_type(txi) hashtype = super().get_hash_type(txi)
if self.coin.fork_id is not None: if self.coin.fork_id is not None:
hashtype |= (self.coin.fork_id << 8) | _SIGHASH_FORKID hashtype |= (self.coin.fork_id << 8) | SigHashType.SIGHASH_FORKID
return hashtype return hashtype
def write_tx_header( def write_tx_header(

View File

@ -9,7 +9,7 @@ from trezor.utils import HashWriter, ensure
from apps.common.writers import write_bitcoin_varint from apps.common.writers import write_bitcoin_varint
from .. import multisig, scripts_decred, writers from .. import multisig, scripts_decred, writers
from ..common import ecdsa_hash_pubkey, ecdsa_sign from ..common import SigHashType, ecdsa_hash_pubkey, ecdsa_sign
from . import approvers, helpers, progress from . import approvers, helpers, progress
from .approvers import BasicApprover from .approvers import BasicApprover
from .bitcoin import Bitcoin from .bitcoin import Bitcoin
@ -18,7 +18,6 @@ DECRED_SERIALIZE_FULL = const(0 << 16)
DECRED_SERIALIZE_NO_WITNESS = const(1 << 16) DECRED_SERIALIZE_NO_WITNESS = const(1 << 16)
DECRED_SERIALIZE_WITNESS_SIGNING = const(3 << 16) DECRED_SERIALIZE_WITNESS_SIGNING = const(3 << 16)
DECRED_SCRIPT_VERSION = const(0) DECRED_SCRIPT_VERSION = const(0)
DECRED_SIGHASH_ALL = const(1)
OUTPUT_SCRIPT_NULL_SSTXCHANGE = ( OUTPUT_SCRIPT_NULL_SSTXCHANGE = (
b"\xBD\x76\xA9\x14\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\x88\xAC" b"\xBD\x76\xA9\x14\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\x88\xAC"
) )
@ -67,7 +66,7 @@ class DecredSigHasher:
threshold: int, threshold: int,
tx: SignTx | PrevTx, tx: SignTx | PrevTx,
coin: CoinInfo, coin: CoinInfo,
sighash_type: int, hash_type: int,
) -> bytes: ) -> bytes:
raise NotImplementedError raise NotImplementedError
@ -75,7 +74,7 @@ class DecredSigHasher:
self, self,
i: int, i: int,
tx: SignTx | PrevTx, tx: SignTx | PrevTx,
sighash_type: int, sighash_type: SigHashType,
) -> bytes: ) -> bytes:
raise NotImplementedError raise NotImplementedError
@ -197,7 +196,7 @@ class Decred(Bitcoin):
) )
h_sign = self.create_hash_writer() h_sign = self.create_hash_writer()
writers.write_uint32(h_sign, DECRED_SIGHASH_ALL) writers.write_uint32(h_sign, SigHashType.SIGHASH_ALL)
writers.write_bytes_fixed(h_sign, prefix_hash, writers.TX_HASH_SIZE) writers.write_bytes_fixed(h_sign, prefix_hash, writers.TX_HASH_SIZE)
writers.write_bytes_fixed(h_sign, witness_hash, writers.TX_HASH_SIZE) writers.write_bytes_fixed(h_sign, witness_hash, writers.TX_HASH_SIZE)
@ -329,7 +328,7 @@ class Decred(Bitcoin):
txi.script_type, txi.script_type,
txi.multisig, txi.multisig,
self.coin, self.coin,
self.get_hash_type(txi), self.get_sighash_type(txi),
pubkey, pubkey,
signature, signature,
) )

View File

@ -9,6 +9,7 @@ from ..common import tagged_hashwriter
if False: if False:
from typing import Protocol, Sequence from typing import Protocol, Sequence
from ..common import SigHashType
class SigHasher(Protocol): class SigHasher(Protocol):
def add_input(self, txi: TxInput, script_pubkey: bytes) -> None: def add_input(self, txi: TxInput, script_pubkey: bytes) -> None:
@ -24,7 +25,7 @@ if False:
threshold: int, threshold: int,
tx: SignTx | PrevTx, tx: SignTx | PrevTx,
coin: coininfo.CoinInfo, coin: coininfo.CoinInfo,
sighash_type: int, hash_type: int,
) -> bytes: ) -> bytes:
... ...
@ -32,7 +33,7 @@ if False:
self, self,
i: int, i: int,
tx: SignTx | PrevTx, tx: SignTx | PrevTx,
sighash_type: int, sighash_type: SigHashType,
) -> bytes: ) -> bytes:
... ...
@ -65,7 +66,7 @@ class BitcoinSigHasher:
threshold: int, threshold: int,
tx: SignTx | PrevTx, tx: SignTx | PrevTx,
coin: coininfo.CoinInfo, coin: coininfo.CoinInfo,
sighash_type: int, hash_type: int,
) -> bytes: ) -> bytes:
h_preimage = HashWriter(sha256()) h_preimage = HashWriter(sha256())
@ -107,7 +108,7 @@ class BitcoinSigHasher:
writers.write_uint32(h_preimage, tx.lock_time) writers.write_uint32(h_preimage, tx.lock_time)
# nHashType # nHashType
writers.write_uint32(h_preimage, sighash_type) writers.write_uint32(h_preimage, hash_type)
return writers.get_tx_hash(h_preimage, double=coin.sign_hash_double) return writers.get_tx_hash(h_preimage, double=coin.sign_hash_double)
@ -115,7 +116,7 @@ class BitcoinSigHasher:
self, self,
i: int, i: int,
tx: SignTx | PrevTx, tx: SignTx | PrevTx,
sighash_type: int, sighash_type: SigHashType,
) -> bytes: ) -> bytes:
h_sigmsg = tagged_hashwriter(b"TapSighash") h_sigmsg = tagged_hashwriter(b"TapSighash")

View File

@ -28,6 +28,7 @@ if False:
from apps.common import coininfo from apps.common import coininfo
from .sig_hasher import SigHasher from .sig_hasher import SigHasher
from .tx_info import OriginalTxInfo, TxInfo from .tx_info import OriginalTxInfo, TxInfo
from ..common import SigHashType
from ..writers import Writer from ..writers import Writer
OVERWINTERED = const(0x8000_0000) OVERWINTERED = const(0x8000_0000)
@ -54,7 +55,7 @@ class ZcashSigHasher:
threshold: int, threshold: int,
tx: SignTx | PrevTx, tx: SignTx | PrevTx,
coin: coininfo.CoinInfo, coin: coininfo.CoinInfo,
sighash_type: int, hash_type: int,
) -> bytes: ) -> bytes:
h_preimage = HashWriter( h_preimage = HashWriter(
blake2b( blake2b(
@ -90,7 +91,7 @@ class ZcashSigHasher:
# 11. valueBalance # 11. valueBalance
write_uint64(h_preimage, 0) write_uint64(h_preimage, 0)
# 12. nHashType # 12. nHashType
write_uint32(h_preimage, sighash_type) write_uint32(h_preimage, hash_type)
# 13a. outpoint # 13a. outpoint
write_bytes_reversed(h_preimage, txi.prev_hash, TX_HASH_SIZE) write_bytes_reversed(h_preimage, txi.prev_hash, TX_HASH_SIZE)
write_uint32(h_preimage, txi.prev_index) write_uint32(h_preimage, txi.prev_index)
@ -107,7 +108,7 @@ class ZcashSigHasher:
self, self,
i: int, i: int,
tx: SignTx | PrevTx, tx: SignTx | PrevTx,
sighash_type: int, sighash_type: SigHashType,
) -> bytes: ) -> bytes:
raise NotImplementedError raise NotImplementedError

View File

@ -3,7 +3,7 @@ from trezor.crypto import der
from trezor.crypto.curve import bip340, secp256k1 from trezor.crypto.curve import bip340, secp256k1
from trezor.crypto.hashlib import sha256 from trezor.crypto.hashlib import sha256
from .common import OP_0, OP_1, ecdsa_hash_pubkey from .common import OP_0, OP_1, SigHashType, ecdsa_hash_pubkey
from .scripts import ( from .scripts import (
output_script_native_segwit, output_script_native_segwit,
output_script_p2pkh, output_script_p2pkh,
@ -34,7 +34,7 @@ class SignatureVerifier:
): ):
self.threshold = 1 self.threshold = 1
self.public_keys: list[memoryview] = [] self.public_keys: list[memoryview] = []
self.signatures: list[tuple[memoryview, int]] = [] self.signatures: list[tuple[memoryview, SigHashType]] = []
self.is_taproot = False self.is_taproot = False
if not script_sig: if not script_sig:
@ -106,8 +106,8 @@ class SignatureVerifier:
if self.threshold != len(self.signatures): if self.threshold != len(self.signatures):
raise wire.DataError("Invalid signature") raise wire.DataError("Invalid signature")
def ensure_hash_type(self, hash_types: Sequence[int]) -> None: def ensure_hash_type(self, sighash_types: Sequence[SigHashType]) -> None:
if any(h not in hash_types for _, h in self.signatures): if any(h not in sighash_types for _, h in self.signatures):
raise wire.DataError("Unsupported sighash type") raise wire.DataError("Unsupported sighash type")
def verify(self, digest: bytes) -> None: def verify(self, digest: bytes) -> None:

View File

@ -1,6 +1,6 @@
from common import * from common import *
from apps.bitcoin.common import SIGHASH_ALL from apps.bitcoin.common import SigHashType
from apps.bitcoin.scripts import output_derive_script from apps.bitcoin.scripts import output_derive_script
from apps.bitcoin.sign_tx.bitcoin import BitcoinSigHasher from apps.bitcoin.sign_tx.bitcoin import BitcoinSigHasher
from apps.bitcoin.writers import get_tx_hash from apps.bitcoin.writers import get_tx_hash
@ -95,7 +95,7 @@ class TestSegwitBip143NativeP2WPKH(unittest.TestCase):
# test data public key hash # test data public key hash
# only for input 2 - input 1 is not segwit # only for input 2 - input 1 is not segwit
result = sig_hasher.hash143(self.inp2, [node.public_key()], 1, self.tx, coin, SIGHASH_ALL) result = sig_hasher.hash143(self.inp2, [node.public_key()], 1, self.tx, coin, SigHashType.SIGHASH_ALL)
self.assertEqual(hexlify(result), b'2fa3f1351618b2532228d7182d3221d95c21fd3d496e7e22e9ded873cf022a8b') self.assertEqual(hexlify(result), b'2fa3f1351618b2532228d7182d3221d95c21fd3d496e7e22e9ded873cf022a8b')

View File

@ -1,6 +1,6 @@
from common import * from common import *
from apps.bitcoin.common import SIGHASH_ALL from apps.bitcoin.common import SigHashType
from apps.bitcoin.scripts import output_derive_script from apps.bitcoin.scripts import output_derive_script
from apps.bitcoin.sign_tx.bitcoin import BitcoinSigHasher from apps.bitcoin.sign_tx.bitcoin import BitcoinSigHasher
from apps.bitcoin.writers import get_tx_hash from apps.bitcoin.writers import get_tx_hash
@ -80,7 +80,7 @@ class TestSegwitBip143(unittest.TestCase):
node = keychain.derive(self.inp1.address_n) node = keychain.derive(self.inp1.address_n)
# test data public key hash # test data public key hash
result = sig_hasher.hash143(self.inp1, [node.public_key()], 1, self.tx, coin, SIGHASH_ALL) result = sig_hasher.hash143(self.inp1, [node.public_key()], 1, self.tx, coin, SigHashType.SIGHASH_ALL)
self.assertEqual(hexlify(result), b'6e28aca7041720995d4acf59bbda64eef5d6f23723d23f2e994757546674bbd9') self.assertEqual(hexlify(result), b'6e28aca7041720995d4acf59bbda64eef5d6f23723d23f2e994757546674bbd9')

View File

@ -1,6 +1,6 @@
from common import * from common import *
from apps.bitcoin.common import SIGHASH_ALL, SIGHASH_ALL_TAPROOT from apps.bitcoin.common import SigHashType
from apps.bitcoin.scripts import output_derive_script from apps.bitcoin.scripts import output_derive_script
from apps.bitcoin.sign_tx.bitcoin import BitcoinSigHasher from apps.bitcoin.sign_tx.bitcoin import BitcoinSigHasher
from apps.bitcoin.writers import get_tx_hash from apps.bitcoin.writers import get_tx_hash
@ -108,12 +108,12 @@ VECTORS = [
[ [
{ {
"index": 3, "index": 3,
"hash_type": SIGHASH_ALL, "hash_type": SigHashType.SIGHASH_ALL,
"result": unhexlify('6ffd256e108685b41831385f57eebf2fca041bc6b5e607ea11b3e03d4cf9d9ba'), "result": unhexlify('6ffd256e108685b41831385f57eebf2fca041bc6b5e607ea11b3e03d4cf9d9ba'),
}, },
{ {
"index": 4, "index": 4,
"hash_type": SIGHASH_ALL_TAPROOT, "hash_type": SigHashType.SIGHASH_ALL_TAPROOT,
"result": unhexlify('9f90136737540ccc18707e1fd398ad222a1a7e4dd65cbfd22dbe4660191efa58'), "result": unhexlify('9f90136737540ccc18707e1fd398ad222a1a7e4dd65cbfd22dbe4660191efa58'),
}, },
] ]
@ -155,7 +155,7 @@ VECTORS = [
[ [
{ {
"index": 1, "index": 1,
"hash_type": SIGHASH_ALL_TAPROOT, "hash_type": SigHashType.SIGHASH_ALL_TAPROOT,
"result": unhexlify('07333acfe6dce8196f1ad62b2e039a3d9f0b6627bf955be767c519c0f8789ff4'), "result": unhexlify('07333acfe6dce8196f1ad62b2e039a3d9f0b6627bf955be767c519c0f8789ff4'),
}, },
] ]

View File

@ -5,7 +5,7 @@ from trezor.messages import TxInput
from trezor.messages import PrevOutput from trezor.messages import PrevOutput
from apps.common import coins from apps.common import coins
from apps.bitcoin.common import SIGHASH_ALL from apps.bitcoin.common import SigHashType
from apps.bitcoin.writers import get_tx_hash from apps.bitcoin.writers import get_tx_hash
if not utils.BITCOIN_ONLY: if not utils.BITCOIN_ONLY:
@ -213,7 +213,7 @@ class TestZcashZip243(unittest.TestCase):
self.assertEqual(hexlify(get_tx_hash(zip243.h_prevouts)), v["prevouts_hash"]) self.assertEqual(hexlify(get_tx_hash(zip243.h_prevouts)), v["prevouts_hash"])
self.assertEqual(hexlify(get_tx_hash(zip243.h_sequence)), v["sequence_hash"]) self.assertEqual(hexlify(get_tx_hash(zip243.h_sequence)), v["sequence_hash"])
self.assertEqual(hexlify(get_tx_hash(zip243.h_outputs)), v["outputs_hash"]) self.assertEqual(hexlify(get_tx_hash(zip243.h_outputs)), v["outputs_hash"])
self.assertEqual(hexlify(zip243.hash143(txi, [unhexlify(i["pubkey"])], 1, tx, coin, SIGHASH_ALL)), v["preimage_hash"]) self.assertEqual(hexlify(zip243.hash143(txi, [unhexlify(i["pubkey"])], 1, tx, coin, SigHashType.SIGHASH_ALL)), v["preimage_hash"])
if __name__ == "__main__": if __name__ == "__main__":