mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-22 22:38:08 +00:00
refactor(core/bitcoin): Use HashWriter in address derivation.
This commit is contained in:
parent
2c003052f5
commit
7811204ed5
@ -3,13 +3,14 @@ from trezor.crypto import base58, cashaddr
|
||||
from trezor.crypto.hashlib import sha256
|
||||
from trezor.enums import InputScriptType
|
||||
from trezor.messages import MultisigRedeemScriptType
|
||||
from trezor.utils import HashWriter
|
||||
|
||||
from apps.common import address_type
|
||||
from apps.common.coininfo import CoinInfo
|
||||
|
||||
from .common import ecdsa_hash_pubkey, encode_bech32_address
|
||||
from .multisig import multisig_get_pubkeys, multisig_pubkey_index
|
||||
from .scripts import output_script_multisig, output_script_native_p2wpkh_or_p2wsh
|
||||
from .scripts import output_script_native_p2wpkh_or_p2wsh, write_output_script_multisig
|
||||
|
||||
if False:
|
||||
from trezor.crypto import bip32
|
||||
@ -75,25 +76,25 @@ def get_address(
|
||||
def address_multisig_p2sh(pubkeys: list[bytes], m: int, coin: CoinInfo) -> str:
|
||||
if coin.address_type_p2sh is None:
|
||||
raise wire.ProcessError("Multisig not enabled on this coin")
|
||||
redeem_script = output_script_multisig(pubkeys, m)
|
||||
redeem_script_hash = coin.script_hash(redeem_script)
|
||||
return address_p2sh(redeem_script_hash, coin)
|
||||
redeem_script = HashWriter(coin.script_hash())
|
||||
write_output_script_multisig(redeem_script, pubkeys, m)
|
||||
return address_p2sh(redeem_script.get_digest(), coin)
|
||||
|
||||
|
||||
def address_multisig_p2wsh_in_p2sh(pubkeys: list[bytes], m: int, coin: CoinInfo) -> str:
|
||||
if coin.address_type_p2sh is None:
|
||||
raise wire.ProcessError("Multisig not enabled on this coin")
|
||||
witness_script = output_script_multisig(pubkeys, m)
|
||||
witness_script_hash = sha256(witness_script).digest()
|
||||
return address_p2wsh_in_p2sh(witness_script_hash, coin)
|
||||
witness_script_h = HashWriter(sha256())
|
||||
write_output_script_multisig(witness_script_h, pubkeys, m)
|
||||
return address_p2wsh_in_p2sh(witness_script_h.get_digest(), coin)
|
||||
|
||||
|
||||
def address_multisig_p2wsh(pubkeys: list[bytes], m: int, hrp: str) -> str:
|
||||
if not hrp:
|
||||
raise wire.ProcessError("Multisig not enabled on this coin")
|
||||
witness_script = output_script_multisig(pubkeys, m)
|
||||
witness_script_hash = sha256(witness_script).digest()
|
||||
return address_p2wsh(witness_script_hash, hrp)
|
||||
witness_script_h = HashWriter(sha256())
|
||||
write_output_script_multisig(witness_script_h, pubkeys, m)
|
||||
return address_p2wsh(witness_script_h.get_digest(), hrp)
|
||||
|
||||
|
||||
def address_pkh(pubkey: bytes, coin: CoinInfo) -> str:
|
||||
|
@ -67,7 +67,7 @@ def verify_nonownership(
|
||||
) -> bool:
|
||||
try:
|
||||
r = utils.BufferReader(proof)
|
||||
if r.read(4) != _VERSION_MAGIC:
|
||||
if r.read_memoryview(4) != _VERSION_MAGIC:
|
||||
raise wire.DataError("Unknown format of proof of ownership")
|
||||
|
||||
flags = r.get()
|
||||
@ -79,7 +79,7 @@ def verify_nonownership(
|
||||
ownership_id = get_identifier(script_pubkey, keychain)
|
||||
not_owned = True
|
||||
for _ in range(id_count):
|
||||
if utils.consteq(ownership_id, r.read(_OWNERSHIP_ID_LEN)):
|
||||
if utils.consteq(ownership_id, r.read_memoryview(_OWNERSHIP_ID_LEN)):
|
||||
not_owned = False
|
||||
|
||||
# Verify the BIP-322 SignatureProof.
|
||||
|
@ -3,9 +3,9 @@ from trezor.utils import BufferReader
|
||||
from apps.common.readers import read_bitcoin_varint
|
||||
|
||||
|
||||
def read_bytes_prefixed(r: BufferReader) -> bytes:
|
||||
def read_memoryview_prefixed(r: BufferReader) -> memoryview:
|
||||
n = read_bitcoin_varint(r)
|
||||
return r.read(n)
|
||||
return r.read_memoryview(n)
|
||||
|
||||
|
||||
def read_op_push(r: BufferReader) -> int:
|
||||
|
@ -13,7 +13,7 @@ from .multisig import (
|
||||
multisig_get_pubkeys,
|
||||
multisig_pubkey_index,
|
||||
)
|
||||
from .readers import read_bytes_prefixed, read_op_push
|
||||
from .readers import read_memoryview_prefixed, read_op_push
|
||||
from .writers import (
|
||||
op_push_length,
|
||||
write_bytes_fixed,
|
||||
@ -23,6 +23,8 @@ from .writers import (
|
||||
)
|
||||
|
||||
if False:
|
||||
from typing import Sequence
|
||||
|
||||
from trezor.messages import MultisigRedeemScriptType, TxInput
|
||||
|
||||
from apps.common.coininfo import CoinInfo
|
||||
@ -116,7 +118,11 @@ def output_derive_script(address: str, coin: CoinInfo) -> bytes:
|
||||
# see https://github.com/bitcoin/bips/blob/master/bip-0143.mediawiki#specification
|
||||
# item 5 for details
|
||||
def write_bip143_script_code_prefixed(
|
||||
w: Writer, txi: TxInput, public_keys: list[bytes], threshold: int, coin: CoinInfo
|
||||
w: Writer,
|
||||
txi: TxInput,
|
||||
public_keys: Sequence[bytes | memoryview],
|
||||
threshold: int,
|
||||
coin: CoinInfo,
|
||||
) -> None:
|
||||
if len(public_keys) > 1:
|
||||
write_output_script_multisig(w, public_keys, threshold, prefixed=True)
|
||||
@ -152,15 +158,15 @@ def write_input_script_p2pkh_or_p2sh_prefixed(
|
||||
append_pubkey(w, pubkey)
|
||||
|
||||
|
||||
def parse_input_script_p2pkh(script_sig: bytes) -> tuple[bytes, bytes, int]:
|
||||
def parse_input_script_p2pkh(script_sig: bytes) -> tuple[memoryview, memoryview, int]:
|
||||
try:
|
||||
r = utils.BufferReader(script_sig)
|
||||
n = read_op_push(r)
|
||||
signature = r.read(n - 1)
|
||||
signature = r.read_memoryview(n - 1)
|
||||
hash_type = r.get()
|
||||
|
||||
n = read_op_push(r)
|
||||
pubkey = r.read()
|
||||
pubkey = r.read_memoryview()
|
||||
if len(pubkey) != n:
|
||||
raise ValueError
|
||||
except (ValueError, EOFError):
|
||||
@ -286,7 +292,7 @@ def write_witness_p2wpkh(
|
||||
write_bytes_prefixed(w, pubkey)
|
||||
|
||||
|
||||
def parse_witness_p2wpkh(witness: bytes) -> tuple[bytes, bytes, int]:
|
||||
def parse_witness_p2wpkh(witness: bytes) -> tuple[memoryview, memoryview, int]:
|
||||
try:
|
||||
r = utils.BufferReader(witness)
|
||||
|
||||
@ -295,10 +301,10 @@ def parse_witness_p2wpkh(witness: bytes) -> tuple[bytes, bytes, int]:
|
||||
raise ValueError
|
||||
|
||||
n = read_bitcoin_varint(r)
|
||||
signature = r.read(n - 1)
|
||||
signature = r.read_memoryview(n - 1)
|
||||
hash_type = r.get()
|
||||
|
||||
pubkey = read_bytes_prefixed(r)
|
||||
pubkey = read_memoryview_prefixed(r)
|
||||
if r.remaining_count():
|
||||
raise ValueError
|
||||
except (ValueError, EOFError):
|
||||
@ -342,7 +348,9 @@ def write_witness_multisig(
|
||||
write_output_script_multisig(w, pubkeys, multisig.m, prefixed=True)
|
||||
|
||||
|
||||
def parse_witness_multisig(witness: bytes) -> tuple[bytes, list[tuple[bytes, int]]]:
|
||||
def parse_witness_multisig(
|
||||
witness: bytes,
|
||||
) -> tuple[memoryview, list[tuple[memoryview, int]]]:
|
||||
try:
|
||||
r = utils.BufferReader(witness)
|
||||
|
||||
@ -356,11 +364,11 @@ def parse_witness_multisig(witness: bytes) -> tuple[bytes, list[tuple[bytes, int
|
||||
signatures = []
|
||||
for i in range(item_count - 2):
|
||||
n = read_bitcoin_varint(r)
|
||||
signature = r.read(n - 1)
|
||||
signature = r.read_memoryview(n - 1)
|
||||
hash_type = r.get()
|
||||
signatures.append((signature, hash_type))
|
||||
|
||||
script = read_bytes_prefixed(r)
|
||||
script = read_memoryview_prefixed(r)
|
||||
if r.remaining_count():
|
||||
raise ValueError
|
||||
except (ValueError, EOFError):
|
||||
@ -416,7 +424,7 @@ def write_input_script_multisig_prefixed(
|
||||
|
||||
def parse_input_script_multisig(
|
||||
script_sig: bytes,
|
||||
) -> tuple[bytes, list[tuple[bytes, int]]]:
|
||||
) -> tuple[memoryview, list[tuple[memoryview, int]]]:
|
||||
try:
|
||||
r = utils.BufferReader(script_sig)
|
||||
|
||||
@ -427,12 +435,12 @@ def parse_input_script_multisig(
|
||||
signatures = []
|
||||
n = read_op_push(r)
|
||||
while r.remaining_count() > n:
|
||||
signature = r.read(n - 1)
|
||||
signature = r.read_memoryview(n - 1)
|
||||
hash_type = r.get()
|
||||
signatures.append((signature, hash_type))
|
||||
n = read_op_push(r)
|
||||
|
||||
script = r.read()
|
||||
script = r.read_memoryview()
|
||||
if len(script) != n:
|
||||
raise ValueError
|
||||
except (ValueError, EOFError):
|
||||
@ -449,7 +457,7 @@ def output_script_multisig(pubkeys: list[bytes], m: int) -> bytearray:
|
||||
|
||||
def write_output_script_multisig(
|
||||
w: Writer,
|
||||
pubkeys: list[bytes],
|
||||
pubkeys: Sequence[bytes | memoryview],
|
||||
m: int,
|
||||
prefixed: bool = False,
|
||||
) -> None:
|
||||
@ -470,11 +478,11 @@ def write_output_script_multisig(
|
||||
w.append(0xAE) # OP_CHECKMULTISIG
|
||||
|
||||
|
||||
def output_script_multisig_length(pubkeys: list[bytes], m: int) -> int:
|
||||
def output_script_multisig_length(pubkeys: Sequence[bytes | memoryview], m: int) -> int:
|
||||
return 1 + len(pubkeys) * (1 + 33) + 1 + 1 # see output_script_multisig
|
||||
|
||||
|
||||
def parse_output_script_multisig(script: bytes) -> tuple[list[bytes], int]:
|
||||
def parse_output_script_multisig(script: bytes) -> tuple[list[memoryview], int]:
|
||||
try:
|
||||
r = utils.BufferReader(script)
|
||||
|
||||
@ -493,7 +501,7 @@ def parse_output_script_multisig(script: bytes) -> tuple[list[bytes], int]:
|
||||
n = read_op_push(r)
|
||||
if n != 33:
|
||||
raise ValueError
|
||||
public_keys.append(r.read(n))
|
||||
public_keys.append(r.read_memoryview(n))
|
||||
|
||||
r.get() # ignore pubkey_count
|
||||
if r.get() != 0xAE: # OP_CHECKMULTISIG
|
||||
@ -550,9 +558,9 @@ def write_bip322_signature_proof(
|
||||
w.append(0x00)
|
||||
|
||||
|
||||
def read_bip322_signature_proof(r: utils.BufferReader) -> tuple[bytes, bytes]:
|
||||
script_sig = read_bytes_prefixed(r)
|
||||
witness = r.read()
|
||||
def read_bip322_signature_proof(r: utils.BufferReader) -> tuple[memoryview, memoryview]:
|
||||
script_sig = read_memoryview_prefixed(r)
|
||||
witness = r.read_memoryview()
|
||||
return script_sig, witness
|
||||
|
||||
|
||||
@ -572,6 +580,6 @@ def append_signature(w: Writer, signature: bytes, hash_type: int) -> None:
|
||||
w.append(hash_type)
|
||||
|
||||
|
||||
def append_pubkey(w: Writer, pubkey: bytes) -> None:
|
||||
def append_pubkey(w: Writer, pubkey: bytes | memoryview) -> None:
|
||||
write_op_push(w, len(pubkey))
|
||||
write_bytes_unchecked(w, pubkey)
|
||||
|
@ -17,6 +17,8 @@ from .hash143 import Bip143Hash
|
||||
from .tx_info import OriginalTxInfo, TxInfo
|
||||
|
||||
if False:
|
||||
from typing import Sequence
|
||||
|
||||
from trezor.crypto import bip32
|
||||
|
||||
from trezor.messages import (
|
||||
@ -396,7 +398,7 @@ class Bitcoin:
|
||||
i: int,
|
||||
txi: TxInput,
|
||||
tx_info: TxInfo | OriginalTxInfo,
|
||||
public_keys: list[bytes],
|
||||
public_keys: Sequence[bytes | memoryview],
|
||||
threshold: int,
|
||||
script_pubkey: bytes,
|
||||
) -> bytes:
|
||||
|
@ -11,6 +11,7 @@ from . import helpers
|
||||
from .bitcoin import Bitcoin
|
||||
|
||||
if False:
|
||||
from typing import Sequence
|
||||
from .tx_info import OriginalTxInfo, TxInfo
|
||||
|
||||
_SIGHASH_FORKID = const(0x40)
|
||||
@ -43,7 +44,7 @@ class Bitcoinlike(Bitcoin):
|
||||
i: int,
|
||||
txi: TxInput,
|
||||
tx_info: TxInfo | OriginalTxInfo,
|
||||
public_keys: list[bytes],
|
||||
public_keys: Sequence[bytes | memoryview],
|
||||
threshold: int,
|
||||
script_pubkey: bytes,
|
||||
tx_hash: bytes | None = None,
|
||||
|
@ -24,6 +24,8 @@ OUTPUT_SCRIPT_NULL_SSTXCHANGE = (
|
||||
)
|
||||
|
||||
if False:
|
||||
from typing import Sequence
|
||||
|
||||
from trezor.messages import (
|
||||
SignTx,
|
||||
TxInput,
|
||||
@ -61,7 +63,7 @@ class DecredHash:
|
||||
def preimage_hash(
|
||||
self,
|
||||
txi: TxInput,
|
||||
public_keys: list[bytes],
|
||||
public_keys: Sequence[bytes | memoryview],
|
||||
threshold: int,
|
||||
tx: SignTx | PrevTx,
|
||||
coin: CoinInfo,
|
||||
|
@ -7,7 +7,7 @@ from apps.common import coininfo
|
||||
from .. import scripts, writers
|
||||
|
||||
if False:
|
||||
from typing import Protocol
|
||||
from typing import Protocol, Sequence
|
||||
|
||||
class Hash143(Protocol):
|
||||
def add_input(self, txi: TxInput) -> None:
|
||||
@ -19,7 +19,7 @@ if False:
|
||||
def preimage_hash(
|
||||
self,
|
||||
txi: TxInput,
|
||||
public_keys: list[bytes],
|
||||
public_keys: Sequence[bytes | memoryview],
|
||||
threshold: int,
|
||||
tx: SignTx | PrevTx,
|
||||
coin: coininfo.CoinInfo,
|
||||
@ -48,7 +48,7 @@ class Bip143Hash:
|
||||
def preimage_hash(
|
||||
self,
|
||||
txi: TxInput,
|
||||
public_keys: list[bytes],
|
||||
public_keys: Sequence[bytes | memoryview],
|
||||
threshold: int,
|
||||
tx: SignTx | PrevTx,
|
||||
coin: coininfo.CoinInfo,
|
||||
|
@ -24,6 +24,7 @@ from . import approvers, helpers
|
||||
from .bitcoinlike import Bitcoinlike
|
||||
|
||||
if False:
|
||||
from typing import Sequence
|
||||
from apps.common import coininfo
|
||||
from .hash143 import Hash143
|
||||
from .tx_info import OriginalTxInfo, TxInfo
|
||||
@ -49,7 +50,7 @@ class Zip243Hash:
|
||||
def preimage_hash(
|
||||
self,
|
||||
txi: TxInput,
|
||||
public_keys: list[bytes],
|
||||
public_keys: Sequence[bytes | memoryview],
|
||||
threshold: int,
|
||||
tx: SignTx | PrevTx,
|
||||
coin: coininfo.CoinInfo,
|
||||
@ -138,7 +139,7 @@ class Zcashlike(Bitcoinlike):
|
||||
i: int,
|
||||
txi: TxInput,
|
||||
tx_info: TxInfo | OriginalTxInfo,
|
||||
public_keys: list[bytes],
|
||||
public_keys: Sequence[bytes | memoryview],
|
||||
threshold: int,
|
||||
script_pubkey: bytes,
|
||||
tx_hash: bytes | None = None,
|
||||
|
@ -30,8 +30,8 @@ class SignatureVerifier:
|
||||
coin: CoinInfo,
|
||||
):
|
||||
self.threshold = 1
|
||||
self.public_keys: list[bytes] = []
|
||||
self.signatures: list[tuple[bytes, int]] = []
|
||||
self.public_keys: list[memoryview] = []
|
||||
self.signatures: list[tuple[memoryview, int]] = []
|
||||
|
||||
if not script_sig:
|
||||
if not witness:
|
||||
@ -118,7 +118,7 @@ class SignatureVerifier:
|
||||
raise wire.DataError("Invalid signature")
|
||||
|
||||
|
||||
def decode_der_signature(der_signature: bytes) -> bytearray:
|
||||
def decode_der_signature(der_signature: memoryview) -> bytearray:
|
||||
seq = der.decode_seq(der_signature)
|
||||
if len(seq) != 2 or any(len(i) > 32 for i in seq):
|
||||
raise ValueError
|
||||
|
@ -1,6 +1,7 @@
|
||||
from trezor.utils import ensure
|
||||
|
||||
if False:
|
||||
from typing import Union
|
||||
from trezor.utils import Writer
|
||||
|
||||
|
||||
@ -68,7 +69,7 @@ def write_uint64_be(w: Writer, n: int) -> int:
|
||||
return 8
|
||||
|
||||
|
||||
def write_bytes_unchecked(w: Writer, b: bytes) -> int:
|
||||
def write_bytes_unchecked(w: Writer, b: Union[bytes, memoryview]) -> int:
|
||||
w.extend(b)
|
||||
return len(b)
|
||||
|
||||
|
@ -42,7 +42,7 @@ def encode_int(i: bytes) -> bytes:
|
||||
return b"\x02" + encode_length(len(i)) + i
|
||||
|
||||
|
||||
def decode_int(r: BufferReader) -> bytes:
|
||||
def decode_int(r: BufferReader) -> memoryview:
|
||||
if r.get() != 0x02:
|
||||
raise ValueError
|
||||
|
||||
@ -62,7 +62,7 @@ def decode_int(r: BufferReader) -> bytes:
|
||||
if r.peek() == 0x00:
|
||||
raise ValueError # excessive zero-padding
|
||||
|
||||
return r.read(n)
|
||||
return r.read_memoryview(n)
|
||||
|
||||
|
||||
def encode_seq(seq: tuple) -> bytes:
|
||||
@ -72,7 +72,7 @@ def encode_seq(seq: tuple) -> bytes:
|
||||
return b"\x30" + encode_length(len(res)) + res
|
||||
|
||||
|
||||
def decode_seq(data: bytes) -> list[bytes]:
|
||||
def decode_seq(data: memoryview) -> list[memoryview]:
|
||||
r = BufferReader(data)
|
||||
|
||||
if r.get() != 0x30:
|
||||
|
@ -219,8 +219,11 @@ class BufferWriter:
|
||||
class BufferReader:
|
||||
"""Seekable and readable view into a buffer."""
|
||||
|
||||
def __init__(self, buffer: bytes) -> None:
|
||||
self.buffer = buffer
|
||||
def __init__(self, buffer: Union[bytes, memoryview]) -> None:
|
||||
if isinstance(buffer, memoryview):
|
||||
self.buffer = buffer
|
||||
else:
|
||||
self.buffer = memoryview(buffer)
|
||||
self.offset = 0
|
||||
|
||||
def seek(self, offset: int) -> None:
|
||||
@ -251,7 +254,15 @@ class BufferReader:
|
||||
If `length` is unspecified, reads all remaining data.
|
||||
|
||||
Note that this method makes a copy of the data. To avoid allocation, use
|
||||
`readinto()`.
|
||||
`readinto()`. To avoid copying use `read_memoryview()`.
|
||||
"""
|
||||
return bytes(self.read_memoryview(length))
|
||||
|
||||
def read_memoryview(self, length: int | None = None) -> memoryview:
|
||||
"""Read and return a memoryview of exactly `length` bytes, or raise
|
||||
EOFError.
|
||||
|
||||
If `length` is unspecified, reads all remaining data.
|
||||
"""
|
||||
if length is None:
|
||||
ret = self.buffer[self.offset :]
|
||||
|
Loading…
Reference in New Issue
Block a user