refactor(core/bitcoin): Use HashWriter in address derivation.

pull/1723/head
Andrew Kozlik 3 years ago committed by Andrew Kozlik
parent 2c003052f5
commit 7811204ed5

@ -3,13 +3,14 @@ from trezor.crypto import base58, cashaddr
from trezor.crypto.hashlib import sha256
from trezor.enums import InputScriptType
from trezor.messages import MultisigRedeemScriptType
from trezor.utils import HashWriter
from apps.common import address_type
from apps.common.coininfo import CoinInfo
from .common import ecdsa_hash_pubkey, encode_bech32_address
from .multisig import multisig_get_pubkeys, multisig_pubkey_index
from .scripts import output_script_multisig, output_script_native_p2wpkh_or_p2wsh
from .scripts import output_script_native_p2wpkh_or_p2wsh, write_output_script_multisig
if False:
from trezor.crypto import bip32
@ -75,25 +76,25 @@ def get_address(
def address_multisig_p2sh(pubkeys: list[bytes], m: int, coin: CoinInfo) -> str:
if coin.address_type_p2sh is None:
raise wire.ProcessError("Multisig not enabled on this coin")
redeem_script = output_script_multisig(pubkeys, m)
redeem_script_hash = coin.script_hash(redeem_script)
return address_p2sh(redeem_script_hash, coin)
redeem_script = HashWriter(coin.script_hash())
write_output_script_multisig(redeem_script, pubkeys, m)
return address_p2sh(redeem_script.get_digest(), coin)
def address_multisig_p2wsh_in_p2sh(pubkeys: list[bytes], m: int, coin: CoinInfo) -> str:
if coin.address_type_p2sh is None:
raise wire.ProcessError("Multisig not enabled on this coin")
witness_script = output_script_multisig(pubkeys, m)
witness_script_hash = sha256(witness_script).digest()
return address_p2wsh_in_p2sh(witness_script_hash, coin)
witness_script_h = HashWriter(sha256())
write_output_script_multisig(witness_script_h, pubkeys, m)
return address_p2wsh_in_p2sh(witness_script_h.get_digest(), coin)
def address_multisig_p2wsh(pubkeys: list[bytes], m: int, hrp: str) -> str:
if not hrp:
raise wire.ProcessError("Multisig not enabled on this coin")
witness_script = output_script_multisig(pubkeys, m)
witness_script_hash = sha256(witness_script).digest()
return address_p2wsh(witness_script_hash, hrp)
witness_script_h = HashWriter(sha256())
write_output_script_multisig(witness_script_h, pubkeys, m)
return address_p2wsh(witness_script_h.get_digest(), hrp)
def address_pkh(pubkey: bytes, coin: CoinInfo) -> str:

@ -67,7 +67,7 @@ def verify_nonownership(
) -> bool:
try:
r = utils.BufferReader(proof)
if r.read(4) != _VERSION_MAGIC:
if r.read_memoryview(4) != _VERSION_MAGIC:
raise wire.DataError("Unknown format of proof of ownership")
flags = r.get()
@ -79,7 +79,7 @@ def verify_nonownership(
ownership_id = get_identifier(script_pubkey, keychain)
not_owned = True
for _ in range(id_count):
if utils.consteq(ownership_id, r.read(_OWNERSHIP_ID_LEN)):
if utils.consteq(ownership_id, r.read_memoryview(_OWNERSHIP_ID_LEN)):
not_owned = False
# Verify the BIP-322 SignatureProof.

@ -3,9 +3,9 @@ from trezor.utils import BufferReader
from apps.common.readers import read_bitcoin_varint
def read_bytes_prefixed(r: BufferReader) -> bytes:
def read_memoryview_prefixed(r: BufferReader) -> memoryview:
n = read_bitcoin_varint(r)
return r.read(n)
return r.read_memoryview(n)
def read_op_push(r: BufferReader) -> int:

@ -13,7 +13,7 @@ from .multisig import (
multisig_get_pubkeys,
multisig_pubkey_index,
)
from .readers import read_bytes_prefixed, read_op_push
from .readers import read_memoryview_prefixed, read_op_push
from .writers import (
op_push_length,
write_bytes_fixed,
@ -23,6 +23,8 @@ from .writers import (
)
if False:
from typing import Sequence
from trezor.messages import MultisigRedeemScriptType, TxInput
from apps.common.coininfo import CoinInfo
@ -116,7 +118,11 @@ def output_derive_script(address: str, coin: CoinInfo) -> bytes:
# see https://github.com/bitcoin/bips/blob/master/bip-0143.mediawiki#specification
# item 5 for details
def write_bip143_script_code_prefixed(
w: Writer, txi: TxInput, public_keys: list[bytes], threshold: int, coin: CoinInfo
w: Writer,
txi: TxInput,
public_keys: Sequence[bytes | memoryview],
threshold: int,
coin: CoinInfo,
) -> None:
if len(public_keys) > 1:
write_output_script_multisig(w, public_keys, threshold, prefixed=True)
@ -152,15 +158,15 @@ def write_input_script_p2pkh_or_p2sh_prefixed(
append_pubkey(w, pubkey)
def parse_input_script_p2pkh(script_sig: bytes) -> tuple[bytes, bytes, int]:
def parse_input_script_p2pkh(script_sig: bytes) -> tuple[memoryview, memoryview, int]:
try:
r = utils.BufferReader(script_sig)
n = read_op_push(r)
signature = r.read(n - 1)
signature = r.read_memoryview(n - 1)
hash_type = r.get()
n = read_op_push(r)
pubkey = r.read()
pubkey = r.read_memoryview()
if len(pubkey) != n:
raise ValueError
except (ValueError, EOFError):
@ -286,7 +292,7 @@ def write_witness_p2wpkh(
write_bytes_prefixed(w, pubkey)
def parse_witness_p2wpkh(witness: bytes) -> tuple[bytes, bytes, int]:
def parse_witness_p2wpkh(witness: bytes) -> tuple[memoryview, memoryview, int]:
try:
r = utils.BufferReader(witness)
@ -295,10 +301,10 @@ def parse_witness_p2wpkh(witness: bytes) -> tuple[bytes, bytes, int]:
raise ValueError
n = read_bitcoin_varint(r)
signature = r.read(n - 1)
signature = r.read_memoryview(n - 1)
hash_type = r.get()
pubkey = read_bytes_prefixed(r)
pubkey = read_memoryview_prefixed(r)
if r.remaining_count():
raise ValueError
except (ValueError, EOFError):
@ -342,7 +348,9 @@ def write_witness_multisig(
write_output_script_multisig(w, pubkeys, multisig.m, prefixed=True)
def parse_witness_multisig(witness: bytes) -> tuple[bytes, list[tuple[bytes, int]]]:
def parse_witness_multisig(
witness: bytes,
) -> tuple[memoryview, list[tuple[memoryview, int]]]:
try:
r = utils.BufferReader(witness)
@ -356,11 +364,11 @@ def parse_witness_multisig(witness: bytes) -> tuple[bytes, list[tuple[bytes, int
signatures = []
for i in range(item_count - 2):
n = read_bitcoin_varint(r)
signature = r.read(n - 1)
signature = r.read_memoryview(n - 1)
hash_type = r.get()
signatures.append((signature, hash_type))
script = read_bytes_prefixed(r)
script = read_memoryview_prefixed(r)
if r.remaining_count():
raise ValueError
except (ValueError, EOFError):
@ -416,7 +424,7 @@ def write_input_script_multisig_prefixed(
def parse_input_script_multisig(
script_sig: bytes,
) -> tuple[bytes, list[tuple[bytes, int]]]:
) -> tuple[memoryview, list[tuple[memoryview, int]]]:
try:
r = utils.BufferReader(script_sig)
@ -427,12 +435,12 @@ def parse_input_script_multisig(
signatures = []
n = read_op_push(r)
while r.remaining_count() > n:
signature = r.read(n - 1)
signature = r.read_memoryview(n - 1)
hash_type = r.get()
signatures.append((signature, hash_type))
n = read_op_push(r)
script = r.read()
script = r.read_memoryview()
if len(script) != n:
raise ValueError
except (ValueError, EOFError):
@ -449,7 +457,7 @@ def output_script_multisig(pubkeys: list[bytes], m: int) -> bytearray:
def write_output_script_multisig(
w: Writer,
pubkeys: list[bytes],
pubkeys: Sequence[bytes | memoryview],
m: int,
prefixed: bool = False,
) -> None:
@ -470,11 +478,11 @@ def write_output_script_multisig(
w.append(0xAE) # OP_CHECKMULTISIG
def output_script_multisig_length(pubkeys: list[bytes], m: int) -> int:
def output_script_multisig_length(pubkeys: Sequence[bytes | memoryview], m: int) -> int:
return 1 + len(pubkeys) * (1 + 33) + 1 + 1 # see output_script_multisig
def parse_output_script_multisig(script: bytes) -> tuple[list[bytes], int]:
def parse_output_script_multisig(script: bytes) -> tuple[list[memoryview], int]:
try:
r = utils.BufferReader(script)
@ -493,7 +501,7 @@ def parse_output_script_multisig(script: bytes) -> tuple[list[bytes], int]:
n = read_op_push(r)
if n != 33:
raise ValueError
public_keys.append(r.read(n))
public_keys.append(r.read_memoryview(n))
r.get() # ignore pubkey_count
if r.get() != 0xAE: # OP_CHECKMULTISIG
@ -550,9 +558,9 @@ def write_bip322_signature_proof(
w.append(0x00)
def read_bip322_signature_proof(r: utils.BufferReader) -> tuple[bytes, bytes]:
script_sig = read_bytes_prefixed(r)
witness = r.read()
def read_bip322_signature_proof(r: utils.BufferReader) -> tuple[memoryview, memoryview]:
script_sig = read_memoryview_prefixed(r)
witness = r.read_memoryview()
return script_sig, witness
@ -572,6 +580,6 @@ def append_signature(w: Writer, signature: bytes, hash_type: int) -> None:
w.append(hash_type)
def append_pubkey(w: Writer, pubkey: bytes) -> None:
def append_pubkey(w: Writer, pubkey: bytes | memoryview) -> None:
write_op_push(w, len(pubkey))
write_bytes_unchecked(w, pubkey)

@ -17,6 +17,8 @@ from .hash143 import Bip143Hash
from .tx_info import OriginalTxInfo, TxInfo
if False:
from typing import Sequence
from trezor.crypto import bip32
from trezor.messages import (
@ -396,7 +398,7 @@ class Bitcoin:
i: int,
txi: TxInput,
tx_info: TxInfo | OriginalTxInfo,
public_keys: list[bytes],
public_keys: Sequence[bytes | memoryview],
threshold: int,
script_pubkey: bytes,
) -> bytes:

@ -11,6 +11,7 @@ from . import helpers
from .bitcoin import Bitcoin
if False:
from typing import Sequence
from .tx_info import OriginalTxInfo, TxInfo
_SIGHASH_FORKID = const(0x40)
@ -43,7 +44,7 @@ class Bitcoinlike(Bitcoin):
i: int,
txi: TxInput,
tx_info: TxInfo | OriginalTxInfo,
public_keys: list[bytes],
public_keys: Sequence[bytes | memoryview],
threshold: int,
script_pubkey: bytes,
tx_hash: bytes | None = None,

@ -24,6 +24,8 @@ OUTPUT_SCRIPT_NULL_SSTXCHANGE = (
)
if False:
from typing import Sequence
from trezor.messages import (
SignTx,
TxInput,
@ -61,7 +63,7 @@ class DecredHash:
def preimage_hash(
self,
txi: TxInput,
public_keys: list[bytes],
public_keys: Sequence[bytes | memoryview],
threshold: int,
tx: SignTx | PrevTx,
coin: CoinInfo,

@ -7,7 +7,7 @@ from apps.common import coininfo
from .. import scripts, writers
if False:
from typing import Protocol
from typing import Protocol, Sequence
class Hash143(Protocol):
def add_input(self, txi: TxInput) -> None:
@ -19,7 +19,7 @@ if False:
def preimage_hash(
self,
txi: TxInput,
public_keys: list[bytes],
public_keys: Sequence[bytes | memoryview],
threshold: int,
tx: SignTx | PrevTx,
coin: coininfo.CoinInfo,
@ -48,7 +48,7 @@ class Bip143Hash:
def preimage_hash(
self,
txi: TxInput,
public_keys: list[bytes],
public_keys: Sequence[bytes | memoryview],
threshold: int,
tx: SignTx | PrevTx,
coin: coininfo.CoinInfo,

@ -24,6 +24,7 @@ from . import approvers, helpers
from .bitcoinlike import Bitcoinlike
if False:
from typing import Sequence
from apps.common import coininfo
from .hash143 import Hash143
from .tx_info import OriginalTxInfo, TxInfo
@ -49,7 +50,7 @@ class Zip243Hash:
def preimage_hash(
self,
txi: TxInput,
public_keys: list[bytes],
public_keys: Sequence[bytes | memoryview],
threshold: int,
tx: SignTx | PrevTx,
coin: coininfo.CoinInfo,
@ -138,7 +139,7 @@ class Zcashlike(Bitcoinlike):
i: int,
txi: TxInput,
tx_info: TxInfo | OriginalTxInfo,
public_keys: list[bytes],
public_keys: Sequence[bytes | memoryview],
threshold: int,
script_pubkey: bytes,
tx_hash: bytes | None = None,

@ -30,8 +30,8 @@ class SignatureVerifier:
coin: CoinInfo,
):
self.threshold = 1
self.public_keys: list[bytes] = []
self.signatures: list[tuple[bytes, int]] = []
self.public_keys: list[memoryview] = []
self.signatures: list[tuple[memoryview, int]] = []
if not script_sig:
if not witness:
@ -118,7 +118,7 @@ class SignatureVerifier:
raise wire.DataError("Invalid signature")
def decode_der_signature(der_signature: bytes) -> bytearray:
def decode_der_signature(der_signature: memoryview) -> bytearray:
seq = der.decode_seq(der_signature)
if len(seq) != 2 or any(len(i) > 32 for i in seq):
raise ValueError

@ -1,6 +1,7 @@
from trezor.utils import ensure
if False:
from typing import Union
from trezor.utils import Writer
@ -68,7 +69,7 @@ def write_uint64_be(w: Writer, n: int) -> int:
return 8
def write_bytes_unchecked(w: Writer, b: bytes) -> int:
def write_bytes_unchecked(w: Writer, b: Union[bytes, memoryview]) -> int:
w.extend(b)
return len(b)

@ -42,7 +42,7 @@ def encode_int(i: bytes) -> bytes:
return b"\x02" + encode_length(len(i)) + i
def decode_int(r: BufferReader) -> bytes:
def decode_int(r: BufferReader) -> memoryview:
if r.get() != 0x02:
raise ValueError
@ -62,7 +62,7 @@ def decode_int(r: BufferReader) -> bytes:
if r.peek() == 0x00:
raise ValueError # excessive zero-padding
return r.read(n)
return r.read_memoryview(n)
def encode_seq(seq: tuple) -> bytes:
@ -72,7 +72,7 @@ def encode_seq(seq: tuple) -> bytes:
return b"\x30" + encode_length(len(res)) + res
def decode_seq(data: bytes) -> list[bytes]:
def decode_seq(data: memoryview) -> list[memoryview]:
r = BufferReader(data)
if r.get() != 0x30:

@ -219,8 +219,11 @@ class BufferWriter:
class BufferReader:
"""Seekable and readable view into a buffer."""
def __init__(self, buffer: bytes) -> None:
self.buffer = buffer
def __init__(self, buffer: Union[bytes, memoryview]) -> None:
if isinstance(buffer, memoryview):
self.buffer = buffer
else:
self.buffer = memoryview(buffer)
self.offset = 0
def seek(self, offset: int) -> None:
@ -251,7 +254,15 @@ class BufferReader:
If `length` is unspecified, reads all remaining data.
Note that this method makes a copy of the data. To avoid allocation, use
`readinto()`.
`readinto()`. To avoid copying use `read_memoryview()`.
"""
return bytes(self.read_memoryview(length))
def read_memoryview(self, length: int | None = None) -> memoryview:
"""Read and return a memoryview of exactly `length` bytes, or raise
EOFError.
If `length` is unspecified, reads all remaining data.
"""
if length is None:
ret = self.buffer[self.offset :]

Loading…
Cancel
Save