refactor(core/bitcoin): Use HashWriter in address derivation.

pull/1723/head
Andrew Kozlik 3 years ago committed by Andrew Kozlik
parent 2c003052f5
commit 7811204ed5

@ -3,13 +3,14 @@ from trezor.crypto import base58, cashaddr
from trezor.crypto.hashlib import sha256 from trezor.crypto.hashlib import sha256
from trezor.enums import InputScriptType from trezor.enums import InputScriptType
from trezor.messages import MultisigRedeemScriptType from trezor.messages import MultisigRedeemScriptType
from trezor.utils import HashWriter
from apps.common import address_type from apps.common import address_type
from apps.common.coininfo import CoinInfo from apps.common.coininfo import CoinInfo
from .common import ecdsa_hash_pubkey, encode_bech32_address from .common import ecdsa_hash_pubkey, encode_bech32_address
from .multisig import multisig_get_pubkeys, multisig_pubkey_index from .multisig import multisig_get_pubkeys, multisig_pubkey_index
from .scripts import output_script_multisig, output_script_native_p2wpkh_or_p2wsh from .scripts import output_script_native_p2wpkh_or_p2wsh, write_output_script_multisig
if False: if False:
from trezor.crypto import bip32 from trezor.crypto import bip32
@ -75,25 +76,25 @@ def get_address(
def address_multisig_p2sh(pubkeys: list[bytes], m: int, coin: CoinInfo) -> str: def address_multisig_p2sh(pubkeys: list[bytes], m: int, coin: CoinInfo) -> str:
if coin.address_type_p2sh is None: if coin.address_type_p2sh is None:
raise wire.ProcessError("Multisig not enabled on this coin") raise wire.ProcessError("Multisig not enabled on this coin")
redeem_script = output_script_multisig(pubkeys, m) redeem_script = HashWriter(coin.script_hash())
redeem_script_hash = coin.script_hash(redeem_script) write_output_script_multisig(redeem_script, pubkeys, m)
return address_p2sh(redeem_script_hash, coin) return address_p2sh(redeem_script.get_digest(), coin)
def address_multisig_p2wsh_in_p2sh(pubkeys: list[bytes], m: int, coin: CoinInfo) -> str: def address_multisig_p2wsh_in_p2sh(pubkeys: list[bytes], m: int, coin: CoinInfo) -> str:
if coin.address_type_p2sh is None: if coin.address_type_p2sh is None:
raise wire.ProcessError("Multisig not enabled on this coin") raise wire.ProcessError("Multisig not enabled on this coin")
witness_script = output_script_multisig(pubkeys, m) witness_script_h = HashWriter(sha256())
witness_script_hash = sha256(witness_script).digest() write_output_script_multisig(witness_script_h, pubkeys, m)
return address_p2wsh_in_p2sh(witness_script_hash, coin) return address_p2wsh_in_p2sh(witness_script_h.get_digest(), coin)
def address_multisig_p2wsh(pubkeys: list[bytes], m: int, hrp: str) -> str: def address_multisig_p2wsh(pubkeys: list[bytes], m: int, hrp: str) -> str:
if not hrp: if not hrp:
raise wire.ProcessError("Multisig not enabled on this coin") raise wire.ProcessError("Multisig not enabled on this coin")
witness_script = output_script_multisig(pubkeys, m) witness_script_h = HashWriter(sha256())
witness_script_hash = sha256(witness_script).digest() write_output_script_multisig(witness_script_h, pubkeys, m)
return address_p2wsh(witness_script_hash, hrp) return address_p2wsh(witness_script_h.get_digest(), hrp)
def address_pkh(pubkey: bytes, coin: CoinInfo) -> str: def address_pkh(pubkey: bytes, coin: CoinInfo) -> str:

@ -67,7 +67,7 @@ def verify_nonownership(
) -> bool: ) -> bool:
try: try:
r = utils.BufferReader(proof) r = utils.BufferReader(proof)
if r.read(4) != _VERSION_MAGIC: if r.read_memoryview(4) != _VERSION_MAGIC:
raise wire.DataError("Unknown format of proof of ownership") raise wire.DataError("Unknown format of proof of ownership")
flags = r.get() flags = r.get()
@ -79,7 +79,7 @@ def verify_nonownership(
ownership_id = get_identifier(script_pubkey, keychain) ownership_id = get_identifier(script_pubkey, keychain)
not_owned = True not_owned = True
for _ in range(id_count): for _ in range(id_count):
if utils.consteq(ownership_id, r.read(_OWNERSHIP_ID_LEN)): if utils.consteq(ownership_id, r.read_memoryview(_OWNERSHIP_ID_LEN)):
not_owned = False not_owned = False
# Verify the BIP-322 SignatureProof. # Verify the BIP-322 SignatureProof.

@ -3,9 +3,9 @@ from trezor.utils import BufferReader
from apps.common.readers import read_bitcoin_varint from apps.common.readers import read_bitcoin_varint
def read_bytes_prefixed(r: BufferReader) -> bytes: def read_memoryview_prefixed(r: BufferReader) -> memoryview:
n = read_bitcoin_varint(r) n = read_bitcoin_varint(r)
return r.read(n) return r.read_memoryview(n)
def read_op_push(r: BufferReader) -> int: def read_op_push(r: BufferReader) -> int:

@ -13,7 +13,7 @@ from .multisig import (
multisig_get_pubkeys, multisig_get_pubkeys,
multisig_pubkey_index, multisig_pubkey_index,
) )
from .readers import read_bytes_prefixed, read_op_push from .readers import read_memoryview_prefixed, read_op_push
from .writers import ( from .writers import (
op_push_length, op_push_length,
write_bytes_fixed, write_bytes_fixed,
@ -23,6 +23,8 @@ from .writers import (
) )
if False: if False:
from typing import Sequence
from trezor.messages import MultisigRedeemScriptType, TxInput from trezor.messages import MultisigRedeemScriptType, TxInput
from apps.common.coininfo import CoinInfo from apps.common.coininfo import CoinInfo
@ -116,7 +118,11 @@ def output_derive_script(address: str, coin: CoinInfo) -> bytes:
# see https://github.com/bitcoin/bips/blob/master/bip-0143.mediawiki#specification # see https://github.com/bitcoin/bips/blob/master/bip-0143.mediawiki#specification
# item 5 for details # item 5 for details
def write_bip143_script_code_prefixed( def write_bip143_script_code_prefixed(
w: Writer, txi: TxInput, public_keys: list[bytes], threshold: int, coin: CoinInfo w: Writer,
txi: TxInput,
public_keys: Sequence[bytes | memoryview],
threshold: int,
coin: CoinInfo,
) -> None: ) -> None:
if len(public_keys) > 1: if len(public_keys) > 1:
write_output_script_multisig(w, public_keys, threshold, prefixed=True) write_output_script_multisig(w, public_keys, threshold, prefixed=True)
@ -152,15 +158,15 @@ def write_input_script_p2pkh_or_p2sh_prefixed(
append_pubkey(w, pubkey) append_pubkey(w, pubkey)
def parse_input_script_p2pkh(script_sig: bytes) -> tuple[bytes, bytes, int]: def parse_input_script_p2pkh(script_sig: bytes) -> tuple[memoryview, memoryview, int]:
try: try:
r = utils.BufferReader(script_sig) r = utils.BufferReader(script_sig)
n = read_op_push(r) n = read_op_push(r)
signature = r.read(n - 1) signature = r.read_memoryview(n - 1)
hash_type = r.get() hash_type = r.get()
n = read_op_push(r) n = read_op_push(r)
pubkey = r.read() pubkey = r.read_memoryview()
if len(pubkey) != n: if len(pubkey) != n:
raise ValueError raise ValueError
except (ValueError, EOFError): except (ValueError, EOFError):
@ -286,7 +292,7 @@ def write_witness_p2wpkh(
write_bytes_prefixed(w, pubkey) write_bytes_prefixed(w, pubkey)
def parse_witness_p2wpkh(witness: bytes) -> tuple[bytes, bytes, int]: def parse_witness_p2wpkh(witness: bytes) -> tuple[memoryview, memoryview, int]:
try: try:
r = utils.BufferReader(witness) r = utils.BufferReader(witness)
@ -295,10 +301,10 @@ def parse_witness_p2wpkh(witness: bytes) -> tuple[bytes, bytes, int]:
raise ValueError raise ValueError
n = read_bitcoin_varint(r) n = read_bitcoin_varint(r)
signature = r.read(n - 1) signature = r.read_memoryview(n - 1)
hash_type = r.get() hash_type = r.get()
pubkey = read_bytes_prefixed(r) pubkey = read_memoryview_prefixed(r)
if r.remaining_count(): if r.remaining_count():
raise ValueError raise ValueError
except (ValueError, EOFError): except (ValueError, EOFError):
@ -342,7 +348,9 @@ def write_witness_multisig(
write_output_script_multisig(w, pubkeys, multisig.m, prefixed=True) write_output_script_multisig(w, pubkeys, multisig.m, prefixed=True)
def parse_witness_multisig(witness: bytes) -> tuple[bytes, list[tuple[bytes, int]]]: def parse_witness_multisig(
witness: bytes,
) -> tuple[memoryview, list[tuple[memoryview, int]]]:
try: try:
r = utils.BufferReader(witness) r = utils.BufferReader(witness)
@ -356,11 +364,11 @@ def parse_witness_multisig(witness: bytes) -> tuple[bytes, list[tuple[bytes, int
signatures = [] signatures = []
for i in range(item_count - 2): for i in range(item_count - 2):
n = read_bitcoin_varint(r) n = read_bitcoin_varint(r)
signature = r.read(n - 1) signature = r.read_memoryview(n - 1)
hash_type = r.get() hash_type = r.get()
signatures.append((signature, hash_type)) signatures.append((signature, hash_type))
script = read_bytes_prefixed(r) script = read_memoryview_prefixed(r)
if r.remaining_count(): if r.remaining_count():
raise ValueError raise ValueError
except (ValueError, EOFError): except (ValueError, EOFError):
@ -416,7 +424,7 @@ def write_input_script_multisig_prefixed(
def parse_input_script_multisig( def parse_input_script_multisig(
script_sig: bytes, script_sig: bytes,
) -> tuple[bytes, list[tuple[bytes, int]]]: ) -> tuple[memoryview, list[tuple[memoryview, int]]]:
try: try:
r = utils.BufferReader(script_sig) r = utils.BufferReader(script_sig)
@ -427,12 +435,12 @@ def parse_input_script_multisig(
signatures = [] signatures = []
n = read_op_push(r) n = read_op_push(r)
while r.remaining_count() > n: while r.remaining_count() > n:
signature = r.read(n - 1) signature = r.read_memoryview(n - 1)
hash_type = r.get() hash_type = r.get()
signatures.append((signature, hash_type)) signatures.append((signature, hash_type))
n = read_op_push(r) n = read_op_push(r)
script = r.read() script = r.read_memoryview()
if len(script) != n: if len(script) != n:
raise ValueError raise ValueError
except (ValueError, EOFError): except (ValueError, EOFError):
@ -449,7 +457,7 @@ def output_script_multisig(pubkeys: list[bytes], m: int) -> bytearray:
def write_output_script_multisig( def write_output_script_multisig(
w: Writer, w: Writer,
pubkeys: list[bytes], pubkeys: Sequence[bytes | memoryview],
m: int, m: int,
prefixed: bool = False, prefixed: bool = False,
) -> None: ) -> None:
@ -470,11 +478,11 @@ def write_output_script_multisig(
w.append(0xAE) # OP_CHECKMULTISIG w.append(0xAE) # OP_CHECKMULTISIG
def output_script_multisig_length(pubkeys: list[bytes], m: int) -> int: def output_script_multisig_length(pubkeys: Sequence[bytes | memoryview], m: int) -> int:
return 1 + len(pubkeys) * (1 + 33) + 1 + 1 # see output_script_multisig return 1 + len(pubkeys) * (1 + 33) + 1 + 1 # see output_script_multisig
def parse_output_script_multisig(script: bytes) -> tuple[list[bytes], int]: def parse_output_script_multisig(script: bytes) -> tuple[list[memoryview], int]:
try: try:
r = utils.BufferReader(script) r = utils.BufferReader(script)
@ -493,7 +501,7 @@ def parse_output_script_multisig(script: bytes) -> tuple[list[bytes], int]:
n = read_op_push(r) n = read_op_push(r)
if n != 33: if n != 33:
raise ValueError raise ValueError
public_keys.append(r.read(n)) public_keys.append(r.read_memoryview(n))
r.get() # ignore pubkey_count r.get() # ignore pubkey_count
if r.get() != 0xAE: # OP_CHECKMULTISIG if r.get() != 0xAE: # OP_CHECKMULTISIG
@ -550,9 +558,9 @@ def write_bip322_signature_proof(
w.append(0x00) w.append(0x00)
def read_bip322_signature_proof(r: utils.BufferReader) -> tuple[bytes, bytes]: def read_bip322_signature_proof(r: utils.BufferReader) -> tuple[memoryview, memoryview]:
script_sig = read_bytes_prefixed(r) script_sig = read_memoryview_prefixed(r)
witness = r.read() witness = r.read_memoryview()
return script_sig, witness return script_sig, witness
@ -572,6 +580,6 @@ def append_signature(w: Writer, signature: bytes, hash_type: int) -> None:
w.append(hash_type) w.append(hash_type)
def append_pubkey(w: Writer, pubkey: bytes) -> None: def append_pubkey(w: Writer, pubkey: bytes | memoryview) -> None:
write_op_push(w, len(pubkey)) write_op_push(w, len(pubkey))
write_bytes_unchecked(w, pubkey) write_bytes_unchecked(w, pubkey)

@ -17,6 +17,8 @@ from .hash143 import Bip143Hash
from .tx_info import OriginalTxInfo, TxInfo from .tx_info import OriginalTxInfo, TxInfo
if False: if False:
from typing import Sequence
from trezor.crypto import bip32 from trezor.crypto import bip32
from trezor.messages import ( from trezor.messages import (
@ -396,7 +398,7 @@ class Bitcoin:
i: int, i: int,
txi: TxInput, txi: TxInput,
tx_info: TxInfo | OriginalTxInfo, tx_info: TxInfo | OriginalTxInfo,
public_keys: list[bytes], public_keys: Sequence[bytes | memoryview],
threshold: int, threshold: int,
script_pubkey: bytes, script_pubkey: bytes,
) -> bytes: ) -> bytes:

@ -11,6 +11,7 @@ from . import helpers
from .bitcoin import Bitcoin from .bitcoin import Bitcoin
if False: if False:
from typing import Sequence
from .tx_info import OriginalTxInfo, TxInfo from .tx_info import OriginalTxInfo, TxInfo
_SIGHASH_FORKID = const(0x40) _SIGHASH_FORKID = const(0x40)
@ -43,7 +44,7 @@ class Bitcoinlike(Bitcoin):
i: int, i: int,
txi: TxInput, txi: TxInput,
tx_info: TxInfo | OriginalTxInfo, tx_info: TxInfo | OriginalTxInfo,
public_keys: list[bytes], public_keys: Sequence[bytes | memoryview],
threshold: int, threshold: int,
script_pubkey: bytes, script_pubkey: bytes,
tx_hash: bytes | None = None, tx_hash: bytes | None = None,

@ -24,6 +24,8 @@ OUTPUT_SCRIPT_NULL_SSTXCHANGE = (
) )
if False: if False:
from typing import Sequence
from trezor.messages import ( from trezor.messages import (
SignTx, SignTx,
TxInput, TxInput,
@ -61,7 +63,7 @@ class DecredHash:
def preimage_hash( def preimage_hash(
self, self,
txi: TxInput, txi: TxInput,
public_keys: list[bytes], public_keys: Sequence[bytes | memoryview],
threshold: int, threshold: int,
tx: SignTx | PrevTx, tx: SignTx | PrevTx,
coin: CoinInfo, coin: CoinInfo,

@ -7,7 +7,7 @@ from apps.common import coininfo
from .. import scripts, writers from .. import scripts, writers
if False: if False:
from typing import Protocol from typing import Protocol, Sequence
class Hash143(Protocol): class Hash143(Protocol):
def add_input(self, txi: TxInput) -> None: def add_input(self, txi: TxInput) -> None:
@ -19,7 +19,7 @@ if False:
def preimage_hash( def preimage_hash(
self, self,
txi: TxInput, txi: TxInput,
public_keys: list[bytes], public_keys: Sequence[bytes | memoryview],
threshold: int, threshold: int,
tx: SignTx | PrevTx, tx: SignTx | PrevTx,
coin: coininfo.CoinInfo, coin: coininfo.CoinInfo,
@ -48,7 +48,7 @@ class Bip143Hash:
def preimage_hash( def preimage_hash(
self, self,
txi: TxInput, txi: TxInput,
public_keys: list[bytes], public_keys: Sequence[bytes | memoryview],
threshold: int, threshold: int,
tx: SignTx | PrevTx, tx: SignTx | PrevTx,
coin: coininfo.CoinInfo, coin: coininfo.CoinInfo,

@ -24,6 +24,7 @@ from . import approvers, helpers
from .bitcoinlike import Bitcoinlike from .bitcoinlike import Bitcoinlike
if False: if False:
from typing import Sequence
from apps.common import coininfo from apps.common import coininfo
from .hash143 import Hash143 from .hash143 import Hash143
from .tx_info import OriginalTxInfo, TxInfo from .tx_info import OriginalTxInfo, TxInfo
@ -49,7 +50,7 @@ class Zip243Hash:
def preimage_hash( def preimage_hash(
self, self,
txi: TxInput, txi: TxInput,
public_keys: list[bytes], public_keys: Sequence[bytes | memoryview],
threshold: int, threshold: int,
tx: SignTx | PrevTx, tx: SignTx | PrevTx,
coin: coininfo.CoinInfo, coin: coininfo.CoinInfo,
@ -138,7 +139,7 @@ class Zcashlike(Bitcoinlike):
i: int, i: int,
txi: TxInput, txi: TxInput,
tx_info: TxInfo | OriginalTxInfo, tx_info: TxInfo | OriginalTxInfo,
public_keys: list[bytes], public_keys: Sequence[bytes | memoryview],
threshold: int, threshold: int,
script_pubkey: bytes, script_pubkey: bytes,
tx_hash: bytes | None = None, tx_hash: bytes | None = None,

@ -30,8 +30,8 @@ class SignatureVerifier:
coin: CoinInfo, coin: CoinInfo,
): ):
self.threshold = 1 self.threshold = 1
self.public_keys: list[bytes] = [] self.public_keys: list[memoryview] = []
self.signatures: list[tuple[bytes, int]] = [] self.signatures: list[tuple[memoryview, int]] = []
if not script_sig: if not script_sig:
if not witness: if not witness:
@ -118,7 +118,7 @@ class SignatureVerifier:
raise wire.DataError("Invalid signature") raise wire.DataError("Invalid signature")
def decode_der_signature(der_signature: bytes) -> bytearray: def decode_der_signature(der_signature: memoryview) -> bytearray:
seq = der.decode_seq(der_signature) seq = der.decode_seq(der_signature)
if len(seq) != 2 or any(len(i) > 32 for i in seq): if len(seq) != 2 or any(len(i) > 32 for i in seq):
raise ValueError raise ValueError

@ -1,6 +1,7 @@
from trezor.utils import ensure from trezor.utils import ensure
if False: if False:
from typing import Union
from trezor.utils import Writer from trezor.utils import Writer
@ -68,7 +69,7 @@ def write_uint64_be(w: Writer, n: int) -> int:
return 8 return 8
def write_bytes_unchecked(w: Writer, b: bytes) -> int: def write_bytes_unchecked(w: Writer, b: Union[bytes, memoryview]) -> int:
w.extend(b) w.extend(b)
return len(b) return len(b)

@ -42,7 +42,7 @@ def encode_int(i: bytes) -> bytes:
return b"\x02" + encode_length(len(i)) + i return b"\x02" + encode_length(len(i)) + i
def decode_int(r: BufferReader) -> bytes: def decode_int(r: BufferReader) -> memoryview:
if r.get() != 0x02: if r.get() != 0x02:
raise ValueError raise ValueError
@ -62,7 +62,7 @@ def decode_int(r: BufferReader) -> bytes:
if r.peek() == 0x00: if r.peek() == 0x00:
raise ValueError # excessive zero-padding raise ValueError # excessive zero-padding
return r.read(n) return r.read_memoryview(n)
def encode_seq(seq: tuple) -> bytes: def encode_seq(seq: tuple) -> bytes:
@ -72,7 +72,7 @@ def encode_seq(seq: tuple) -> bytes:
return b"\x30" + encode_length(len(res)) + res return b"\x30" + encode_length(len(res)) + res
def decode_seq(data: bytes) -> list[bytes]: def decode_seq(data: memoryview) -> list[memoryview]:
r = BufferReader(data) r = BufferReader(data)
if r.get() != 0x30: if r.get() != 0x30:

@ -219,8 +219,11 @@ class BufferWriter:
class BufferReader: class BufferReader:
"""Seekable and readable view into a buffer.""" """Seekable and readable view into a buffer."""
def __init__(self, buffer: bytes) -> None: def __init__(self, buffer: Union[bytes, memoryview]) -> None:
self.buffer = buffer if isinstance(buffer, memoryview):
self.buffer = buffer
else:
self.buffer = memoryview(buffer)
self.offset = 0 self.offset = 0
def seek(self, offset: int) -> None: def seek(self, offset: int) -> None:
@ -251,7 +254,15 @@ class BufferReader:
If `length` is unspecified, reads all remaining data. If `length` is unspecified, reads all remaining data.
Note that this method makes a copy of the data. To avoid allocation, use Note that this method makes a copy of the data. To avoid allocation, use
`readinto()`. `readinto()`. To avoid copying use `read_memoryview()`.
"""
return bytes(self.read_memoryview(length))
def read_memoryview(self, length: int | None = None) -> memoryview:
"""Read and return a memoryview of exactly `length` bytes, or raise
EOFError.
If `length` is unspecified, reads all remaining data.
""" """
if length is None: if length is None:
ret = self.buffer[self.offset :] ret = self.buffer[self.offset :]

Loading…
Cancel
Save