1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-18 12:28:09 +00:00

chore(core): decrease bitcoin size by 1740 bytes

This commit is contained in:
grdddj 2022-09-21 09:56:28 +02:00 committed by matejcik
parent 45b4b609db
commit 55bb61d404
45 changed files with 1186 additions and 975 deletions

View File

@ -1,22 +1,20 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor import wire from trezor.crypto import base58
from trezor.crypto import base58, cashaddr
from trezor.crypto.curve import bip340
from trezor.crypto.hashlib import sha256 from trezor.crypto.hashlib import sha256
from trezor.enums import InputScriptType
from trezor.utils import HashWriter from trezor.utils import HashWriter
from trezor.wire import ProcessError
from apps.common import address_type from apps.common import address_type
from apps.common.coininfo import CoinInfo
from .common import ecdsa_hash_pubkey, encode_bech32_address from .common import ecdsa_hash_pubkey, encode_bech32_address
from .multisig import multisig_get_pubkeys, multisig_pubkey_index
from .scripts import output_script_native_segwit, write_output_script_multisig from .scripts import output_script_native_segwit, write_output_script_multisig
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import MultisigRedeemScriptType from trezor.messages import MultisigRedeemScriptType
from trezor.crypto import bip32 from trezor.crypto import bip32
from apps.common.coininfo import CoinInfo
from trezor.enums import InputScriptType
def get_address( def get_address(
@ -25,22 +23,27 @@ def get_address(
node: bip32.HDNode, node: bip32.HDNode,
multisig: MultisigRedeemScriptType | None = None, multisig: MultisigRedeemScriptType | None = None,
) -> str: ) -> str:
from trezor.enums import InputScriptType
from .multisig import multisig_get_pubkeys, multisig_pubkey_index
node_public_key = node.public_key() # result_cache
if multisig: if multisig:
# Ensure that our public key is included in the multisig. # Ensure that our public key is included in the multisig.
multisig_pubkey_index(multisig, node.public_key()) multisig_pubkey_index(multisig, node_public_key)
if script_type in (InputScriptType.SPENDADDRESS, InputScriptType.SPENDMULTISIG): if script_type in (InputScriptType.SPENDADDRESS, InputScriptType.SPENDMULTISIG):
if multisig: # p2sh multisig if multisig: # p2sh multisig
if coin.address_type_p2sh is None: if coin.address_type_p2sh is None:
raise wire.ProcessError("Multisig not enabled on this coin") raise ProcessError("Multisig not enabled on this coin")
pubkeys = multisig_get_pubkeys(multisig) pubkeys = multisig_get_pubkeys(multisig)
address = address_multisig_p2sh(pubkeys, multisig.m, coin) address = _address_multisig_p2sh(pubkeys, multisig.m, coin)
if coin.cashaddr_prefix is not None: if coin.cashaddr_prefix is not None:
address = address_to_cashaddr(address, coin) address = address_to_cashaddr(address, coin)
return address return address
if script_type == InputScriptType.SPENDMULTISIG: if script_type == InputScriptType.SPENDMULTISIG:
raise wire.ProcessError("Multisig details required") raise ProcessError("Multisig details required")
# p2pkh # p2pkh
address = node.address(coin.address_type) address = node.address(coin.address_type)
@ -50,63 +53,65 @@ def get_address(
elif script_type == InputScriptType.SPENDWITNESS: # native p2wpkh or native p2wsh elif script_type == InputScriptType.SPENDWITNESS: # native p2wpkh or native p2wsh
if not coin.segwit or not coin.bech32_prefix: if not coin.segwit or not coin.bech32_prefix:
raise wire.ProcessError("Segwit not enabled on this coin") raise ProcessError("Segwit not enabled on this coin")
# native p2wsh multisig # native p2wsh multisig
if multisig is not None: if multisig is not None:
pubkeys = multisig_get_pubkeys(multisig) pubkeys = multisig_get_pubkeys(multisig)
return address_multisig_p2wsh(pubkeys, multisig.m, coin.bech32_prefix) return _address_multisig_p2wsh(pubkeys, multisig.m, coin.bech32_prefix)
# native p2wpkh # native p2wpkh
return address_p2wpkh(node.public_key(), coin) return address_p2wpkh(node_public_key, coin)
elif script_type == InputScriptType.SPENDTAPROOT: # taproot elif script_type == InputScriptType.SPENDTAPROOT: # taproot
if not coin.taproot or not coin.bech32_prefix: if not coin.taproot or not coin.bech32_prefix:
raise wire.ProcessError("Taproot not enabled on this coin") raise ProcessError("Taproot not enabled on this coin")
if multisig is not None: if multisig is not None:
raise wire.ProcessError("Multisig not supported for taproot") raise ProcessError("Multisig not supported for taproot")
return address_p2tr(node.public_key(), coin) return _address_p2tr(node_public_key, coin)
elif ( elif (
script_type == InputScriptType.SPENDP2SHWITNESS script_type == InputScriptType.SPENDP2SHWITNESS
): # p2wpkh or p2wsh nested in p2sh ): # p2wpkh or p2wsh nested in p2sh
if not coin.segwit or coin.address_type_p2sh is None: if not coin.segwit or coin.address_type_p2sh is None:
raise wire.ProcessError("Segwit not enabled on this coin") raise ProcessError("Segwit not enabled on this coin")
# p2wsh multisig nested in p2sh # p2wsh multisig nested in p2sh
if multisig is not None: if multisig is not None:
pubkeys = multisig_get_pubkeys(multisig) pubkeys = multisig_get_pubkeys(multisig)
return address_multisig_p2wsh_in_p2sh(pubkeys, multisig.m, coin) return _address_multisig_p2wsh_in_p2sh(pubkeys, multisig.m, coin)
# p2wpkh nested in p2sh # p2wpkh nested in p2sh
return address_p2wpkh_in_p2sh(node.public_key(), coin) return address_p2wpkh_in_p2sh(node_public_key, coin)
else: else:
raise wire.ProcessError("Invalid script type") raise ProcessError("Invalid script type")
def address_multisig_p2sh(pubkeys: list[bytes], m: int, coin: CoinInfo) -> str: def _address_multisig_p2sh(pubkeys: list[bytes], m: int, coin: CoinInfo) -> str:
if coin.address_type_p2sh is None: if coin.address_type_p2sh is None:
raise wire.ProcessError("Multisig not enabled on this coin") raise ProcessError("Multisig not enabled on this coin")
redeem_script = HashWriter(coin.script_hash()) redeem_script = HashWriter(coin.script_hash())
write_output_script_multisig(redeem_script, pubkeys, m) write_output_script_multisig(redeem_script, pubkeys, m)
return address_p2sh(redeem_script.get_digest(), coin) return address_p2sh(redeem_script.get_digest(), coin)
def address_multisig_p2wsh_in_p2sh(pubkeys: list[bytes], m: int, coin: CoinInfo) -> str: def _address_multisig_p2wsh_in_p2sh(
pubkeys: list[bytes], m: int, coin: CoinInfo
) -> str:
if coin.address_type_p2sh is None: if coin.address_type_p2sh is None:
raise wire.ProcessError("Multisig not enabled on this coin") raise ProcessError("Multisig not enabled on this coin")
witness_script_h = HashWriter(sha256()) witness_script_h = HashWriter(sha256())
write_output_script_multisig(witness_script_h, pubkeys, m) write_output_script_multisig(witness_script_h, pubkeys, m)
return address_p2wsh_in_p2sh(witness_script_h.get_digest(), coin) return _address_p2wsh_in_p2sh(witness_script_h.get_digest(), coin)
def address_multisig_p2wsh(pubkeys: list[bytes], m: int, hrp: str) -> str: def _address_multisig_p2wsh(pubkeys: list[bytes], m: int, hrp: str) -> str:
if not hrp: if not hrp:
raise wire.ProcessError("Multisig not enabled on this coin") raise ProcessError("Multisig not enabled on this coin")
witness_script_h = HashWriter(sha256()) witness_script_h = HashWriter(sha256())
write_output_script_multisig(witness_script_h, pubkeys, m) write_output_script_multisig(witness_script_h, pubkeys, m)
return address_p2wsh(witness_script_h.get_digest(), hrp) return _address_p2wsh(witness_script_h.get_digest(), hrp)
def address_pkh(pubkey: bytes, coin: CoinInfo) -> str: def address_pkh(pubkey: bytes, coin: CoinInfo) -> str:
@ -126,7 +131,7 @@ def address_p2wpkh_in_p2sh(pubkey: bytes, coin: CoinInfo) -> str:
return address_p2sh(redeem_script_hash, coin) return address_p2sh(redeem_script_hash, coin)
def address_p2wsh_in_p2sh(witness_script_hash: bytes, coin: CoinInfo) -> str: def _address_p2wsh_in_p2sh(witness_script_hash: bytes, coin: CoinInfo) -> str:
redeem_script = output_script_native_segwit(0, witness_script_hash) redeem_script = output_script_native_segwit(0, witness_script_hash)
redeem_script_hash = coin.script_hash(redeem_script).digest() redeem_script_hash = coin.script_hash(redeem_script).digest()
return address_p2sh(redeem_script_hash, coin) return address_p2sh(redeem_script_hash, coin)
@ -138,17 +143,21 @@ def address_p2wpkh(pubkey: bytes, coin: CoinInfo) -> str:
return encode_bech32_address(coin.bech32_prefix, 0, pubkeyhash) return encode_bech32_address(coin.bech32_prefix, 0, pubkeyhash)
def address_p2wsh(witness_script_hash: bytes, hrp: str) -> str: def _address_p2wsh(witness_script_hash: bytes, hrp: str) -> str:
return encode_bech32_address(hrp, 0, witness_script_hash) return encode_bech32_address(hrp, 0, witness_script_hash)
def address_p2tr(pubkey: bytes, coin: CoinInfo) -> str: def _address_p2tr(pubkey: bytes, coin: CoinInfo) -> str:
from trezor.crypto.curve import bip340
assert coin.bech32_prefix is not None assert coin.bech32_prefix is not None
output_pubkey = bip340.tweak_public_key(pubkey[1:]) output_pubkey = bip340.tweak_public_key(pubkey[1:])
return encode_bech32_address(coin.bech32_prefix, 1, output_pubkey) return encode_bech32_address(coin.bech32_prefix, 1, output_pubkey)
def address_to_cashaddr(address: str, coin: CoinInfo) -> str: def address_to_cashaddr(address: str, coin: CoinInfo) -> str:
from trezor.crypto import cashaddr
assert coin.cashaddr_prefix is not None assert coin.cashaddr_prefix is not None
raw = base58.decode_check(address, coin.b58_hash) raw = base58.decode_check(address, coin.b58_hash)
version, data = raw[0], raw[1:] version, data = raw[0], raw[1:]

View File

@ -1,16 +1,11 @@
from micropython import const from micropython import const
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor import utils, wire
from trezor.messages import AuthorizeCoinJoin
from apps.common import authorization
from .common import BIP32_WALLET_DEPTH from .common import BIP32_WALLET_DEPTH
from .writers import write_bytes_prefixed
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import ( from trezor.messages import (
AuthorizeCoinJoin,
GetOwnershipProof, GetOwnershipProof,
SignTx, SignTx,
TxInput, TxInput,
@ -27,14 +22,19 @@ class CoinJoinAuthorization:
self.params = params self.params = params
def check_get_ownership_proof(self, msg: GetOwnershipProof) -> bool: def check_get_ownership_proof(self, msg: GetOwnershipProof) -> bool:
from trezor import utils
from .writers import write_bytes_prefixed
params = self.params # local_cache_attribute
# Check whether the current params matches the parameters of the request. # Check whether the current params matches the parameters of the request.
coordinator = utils.empty_bytearray(1 + len(self.params.coordinator.encode())) coordinator = utils.empty_bytearray(1 + len(params.coordinator.encode()))
write_bytes_prefixed(coordinator, self.params.coordinator.encode()) write_bytes_prefixed(coordinator, params.coordinator.encode())
return ( return (
len(msg.address_n) >= BIP32_WALLET_DEPTH len(msg.address_n) >= BIP32_WALLET_DEPTH
and msg.address_n[:-BIP32_WALLET_DEPTH] == self.params.address_n and msg.address_n[:-BIP32_WALLET_DEPTH] == params.address_n
and msg.coin_name == self.params.coin_name and msg.coin_name == params.coin_name
and msg.script_type == self.params.script_type and msg.script_type == params.script_type
and msg.commitment_data.startswith(bytes(coordinator)) and msg.commitment_data.startswith(bytes(coordinator))
) )
@ -48,15 +48,22 @@ class CoinJoinAuthorization:
) )
def approve_sign_tx(self, msg: SignTx) -> bool: def approve_sign_tx(self, msg: SignTx) -> bool:
if self.params.max_rounds < 1 or msg.coin_name != self.params.coin_name: from apps.common import authorization
params = self.params # local_cache_attribute
if params.max_rounds < 1 or msg.coin_name != params.coin_name:
return False return False
self.params.max_rounds -= 1 params.max_rounds -= 1
authorization.set(self.params) authorization.set(params)
return True return True
def from_cached_message(auth_msg: MessageType) -> CoinJoinAuthorization: def from_cached_message(auth_msg: MessageType) -> CoinJoinAuthorization:
from trezor import wire
from trezor.messages import AuthorizeCoinJoin
if not AuthorizeCoinJoin.is_type_of(auth_msg): if not AuthorizeCoinJoin.is_type_of(auth_msg):
raise wire.ProcessError("Appropriate params was not found") raise wire.ProcessError("Appropriate params was not found")

View File

@ -8,7 +8,7 @@ if TYPE_CHECKING:
from trezor.messages import AuthorizeCoinJoin, Success from trezor.messages import AuthorizeCoinJoin, Success
from apps.common.coininfo import CoinInfo from apps.common.coininfo import CoinInfo
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
from trezor import wire from trezor.wire import Context
_MAX_COORDINATOR_LEN = const(36) _MAX_COORDINATOR_LEN = const(36)
_MAX_ROUNDS = const(500) _MAX_ROUNDS = const(500)
@ -17,41 +17,44 @@ _MAX_COORDINATOR_FEE_RATE = 5 * pow(10, FEE_RATE_DECIMALS) # 5 %
@with_keychain @with_keychain
async def authorize_coinjoin( async def authorize_coinjoin(
ctx: wire.Context, msg: AuthorizeCoinJoin, keychain: Keychain, coin: CoinInfo ctx: Context, msg: AuthorizeCoinJoin, keychain: Keychain, coin: CoinInfo
) -> Success: ) -> Success:
from trezor import wire
from trezor.enums import ButtonRequestType from trezor.enums import ButtonRequestType
from trezor.messages import Success from trezor.messages import Success
from trezor.ui.layouts import confirm_coinjoin, confirm_metadata from trezor.ui.layouts import confirm_coinjoin, confirm_metadata
from trezor.wire import DataError
from apps.common import authorization, safety_checks from apps.common import authorization, safety_checks
from apps.common.keychain import FORBIDDEN_KEY_PATH from apps.common.keychain import FORBIDDEN_KEY_PATH
from apps.common.paths import SLIP25_PURPOSE, validate_path from apps.common.paths import SLIP25_PURPOSE, validate_path
from .keychain import validate_path_against_script_type
from .common import BIP32_WALLET_DEPTH, format_fee_rate from .common import BIP32_WALLET_DEPTH, format_fee_rate
from .keychain import validate_path_against_script_type
safety_checks_is_strict = safety_checks.is_strict() # result_cache
address_n = msg.address_n # local_cache_attribute
if len(msg.coordinator) > _MAX_COORDINATOR_LEN or not all( if len(msg.coordinator) > _MAX_COORDINATOR_LEN or not all(
32 <= ord(x) <= 126 for x in msg.coordinator 32 <= ord(x) <= 126 for x in msg.coordinator
): ):
raise wire.DataError("Invalid coordinator name.") raise DataError("Invalid coordinator name.")
if msg.max_rounds > _MAX_ROUNDS and safety_checks.is_strict(): if msg.max_rounds > _MAX_ROUNDS and safety_checks_is_strict:
raise wire.DataError("The number of rounds is unexpectedly large.") raise DataError("The number of rounds is unexpectedly large.")
if ( if (
msg.max_coordinator_fee_rate > _MAX_COORDINATOR_FEE_RATE msg.max_coordinator_fee_rate > _MAX_COORDINATOR_FEE_RATE
and safety_checks.is_strict() and safety_checks_is_strict
): ):
raise wire.DataError("The coordination fee rate is unexpectedly large.") raise DataError("The coordination fee rate is unexpectedly large.")
if msg.max_fee_per_kvbyte > 10 * coin.maxfee_kb and safety_checks.is_strict(): if msg.max_fee_per_kvbyte > 10 * coin.maxfee_kb and safety_checks_is_strict:
raise wire.DataError("The fee per vbyte is unexpectedly large.") raise DataError("The fee per vbyte is unexpectedly large.")
if not msg.address_n: if not address_n:
raise wire.DataError("Empty path not allowed.") raise DataError("Empty path not allowed.")
if msg.address_n[0] != SLIP25_PURPOSE and safety_checks.is_strict(): if address_n[0] != SLIP25_PURPOSE and safety_checks_is_strict:
raise FORBIDDEN_KEY_PATH raise FORBIDDEN_KEY_PATH
max_fee_per_vbyte = format_fee_rate( max_fee_per_vbyte = format_fee_rate(
@ -65,7 +68,7 @@ async def authorize_coinjoin(
ctx, ctx,
keychain, keychain,
validation_path, validation_path,
msg.address_n[0] == SLIP25_PURPOSE, address_n[0] == SLIP25_PURPOSE,
validate_path_against_script_type( validate_path_against_script_type(
coin, address_n=validation_path, script_type=msg.script_type coin, address_n=validation_path, script_type=msg.script_type
), ),

View File

@ -2,17 +2,16 @@ from micropython import const
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor import wire from trezor import wire
from trezor.crypto import bech32, bip32, der from trezor.crypto import bech32
from trezor.crypto.curve import bip340, secp256k1 from trezor.crypto.curve import bip340
from trezor.crypto.hashlib import sha256
from trezor.enums import InputScriptType, OutputScriptType from trezor.enums import InputScriptType, OutputScriptType
from trezor.strings import format_amount
from trezor.utils import HashWriter, ensure
if TYPE_CHECKING: if TYPE_CHECKING:
from enum import IntEnum from enum import IntEnum
from apps.common.coininfo import CoinInfo from apps.common.coininfo import CoinInfo
from trezor.messages import TxInput from trezor.messages import TxInput
from trezor.utils import HashWriter
from trezor.crypto import bip32
else: else:
IntEnum = object IntEnum = object
@ -21,7 +20,11 @@ BITCOIN_NAMES = ("Bitcoin", "Regtest", "Testnet")
class SigHashType(IntEnum): class SigHashType(IntEnum):
"""Enumeration type listing the supported signature hash types.""" """Enumeration type listing the supported signature hash types.
Class constants defined below don't need to be used in the code.
They are a list of all allowed incoming sighash types.
"""
# Signature hash type with the same semantics as SIGHASH_ALL, but instead # Signature hash type with the same semantics as SIGHASH_ALL, but instead
# of having to include the byte in the signature, it is implied. # of having to include the byte in the signature, it is implied.
@ -37,8 +40,6 @@ class SigHashType(IntEnum):
# Signature hash type with the same semantics as SIGHASH_ALL. Used in some # Signature hash type with the same semantics as SIGHASH_ALL. Used in some
# Bitcoin-like altcoins for replay protection. # Bitcoin-like altcoins for replay protection.
# NOTE: this seems to be unused, but when deleted, it breaks some tests
# (test_send_bch_external_presigned and test_send_btg_external_presigned)
SIGHASH_ALL_FORKID = 0x41 SIGHASH_ALL_FORKID = 0x41
@classmethod @classmethod
@ -100,6 +101,9 @@ NONSEGWIT_INPUT_SCRIPT_TYPES = (
def ecdsa_sign(node: bip32.HDNode, digest: bytes) -> bytes: def ecdsa_sign(node: bip32.HDNode, digest: bytes) -> bytes:
from trezor.crypto import der
from trezor.crypto.curve import secp256k1
sig = secp256k1.sign(node.private_key(), digest) sig = secp256k1.sign(node.private_key(), digest)
sigder = der.encode_seq((sig[1:33], sig[33:65])) sigder = der.encode_seq((sig[1:33], sig[33:65]))
return sigder return sigder
@ -112,6 +116,8 @@ def bip340_sign(node: bip32.HDNode, digest: bytes) -> bytes:
def ecdsa_hash_pubkey(pubkey: bytes, coin: CoinInfo) -> bytes: def ecdsa_hash_pubkey(pubkey: bytes, coin: CoinInfo) -> bytes:
from trezor.utils import ensure
if pubkey[0] == 0x04: if pubkey[0] == 0x04:
ensure(len(pubkey) == 65) # uncompressed format ensure(len(pubkey) == 65) # uncompressed format
elif pubkey[0] == 0x00: elif pubkey[0] == 0x00:
@ -178,6 +184,9 @@ def input_is_external_unverified(txi: TxInput) -> bool:
def tagged_hashwriter(tag: bytes) -> HashWriter: def tagged_hashwriter(tag: bytes) -> HashWriter:
from trezor.crypto.hashlib import sha256
from trezor.utils import HashWriter
tag_digest = sha256(tag).digest() tag_digest = sha256(tag).digest()
ctx = sha256(tag_digest) ctx = sha256(tag_digest)
ctx.update(tag_digest) ctx.update(tag_digest)
@ -187,6 +196,8 @@ def tagged_hashwriter(tag: bytes) -> HashWriter:
def format_fee_rate( def format_fee_rate(
fee_rate: float, coin: CoinInfo, include_shortcut: bool = False fee_rate: float, coin: CoinInfo, include_shortcut: bool = False
) -> str: ) -> str:
from trezor.strings import format_amount
# Use format_amount to get correct thousands separator -- micropython's built-in # Use format_amount to get correct thousands separator -- micropython's built-in
# formatting doesn't add thousands sep to floating point numbers. # formatting doesn't add thousands sep to floating point numbers.
# We multiply by 100 to get a fixed-point integer with two decimal places, # We multiply by 100 to get a fixed-point integer with two decimal places,

View File

@ -1,19 +1,9 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor.crypto import bip32 from .keychain import with_keychain
from trezor.enums import InputScriptType
from trezor.messages import Address
from trezor.ui.layouts import show_address
from apps.common.address_mac import get_address_mac
from apps.common.paths import address_n_to_str, validate_path
from . import addresses
from .keychain import validate_path_against_script_type, with_keychain
from .multisig import multisig_pubkey_index
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import GetAddress, HDNodeType from trezor.messages import GetAddress, HDNodeType, Address
from trezor import wire from trezor import wire
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
from apps.common.coininfo import CoinInfo from apps.common.coininfo import CoinInfo
@ -22,6 +12,8 @@ if TYPE_CHECKING:
def _get_xpubs( def _get_xpubs(
coin: CoinInfo, xpub_magic: int, pubnodes: list[HDNodeType] coin: CoinInfo, xpub_magic: int, pubnodes: list[HDNodeType]
) -> list[str]: ) -> list[str]:
from trezor.crypto import bip32
result = [] result = []
for pubnode in pubnodes: for pubnode in pubnodes:
node = bip32.HDNode( node = bip32.HDNode(
@ -41,22 +33,37 @@ def _get_xpubs(
async def get_address( async def get_address(
ctx: wire.Context, msg: GetAddress, keychain: Keychain, coin: CoinInfo ctx: wire.Context, msg: GetAddress, keychain: Keychain, coin: CoinInfo
) -> Address: ) -> Address:
from trezor.enums import InputScriptType
from trezor.messages import Address
from trezor.ui.layouts import show_address
from apps.common.address_mac import get_address_mac
from apps.common.paths import address_n_to_str, validate_path
from . import addresses
from .keychain import validate_path_against_script_type
from .multisig import multisig_pubkey_index
multisig = msg.multisig # local_cache_attribute
address_n = msg.address_n # local_cache_attribute
script_type = msg.script_type # local_cache_attribute
if msg.show_display: if msg.show_display:
# skip soft-validation for silent calls # skip soft-validation for silent calls
await validate_path( await validate_path(
ctx, ctx,
keychain, keychain,
msg.address_n, address_n,
validate_path_against_script_type(coin, msg), validate_path_against_script_type(coin, msg),
) )
node = keychain.derive(msg.address_n) node = keychain.derive(address_n)
address = addresses.get_address(msg.script_type, coin, node, msg.multisig) address = addresses.get_address(script_type, coin, node, multisig)
address_short = addresses.address_short(coin, address) address_short = addresses.address_short(coin, address)
address_case_sensitive = True address_case_sensitive = True
if coin.segwit and msg.script_type in ( if coin.segwit and script_type in (
InputScriptType.SPENDWITNESS, InputScriptType.SPENDWITNESS,
InputScriptType.SPENDTAPROOT, InputScriptType.SPENDTAPROOT,
): ):
@ -66,15 +73,15 @@ async def get_address(
mac: bytes | None = None mac: bytes | None = None
multisig_xpub_magic = coin.xpub_magic multisig_xpub_magic = coin.xpub_magic
if msg.multisig: if multisig:
if coin.segwit and not msg.ignore_xpub_magic: if coin.segwit and not msg.ignore_xpub_magic:
if ( if (
msg.script_type == InputScriptType.SPENDWITNESS script_type == InputScriptType.SPENDWITNESS
and coin.xpub_magic_multisig_segwit_native is not None and coin.xpub_magic_multisig_segwit_native is not None
): ):
multisig_xpub_magic = coin.xpub_magic_multisig_segwit_native multisig_xpub_magic = coin.xpub_magic_multisig_segwit_native
elif ( elif (
msg.script_type == InputScriptType.SPENDP2SHWITNESS script_type == InputScriptType.SPENDP2SHWITNESS
and coin.xpub_magic_multisig_segwit_p2sh is not None and coin.xpub_magic_multisig_segwit_p2sh is not None
): ):
multisig_xpub_magic = coin.xpub_magic_multisig_segwit_p2sh multisig_xpub_magic = coin.xpub_magic_multisig_segwit_p2sh
@ -82,33 +89,33 @@ async def get_address(
# Attach a MAC for single-sig addresses, but only if the path is standard # Attach a MAC for single-sig addresses, but only if the path is standard
# or if the user explicitly confirms a non-standard path. # or if the user explicitly confirms a non-standard path.
if msg.show_display or ( if msg.show_display or (
keychain.is_in_keychain(msg.address_n) keychain.is_in_keychain(address_n)
and validate_path_against_script_type(coin, msg) and validate_path_against_script_type(coin, msg)
): ):
mac = get_address_mac(address, coin.slip44, keychain) mac = get_address_mac(address, coin.slip44, keychain)
if msg.show_display: if msg.show_display:
if msg.multisig: if multisig:
if msg.multisig.nodes: if multisig.nodes:
pubnodes = msg.multisig.nodes pubnodes = multisig.nodes
else: else:
pubnodes = [hd.node for hd in msg.multisig.pubkeys] pubnodes = [hd.node for hd in multisig.pubkeys]
multisig_index = multisig_pubkey_index(msg.multisig, node.public_key()) multisig_index = multisig_pubkey_index(multisig, node.public_key())
title = f"Multisig {msg.multisig.m} of {len(pubnodes)}" title = f"Multisig {multisig.m} of {len(pubnodes)}"
await show_address( await show_address(
ctx, ctx,
address=address_short, address_short,
case_sensitive=address_case_sensitive, case_sensitive=address_case_sensitive,
title=title, title=title,
multisig_index=multisig_index, multisig_index=multisig_index,
xpubs=_get_xpubs(coin, multisig_xpub_magic, pubnodes), xpubs=_get_xpubs(coin, multisig_xpub_magic, pubnodes),
) )
else: else:
title = address_n_to_str(msg.address_n) title = address_n_to_str(address_n)
await show_address( await show_address(
ctx, ctx,
address=address_short, address_short,
address_qr=address, address_qr=address,
case_sensitive=address_case_sensitive, case_sensitive=address_case_sensitive,
title=title, title=title,

View File

@ -1,25 +1,30 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor import wire from .keychain import with_keychain
from trezor.enums import InputScriptType
from trezor.messages import OwnershipId
from apps.common.paths import validate_path
from . import addresses, common, scripts
from .keychain import validate_path_against_script_type, with_keychain
from .ownership import get_identifier
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import GetOwnershipId from trezor.messages import GetOwnershipId, OwnershipId
from trezor.wire import Context
from apps.common.coininfo import CoinInfo from apps.common.coininfo import CoinInfo
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
@with_keychain @with_keychain
async def get_ownership_id( async def get_ownership_id(
ctx: wire.Context, msg: GetOwnershipId, keychain: Keychain, coin: CoinInfo ctx: Context, msg: GetOwnershipId, keychain: Keychain, coin: CoinInfo
) -> OwnershipId: ) -> OwnershipId:
from trezor.wire import DataError
from trezor.enums import InputScriptType
from trezor.messages import OwnershipId
from apps.common.paths import validate_path
from . import addresses, common, scripts
from .keychain import validate_path_against_script_type
from .ownership import get_identifier
script_type = msg.script_type # local_cache_attribute
await validate_path( await validate_path(
ctx, ctx,
keychain, keychain,
@ -27,17 +32,17 @@ async def get_ownership_id(
validate_path_against_script_type(coin, msg), validate_path_against_script_type(coin, msg),
) )
if msg.script_type not in common.INTERNAL_INPUT_SCRIPT_TYPES: if script_type not in common.INTERNAL_INPUT_SCRIPT_TYPES:
raise wire.DataError("Invalid script type") raise DataError("Invalid script type")
if msg.script_type in common.SEGWIT_INPUT_SCRIPT_TYPES and not coin.segwit: if script_type in common.SEGWIT_INPUT_SCRIPT_TYPES and not coin.segwit:
raise wire.DataError("Segwit not enabled on this coin") raise DataError("Segwit not enabled on this coin")
if msg.script_type == InputScriptType.SPENDTAPROOT and not coin.taproot: if script_type == InputScriptType.SPENDTAPROOT and not coin.taproot:
raise wire.DataError("Taproot not enabled on this coin") raise DataError("Taproot not enabled on this coin")
node = keychain.derive(msg.address_n) node = keychain.derive(msg.address_n)
address = addresses.get_address(msg.script_type, coin, node, msg.multisig) address = addresses.get_address(script_type, coin, node, msg.multisig)
script_pubkey = scripts.output_derive_script(address, coin) script_pubkey = scripts.output_derive_script(address, coin)
ownership_id = get_identifier(script_pubkey, keychain) ownership_id = get_identifier(script_pubkey, keychain)

View File

@ -1,18 +1,10 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor import ui, wire from .keychain import with_keychain
from trezor.enums import InputScriptType
from trezor.messages import OwnershipProof
from trezor.ui.layouts import confirm_action, confirm_blob
from apps.common.paths import validate_path
from . import addresses, common, scripts
from .keychain import validate_path_against_script_type, with_keychain
from .ownership import generate_proof, get_identifier
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import GetOwnershipProof from trezor.messages import GetOwnershipProof, OwnershipProof
from trezor.wire import Context
from apps.common.coininfo import CoinInfo from apps.common.coininfo import CoinInfo
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
from .authorization import CoinJoinAuthorization from .authorization import CoinJoinAuthorization
@ -20,15 +12,30 @@ if TYPE_CHECKING:
@with_keychain @with_keychain
async def get_ownership_proof( async def get_ownership_proof(
ctx: wire.Context, ctx: Context,
msg: GetOwnershipProof, msg: GetOwnershipProof,
keychain: Keychain, keychain: Keychain,
coin: CoinInfo, coin: CoinInfo,
authorization: CoinJoinAuthorization | None = None, authorization: CoinJoinAuthorization | None = None,
) -> OwnershipProof: ) -> OwnershipProof:
from trezor import ui
from trezor.wire import DataError, ProcessError
from trezor.enums import InputScriptType
from trezor.messages import OwnershipProof
from trezor.ui.layouts import confirm_action, confirm_blob
from apps.common.paths import validate_path
from . import addresses, common, scripts
from .keychain import validate_path_against_script_type
from .ownership import generate_proof, get_identifier
script_type = msg.script_type # local_cache_attribute
ownership_ids = msg.ownership_ids # local_cache_attribute
if authorization: if authorization:
if not authorization.check_get_ownership_proof(msg): if not authorization.check_get_ownership_proof(msg):
raise wire.ProcessError("Unauthorized operation") raise ProcessError("Unauthorized operation")
else: else:
await validate_path( await validate_path(
ctx, ctx,
@ -37,57 +44,57 @@ async def get_ownership_proof(
validate_path_against_script_type(coin, msg), validate_path_against_script_type(coin, msg),
) )
if msg.script_type not in common.INTERNAL_INPUT_SCRIPT_TYPES: if script_type not in common.INTERNAL_INPUT_SCRIPT_TYPES:
raise wire.DataError("Invalid script type") raise DataError("Invalid script type")
if msg.script_type in common.SEGWIT_INPUT_SCRIPT_TYPES and not coin.segwit: if script_type in common.SEGWIT_INPUT_SCRIPT_TYPES and not coin.segwit:
raise wire.DataError("Segwit not enabled on this coin") raise DataError("Segwit not enabled on this coin")
if msg.script_type == InputScriptType.SPENDTAPROOT and not coin.taproot: if script_type == InputScriptType.SPENDTAPROOT and not coin.taproot:
raise wire.DataError("Taproot not enabled on this coin") raise DataError("Taproot not enabled on this coin")
node = keychain.derive(msg.address_n) node = keychain.derive(msg.address_n)
address = addresses.get_address(msg.script_type, coin, node, msg.multisig) address = addresses.get_address(script_type, coin, node, msg.multisig)
script_pubkey = scripts.output_derive_script(address, coin) script_pubkey = scripts.output_derive_script(address, coin)
ownership_id = get_identifier(script_pubkey, keychain) ownership_id = get_identifier(script_pubkey, keychain)
# If the scriptPubKey is multisig, then the caller has to provide # If the scriptPubKey is multisig, then the caller has to provide
# ownership IDs, otherwise providing an ID is optional. # ownership IDs, otherwise providing an ID is optional.
if msg.multisig: if msg.multisig:
if ownership_id not in msg.ownership_ids: if ownership_id not in ownership_ids:
raise wire.DataError("Missing ownership identifier") raise DataError("Missing ownership identifier")
elif msg.ownership_ids: elif ownership_ids:
if msg.ownership_ids != [ownership_id]: if ownership_ids != [ownership_id]:
raise wire.DataError("Invalid ownership identifier") raise DataError("Invalid ownership identifier")
else: else:
msg.ownership_ids = [ownership_id] ownership_ids = [ownership_id]
# In order to set the "user confirmation" bit in the proof, the user must actually confirm. # In order to set the "user confirmation" bit in the proof, the user must actually confirm.
if msg.user_confirmation and not authorization: if msg.user_confirmation and not authorization:
await confirm_action( await confirm_action(
ctx, ctx,
"confirm_ownership_proof", "confirm_ownership_proof",
title="Proof of ownership", "Proof of ownership",
description="Do you want to create a proof of ownership?", description="Do you want to create a proof of ownership?",
) )
if msg.commitment_data: if msg.commitment_data:
await confirm_blob( await confirm_blob(
ctx, ctx,
"confirm_ownership_proof", "confirm_ownership_proof",
title="Proof of ownership", "Proof of ownership",
description="Commitment data:", msg.commitment_data,
data=msg.commitment_data, "Commitment data:",
icon=ui.ICON_CONFIG, icon=ui.ICON_CONFIG,
icon_color=ui.ORANGE_ICON, icon_color=ui.ORANGE_ICON,
) )
ownership_proof, signature = generate_proof( ownership_proof, signature = generate_proof(
node, node,
msg.script_type, script_type,
msg.multisig, msg.multisig,
coin, coin,
msg.user_confirmation, msg.user_confirmation,
msg.ownership_ids, ownership_ids,
script_pubkey, script_pubkey,
msg.commitment_data, msg.commitment_data,
) )

View File

@ -1,37 +1,41 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor import wire
from trezor.enums import InputScriptType
from trezor.messages import HDNodeType, PublicKey, UnlockPath
from apps.common import coininfo, paths
from apps.common.keychain import FORBIDDEN_KEY_PATH, get_keychain
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import GetPublicKey from trezor.messages import GetPublicKey, PublicKey
from trezor.protobuf import MessageType from trezor.protobuf import MessageType
from trezor.wire import Context
async def get_public_key( async def get_public_key(
ctx: wire.Context, msg: GetPublicKey, auth_msg: MessageType | None = None ctx: Context, msg: GetPublicKey, auth_msg: MessageType | None = None
) -> PublicKey: ) -> PublicKey:
from trezor import wire
from trezor.enums import InputScriptType
from trezor.messages import HDNodeType, PublicKey, UnlockPath
from apps.common import coininfo, paths
from apps.common.keychain import FORBIDDEN_KEY_PATH, get_keychain
coin_name = msg.coin_name or "Bitcoin" coin_name = msg.coin_name or "Bitcoin"
script_type = msg.script_type or InputScriptType.SPENDADDRESS script_type = msg.script_type or InputScriptType.SPENDADDRESS
coin = coininfo.by_name(coin_name) coin = coininfo.by_name(coin_name)
curve_name = msg.ecdsa_curve_name or coin.curve_name curve_name = msg.ecdsa_curve_name or coin.curve_name
address_n = msg.address_n # local_cache_attribute
ignore_xpub_magic = msg.ignore_xpub_magic # local_cache_attribute
xpub_magic = coin.xpub_magic # local_cache_attribute
if msg.address_n and msg.address_n[0] == paths.SLIP25_PURPOSE: if address_n and address_n[0] == paths.SLIP25_PURPOSE:
# UnlockPath is required to access SLIP25 paths. # UnlockPath is required to access SLIP25 paths.
if not UnlockPath.is_type_of(auth_msg): if not UnlockPath.is_type_of(auth_msg):
raise FORBIDDEN_KEY_PATH raise FORBIDDEN_KEY_PATH
# Verify that the desired path lies in the unlocked subtree. # Verify that the desired path lies in the unlocked subtree.
if auth_msg.address_n != msg.address_n[: len(auth_msg.address_n)]: if auth_msg.address_n != address_n[: len(auth_msg.address_n)]:
raise FORBIDDEN_KEY_PATH raise FORBIDDEN_KEY_PATH
keychain = await get_keychain(ctx, curve_name, [paths.AlwaysMatchingSchema]) keychain = await get_keychain(ctx, curve_name, [paths.AlwaysMatchingSchema])
node = keychain.derive(msg.address_n) node = keychain.derive(address_n)
if ( if (
script_type script_type
@ -40,26 +44,26 @@ async def get_public_key(
InputScriptType.SPENDMULTISIG, InputScriptType.SPENDMULTISIG,
InputScriptType.SPENDTAPROOT, InputScriptType.SPENDTAPROOT,
) )
and coin.xpub_magic is not None and xpub_magic is not None
): ):
node_xpub = node.serialize_public(coin.xpub_magic) node_xpub = node.serialize_public(xpub_magic)
elif ( elif (
coin.segwit coin.segwit
and script_type == InputScriptType.SPENDP2SHWITNESS and script_type == InputScriptType.SPENDP2SHWITNESS
and (msg.ignore_xpub_magic or coin.xpub_magic_segwit_p2sh is not None) and (ignore_xpub_magic or coin.xpub_magic_segwit_p2sh is not None)
): ):
assert coin.xpub_magic_segwit_p2sh is not None assert coin.xpub_magic_segwit_p2sh is not None
node_xpub = node.serialize_public( node_xpub = node.serialize_public(
coin.xpub_magic if msg.ignore_xpub_magic else coin.xpub_magic_segwit_p2sh xpub_magic if ignore_xpub_magic else coin.xpub_magic_segwit_p2sh
) )
elif ( elif (
coin.segwit coin.segwit
and script_type == InputScriptType.SPENDWITNESS and script_type == InputScriptType.SPENDWITNESS
and (msg.ignore_xpub_magic or coin.xpub_magic_segwit_native is not None) and (ignore_xpub_magic or coin.xpub_magic_segwit_native is not None)
): ):
assert coin.xpub_magic_segwit_native is not None assert coin.xpub_magic_segwit_native is not None
node_xpub = node.serialize_public( node_xpub = node.serialize_public(
coin.xpub_magic if msg.ignore_xpub_magic else coin.xpub_magic_segwit_native xpub_magic if ignore_xpub_magic else coin.xpub_magic_segwit_native
) )
else: else:
raise wire.DataError("Invalid combination of coin and script_type") raise wire.DataError("Invalid combination of coin and script_type")

View File

@ -1,13 +1,8 @@
import gc
from micropython import const from micropython import const
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor import wire from trezor.messages import AuthorizeCoinJoin
from trezor.enums import InputScriptType
from trezor.messages import AuthorizeCoinJoin, GetOwnershipProof, SignTx, UnlockPath
from apps.common import coininfo
from apps.common.keychain import get_keychain
from apps.common.paths import PATTERN_BIP44, PathSchema from apps.common.paths import PATTERN_BIP44, PathSchema
from . import authorization from . import authorization
@ -18,17 +13,22 @@ if TYPE_CHECKING:
from typing_extensions import Protocol from typing_extensions import Protocol
from trezor.protobuf import MessageType from trezor.protobuf import MessageType
from trezor.wire import Context
from trezor.enums import InputScriptType
from trezor.messages import ( from trezor.messages import (
GetAddress, GetAddress,
GetOwnershipId, GetOwnershipId,
GetPublicKey, GetPublicKey,
SignMessage, SignMessage,
VerifyMessage, VerifyMessage,
GetOwnershipProof,
SignTx,
) )
from apps.common.keychain import Keychain, MsgOut, Handler from apps.common.keychain import Keychain, MsgOut, Handler
from apps.common.paths import Bip32Path from apps.common.paths import Bip32Path
from apps.common import coininfo
BitcoinMessage = ( BitcoinMessage = (
AuthorizeCoinJoin AuthorizeCoinJoin
@ -99,7 +99,11 @@ def validate_path_against_script_type(
script_type: InputScriptType | None = None, script_type: InputScriptType | None = None,
multisig: bool = False, multisig: bool = False,
) -> bool: ) -> bool:
from trezor.enums import InputScriptType
patterns = [] patterns = []
append = patterns.append # local_cache_attribute
slip44 = coin.slip44 # local_cache_attribute
if msg is not None: if msg is not None:
assert address_n is None and script_type is None assert address_n is None and script_type is None
@ -111,58 +115,60 @@ def validate_path_against_script_type(
assert address_n is not None and script_type is not None assert address_n is not None and script_type is not None
if script_type == InputScriptType.SPENDADDRESS and not multisig: if script_type == InputScriptType.SPENDADDRESS and not multisig:
patterns.append(PATTERN_BIP44) append(PATTERN_BIP44)
if coin.slip44 == _SLIP44_BITCOIN: if slip44 == _SLIP44_BITCOIN:
patterns.append(PATTERN_GREENADDRESS_A) append(PATTERN_GREENADDRESS_A)
patterns.append(PATTERN_GREENADDRESS_B) append(PATTERN_GREENADDRESS_B)
elif ( elif (
script_type in (InputScriptType.SPENDADDRESS, InputScriptType.SPENDMULTISIG) script_type in (InputScriptType.SPENDADDRESS, InputScriptType.SPENDMULTISIG)
and multisig and multisig
): ):
patterns.append(PATTERN_BIP48_RAW) append(PATTERN_BIP48_RAW)
if coin.slip44 == _SLIP44_BITCOIN or ( if slip44 == _SLIP44_BITCOIN or (
coin.fork_id is not None and coin.slip44 != _SLIP44_TESTNET coin.fork_id is not None and slip44 != _SLIP44_TESTNET
): ):
patterns.append(PATTERN_BIP45) append(PATTERN_BIP45)
if coin.slip44 == _SLIP44_BITCOIN: if slip44 == _SLIP44_BITCOIN:
patterns.append(PATTERN_GREENADDRESS_A) append(PATTERN_GREENADDRESS_A)
patterns.append(PATTERN_GREENADDRESS_B) append(PATTERN_GREENADDRESS_B)
if coin.coin_name in BITCOIN_NAMES: if coin.coin_name in BITCOIN_NAMES:
patterns.append(PATTERN_UNCHAINED_HARDENED) append(PATTERN_UNCHAINED_HARDENED)
patterns.append(PATTERN_UNCHAINED_UNHARDENED) append(PATTERN_UNCHAINED_UNHARDENED)
patterns.append(PATTERN_UNCHAINED_DEPRECATED) append(PATTERN_UNCHAINED_DEPRECATED)
elif coin.segwit and script_type == InputScriptType.SPENDP2SHWITNESS: elif coin.segwit and script_type == InputScriptType.SPENDP2SHWITNESS:
patterns.append(PATTERN_BIP49) append(PATTERN_BIP49)
if multisig: if multisig:
patterns.append(PATTERN_BIP48_P2SHSEGWIT) append(PATTERN_BIP48_P2SHSEGWIT)
if coin.slip44 == _SLIP44_BITCOIN: if slip44 == _SLIP44_BITCOIN:
patterns.append(PATTERN_GREENADDRESS_A) append(PATTERN_GREENADDRESS_A)
patterns.append(PATTERN_GREENADDRESS_B) append(PATTERN_GREENADDRESS_B)
if coin.coin_name in BITCOIN_NAMES: if coin.coin_name in BITCOIN_NAMES:
patterns.append(PATTERN_CASA) append(PATTERN_CASA)
elif coin.segwit and script_type == InputScriptType.SPENDWITNESS: elif coin.segwit and script_type == InputScriptType.SPENDWITNESS:
patterns.append(PATTERN_BIP84) append(PATTERN_BIP84)
if multisig: if multisig:
patterns.append(PATTERN_BIP48_SEGWIT) append(PATTERN_BIP48_SEGWIT)
if coin.slip44 == _SLIP44_BITCOIN: if slip44 == _SLIP44_BITCOIN:
patterns.append(PATTERN_GREENADDRESS_A) append(PATTERN_GREENADDRESS_A)
patterns.append(PATTERN_GREENADDRESS_B) append(PATTERN_GREENADDRESS_B)
elif coin.taproot and script_type == InputScriptType.SPENDTAPROOT: elif coin.taproot and script_type == InputScriptType.SPENDTAPROOT:
patterns.append(PATTERN_BIP86) append(PATTERN_BIP86)
patterns.append(PATTERN_SLIP25_TAPROOT) append(PATTERN_SLIP25_TAPROOT)
return any( return any(
PathSchema.parse(pattern, coin.slip44).match(address_n) for pattern in patterns PathSchema.parse(pattern, coin.slip44).match(address_n) for pattern in patterns
) )
def get_schemas_for_coin( def _get_schemas_for_coin(
coin: coininfo.CoinInfo, unlock_schemas: Iterable[PathSchema] = () coin: coininfo.CoinInfo, unlock_schemas: Iterable[PathSchema] = ()
) -> Iterable[PathSchema]: ) -> Iterable[PathSchema]:
import gc
# basic patterns # basic patterns
patterns = [ patterns = [
PATTERN_BIP44, PATTERN_BIP44,
@ -237,7 +243,10 @@ def get_schemas_from_patterns(
return schemas return schemas
def get_coin_by_name(coin_name: str | None) -> coininfo.CoinInfo: def _get_coin_by_name(coin_name: str | None) -> coininfo.CoinInfo:
from apps.common import coininfo
from trezor import wire
if coin_name is None: if coin_name is None:
coin_name = "Bitcoin" coin_name = "Bitcoin"
@ -247,12 +256,14 @@ def get_coin_by_name(coin_name: str | None) -> coininfo.CoinInfo:
raise wire.DataError("Unsupported coin type") raise wire.DataError("Unsupported coin type")
async def get_keychain_for_coin( async def _get_keychain_for_coin(
ctx: wire.Context, ctx: Context,
coin: coininfo.CoinInfo, coin: coininfo.CoinInfo,
unlock_schemas: Iterable[PathSchema] = (), unlock_schemas: Iterable[PathSchema] = (),
) -> Keychain: ) -> Keychain:
schemas = get_schemas_for_coin(coin, unlock_schemas) from apps.common.keychain import get_keychain
schemas = _get_schemas_for_coin(coin, unlock_schemas)
slip21_namespaces = [[b"SLIP-0019"], [b"SLIP-0024"]] slip21_namespaces = [[b"SLIP-0019"], [b"SLIP-0024"]]
keychain = await get_keychain(ctx, coin.curve_name, schemas, slip21_namespaces) keychain = await get_keychain(ctx, coin.curve_name, schemas, slip21_namespaces)
return keychain return keychain
@ -265,6 +276,7 @@ def _get_unlock_schemas(
Provides additional keychain schemas that are unlocked by the particular Provides additional keychain schemas that are unlocked by the particular
combination of `msg` and `auth_msg`. combination of `msg` and `auth_msg`.
""" """
from trezor.messages import GetOwnershipProof, SignTx, UnlockPath
if AuthorizeCoinJoin.is_type_of(msg): if AuthorizeCoinJoin.is_type_of(msg):
# When processing the AuthorizeCoinJoin message, validate_path() always # When processing the AuthorizeCoinJoin message, validate_path() always
@ -298,13 +310,13 @@ def _get_unlock_schemas(
def with_keychain(func: HandlerWithCoinInfo[MsgOut]) -> Handler[MsgIn, MsgOut]: def with_keychain(func: HandlerWithCoinInfo[MsgOut]) -> Handler[MsgIn, MsgOut]:
async def wrapper( async def wrapper(
ctx: wire.Context, ctx: Context,
msg: MsgIn, msg: MsgIn,
auth_msg: MessageType | None = None, auth_msg: MessageType | None = None,
) -> MsgOut: ) -> MsgOut:
coin = get_coin_by_name(msg.coin_name) coin = _get_coin_by_name(msg.coin_name)
unlock_schemas = _get_unlock_schemas(msg, auth_msg, coin) unlock_schemas = _get_unlock_schemas(msg, auth_msg, coin)
keychain = await get_keychain_for_coin(ctx, coin, unlock_schemas) keychain = await _get_keychain_for_coin(ctx, coin, unlock_schemas)
if AuthorizeCoinJoin.is_type_of(auth_msg): if AuthorizeCoinJoin.is_type_of(auth_msg):
auth_obj = authorization.from_cached_message(auth_msg) auth_obj = authorization.from_cached_message(auth_msg)
return await func(ctx, msg, keychain, coin, auth_obj) return await func(ctx, msg, keychain, coin, auth_obj)

View File

@ -1,19 +1,17 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor import wire from trezor.wire import DataError
from trezor.crypto import bip32
from trezor.crypto.hashlib import sha256
from trezor.utils import HashWriter
from apps.common import paths
from .writers import write_bytes_fixed, write_uint32
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import HDNodeType, MultisigRedeemScriptType from trezor.messages import HDNodeType, MultisigRedeemScriptType
from apps.common import paths
def multisig_fingerprint(multisig: MultisigRedeemScriptType) -> bytes: def multisig_fingerprint(multisig: MultisigRedeemScriptType) -> bytes:
from trezor.crypto.hashlib import sha256
from trezor.utils import HashWriter
from .writers import write_bytes_fixed, write_uint32
if multisig.nodes: if multisig.nodes:
pubnodes = multisig.nodes pubnodes = multisig.nodes
else: else:
@ -22,11 +20,11 @@ def multisig_fingerprint(multisig: MultisigRedeemScriptType) -> bytes:
n = len(pubnodes) n = len(pubnodes)
if n < 1 or n > 15 or m < 1 or m > 15: if n < 1 or n > 15 or m < 1 or m > 15:
raise wire.DataError("Invalid multisig parameters") raise DataError("Invalid multisig parameters")
for d in pubnodes: for d in pubnodes:
if len(d.public_key) != 33 or len(d.chain_code) != 32: if len(d.public_key) != 33 or len(d.chain_code) != 32:
raise wire.DataError("Invalid multisig parameters") raise DataError("Invalid multisig parameters")
# casting to bytes(), sorting on bytearray() is not supported in MicroPython # casting to bytes(), sorting on bytearray() is not supported in MicroPython
pubnodes = sorted(pubnodes, key=lambda n: bytes(n.public_key)) pubnodes = sorted(pubnodes, key=lambda n: bytes(n.public_key))
@ -45,11 +43,13 @@ def multisig_fingerprint(multisig: MultisigRedeemScriptType) -> bytes:
def validate_multisig(multisig: MultisigRedeemScriptType) -> None: def validate_multisig(multisig: MultisigRedeemScriptType) -> None:
from apps.common import paths
if any(paths.is_hardened(n) for n in multisig.address_n): if any(paths.is_hardened(n) for n in multisig.address_n):
raise wire.DataError("Cannot perform hardened derivation from XPUB") raise DataError("Cannot perform hardened derivation from XPUB")
for hd in multisig.pubkeys: for hd in multisig.pubkeys:
if any(paths.is_hardened(n) for n in hd.address_n): if any(paths.is_hardened(n) for n in hd.address_n):
raise wire.DataError("Cannot perform hardened derivation from XPUB") raise DataError("Cannot perform hardened derivation from XPUB")
def multisig_pubkey_index(multisig: MultisigRedeemScriptType, pubkey: bytes) -> int: def multisig_pubkey_index(multisig: MultisigRedeemScriptType, pubkey: bytes) -> int:
@ -62,10 +62,12 @@ def multisig_pubkey_index(multisig: MultisigRedeemScriptType, pubkey: bytes) ->
for i, hd in enumerate(multisig.pubkeys): for i, hd in enumerate(multisig.pubkeys):
if multisig_get_pubkey(hd.node, hd.address_n) == pubkey: if multisig_get_pubkey(hd.node, hd.address_n) == pubkey:
return i return i
raise wire.DataError("Pubkey not found in multisig script") raise DataError("Pubkey not found in multisig script")
def multisig_get_pubkey(n: HDNodeType, p: paths.Bip32Path) -> bytes: def multisig_get_pubkey(n: HDNodeType, p: paths.Bip32Path) -> bytes:
from trezor.crypto import bip32
node = bip32.HDNode( node = bip32.HDNode(
depth=n.depth, depth=n.depth,
fingerprint=n.fingerprint, fingerprint=n.fingerprint,

View File

@ -1,28 +1,22 @@
from micropython import const from micropython import const
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor import utils, wire from trezor import utils
from trezor.crypto import bip32, hmac
from trezor.crypto.hashlib import sha256 from trezor.crypto.hashlib import sha256
from trezor.enums import InputScriptType
from trezor.utils import HashWriter from trezor.utils import HashWriter
from trezor.wire import DataError
from apps.bitcoin.writers import ( from apps.bitcoin.writers import write_bytes_prefixed
write_bytes_fixed,
write_bytes_prefixed,
write_compact_size,
write_uint8,
)
from apps.common.keychain import Keychain
from apps.common.readers import read_compact_size from apps.common.readers import read_compact_size
from . import common from .scripts import read_bip322_signature_proof
from .scripts import read_bip322_signature_proof, write_bip322_signature_proof
from .verification import SignatureVerifier
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import MultisigRedeemScriptType from trezor.messages import MultisigRedeemScriptType
from trezor.enums import InputScriptType
from apps.common.coininfo import CoinInfo from apps.common.coininfo import CoinInfo
from trezor.crypto import bip32
from apps.common.keychain import Keychain
# This module implements the SLIP-0019 proof of ownership format, see # This module implements the SLIP-0019 proof of ownership format, see
# https://github.com/satoshilabs/slips/blob/master/slip-0019.md. # https://github.com/satoshilabs/slips/blob/master/slip-0019.md.
@ -44,6 +38,15 @@ def generate_proof(
script_pubkey: bytes, script_pubkey: bytes,
commitment_data: bytes, commitment_data: bytes,
) -> tuple[bytes, bytes]: ) -> tuple[bytes, bytes]:
from trezor.enums import InputScriptType
from apps.bitcoin.writers import (
write_bytes_fixed,
write_compact_size,
write_uint8,
)
from .scripts import write_bip322_signature_proof
from . import common
flags = 0 flags = 0
if user_confirmed: if user_confirmed:
flags |= _FLAG_USER_CONFIRMED flags |= _FLAG_USER_CONFIRMED
@ -69,7 +72,7 @@ def generate_proof(
elif script_type == InputScriptType.SPENDTAPROOT: elif script_type == InputScriptType.SPENDTAPROOT:
signature = common.bip340_sign(node, sighash.get_digest()) signature = common.bip340_sign(node, sighash.get_digest())
else: else:
raise wire.DataError("Unsupported script type.") raise DataError("Unsupported script type.")
public_key = node.public_key() public_key = node.public_key()
write_bip322_signature_proof( write_bip322_signature_proof(
proof, script_type, multisig, coin, public_key, signature proof, script_type, multisig, coin, public_key, signature
@ -85,14 +88,16 @@ def verify_nonownership(
keychain: Keychain, keychain: Keychain,
coin: CoinInfo, coin: CoinInfo,
) -> bool: ) -> bool:
from .verification import SignatureVerifier
try: try:
r = utils.BufferReader(proof) r = utils.BufferReader(proof)
if r.read_memoryview(4) != _VERSION_MAGIC: if r.read_memoryview(4) != _VERSION_MAGIC:
raise wire.DataError("Unknown format of proof of ownership") raise DataError("Unknown format of proof of ownership")
flags = r.get() flags = r.get()
if flags & 0b1111_1110: if flags & 0b1111_1110:
raise wire.DataError("Unknown flags in proof of ownership") raise DataError("Unknown flags in proof of ownership")
# Determine whether our ownership ID appears in the proof. # Determine whether our ownership ID appears in the proof.
id_count = read_compact_size(r) id_count = read_compact_size(r)
@ -119,7 +124,7 @@ def verify_nonownership(
verifier = SignatureVerifier(script_pubkey, script_sig, witness, coin) verifier = SignatureVerifier(script_pubkey, script_sig, witness, coin)
verifier.verify(sighash.get_digest()) verifier.verify(sighash.get_digest())
except (ValueError, EOFError): except (ValueError, EOFError):
raise wire.DataError("Invalid proof of ownership") raise DataError("Invalid proof of ownership")
return not_owned return not_owned
@ -128,11 +133,11 @@ def read_scriptsig_witness(ownership_proof: bytes) -> tuple[memoryview, memoryvi
try: try:
r = utils.BufferReader(ownership_proof) r = utils.BufferReader(ownership_proof)
if r.read_memoryview(4) != _VERSION_MAGIC: if r.read_memoryview(4) != _VERSION_MAGIC:
raise wire.DataError("Unknown format of proof of ownership") raise DataError("Unknown format of proof of ownership")
flags = r.get() flags = r.get()
if flags & 0b1111_1110: if flags & 0b1111_1110:
raise wire.DataError("Unknown flags in proof of ownership") raise DataError("Unknown flags in proof of ownership")
# Skip ownership IDs. # Skip ownership IDs.
id_count = read_compact_size(r) id_count = read_compact_size(r)
@ -141,10 +146,12 @@ def read_scriptsig_witness(ownership_proof: bytes) -> tuple[memoryview, memoryvi
return read_bip322_signature_proof(r) return read_bip322_signature_proof(r)
except (ValueError, EOFError): except (ValueError, EOFError):
raise wire.DataError("Invalid proof of ownership") raise DataError("Invalid proof of ownership")
def get_identifier(script_pubkey: bytes, keychain: Keychain) -> bytes: def get_identifier(script_pubkey: bytes, keychain: Keychain) -> bytes:
from trezor.crypto import hmac
# k = Key(m/"SLIP-0019"/"Ownership identification key") # k = Key(m/"SLIP-0019"/"Ownership identification key")
node = keychain.derive_slip21(_OWNERSHIP_ID_KEY_PATH) node = keychain.derive_slip21(_OWNERSHIP_ID_KEY_PATH)

View File

@ -1,27 +1,32 @@
from trezor.utils import BufferReader from typing import TYPE_CHECKING
from apps.common.readers import read_compact_size if TYPE_CHECKING:
from trezor.utils import BufferReader
def read_memoryview_prefixed(r: BufferReader) -> memoryview: def read_memoryview_prefixed(r: BufferReader) -> memoryview:
from apps.common.readers import read_compact_size
n = read_compact_size(r) n = read_compact_size(r)
return r.read_memoryview(n) return r.read_memoryview(n)
def read_op_push(r: BufferReader) -> int: def read_op_push(r: BufferReader) -> int:
prefix = r.get() get = r.get # local_cache_attribute
prefix = get()
if prefix < 0x4C: if prefix < 0x4C:
n = prefix n = prefix
elif prefix == 0x4C: elif prefix == 0x4C:
n = r.get() n = get()
elif prefix == 0x4D: elif prefix == 0x4D:
n = r.get() n = get()
n += r.get() << 8 n += get() << 8
elif prefix == 0x4E: elif prefix == 0x4E:
n = r.get() n = get()
n += r.get() << 8 n += get() << 8
n += r.get() << 16 n += get() << 16
n += r.get() << 24 n += get() << 24
else: else:
raise ValueError raise ValueError
return n return n

View File

@ -1,24 +1,18 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor import utils, wire from trezor import utils
from trezor.crypto import base58, cashaddr
from trezor.crypto.hashlib import sha256
from trezor.enums import InputScriptType from trezor.enums import InputScriptType
from trezor.utils import BufferReader, empty_bytearray
from trezor.wire import DataError
from apps.common import address_type
from apps.common.readers import read_compact_size from apps.common.readers import read_compact_size
from apps.common.writers import write_compact_size from apps.common.writers import write_compact_size
from . import common from . import common
from .common import SigHashType from .common import SigHashType
from .multisig import ( from .multisig import multisig_get_pubkeys, multisig_pubkey_index
multisig_get_pubkey_count,
multisig_get_pubkeys,
multisig_pubkey_index,
)
from .readers import read_memoryview_prefixed, read_op_push from .readers import read_memoryview_prefixed, read_op_push
from .writers import ( from .writers import (
op_push_length,
write_bytes_fixed, write_bytes_fixed,
write_bytes_prefixed, write_bytes_prefixed,
write_bytes_unchecked, write_bytes_unchecked,
@ -44,10 +38,15 @@ def write_input_script_prefixed(
pubkey: bytes, pubkey: bytes,
signature: bytes, signature: bytes,
) -> None: ) -> None:
if script_type == InputScriptType.SPENDADDRESS: from trezor.crypto.hashlib import sha256
from trezor import wire
IST = InputScriptType # local_cache_global
if script_type == IST.SPENDADDRESS:
# p2pkh or p2sh # p2pkh or p2sh
write_input_script_p2pkh_or_p2sh_prefixed(w, pubkey, signature, sighash_type) write_input_script_p2pkh_or_p2sh_prefixed(w, pubkey, signature, sighash_type)
elif script_type == InputScriptType.SPENDP2SHWITNESS: elif script_type == IST.SPENDP2SHWITNESS:
# p2wpkh or p2wsh using p2sh # p2wpkh or p2wsh using p2sh
if multisig is not None: if multisig is not None:
@ -63,15 +62,15 @@ def write_input_script_prefixed(
write_input_script_p2wpkh_in_p2sh( write_input_script_p2wpkh_in_p2sh(
w, common.ecdsa_hash_pubkey(pubkey, coin), prefixed=True w, common.ecdsa_hash_pubkey(pubkey, coin), prefixed=True
) )
elif script_type in (InputScriptType.SPENDWITNESS, InputScriptType.SPENDTAPROOT): elif script_type in (IST.SPENDWITNESS, IST.SPENDTAPROOT):
# native p2wpkh or p2wsh or p2tr # native p2wpkh or p2wsh or p2tr
script_sig = input_script_native_segwit() script_sig = _input_script_native_segwit()
write_bytes_prefixed(w, script_sig) write_bytes_prefixed(w, script_sig)
elif script_type == InputScriptType.SPENDMULTISIG: elif script_type == IST.SPENDMULTISIG:
# p2sh multisig # p2sh multisig
assert multisig is not None # checked in sanitize_tx_input assert multisig is not None # checked in _sanitize_tx_input
signature_index = multisig_pubkey_index(multisig, pubkey) signature_index = multisig_pubkey_index(multisig, pubkey)
write_input_script_multisig_prefixed( _write_input_script_multisig_prefixed(
w, multisig, signature, signature_index, sighash_type, coin w, multisig, signature, signature_index, sighash_type, coin
) )
else: else:
@ -79,6 +78,9 @@ def write_input_script_prefixed(
def output_derive_script(address: str, coin: CoinInfo) -> bytes: def output_derive_script(address: str, coin: CoinInfo) -> bytes:
from trezor.crypto import base58, cashaddr
from apps.common import address_type
if coin.bech32_prefix and address.startswith(coin.bech32_prefix): if coin.bech32_prefix and address.startswith(coin.bech32_prefix):
# p2wpkh or p2wsh or p2tr # p2wpkh or p2wsh or p2tr
witver, witprog = common.decode_bech32_address(coin.bech32_prefix, address) witver, witprog = common.decode_bech32_address(coin.bech32_prefix, address)
@ -96,13 +98,13 @@ def output_derive_script(address: str, coin: CoinInfo) -> bytes:
elif version == cashaddr.ADDRESS_TYPE_P2SH: elif version == cashaddr.ADDRESS_TYPE_P2SH:
version = coin.address_type_p2sh version = coin.address_type_p2sh
else: else:
raise wire.DataError("Unknown cashaddr address type") raise DataError("Unknown cashaddr address type")
raw_address = bytes([version]) + data raw_address = bytes([version]) + data
else: else:
try: try:
raw_address = base58.decode_check(address, coin.b58_hash) raw_address = base58.decode_check(address, coin.b58_hash)
except ValueError: except ValueError:
raise wire.DataError("Invalid address") raise DataError("Invalid address")
if address_type.check(coin.address_type, raw_address): if address_type.check(coin.address_type, raw_address):
# p2pkh # p2pkh
@ -115,7 +117,7 @@ def output_derive_script(address: str, coin: CoinInfo) -> bytes:
script = output_script_p2sh(scripthash) script = output_script_p2sh(scripthash)
return script return script
raise wire.DataError("Invalid address type") raise DataError("Invalid address type")
# see https://github.com/bitcoin/bips/blob/master/bip-0143.mediawiki#specification # see https://github.com/bitcoin/bips/blob/master/bip-0143.mediawiki#specification
@ -144,7 +146,7 @@ def write_bip143_script_code_prefixed(
w, common.ecdsa_hash_pubkey(public_keys[0], coin), prefixed=True w, common.ecdsa_hash_pubkey(public_keys[0], coin), prefixed=True
) )
else: else:
raise wire.DataError("Unknown input script type for bip143 script code") raise DataError("Unknown input script type for bip143 script code")
# P2PKH, P2SH # P2PKH, P2SH
@ -164,7 +166,7 @@ def parse_input_script_p2pkh(
script_sig: bytes, script_sig: bytes,
) -> tuple[memoryview, memoryview, SigHashType]: ) -> tuple[memoryview, memoryview, SigHashType]:
try: try:
r = utils.BufferReader(script_sig) r = BufferReader(script_sig)
n = read_op_push(r) n = read_op_push(r)
signature = r.read_memoryview(n - 1) signature = r.read_memoryview(n - 1)
sighash_type = SigHashType.from_int(r.get()) sighash_type = SigHashType.from_int(r.get())
@ -174,7 +176,7 @@ def parse_input_script_p2pkh(
if len(pubkey) != n: if len(pubkey) != n:
raise ValueError raise ValueError
except (ValueError, EOFError): except (ValueError, EOFError):
raise wire.DataError("Invalid scriptSig.") raise DataError("Invalid scriptSig.")
return pubkey, signature, sighash_type return pubkey, signature, sighash_type
@ -182,18 +184,20 @@ def parse_input_script_p2pkh(
def write_output_script_p2pkh( def write_output_script_p2pkh(
w: Writer, pubkeyhash: bytes, prefixed: bool = False w: Writer, pubkeyhash: bytes, prefixed: bool = False
) -> None: ) -> None:
append = w.append # local_cache_attribute
if prefixed: if prefixed:
write_compact_size(w, 25) write_compact_size(w, 25)
w.append(0x76) # OP_DUP append(0x76) # OP_DUP
w.append(0xA9) # OP_HASH160 append(0xA9) # OP_HASH160
w.append(0x14) # OP_DATA_20 append(0x14) # OP_DATA_20
write_bytes_fixed(w, pubkeyhash, 20) write_bytes_fixed(w, pubkeyhash, 20)
w.append(0x88) # OP_EQUALVERIFY append(0x88) # OP_EQUALVERIFY
w.append(0xAC) # OP_CHECKSIG append(0xAC) # OP_CHECKSIG
def output_script_p2pkh(pubkeyhash: bytes) -> bytearray: def output_script_p2pkh(pubkeyhash: bytes) -> bytearray:
s = utils.empty_bytearray(25) s = empty_bytearray(25)
write_output_script_p2pkh(s, pubkeyhash) write_output_script_p2pkh(s, pubkeyhash)
return s return s
@ -225,7 +229,7 @@ def output_script_p2sh(scripthash: bytes) -> bytearray:
# https://github.com/bitcoin/bips/blob/master/bip-0341.mediawiki#script-validation-rules # https://github.com/bitcoin/bips/blob/master/bip-0341.mediawiki#script-validation-rules
def input_script_native_segwit() -> bytearray: def _input_script_native_segwit() -> bytearray:
# Completely replaced by the witness and therefore empty. # Completely replaced by the witness and therefore empty.
return bytearray(0) return bytearray(0)
@ -238,7 +242,7 @@ def output_script_native_segwit(witver: int, witprog: bytes) -> bytearray:
length = len(witprog) length = len(witprog)
utils.ensure((length == 20 and witver == 0) or length == 32) utils.ensure((length == 20 and witver == 0) or length == 32)
w = utils.empty_bytearray(2 + length) w = empty_bytearray(2 + length)
w.append(witver + 0x50 if witver else 0) # witness version byte (OP_witver) w.append(witver + 0x50 if witver else 0) # witness version byte (OP_witver)
w.append(length) # witness program length is 20 (P2WPKH) or 32 (P2WSH, P2TR) bytes w.append(length) # witness program length is 20 (P2WPKH) or 32 (P2WSH, P2TR) bytes
write_bytes_fixed(w, witprog, length) write_bytes_fixed(w, witprog, length)
@ -248,7 +252,7 @@ def output_script_native_segwit(witver: int, witprog: bytes) -> bytearray:
def parse_output_script_p2tr(script_pubkey: bytes) -> memoryview: def parse_output_script_p2tr(script_pubkey: bytes) -> memoryview:
# 51 20 <32-byte-taproot-output-key> # 51 20 <32-byte-taproot-output-key>
try: try:
r = utils.BufferReader(script_pubkey) r = BufferReader(script_pubkey)
if r.get() != common.OP_1: if r.get() != common.OP_1:
# P2TR should be SegWit version 1 # P2TR should be SegWit version 1
@ -262,7 +266,7 @@ def parse_output_script_p2tr(script_pubkey: bytes) -> memoryview:
if r.remaining_count(): if r.remaining_count():
raise ValueError raise ValueError
except (ValueError, EOFError): except (ValueError, EOFError):
raise wire.DataError("Invalid scriptPubKey.") raise DataError("Invalid scriptPubKey.")
return pubkey return pubkey
@ -325,7 +329,7 @@ def write_witness_p2wpkh(
def parse_witness_p2wpkh(witness: bytes) -> tuple[memoryview, memoryview, SigHashType]: def parse_witness_p2wpkh(witness: bytes) -> tuple[memoryview, memoryview, SigHashType]:
try: try:
r = utils.BufferReader(witness) r = BufferReader(witness)
if r.get() != 2: if r.get() != 2:
# num of stack items, in P2WPKH it's always 2 # num of stack items, in P2WPKH it's always 2
@ -339,7 +343,7 @@ def parse_witness_p2wpkh(witness: bytes) -> tuple[memoryview, memoryview, SigHas
if r.remaining_count(): if r.remaining_count():
raise ValueError raise ValueError
except (ValueError, EOFError): except (ValueError, EOFError):
raise wire.DataError("Invalid witness.") raise DataError("Invalid witness.")
return pubkey, signature, sighash_type return pubkey, signature, sighash_type
@ -351,6 +355,8 @@ def write_witness_multisig(
signature_index: int, signature_index: int,
sighash_type: SigHashType, sighash_type: SigHashType,
) -> None: ) -> None:
from .multisig import multisig_get_pubkey_count
# get other signatures, stretch with empty bytes to the number of the pubkeys # get other signatures, stretch with empty bytes to the number of the pubkeys
signatures = multisig.signatures + [b""] * ( signatures = multisig.signatures + [b""] * (
multisig_get_pubkey_count(multisig) - len(multisig.signatures) multisig_get_pubkey_count(multisig) - len(multisig.signatures)
@ -358,7 +364,7 @@ def write_witness_multisig(
# fill in our signature # fill in our signature
if signatures[signature_index]: if signatures[signature_index]:
raise wire.DataError("Invalid multisig parameters") raise DataError("Invalid multisig parameters")
signatures[signature_index] = signature signatures[signature_index] = signature
# witness program + signatures + redeem script # witness program + signatures + redeem script
@ -383,7 +389,7 @@ def parse_witness_multisig(
witness: bytes, witness: bytes,
) -> tuple[memoryview, list[tuple[memoryview, SigHashType]]]: ) -> tuple[memoryview, list[tuple[memoryview, SigHashType]]]:
try: try:
r = utils.BufferReader(witness) r = BufferReader(witness)
# Get number of witness stack items. # Get number of witness stack items.
item_count = read_compact_size(r) item_count = read_compact_size(r)
@ -403,7 +409,7 @@ def parse_witness_multisig(
if r.remaining_count(): if r.remaining_count():
raise ValueError raise ValueError
except (ValueError, EOFError): except (ValueError, EOFError):
raise wire.DataError("Invalid witness.") raise DataError("Invalid witness.")
return script, signatures return script, signatures
@ -420,7 +426,7 @@ def write_witness_p2tr(w: Writer, signature: bytes, sighash_type: SigHashType) -
def parse_witness_p2tr(witness: bytes) -> tuple[memoryview, SigHashType]: def parse_witness_p2tr(witness: bytes) -> tuple[memoryview, SigHashType]:
try: try:
r = utils.BufferReader(witness) r = BufferReader(witness)
if r.get() != 1: # Number of stack items. if r.get() != 1: # Number of stack items.
# Only Taproot key path spending without annex is supported. # Only Taproot key path spending without annex is supported.
@ -439,7 +445,7 @@ def parse_witness_p2tr(witness: bytes) -> tuple[memoryview, SigHashType]:
if r.remaining_count(): if r.remaining_count():
raise ValueError raise ValueError
except (ValueError, EOFError): except (ValueError, EOFError):
raise wire.DataError("Invalid witness.") raise DataError("Invalid witness.")
return signature, sighash_type return signature, sighash_type
@ -450,7 +456,7 @@ def parse_witness_p2tr(witness: bytes) -> tuple[memoryview, SigHashType]:
# Used either as P2SH, P2WSH, or P2WSH nested in P2SH. # Used either as P2SH, P2WSH, or P2WSH nested in P2SH.
def write_input_script_multisig_prefixed( def _write_input_script_multisig_prefixed(
w: Writer, w: Writer,
multisig: MultisigRedeemScriptType, multisig: MultisigRedeemScriptType,
signature: bytes, signature: bytes,
@ -458,9 +464,11 @@ def write_input_script_multisig_prefixed(
sighash_type: SigHashType, sighash_type: SigHashType,
coin: CoinInfo, coin: CoinInfo,
) -> None: ) -> None:
from .writers import op_push_length
signatures = multisig.signatures # other signatures signatures = multisig.signatures # other signatures
if len(signatures[signature_index]) > 0: if len(signatures[signature_index]) > 0:
raise wire.DataError("Invalid multisig parameters") raise DataError("Invalid multisig parameters")
signatures[signature_index] = signature # our signature signatures[signature_index] = signature # our signature
# length of the redeem script # length of the redeem script
@ -493,7 +501,7 @@ def parse_input_script_multisig(
script_sig: bytes, script_sig: bytes,
) -> tuple[memoryview, list[tuple[memoryview, SigHashType]]]: ) -> tuple[memoryview, list[tuple[memoryview, SigHashType]]]:
try: try:
r = utils.BufferReader(script_sig) r = BufferReader(script_sig)
# Skip over OP_FALSE, which is due to the old OP_CHECKMULTISIG bug. # Skip over OP_FALSE, which is due to the old OP_CHECKMULTISIG bug.
if r.get() != 0: if r.get() != 0:
@ -511,13 +519,13 @@ def parse_input_script_multisig(
if len(script) != n: if len(script) != n:
raise ValueError raise ValueError
except (ValueError, EOFError): except (ValueError, EOFError):
raise wire.DataError("Invalid scriptSig.") raise DataError("Invalid scriptSig.")
return script, signatures return script, signatures
def output_script_multisig(pubkeys: list[bytes], m: int) -> bytearray: def output_script_multisig(pubkeys: list[bytes], m: int) -> bytearray:
w = utils.empty_bytearray(output_script_multisig_length(pubkeys, m)) w = empty_bytearray(output_script_multisig_length(pubkeys, m))
write_output_script_multisig(w, pubkeys, m) write_output_script_multisig(w, pubkeys, m)
return w return w
@ -530,10 +538,10 @@ def write_output_script_multisig(
) -> None: ) -> None:
n = len(pubkeys) n = len(pubkeys)
if n < 1 or n > 15 or m < 1 or m > 15 or m > n: if n < 1 or n > 15 or m < 1 or m > 15 or m > n:
raise wire.DataError("Invalid multisig parameters") raise DataError("Invalid multisig parameters")
for pubkey in pubkeys: for pubkey in pubkeys:
if len(pubkey) != 33: if len(pubkey) != 33:
raise wire.DataError("Invalid multisig parameters") raise DataError("Invalid multisig parameters")
if prefixed: if prefixed:
write_compact_size(w, output_script_multisig_length(pubkeys, m)) write_compact_size(w, output_script_multisig_length(pubkeys, m))
@ -551,7 +559,7 @@ def output_script_multisig_length(pubkeys: Sequence[bytes | memoryview], m: int)
def parse_output_script_multisig(script: bytes) -> tuple[list[memoryview], int]: def parse_output_script_multisig(script: bytes) -> tuple[list[memoryview], int]:
try: try:
r = utils.BufferReader(script) r = BufferReader(script)
threshold = r.get() - 0x50 threshold = r.get() - 0x50
pubkey_count = script[-2] - 0x50 pubkey_count = script[-2] - 0x50
@ -577,7 +585,7 @@ def parse_output_script_multisig(script: bytes) -> tuple[list[memoryview], int]:
raise ValueError raise ValueError
except (ValueError, IndexError, EOFError): except (ValueError, IndexError, EOFError):
raise wire.DataError("Invalid multisig script") raise DataError("Invalid multisig script")
return public_keys, threshold return public_keys, threshold
@ -587,7 +595,7 @@ def parse_output_script_multisig(script: bytes) -> tuple[list[memoryview], int]:
def output_script_paytoopreturn(data: bytes) -> bytearray: def output_script_paytoopreturn(data: bytes) -> bytearray:
w = utils.empty_bytearray(1 + 5 + len(data)) w = empty_bytearray(1 + 5 + len(data))
w.append(0x6A) # OP_RETURN w.append(0x6A) # OP_RETURN
write_op_push(w, len(data)) write_op_push(w, len(data))
w.extend(data) w.extend(data)
@ -627,7 +635,7 @@ def write_bip322_signature_proof(
w.append(0x00) w.append(0x00)
def read_bip322_signature_proof(r: utils.BufferReader) -> tuple[memoryview, memoryview]: def read_bip322_signature_proof(r: BufferReader) -> tuple[memoryview, memoryview]:
script_sig = read_memoryview_prefixed(r) script_sig = read_memoryview_prefixed(r)
witness = r.read_memoryview() witness = r.read_memoryview()
return script_sig, witness return script_sig, witness

View File

@ -1,27 +1,25 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor import utils, wire from trezor import utils
from trezor.crypto import base58 from trezor.crypto import base58
from trezor.crypto.base58 import blake256d_32 from trezor.crypto.base58 import blake256d_32
from trezor.enums import InputScriptType from trezor.wire import DataError
from apps.common.writers import write_bytes_fixed, write_uint64_le
from . import scripts from . import scripts
from .common import SigHashType
from .multisig import multisig_get_pubkeys, multisig_pubkey_index
from .scripts import ( # noqa: F401 from .scripts import ( # noqa: F401
output_script_paytoopreturn, output_script_paytoopreturn,
write_output_script_multisig, write_output_script_multisig,
write_output_script_p2pkh, write_output_script_p2pkh,
) )
from .writers import op_push_length, write_compact_size, write_op_push from .writers import write_compact_size
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import MultisigRedeemScriptType from trezor.messages import MultisigRedeemScriptType
from trezor.enums import InputScriptType
from apps.common.coininfo import CoinInfo from apps.common.coininfo import CoinInfo
from .common import SigHashType
from .writers import Writer from .writers import Writer
@ -34,6 +32,10 @@ def write_input_script_prefixed(
pubkey: bytes, pubkey: bytes,
signature: bytes, signature: bytes,
) -> None: ) -> None:
from trezor import wire
from trezor.enums import InputScriptType
from .multisig import multisig_pubkey_index
if script_type == InputScriptType.SPENDADDRESS: if script_type == InputScriptType.SPENDADDRESS:
# p2pkh or p2sh # p2pkh or p2sh
scripts.write_input_script_p2pkh_or_p2sh_prefixed( scripts.write_input_script_p2pkh_or_p2sh_prefixed(
@ -41,16 +43,16 @@ def write_input_script_prefixed(
) )
elif script_type == InputScriptType.SPENDMULTISIG: elif script_type == InputScriptType.SPENDMULTISIG:
# p2sh multisig # p2sh multisig
assert multisig is not None # checked in sanitize_tx_input assert multisig is not None # checked in _sanitize_tx_input
signature_index = multisig_pubkey_index(multisig, pubkey) signature_index = multisig_pubkey_index(multisig, pubkey)
write_input_script_multisig_prefixed( _write_input_script_multisig_prefixed(
w, multisig, signature, signature_index, sighash_type, coin w, multisig, signature, signature_index, sighash_type, coin
) )
else: else:
raise wire.ProcessError("Invalid script type") raise wire.ProcessError("Invalid script type")
def write_input_script_multisig_prefixed( def _write_input_script_multisig_prefixed(
w: Writer, w: Writer,
multisig: MultisigRedeemScriptType, multisig: MultisigRedeemScriptType,
signature: bytes, signature: bytes,
@ -58,9 +60,12 @@ def write_input_script_multisig_prefixed(
sighash_type: SigHashType, sighash_type: SigHashType,
coin: CoinInfo, coin: CoinInfo,
) -> None: ) -> None:
from .multisig import multisig_get_pubkeys
from .writers import op_push_length, write_op_push
signatures = multisig.signatures # other signatures signatures = multisig.signatures # other signatures
if len(signatures[signature_index]) > 0: if len(signatures[signature_index]) > 0:
raise wire.DataError("Invalid multisig parameters") raise DataError("Invalid multisig parameters")
signatures[signature_index] = signature # our signature signatures[signature_index] = signature # our signature
# length of the redeem script # length of the redeem script
@ -89,7 +94,7 @@ def output_script_sstxsubmissionpkh(addr: str) -> bytearray:
try: try:
raw_address = base58.decode_check(addr, blake256d_32) raw_address = base58.decode_check(addr, blake256d_32)
except ValueError: except ValueError:
raise wire.DataError("Invalid address") raise DataError("Invalid address")
w = utils.empty_bytearray(26) w = utils.empty_bytearray(26)
w.append(0xBA) # OP_SSTX w.append(0xBA) # OP_SSTX
@ -102,7 +107,7 @@ def output_script_sstxchange(addr: str) -> bytearray:
try: try:
raw_address = base58.decode_check(addr, blake256d_32) raw_address = base58.decode_check(addr, blake256d_32)
except ValueError: except ValueError:
raise wire.DataError("Invalid address") raise DataError("Invalid address")
w = utils.empty_bytearray(26) w = utils.empty_bytearray(26)
w.append(0xBD) # OP_SSTXCHANGE w.append(0xBD) # OP_SSTXCHANGE
@ -128,6 +133,8 @@ def write_output_script_ssgen_prefixed(w: Writer, pkh: bytes) -> None:
# Stake commitment OPRETURN. # Stake commitment OPRETURN.
def sstxcommitment_pkh(pkh: bytes, amount: int) -> bytes: def sstxcommitment_pkh(pkh: bytes, amount: int) -> bytes:
from apps.common.writers import write_bytes_fixed, write_uint64_le
w = utils.empty_bytearray(30) w = utils.empty_bytearray(30)
write_bytes_fixed(w, pkh, 20) write_bytes_fixed(w, pkh, 20)
write_uint64_le(w, amount) write_uint64_le(w, amount)

View File

@ -1,19 +1,10 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor import wire from .keychain import with_keychain
from trezor.crypto.curve import secp256k1
from trezor.enums import InputScriptType
from trezor.messages import MessageSignature
from trezor.ui.layouts import confirm_signverify
from apps.common.paths import validate_path
from apps.common.signverify import decode_message, message_digest
from .addresses import address_short, get_address
from .keychain import validate_path_against_script_type, with_keychain
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import SignMessage from trezor.messages import SignMessage, MessageSignature
from trezor.wire import Context
from apps.common.coininfo import CoinInfo from apps.common.coininfo import CoinInfo
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
@ -21,8 +12,20 @@ if TYPE_CHECKING:
@with_keychain @with_keychain
async def sign_message( async def sign_message(
ctx: wire.Context, msg: SignMessage, keychain: Keychain, coin: CoinInfo ctx: Context, msg: SignMessage, keychain: Keychain, coin: CoinInfo
) -> MessageSignature: ) -> MessageSignature:
from trezor import wire
from trezor.crypto.curve import secp256k1
from trezor.enums import InputScriptType
from trezor.messages import MessageSignature
from trezor.ui.layouts import confirm_signverify
from apps.common.paths import validate_path
from apps.common.signverify import decode_message, message_digest
from .addresses import address_short, get_address
from .keychain import validate_path_against_script_type
message = msg.message message = msg.message
address_n = msg.address_n address_n = msg.address_n
script_type = msg.script_type or InputScriptType.SPENDADDRESS script_type = msg.script_type or InputScriptType.SPENDADDRESS

View File

@ -1,12 +1,8 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor import utils, wire from trezor import utils
from trezor.enums import RequestType
from trezor.messages import TxRequest
from ..common import BITCOIN_NAMES
from ..keychain import with_keychain from ..keychain import with_keychain
from . import approvers, bitcoin, helpers, progress
if not utils.BITCOIN_ONLY: if not utils.BITCOIN_ONLY:
from . import bitcoinlike, decred, zcash_v4 from . import bitcoinlike, decred, zcash_v4
@ -15,6 +11,7 @@ if not utils.BITCOIN_ONLY:
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Protocol from typing import Protocol
from trezor.wire import Context
from trezor.messages import ( from trezor.messages import (
SignTx, SignTx,
TxAckInput, TxAckInput,
@ -23,11 +20,13 @@ if TYPE_CHECKING:
TxAckPrevInput, TxAckPrevInput,
TxAckPrevOutput, TxAckPrevOutput,
TxAckPrevExtraData, TxAckPrevExtraData,
TxRequest,
) )
from apps.common.coininfo import CoinInfo from apps.common.coininfo import CoinInfo
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
from . import approvers
from ..authorization import CoinJoinAuthorization from ..authorization import CoinJoinAuthorization
TxAckType = ( TxAckType = (
@ -55,12 +54,18 @@ if TYPE_CHECKING:
@with_keychain @with_keychain
async def sign_tx( async def sign_tx(
ctx: wire.Context, ctx: Context,
msg: SignTx, msg: SignTx,
keychain: Keychain, keychain: Keychain,
coin: CoinInfo, coin: CoinInfo,
authorization: CoinJoinAuthorization | None = None, authorization: CoinJoinAuthorization | None = None,
) -> TxRequest: ) -> TxRequest:
from trezor.enums import RequestType
from trezor.messages import TxRequest
from ..common import BITCOIN_NAMES
from . import approvers, bitcoin, helpers, progress
approver: approvers.Approver | None = None approver: approvers.Approver | None = None
if authorization: if authorization:
approver = approvers.CoinJoinApprover(msg, coin, authorization) approver = approvers.CoinJoinApprover(msg, coin, authorization)

View File

@ -1,23 +1,19 @@
from micropython import const from micropython import const
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor import wire
from trezor.crypto.curve import bip340, secp256k1 from trezor.crypto.curve import bip340, secp256k1
from trezor.crypto.hashlib import sha256 from trezor.crypto.hashlib import sha256
from trezor.enums import OutputScriptType
from trezor.ui.components.common.confirm import INFO
from trezor.utils import HashWriter from trezor.utils import HashWriter
from trezor.wire import DataError, ProcessError
from apps.common import safety_checks from apps.common import safety_checks
from .. import writers from .. import writers
from ..authorization import FEE_RATE_DECIMALS
from ..common import input_is_external_unverified from ..common import input_is_external_unverified
from ..keychain import validate_path_against_script_type from ..keychain import validate_path_against_script_type
from . import helpers, tx_weight from . import helpers, tx_weight
from .payment_request import PaymentRequestVerifier
from .sig_hasher import BitcoinSigHasher from .sig_hasher import BitcoinSigHasher
from .tx_info import OriginalTxInfo, TxInfo from .tx_info import OriginalTxInfo
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.crypto import bip32 from trezor.crypto import bip32
@ -27,6 +23,8 @@ if TYPE_CHECKING:
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
from ..authorization import CoinJoinAuthorization from ..authorization import CoinJoinAuthorization
from .tx_info import TxInfo
from .payment_request import PaymentRequestVerifier
# An Approver object computes the transaction totals and either prompts the user # An Approver object computes the transaction totals and either prompts the user
@ -79,7 +77,7 @@ class Approver:
if input_is_external_unverified(txi): if input_is_external_unverified(txi):
self.has_unverified_external_input = True self.has_unverified_external_input = True
if safety_checks.is_strict(): if safety_checks.is_strict():
raise wire.ProcessError("Unverifiable external input.") raise ProcessError("Unverifiable external input.")
else: else:
self.external_in += txi.amount self.external_in += txi.amount
if txi.orig_hash: if txi.orig_hash:
@ -92,6 +90,8 @@ class Approver:
async def add_payment_request( async def add_payment_request(
self, msg: TxAckPaymentRequest, keychain: Keychain self, msg: TxAckPaymentRequest, keychain: Keychain
) -> None: ) -> None:
from .payment_request import PaymentRequestVerifier
self.finish_payment_request() self.finish_payment_request()
self.payment_req_verifier = PaymentRequestVerifier(msg, self.coin, keychain) self.payment_req_verifier = PaymentRequestVerifier(msg, self.coin, keychain)
@ -135,7 +135,7 @@ class Approver:
class BasicApprover(Approver): class BasicApprover(Approver):
# the maximum number of change-outputs allowed without user confirmation # the maximum number of change-outputs allowed without user confirmation
MAX_SILENT_CHANGE_COUNT = const(2) MAX_SILENT_CHANGE_COUNT = 2
def __init__(self, tx: SignTx, coin: CoinInfo) -> None: def __init__(self, tx: SignTx, coin: CoinInfo) -> None:
super().__init__(tx, coin) super().__init__(tx, coin)
@ -158,7 +158,7 @@ class BasicApprover(Approver):
not validate_path_against_script_type(self.coin, txi) not validate_path_against_script_type(self.coin, txi)
and not self.foreign_address_confirmed and not self.foreign_address_confirmed
): ):
raise wire.ProcessError("Transaction has changed during signing") raise ProcessError("Transaction has changed during signing")
def add_change_output(self, txo: TxOutput, script_pubkey: bytes) -> None: def add_change_output(self, txo: TxOutput, script_pubkey: bytes) -> None:
super().add_change_output(txo, script_pubkey) super().add_change_output(txo, script_pubkey)
@ -170,6 +170,8 @@ class BasicApprover(Approver):
script_pubkey: bytes, script_pubkey: bytes,
orig_txo: TxOutput | None = None, orig_txo: TxOutput | None = None,
) -> None: ) -> None:
from trezor.enums import OutputScriptType
await super().add_external_output(txo, script_pubkey, orig_txo) await super().add_external_output(txo, script_pubkey, orig_txo)
if orig_txo: if orig_txo:
@ -180,7 +182,7 @@ class BasicApprover(Approver):
if self.is_payjoin(): if self.is_payjoin():
# In case of PayJoin the above could be used to increase other external # In case of PayJoin the above could be used to increase other external
# outputs, which would create too much UI complexity. # outputs, which would create too much UI complexity.
raise wire.ProcessError( raise ProcessError(
"Reducing original output amounts is not supported." "Reducing original output amounts is not supported."
) )
await helpers.confirm_modify_output( await helpers.confirm_modify_output(
@ -191,7 +193,7 @@ class BasicApprover(Approver):
# confirmation, because approve_tx() together with the branch above ensures that # confirmation, because approve_tx() together with the branch above ensures that
# the increase is paid by external inputs. # the increase is paid by external inputs.
if not self.is_payjoin(): if not self.is_payjoin():
raise wire.ProcessError( raise ProcessError(
"Increasing original output amounts is not supported." "Increasing original output amounts is not supported."
) )
@ -199,7 +201,7 @@ class BasicApprover(Approver):
# Skip output confirmation for replacement transactions, # Skip output confirmation for replacement transactions,
# but don't allow adding new OP_RETURN outputs. # but don't allow adding new OP_RETURN outputs.
if txo.script_type == OutputScriptType.PAYTOOPRETURN and not orig_txo: if txo.script_type == OutputScriptType.PAYTOOPRETURN and not orig_txo:
raise wire.ProcessError( raise ProcessError(
"Adding new OP_RETURN outputs in replacement transactions is not supported." "Adding new OP_RETURN outputs in replacement transactions is not supported."
) )
elif txo.payment_req_index is None or self.show_payment_req_details: elif txo.payment_req_index is None or self.show_payment_req_details:
@ -210,9 +212,11 @@ class BasicApprover(Approver):
async def add_payment_request( async def add_payment_request(
self, msg: TxAckPaymentRequest, keychain: Keychain self, msg: TxAckPaymentRequest, keychain: Keychain
) -> None: ) -> None:
from trezor.ui.components.common.confirm import INFO
await super().add_payment_request(msg, keychain) await super().add_payment_request(msg, keychain)
if msg.amount is None: if msg.amount is None:
raise wire.DataError("Missing payment request amount.") raise DataError("Missing payment request amount.")
result = await helpers.confirm_payment_request(msg, self.coin, self.amount_unit) result = await helpers.confirm_payment_request(msg, self.coin, self.amount_unit)
self.show_payment_req_details = result is INFO self.show_payment_req_details = result is INFO
@ -238,6 +242,11 @@ class BasicApprover(Approver):
await helpers.confirm_replacement(description, orig.orig_hash) await helpers.confirm_replacement(description, orig.orig_hash)
async def approve_tx(self, tx_info: TxInfo, orig_txs: list[OriginalTxInfo]) -> None: async def approve_tx(self, tx_info: TxInfo, orig_txs: list[OriginalTxInfo]) -> None:
from trezor.wire import NotEnoughFunds
coin = self.coin # local_cache_attribute
amount_unit = self.amount_unit # local_cache_attribute
await super().approve_tx(tx_info, orig_txs) await super().approve_tx(tx_info, orig_txs)
if self.has_unverified_external_input: if self.has_unverified_external_input:
@ -246,21 +255,21 @@ class BasicApprover(Approver):
fee = self.total_in - self.total_out fee = self.total_in - self.total_out
# some coins require negative fees for reward TX # some coins require negative fees for reward TX
if fee < 0 and not self.coin.negative_fee: if fee < 0 and not coin.negative_fee:
raise wire.NotEnoughFunds("Not enough funds") raise NotEnoughFunds("Not enough funds")
total = self.total_in - self.change_out total = self.total_in - self.change_out
spending = total - self.external_in spending = total - self.external_in
tx_size_vB = self.weight.get_virtual_size() tx_size_vB = self.weight.get_virtual_size()
fee_rate = fee / tx_size_vB fee_rate = fee / tx_size_vB
# fee_threshold = (coin.maxfee per byte * tx size) # fee_threshold = (coin.maxfee per byte * tx size)
fee_threshold = (self.coin.maxfee_kb / 1000) * tx_size_vB fee_threshold = (coin.maxfee_kb / 1000) * tx_size_vB
# fee > (coin.maxfee per byte * tx size) # fee > (coin.maxfee per byte * tx size)
if fee > fee_threshold: if fee > fee_threshold:
if fee > 10 * fee_threshold and safety_checks.is_strict(): if fee > 10 * fee_threshold and safety_checks.is_strict():
raise wire.DataError("The fee is unexpectedly large") raise DataError("The fee is unexpectedly large")
await helpers.confirm_feeoverthreshold(fee, self.coin, self.amount_unit) await helpers.confirm_feeoverthreshold(fee, coin, amount_unit)
if self.change_count > self.MAX_SILENT_CHANGE_COUNT: if self.change_count > self.MAX_SILENT_CHANGE_COUNT:
await helpers.confirm_change_count_over_threshold(self.change_count) await helpers.confirm_change_count_over_threshold(self.change_count)
@ -273,7 +282,7 @@ class BasicApprover(Approver):
orig_fee = self.orig_total_in - self.orig_total_out orig_fee = self.orig_total_in - self.orig_total_out
if fee < 0 or orig_fee < 0: if fee < 0 or orig_fee < 0:
raise wire.ProcessError( raise ProcessError(
"Negative fees not supported in transaction replacement." "Negative fees not supported in transaction replacement."
) )
@ -283,14 +292,14 @@ class BasicApprover(Approver):
# not increase by more than the fee difference (so additional funds # not increase by more than the fee difference (so additional funds
# can only go towards the fee, which is confirmed by the user). # can only go towards the fee, which is confirmed by the user).
if spending - orig_spending > fee - orig_fee: if spending - orig_spending > fee - orig_fee:
raise wire.ProcessError("Invalid replacement transaction.") raise ProcessError("Invalid replacement transaction.")
# Replacement transactions must not change the effective nLockTime. # Replacement transactions must not change the effective nLockTime.
lock_time = 0 if tx_info.lock_time_disabled() else tx_info.tx.lock_time lock_time = 0 if tx_info.lock_time_disabled() else tx_info.tx.lock_time
for orig in orig_txs: for orig in orig_txs:
orig_lock_time = 0 if orig.lock_time_disabled() else orig.tx.lock_time orig_lock_time = 0 if orig.lock_time_disabled() else orig.tx.lock_time
if lock_time != orig_lock_time: if lock_time != orig_lock_time:
raise wire.ProcessError( raise ProcessError(
"Original transactions must have same effective nLockTime as replacement transaction." "Original transactions must have same effective nLockTime as replacement transaction."
) )
@ -299,14 +308,14 @@ class BasicApprover(Approver):
# coming entirely from the user's own funds and from decreases of external outputs. # coming entirely from the user's own funds and from decreases of external outputs.
# We consider the decreases as belonging to the user. # We consider the decreases as belonging to the user.
await helpers.confirm_modify_fee( await helpers.confirm_modify_fee(
fee - orig_fee, fee, fee_rate, self.coin, self.amount_unit fee - orig_fee, fee, fee_rate, coin, amount_unit
) )
elif spending > orig_spending: elif spending > orig_spending:
# PayJoin and user is spending more: Show the increase in the user's contribution # PayJoin and user is spending more: Show the increase in the user's contribution
# to the fee, ignoring any contribution from external inputs. Decreasing of # to the fee, ignoring any contribution from external inputs. Decreasing of
# external outputs is not allowed in PayJoin, so there is no need to handle those. # external outputs is not allowed in PayJoin, so there is no need to handle those.
await helpers.confirm_modify_fee( await helpers.confirm_modify_fee(
spending - orig_spending, fee, fee_rate, self.coin, self.amount_unit spending - orig_spending, fee, fee_rate, coin, amount_unit
) )
else: else:
# PayJoin and user is not spending more: When new external inputs are involved and # PayJoin and user is not spending more: When new external inputs are involved and
@ -321,13 +330,9 @@ class BasicApprover(Approver):
) )
if not self.external_in: if not self.external_in:
await helpers.confirm_total( await helpers.confirm_total(total, fee, fee_rate, coin, amount_unit)
total, fee, fee_rate, self.coin, self.amount_unit
)
else: else:
await helpers.confirm_joint_total( await helpers.confirm_joint_total(spending, total, coin, amount_unit)
spending, total, self.coin, self.amount_unit
)
class CoinJoinApprover(Approver): class CoinJoinApprover(Approver):
@ -336,7 +341,7 @@ class CoinJoinApprover(Approver):
MIN_REGISTRABLE_OUTPUT_AMOUNT = const(5000) MIN_REGISTRABLE_OUTPUT_AMOUNT = const(5000)
# Largest possible weight of an output supported by Trezor (P2TR or P2WSH). # Largest possible weight of an output supported by Trezor (P2TR or P2WSH).
MAX_OUTPUT_WEIGHT = const(4 * (8 + 1 + 1 + 1 + 32)) MAX_OUTPUT_WEIGHT = 4 * (8 + 1 + 1 + 1 + 32)
# Masks for the signable and no_fee bits in coinjoin_flags. # Masks for the signable and no_fee bits in coinjoin_flags.
COINJOIN_FLAGS_SIGNABLE = const(0x01) COINJOIN_FLAGS_SIGNABLE = const(0x01)
@ -356,7 +361,7 @@ class CoinJoinApprover(Approver):
super().__init__(tx, coin) super().__init__(tx, coin)
if not tx.coinjoin_request: if not tx.coinjoin_request:
raise wire.DataError("Missing CoinJoin request.") raise DataError("Missing CoinJoin request.")
self.request = tx.coinjoin_request self.request = tx.coinjoin_request
self.authorization = authorization self.authorization = authorization
@ -384,7 +389,7 @@ class CoinJoinApprover(Approver):
async def add_internal_input(self, txi: TxInput, node: bip32.HDNode) -> None: async def add_internal_input(self, txi: TxInput, node: bip32.HDNode) -> None:
self.our_weight.add_input(txi) self.our_weight.add_input(txi)
if not self.authorization.check_sign_tx_input(txi, self.coin): if not self.authorization.check_sign_tx_input(txi, self.coin):
raise wire.ProcessError("Unauthorized path") raise ProcessError("Unauthorized path")
# Compute the masking bit for the signable bit in coinjoin flags. # Compute the masking bit for the signable bit in coinjoin flags.
internal_private_key = node.private_key() internal_private_key = node.private_key()
@ -400,7 +405,7 @@ class CoinJoinApprover(Approver):
# Ensure that the input can be signed. # Ensure that the input can be signed.
if bool(txi.coinjoin_flags & self.COINJOIN_FLAGS_SIGNABLE) ^ mask != 1: if bool(txi.coinjoin_flags & self.COINJOIN_FLAGS_SIGNABLE) ^ mask != 1:
raise wire.ProcessError("Unauthorized input") raise ProcessError("Unauthorized input")
# Add to coordination_fee_base, except for remixes and small inputs which are # Add to coordination_fee_base, except for remixes and small inputs which are
# not charged a coordination fee. # not charged a coordination fee.
@ -416,7 +421,7 @@ class CoinJoinApprover(Approver):
# in multiple signatures schemes (ECDSA and Schnorr) and we want to be sure that the user # in multiple signatures schemes (ECDSA and Schnorr) and we want to be sure that the user
# went through a warning screen before we sign the input. # went through a warning screen before we sign the input.
if not self.authorization.check_sign_tx_input(txi, self.coin): if not self.authorization.check_sign_tx_input(txi, self.coin):
raise wire.ProcessError("Unauthorized path") raise ProcessError("Unauthorized path")
def add_external_input(self, txi: TxInput) -> None: def add_external_input(self, txi: TxInput) -> None:
super().add_external_input(txi) super().add_external_input(txi)
@ -425,7 +430,7 @@ class CoinJoinApprover(Approver):
# is not critical for security, we are just being cautious, because # is not critical for security, we are just being cautious, because
# CoinJoin is automated and this is not a very legitimate use-case. # CoinJoin is automated and this is not a very legitimate use-case.
if input_is_external_unverified(txi): if input_is_external_unverified(txi):
raise wire.ProcessError("Unverifiable external input.") raise ProcessError("Unverifiable external input.")
def add_change_output(self, txo: TxOutput, script_pubkey: bytes) -> None: def add_change_output(self, txo: TxOutput, script_pubkey: bytes) -> None:
super().add_change_output(txo, script_pubkey) super().add_change_output(txo, script_pubkey)
@ -438,7 +443,7 @@ class CoinJoinApprover(Approver):
def _verify_coinjoin_request(self, tx_info: TxInfo): def _verify_coinjoin_request(self, tx_info: TxInfo):
if not isinstance(tx_info.sig_hasher, BitcoinSigHasher): if not isinstance(tx_info.sig_hasher, BitcoinSigHasher):
raise wire.ProcessError("Unexpected signature hasher.") raise ProcessError("Unexpected signature hasher.")
# Finish hashing the CoinJoin request. # Finish hashing the CoinJoin request.
writers.write_bytes_fixed( writers.write_bytes_fixed(
@ -464,10 +469,12 @@ class CoinJoinApprover(Approver):
) )
async def approve_tx(self, tx_info: TxInfo, orig_txs: list[OriginalTxInfo]) -> None: async def approve_tx(self, tx_info: TxInfo, orig_txs: list[OriginalTxInfo]) -> None:
from ..authorization import FEE_RATE_DECIMALS
await super().approve_tx(tx_info, orig_txs) await super().approve_tx(tx_info, orig_txs)
if not self._verify_coinjoin_request(tx_info): if not self._verify_coinjoin_request(tx_info):
raise wire.DataError("Invalid signature in CoinJoin request.") raise DataError("Invalid signature in CoinJoin request.")
# The mining fee of the transaction as a whole. # The mining fee of the transaction as a whole.
mining_fee = self.total_in - self.total_out mining_fee = self.total_in - self.total_out
@ -512,13 +519,13 @@ class CoinJoinApprover(Approver):
+ our_max_mining_fee + our_max_mining_fee
+ min_allowed_output_amount_plus_fee + min_allowed_output_amount_plus_fee
): ):
raise wire.ProcessError("Total fee over threshold.") raise ProcessError("Total fee over threshold.")
if not self.authorization.approve_sign_tx(tx_info.tx): if not self.authorization.approve_sign_tx(tx_info.tx):
raise wire.ProcessError("Exceeded number of CoinJoin rounds.") raise ProcessError("Exceeded number of CoinJoin rounds.")
def _add_output(self, txo: TxOutput, script_pubkey: bytes) -> None: def _add_output(self, txo: TxOutput, script_pubkey: bytes) -> None:
super()._add_output(txo, script_pubkey) super()._add_output(txo, script_pubkey)
if txo.payment_req_index: if txo.payment_req_index:
raise wire.DataError("Unexpected payment request.") raise DataError("Unexpected payment request.")

View File

@ -1,28 +1,21 @@
from micropython import const from micropython import const
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor import wire
from trezor.crypto.hashlib import sha256 from trezor.crypto.hashlib import sha256
from trezor.enums import InputScriptType, OutputScriptType from trezor.enums import InputScriptType
from trezor.messages import TxRequest, TxRequestDetailsType, TxRequestSerializedType from trezor.utils import HashWriter, empty_bytearray
from trezor.utils import HashWriter, empty_bytearray, ensure from trezor.wire import DataError, ProcessError
from apps.common.writers import write_compact_size from apps.common.writers import write_compact_size
from .. import addresses, common, multisig, scripts, writers from .. import addresses, common, multisig, scripts, writers
from ..common import ( from ..common import SigHashType, ecdsa_sign, input_is_external
SigHashType,
bip340_sign,
ecdsa_sign,
input_is_external,
input_is_segwit,
)
from ..ownership import verify_nonownership from ..ownership import verify_nonownership
from ..verification import SignatureVerifier from ..verification import SignatureVerifier
from . import approvers, helpers from . import helpers
from .helpers import request_tx_input, request_tx_output
from .progress import progress from .progress import progress
from .sig_hasher import BitcoinSigHasher from .tx_info import OriginalTxInfo
from .tx_info import OriginalTxInfo, TxInfo
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Sequence from typing import Sequence
@ -41,7 +34,10 @@ if TYPE_CHECKING:
from apps.common.coininfo import CoinInfo from apps.common.coininfo import CoinInfo
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
from . import approvers
from .sig_hasher import SigHasher from .sig_hasher import SigHasher
from .tx_info import TxInfo
from ..writers import Writer
# the number of bytes to preallocate for serialized transaction chunks # the number of bytes to preallocate for serialized transaction chunks
@ -101,6 +97,14 @@ class Bitcoin:
coin: CoinInfo, coin: CoinInfo,
approver: approvers.Approver | None, approver: approvers.Approver | None,
) -> None: ) -> None:
from trezor.messages import (
TxRequest,
TxRequestDetailsType,
TxRequestSerializedType,
)
from . import approvers
from .tx_info import TxInfo
global _SERIALIZED_TX_BUFFER global _SERIALIZED_TX_BUFFER
self.tx_info = TxInfo(self, helpers.sanitize_sign_tx(tx, coin)) self.tx_info = TxInfo(self, helpers.sanitize_sign_tx(tx, coin))
@ -151,15 +155,20 @@ class Bitcoin:
return HashWriter(sha256()) return HashWriter(sha256())
def create_sig_hasher(self, tx: SignTx | PrevTx) -> SigHasher: def create_sig_hasher(self, tx: SignTx | PrevTx) -> SigHasher:
from .sig_hasher import BitcoinSigHasher
return BitcoinSigHasher() return BitcoinSigHasher()
async def step1_process_inputs(self) -> None: async def step1_process_inputs(self) -> None:
from ..common import input_is_segwit
tx_info = self.tx_info # local_cache_attribute
h_presigned_inputs_check = HashWriter(sha256()) h_presigned_inputs_check = HashWriter(sha256())
for i in range(self.tx_info.tx.inputs_count): for i in range(tx_info.tx.inputs_count):
# STAGE_REQUEST_1_INPUT in legacy # STAGE_REQUEST_1_INPUT in legacy
progress.advance() progress.advance()
txi = await helpers.request_tx_input(self.tx_req, i, self.coin) txi = await request_tx_input(self.tx_req, i, self.coin)
if txi.script_type not in ( if txi.script_type not in (
InputScriptType.SPENDTAPROOT, InputScriptType.SPENDTAPROOT,
InputScriptType.EXTERNAL, InputScriptType.EXTERNAL,
@ -186,14 +195,14 @@ class Bitcoin:
if txi.orig_hash: if txi.orig_hash:
await self.process_original_input(txi, script_pubkey) await self.process_original_input(txi, script_pubkey)
self.tx_info.h_inputs_check = self.tx_info.get_tx_check_digest() tx_info.h_inputs_check = tx_info.get_tx_check_digest()
self.h_presigned_inputs = h_presigned_inputs_check.get_digest() self.h_presigned_inputs = h_presigned_inputs_check.get_digest()
# Finalize original inputs. # Finalize original inputs.
for orig in self.orig_txs: for orig in self.orig_txs:
orig.h_inputs_check = orig.get_tx_check_digest() orig.h_inputs_check = orig.get_tx_check_digest()
if orig.index != orig.tx.inputs_count: if orig.index != orig.tx.inputs_count:
raise wire.ProcessError("Removal of original inputs is not supported.") raise ProcessError("Removal of original inputs is not supported.")
orig.index = 0 # Reset counter for outputs. orig.index = 0 # Reset counter for outputs.
@ -201,7 +210,7 @@ class Bitcoin:
for i in range(self.tx_info.tx.outputs_count): for i in range(self.tx_info.tx.outputs_count):
# STAGE_REQUEST_2_OUTPUT in legacy # STAGE_REQUEST_2_OUTPUT in legacy
progress.advance() progress.advance()
txo = await helpers.request_tx_output(self.tx_req, i, self.coin) txo = await request_tx_output(self.tx_req, i, self.coin)
script_pubkey = self.output_derive_script(txo) script_pubkey = self.output_derive_script(txo)
orig_txo: TxOutput | None = None orig_txo: TxOutput | None = None
if txo.orig_hash: if txo.orig_hash:
@ -228,7 +237,7 @@ class Bitcoin:
for i in range(self.tx_info.tx.inputs_count): for i in range(self.tx_info.tx.inputs_count):
if i in self.presigned: if i in self.presigned:
progress.advance() progress.advance()
txi = await helpers.request_tx_input(self.tx_req, i, self.coin) txi = await request_tx_input(self.tx_req, i, self.coin)
writers.write_tx_input_check(h_check, txi) writers.write_tx_input_check(h_check, txi)
# txi.script_pubkey checked in sanitize_tx_input # txi.script_pubkey checked in sanitize_tx_input
@ -247,24 +256,24 @@ class Bitcoin:
# multiple rounds of the attack. # multiple rounds of the attack.
expected_digest = self.tx_info.h_inputs_check expected_digest = self.tx_info.h_inputs_check
for i in range(self.tx_info.tx.inputs_count): for i in range(self.tx_info.tx.inputs_count):
txi = await helpers.request_tx_input(self.tx_req, i, self.coin) txi = await request_tx_input(self.tx_req, i, self.coin)
writers.write_tx_input_check(h_check, txi) writers.write_tx_input_check(h_check, txi)
prev_amount, script_pubkey = await self.get_prevtx_output( prev_amount, script_pubkey = await self.get_prevtx_output(
txi.prev_hash, txi.prev_index txi.prev_hash, txi.prev_index
) )
if prev_amount != txi.amount: if prev_amount != txi.amount:
raise wire.DataError("Invalid amount specified") raise DataError("Invalid amount specified")
if script_pubkey != self.input_derive_script(txi): if script_pubkey != self.input_derive_script(txi):
raise wire.DataError("Input does not match scriptPubKey") raise DataError("Input does not match scriptPubKey")
if i in self.presigned: if i in self.presigned:
await self.verify_presigned_external_input(i, txi, script_pubkey) await self.verify_presigned_external_input(i, txi, script_pubkey)
# check that the inputs were the same as those streamed for approval # check that the inputs were the same as those streamed for approval
if h_check.get_digest() != expected_digest: if h_check.get_digest() != expected_digest:
raise wire.ProcessError("Transaction has changed during signing") raise ProcessError("Transaction has changed during signing")
# verify the signature of one SIGHASH_ALL input in each original transaction # verify the signature of one SIGHASH_ALL input in each original transaction
await self.verify_original_txs() await self.verify_original_txs()
@ -306,9 +315,7 @@ class Bitcoin:
if self.serialize: if self.serialize:
if i in self.presigned: if i in self.presigned:
progress.advance() progress.advance()
txi = await helpers.request_tx_input( txi = await request_tx_input(self.tx_req, i, self.coin)
self.tx_req, i, self.coin
)
self.serialized_tx.extend(txi.witness or b"\0") self.serialized_tx.extend(txi.witness or b"\0")
else: else:
self.serialized_tx.append(0) self.serialized_tx.append(0)
@ -329,7 +336,7 @@ class Bitcoin:
async def process_internal_input(self, txi: TxInput, node: bip32.HDNode) -> None: async def process_internal_input(self, txi: TxInput, node: bip32.HDNode) -> None:
if txi.script_type not in common.INTERNAL_INPUT_SCRIPT_TYPES: if txi.script_type not in common.INTERNAL_INPUT_SCRIPT_TYPES:
raise wire.DataError("Wrong input script type") raise DataError("Wrong input script type")
await self.approver.add_internal_input(txi, node) await self.approver.add_internal_input(txi, node)
@ -346,33 +353,32 @@ class Bitcoin:
self.keychain, self.keychain,
self.coin, self.coin,
): ):
raise wire.DataError("Invalid external input") raise DataError("Invalid external input")
async def process_original_input(self, txi: TxInput, script_pubkey: bytes) -> None: async def process_original_input(self, txi: TxInput, script_pubkey: bytes) -> None:
assert txi.orig_hash is not None orig_hash = txi.orig_hash # local_cache_attribute
assert txi.orig_index is not None orig_index = txi.orig_index # local_cache_attribute
assert orig_hash is not None
assert orig_index is not None
for orig in self.orig_txs: for orig in self.orig_txs:
if orig.orig_hash == txi.orig_hash: if orig.orig_hash == orig_hash:
break break
else: else:
orig_meta = await helpers.request_tx_meta( orig_meta = await helpers.request_tx_meta(self.tx_req, self.coin, orig_hash)
self.tx_req, self.coin, txi.orig_hash orig = OriginalTxInfo(self, orig_meta, orig_hash)
)
orig = OriginalTxInfo(self, orig_meta, txi.orig_hash)
self.orig_txs.append(orig) self.orig_txs.append(orig)
if txi.orig_index >= orig.tx.inputs_count: if orig_index >= orig.tx.inputs_count:
raise wire.ProcessError("Not enough inputs in original transaction.") raise ProcessError("Not enough inputs in original transaction.")
if orig.index != txi.orig_index: if orig.index != orig_index:
raise wire.ProcessError( raise ProcessError(
"Rearranging or removal of original inputs is not supported." "Rearranging or removal of original inputs is not supported."
) )
orig_txi = await helpers.request_tx_input( orig_txi = await request_tx_input(self.tx_req, orig_index, self.coin, orig_hash)
self.tx_req, txi.orig_index, self.coin, txi.orig_hash
)
# Verify that the original input matches: # Verify that the original input matches:
# #
@ -390,7 +396,7 @@ class Bitcoin:
or orig_txi.script_type != txi.script_type or orig_txi.script_type != txi.script_type
or self.input_derive_script(orig_txi) != script_pubkey or self.input_derive_script(orig_txi) != script_pubkey
): ):
raise wire.ProcessError("Original input does not match current input.") raise ProcessError("Original input does not match current input.")
orig.add_input(orig_txi, script_pubkey) orig.add_input(orig_txi, script_pubkey)
orig.index += 1 orig.index += 1
@ -399,9 +405,7 @@ class Bitcoin:
self, orig: OriginalTxInfo, orig_hash: bytes, last_index: int self, orig: OriginalTxInfo, orig_hash: bytes, last_index: int
) -> None: ) -> None:
while orig.index < last_index: while orig.index < last_index:
txo = await helpers.request_tx_output( txo = await request_tx_output(self.tx_req, orig.index, self.coin, orig_hash)
self.tx_req, orig.index, self.coin, orig_hash
)
orig.add_output(txo, self.output_derive_script(txo)) orig.add_output(txo, self.output_derive_script(txo))
if orig.output_is_change(txo): if orig.output_is_change(txo):
@ -409,7 +413,7 @@ class Bitcoin:
self.approver.add_orig_change_output(txo) self.approver.add_orig_change_output(txo)
else: else:
# Removal of external outputs requires prompting the user. Not implemented. # Removal of external outputs requires prompting the user. Not implemented.
raise wire.ProcessError( raise ProcessError(
"Removal of original external outputs is not supported." "Removal of original external outputs is not supported."
) )
@ -418,35 +422,36 @@ class Bitcoin:
async def get_original_output( async def get_original_output(
self, txo: TxOutput, script_pubkey: bytes self, txo: TxOutput, script_pubkey: bytes
) -> TxOutput: ) -> TxOutput:
assert txo.orig_hash is not None orig_hash = txo.orig_hash # local_cache_attribute
assert txo.orig_index is not None orig_index = txo.orig_index # local_cache_attribute
assert orig_hash is not None
assert orig_index is not None
for orig in self.orig_txs: for orig in self.orig_txs:
if orig.orig_hash == txo.orig_hash: if orig.orig_hash == orig_hash:
break break
else: else:
raise wire.ProcessError("Unknown original transaction.") raise ProcessError("Unknown original transaction.")
if txo.orig_index >= orig.tx.outputs_count: if orig_index >= orig.tx.outputs_count:
raise wire.ProcessError("Not enough outputs in original transaction.") raise ProcessError("Not enough outputs in original transaction.")
if orig.index > txo.orig_index: if orig.index > orig_index:
raise wire.ProcessError("Rearranging of original outputs is not supported.") raise ProcessError("Rearranging of original outputs is not supported.")
# First fetch any removed original outputs which precede the one we want. # First fetch any removed original outputs which precede the one we want.
await self.fetch_removed_original_outputs(orig, txo.orig_hash, txo.orig_index) await self.fetch_removed_original_outputs(orig, orig_hash, orig_index)
orig_txo = await helpers.request_tx_output( orig_txo = await request_tx_output(
self.tx_req, orig.index, self.coin, txo.orig_hash self.tx_req, orig.index, self.coin, orig_hash
) )
if script_pubkey != self.output_derive_script(orig_txo): if script_pubkey != self.output_derive_script(orig_txo):
raise wire.ProcessError("Not an original output.") raise ProcessError("Not an original output.")
if self.tx_info.output_is_change(txo) and not orig.output_is_change(orig_txo): if self.tx_info.output_is_change(txo) and not orig.output_is_change(orig_txo):
raise wire.ProcessError( raise ProcessError("Original output is missing change-output parameters.")
"Original output is missing change-output parameters."
)
orig.add_output(orig_txo, script_pubkey) orig.add_output(orig_txo, script_pubkey)
@ -466,9 +471,7 @@ class Bitcoin:
for i in range(orig.tx.inputs_count): for i in range(orig.tx.inputs_count):
progress.advance() progress.advance()
txi = await helpers.request_tx_input( txi = await request_tx_input(self.tx_req, i, self.coin, orig.orig_hash)
self.tx_req, i, self.coin, orig.orig_hash
)
writers.write_tx_input_check(h_check, txi) writers.write_tx_input_check(h_check, txi)
script_pubkey = self.input_derive_script(txi) script_pubkey = self.input_derive_script(txi)
verifier = SignatureVerifier( verifier = SignatureVerifier(
@ -489,7 +492,7 @@ class Bitcoin:
# check that the inputs were the same as those streamed for approval # check that the inputs were the same as those streamed for approval
if h_check.get_digest() != orig.h_inputs_check: if h_check.get_digest() != orig.h_inputs_check:
raise wire.ProcessError("Transaction has changed during signing") raise ProcessError("Transaction has changed during signing")
async def approve_output( async def approve_output(
self, self,
@ -497,23 +500,24 @@ class Bitcoin:
script_pubkey: bytes, script_pubkey: bytes,
orig_txo: TxOutput | None, orig_txo: TxOutput | None,
) -> None: ) -> None:
if txo.payment_req_index != self.payment_req_index: payment_req_index = txo.payment_req_index # local_cache_attribute
if txo.payment_req_index is None: approver = self.approver # local_cache_attribute
if payment_req_index != self.payment_req_index:
if payment_req_index is None:
self.approver.finish_payment_request() self.approver.finish_payment_request()
else: else:
tx_ack_payment_req = await helpers.request_payment_req( tx_ack_payment_req = await helpers.request_payment_req(
self.tx_req, txo.payment_req_index self.tx_req, payment_req_index
) )
await self.approver.add_payment_request( await approver.add_payment_request(tx_ack_payment_req, self.keychain)
tx_ack_payment_req, self.keychain self.payment_req_index = payment_req_index
)
self.payment_req_index = txo.payment_req_index
if self.tx_info.output_is_change(txo): if self.tx_info.output_is_change(txo):
# Output is change and does not need approval. # Output is change and does not need approval.
self.approver.add_change_output(txo, script_pubkey) approver.add_change_output(txo, script_pubkey)
else: else:
await self.approver.add_external_output(txo, script_pubkey, orig_txo) await approver.add_external_output(txo, script_pubkey, orig_txo)
self.tx_info.add_output(txo, script_pubkey) self.tx_info.add_output(txo, script_pubkey)
@ -568,18 +572,18 @@ class Bitcoin:
verifier.verify(tx_digest) verifier.verify(tx_digest)
async def serialize_external_input(self, i: int) -> None: async def serialize_external_input(self, i: int) -> None:
txi = await helpers.request_tx_input(self.tx_req, i, self.coin) txi = await request_tx_input(self.tx_req, i, self.coin)
if not input_is_external(txi): if not input_is_external(txi):
raise wire.ProcessError("Transaction has changed during signing") raise ProcessError("Transaction has changed during signing")
self.write_tx_input(self.serialized_tx, txi, txi.script_sig or bytes()) self.write_tx_input(self.serialized_tx, txi, txi.script_sig or bytes())
async def serialize_segwit_input(self, i: int) -> None: async def serialize_segwit_input(self, i: int) -> None:
# STAGE_REQUEST_SEGWIT_INPUT in legacy # STAGE_REQUEST_SEGWIT_INPUT in legacy
txi = await helpers.request_tx_input(self.tx_req, i, self.coin) txi = await request_tx_input(self.tx_req, i, self.coin)
if txi.script_type not in common.SEGWIT_INPUT_SCRIPT_TYPES: if txi.script_type not in common.SEGWIT_INPUT_SCRIPT_TYPES:
raise wire.ProcessError("Transaction has changed during signing") raise ProcessError("Transaction has changed during signing")
self.tx_info.check_input(txi) self.tx_info.check_input(txi)
if txi.script_type == InputScriptType.SPENDP2SHWITNESS: if txi.script_type == InputScriptType.SPENDP2SHWITNESS:
@ -595,7 +599,7 @@ class Bitcoin:
if self.taproot_only: if self.taproot_only:
# Prevents an attacker from bypassing prev tx checking by providing a different # Prevents an attacker from bypassing prev tx checking by providing a different
# script type than the one that was provided during the confirmation phase. # script type than the one that was provided during the confirmation phase.
raise wire.ProcessError("Transaction has changed during signing") raise ProcessError("Transaction has changed during signing")
node = self.keychain.derive(txi.address_n) node = self.keychain.derive(txi.address_n)
public_key = node.public_key() public_key = node.public_key()
@ -621,6 +625,8 @@ class Bitcoin:
return public_key, signature return public_key, signature
def sign_taproot_input(self, i: int, txi: TxInput) -> bytes: def sign_taproot_input(self, i: int, txi: TxInput) -> bytes:
from ..common import bip340_sign
sigmsg_digest = self.tx_info.sig_hasher.hash341( sigmsg_digest = self.tx_info.sig_hasher.hash341(
i, i,
self.tx_info.tx, self.tx_info.tx,
@ -632,11 +638,11 @@ class Bitcoin:
async def sign_segwit_input(self, i: int) -> None: async def sign_segwit_input(self, i: int) -> None:
# STAGE_REQUEST_SEGWIT_WITNESS in legacy # STAGE_REQUEST_SEGWIT_WITNESS in legacy
txi = await helpers.request_tx_input(self.tx_req, i, self.coin) txi = await request_tx_input(self.tx_req, i, self.coin)
self.tx_info.check_input(txi) self.tx_info.check_input(txi)
self.approver.check_internal_input(txi) self.approver.check_internal_input(txi)
if txi.script_type not in common.SEGWIT_INPUT_SCRIPT_TYPES: if txi.script_type not in common.SEGWIT_INPUT_SCRIPT_TYPES:
raise wire.ProcessError("Transaction has changed during signing") raise ProcessError("Transaction has changed during signing")
if txi.script_type == InputScriptType.SPENDTAPROOT: if txi.script_type == InputScriptType.SPENDTAPROOT:
signature = self.sign_taproot_input(i, txi) signature = self.sign_taproot_input(i, txi)
@ -675,6 +681,9 @@ class Bitcoin:
tx_info: TxInfo | OriginalTxInfo, tx_info: TxInfo | OriginalTxInfo,
script_pubkey: bytes | None = None, script_pubkey: bytes | None = None,
) -> tuple[bytes, TxInput, bip32.HDNode | None]: ) -> tuple[bytes, TxInput, bip32.HDNode | None]:
tx = tx_info.tx # local_cache_attribute
coin = self.coin # local_cache_attribute
tx_hash = tx_info.orig_hash if isinstance(tx_info, OriginalTxInfo) else None tx_hash = tx_info.orig_hash if isinstance(tx_info, OriginalTxInfo) else None
# the transaction digest which gets signed for this input # the transaction digest which gets signed for this input
@ -682,15 +691,15 @@ class Bitcoin:
# should come out the same as h_tx_check, checked before signing the digest # should come out the same as h_tx_check, checked before signing the digest
h_check = HashWriter(sha256()) h_check = HashWriter(sha256())
self.write_tx_header(h_sign, tx_info.tx, witness_marker=False) self.write_tx_header(h_sign, tx, witness_marker=False)
write_compact_size(h_sign, tx_info.tx.inputs_count) write_compact_size(h_sign, tx.inputs_count)
txi_sign = None txi_sign = None
node = None node = None
for i in range(tx_info.tx.inputs_count): for i in range(tx.inputs_count):
# STAGE_REQUEST_4_INPUT in legacy # STAGE_REQUEST_4_INPUT in legacy
progress.advance() progress.advance()
txi = await helpers.request_tx_input(self.tx_req, i, self.coin, tx_hash) txi = await request_tx_input(self.tx_req, i, coin, tx_hash)
writers.write_tx_input_check(h_check, txi) writers.write_tx_input_check(h_check, txi)
# Only the previous UTXO's scriptPubKey is included in h_sign. # Only the previous UTXO's scriptPubKey is included in h_sign.
if i == index: if i == index:
@ -699,54 +708,55 @@ class Bitcoin:
self.tx_info.check_input(txi) self.tx_info.check_input(txi)
node = self.keychain.derive(txi.address_n) node = self.keychain.derive(txi.address_n)
key_sign_pub = node.public_key() key_sign_pub = node.public_key()
if txi.multisig: txi_multisig = txi.multisig # local_cache_attribute
if txi_multisig:
# Sanity check to ensure we are signing with a key that is included in the multisig. # Sanity check to ensure we are signing with a key that is included in the multisig.
multisig.multisig_pubkey_index(txi.multisig, key_sign_pub) multisig.multisig_pubkey_index(txi_multisig, key_sign_pub)
if txi.script_type == InputScriptType.SPENDMULTISIG: if txi.script_type == InputScriptType.SPENDMULTISIG:
assert txi.multisig is not None # checked in sanitize_tx_input assert txi_multisig is not None # checked in _sanitize_tx_input
script_pubkey = scripts.output_script_multisig( script_pubkey = scripts.output_script_multisig(
multisig.multisig_get_pubkeys(txi.multisig), multisig.multisig_get_pubkeys(txi_multisig),
txi.multisig.m, txi_multisig.m,
) )
elif txi.script_type == InputScriptType.SPENDADDRESS: elif txi.script_type == InputScriptType.SPENDADDRESS:
script_pubkey = scripts.output_script_p2pkh( script_pubkey = scripts.output_script_p2pkh(
addresses.ecdsa_hash_pubkey(key_sign_pub, self.coin) addresses.ecdsa_hash_pubkey(key_sign_pub, coin)
) )
else: else:
raise wire.ProcessError("Unknown transaction type") raise ProcessError("Unknown transaction type")
self.write_tx_input(h_sign, txi, script_pubkey) self.write_tx_input(h_sign, txi, script_pubkey)
else: else:
self.write_tx_input(h_sign, txi, bytes()) self.write_tx_input(h_sign, txi, bytes())
if txi_sign is None: if txi_sign is None:
raise RuntimeError # index >= tx_info.tx.inputs_count raise RuntimeError # index >= tx_info_tx.inputs_count
write_compact_size(h_sign, tx_info.tx.outputs_count) write_compact_size(h_sign, tx.outputs_count)
for i in range(tx_info.tx.outputs_count): for i in range(tx.outputs_count):
# STAGE_REQUEST_4_OUTPUT in legacy # STAGE_REQUEST_4_OUTPUT in legacy
progress.advance() progress.advance()
txo = await helpers.request_tx_output(self.tx_req, i, self.coin, tx_hash) txo = await request_tx_output(self.tx_req, i, coin, tx_hash)
script_pubkey = self.output_derive_script(txo) script_pubkey = self.output_derive_script(txo)
self.write_tx_output(h_check, txo, script_pubkey) self.write_tx_output(h_check, txo, script_pubkey)
self.write_tx_output(h_sign, txo, script_pubkey) self.write_tx_output(h_sign, txo, script_pubkey)
writers.write_uint32(h_sign, tx_info.tx.lock_time) writers.write_uint32(h_sign, tx.lock_time)
writers.write_uint32(h_sign, self.get_hash_type(txi_sign)) writers.write_uint32(h_sign, self.get_hash_type(txi_sign))
# check that the inputs were the same as those streamed for approval # check that the inputs were the same as those streamed for approval
if tx_info.get_tx_check_digest() != h_check.get_digest(): if tx_info.get_tx_check_digest() != h_check.get_digest():
raise wire.ProcessError("Transaction has changed during signing") raise ProcessError("Transaction has changed during signing")
tx_digest = writers.get_tx_hash(h_sign, double=self.coin.sign_hash_double) tx_digest = writers.get_tx_hash(h_sign, coin.sign_hash_double)
return tx_digest, txi_sign, node return tx_digest, txi_sign, node
async def sign_nonsegwit_input(self, i: int) -> None: async def sign_nonsegwit_input(self, i: int) -> None:
if self.taproot_only: if self.taproot_only:
# Prevents an attacker from bypassing prev tx checking by providing a different # Prevents an attacker from bypassing prev tx checking by providing a different
# script type than the one that was provided during the confirmation phase. # script type than the one that was provided during the confirmation phase.
raise wire.ProcessError("Transaction has changed during signing") raise ProcessError("Transaction has changed during signing")
tx_digest, txi, node = await self.get_legacy_tx_digest(i, self.tx_info) tx_digest, txi, node = await self.get_legacy_tx_digest(i, self.tx_info)
assert node is not None assert node is not None
@ -763,21 +773,23 @@ class Bitcoin:
async def serialize_output(self, i: int) -> None: async def serialize_output(self, i: int) -> None:
# STAGE_REQUEST_5_OUTPUT in legacy # STAGE_REQUEST_5_OUTPUT in legacy
txo = await helpers.request_tx_output(self.tx_req, i, self.coin) txo = await request_tx_output(self.tx_req, i, self.coin)
script_pubkey = self.output_derive_script(txo) script_pubkey = self.output_derive_script(txo)
self.write_tx_output(self.serialized_tx, txo, script_pubkey) self.write_tx_output(self.serialized_tx, txo, script_pubkey)
async def get_prevtx_output( async def get_prevtx_output(
self, prev_hash: bytes, prev_index: int self, prev_hash: bytes, prev_index: int
) -> tuple[int, bytes]: ) -> tuple[int, bytes]:
coin = self.coin # local_cache_attribute
amount_out = 0 # output amount amount_out = 0 # output amount
# STAGE_REQUEST_3_PREV_META in legacy # STAGE_REQUEST_3_PREV_META in legacy
tx = await helpers.request_tx_meta(self.tx_req, self.coin, prev_hash) tx = await helpers.request_tx_meta(self.tx_req, coin, prev_hash)
progress.init_prev_tx(tx.inputs_count, tx.outputs_count) progress.init_prev_tx(tx.inputs_count, tx.outputs_count)
if tx.outputs_count <= prev_index: if tx.outputs_count <= prev_index:
raise wire.ProcessError("Not enough outputs in previous transaction.") raise ProcessError("Not enough outputs in previous transaction.")
txh = self.create_hash_writer() txh = self.create_hash_writer()
@ -788,9 +800,7 @@ class Bitcoin:
for i in range(tx.inputs_count): for i in range(tx.inputs_count):
# STAGE_REQUEST_3_PREV_INPUT in legacy # STAGE_REQUEST_3_PREV_INPUT in legacy
progress.advance_prev_tx() progress.advance_prev_tx()
txi = await helpers.request_tx_prev_input( txi = await helpers.request_tx_prev_input(self.tx_req, i, coin, prev_hash)
self.tx_req, i, self.coin, prev_hash
)
self.write_tx_input(txh, txi, txi.script_sig) self.write_tx_input(txh, txi, txi.script_sig)
write_compact_size(txh, tx.outputs_count) write_compact_size(txh, tx.outputs_count)
@ -800,7 +810,7 @@ class Bitcoin:
# STAGE_REQUEST_3_PREV_OUTPUT in legacy # STAGE_REQUEST_3_PREV_OUTPUT in legacy
progress.advance_prev_tx() progress.advance_prev_tx()
txo_bin = await helpers.request_tx_prev_output( txo_bin = await helpers.request_tx_prev_output(
self.tx_req, i, self.coin, prev_hash self.tx_req, i, coin, prev_hash
) )
self.write_tx_output(txh, txo_bin, txo_bin.script_pubkey) self.write_tx_output(txh, txo_bin, txo_bin.script_pubkey)
if i == prev_index: if i == prev_index:
@ -812,11 +822,8 @@ class Bitcoin:
await self.write_prev_tx_footer(txh, tx, prev_hash) await self.write_prev_tx_footer(txh, tx, prev_hash)
if ( if writers.get_tx_hash(txh, coin.sign_hash_double, True) != prev_hash:
writers.get_tx_hash(txh, double=self.coin.sign_hash_double, reverse=True) raise ProcessError("Encountered invalid prev_hash")
!= prev_hash
):
raise wire.ProcessError("Encountered invalid prev_hash")
return amount_out, script_pubkey return amount_out, script_pubkey
@ -843,7 +850,7 @@ class Bitcoin:
def write_tx_input_derived( def write_tx_input_derived(
self, self,
w: writers.Writer, w: Writer,
txi: TxInput, txi: TxInput,
pubkey: bytes, pubkey: bytes,
signature: bytes, signature: bytes,
@ -863,7 +870,7 @@ class Bitcoin:
@staticmethod @staticmethod
def write_tx_input( def write_tx_input(
w: writers.Writer, w: Writer,
txi: TxInput | PrevInput, txi: TxInput | PrevInput,
script: bytes, script: bytes,
) -> None: ) -> None:
@ -871,7 +878,7 @@ class Bitcoin:
@staticmethod @staticmethod
def write_tx_output( def write_tx_output(
w: writers.Writer, w: Writer,
txo: TxOutput | PrevOutput, txo: TxOutput | PrevOutput,
script_pubkey: bytes, script_pubkey: bytes,
) -> None: ) -> None:
@ -879,7 +886,7 @@ class Bitcoin:
def write_tx_header( def write_tx_header(
self, self,
w: writers.Writer, w: Writer,
tx: SignTx | PrevTx, tx: SignTx | PrevTx,
witness_marker: bool, witness_marker: bool,
) -> None: ) -> None:
@ -888,21 +895,25 @@ class Bitcoin:
write_compact_size(w, 0x00) # segwit witness marker write_compact_size(w, 0x00) # segwit witness marker
write_compact_size(w, 0x01) # segwit witness flag write_compact_size(w, 0x01) # segwit witness flag
def write_tx_footer(self, w: writers.Writer, tx: SignTx | PrevTx) -> None: def write_tx_footer(self, w: Writer, tx: SignTx | PrevTx) -> None:
writers.write_uint32(w, tx.lock_time) writers.write_uint32(w, tx.lock_time)
async def write_prev_tx_footer( async def write_prev_tx_footer(
self, w: writers.Writer, tx: PrevTx, prev_hash: bytes self, w: Writer, tx: PrevTx, prev_hash: bytes
) -> None: ) -> None:
self.write_tx_footer(w, tx) self.write_tx_footer(w, tx)
def set_serialized_signature(self, index: int, signature: bytes) -> None: def set_serialized_signature(self, index: int, signature: bytes) -> None:
# Only one signature per TxRequest can be serialized. from trezor.utils import ensure
assert self.tx_req.serialized is not None
ensure(self.tx_req.serialized.signature is None)
self.tx_req.serialized.signature_index = index serialized = self.tx_req.serialized # local_cache_attribute
self.tx_req.serialized.signature = signature
# Only one signature per TxRequest can be serialized.
assert serialized is not None
ensure(serialized.signature is None)
serialized.signature_index = index
serialized.signature = signature
# scriptPubKey derivation # scriptPubKey derivation
# === # ===
@ -911,7 +922,7 @@ class Bitcoin:
self, txi: TxInput, node: bip32.HDNode | None = None self, txi: TxInput, node: bip32.HDNode | None = None
) -> bytes: ) -> bytes:
if input_is_external(txi): if input_is_external(txi):
assert txi.script_pubkey is not None # checked in sanitize_tx_input assert txi.script_pubkey is not None # checked in _sanitize_tx_input
return txi.script_pubkey return txi.script_pubkey
if node is None: if node is None:
@ -921,8 +932,10 @@ class Bitcoin:
return scripts.output_derive_script(address, self.coin) return scripts.output_derive_script(address, self.coin)
def output_derive_script(self, txo: TxOutput) -> bytes: def output_derive_script(self, txo: TxOutput) -> bytes:
from trezor.enums import OutputScriptType
if txo.script_type == OutputScriptType.PAYTOOPRETURN: if txo.script_type == OutputScriptType.PAYTOOPRETURN:
assert txo.op_return_data is not None # checked in sanitize_tx_output assert txo.op_return_data is not None # checked in _sanitize_tx_output
return scripts.output_script_paytoopreturn(txo.op_return_data) return scripts.output_script_paytoopreturn(txo.op_return_data)
if txo.address_n: if txo.address_n:
@ -932,12 +945,12 @@ class Bitcoin:
txo.script_type txo.script_type
] ]
except KeyError: except KeyError:
raise wire.DataError("Invalid script type") raise DataError("Invalid script type")
node = self.keychain.derive(txo.address_n) node = self.keychain.derive(txo.address_n)
txo.address = addresses.get_address( txo.address = addresses.get_address(
input_script_type, self.coin, node, txo.multisig input_script_type, self.coin, node, txo.multisig
) )
assert txo.address is not None # checked in sanitize_tx_output assert txo.address is not None # checked in _sanitize_tx_output
return scripts.output_derive_script(txo.address, self.coin) return scripts.output_derive_script(txo.address, self.coin)

View File

@ -1,11 +1,6 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor import wire from .. import writers
from apps.common.writers import write_compact_size
from .. import multisig, writers
from ..common import NONSEGWIT_INPUT_SCRIPT_TYPES, SigHashType
from . import helpers from . import helpers
from .bitcoin import Bitcoin from .bitcoin import Bitcoin
@ -17,6 +12,10 @@ if TYPE_CHECKING:
class Bitcoinlike(Bitcoin): class Bitcoinlike(Bitcoin):
async def sign_nonsegwit_bip143_input(self, i_sign: int) -> None: async def sign_nonsegwit_bip143_input(self, i_sign: int) -> None:
from trezor import wire
from .. import multisig
from ..common import NONSEGWIT_INPUT_SCRIPT_TYPES
txi = await helpers.request_tx_input(self.tx_req, i_sign, self.coin) txi = await helpers.request_tx_input(self.tx_req, i_sign, self.coin)
self.tx_info.check_input(txi) self.tx_info.check_input(txi)
self.approver.check_internal_input(txi) self.approver.check_internal_input(txi)
@ -64,6 +63,8 @@ class Bitcoinlike(Bitcoin):
) )
def get_hash_type(self, txi: TxInput) -> int: def get_hash_type(self, txi: TxInput) -> int:
from ..common import SigHashType
hashtype = super().get_hash_type(txi) hashtype = super().get_hash_type(txi)
if self.coin.fork_id is not None: if self.coin.fork_id is not None:
hashtype |= (self.coin.fork_id << 8) | SigHashType.SIGHASH_FORKID hashtype |= (self.coin.fork_id << 8) | SigHashType.SIGHASH_FORKID
@ -75,6 +76,8 @@ class Bitcoinlike(Bitcoin):
tx: SignTx | PrevTx, tx: SignTx | PrevTx,
witness_marker: bool, witness_marker: bool,
) -> None: ) -> None:
from apps.common.writers import write_compact_size
writers.write_uint32(w, tx.version) # nVersion writers.write_uint32(w, tx.version) # nVersion
if self.coin.timestamp: if self.coin.timestamp:
assert tx.timestamp is not None # checked in sanitize_* assert tx.timestamp is not None # checked in sanitize_*

View File

@ -1,18 +1,18 @@
from micropython import const from micropython import const
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor import wire
from trezor.crypto.hashlib import blake256 from trezor.crypto.hashlib import blake256
from trezor.enums import DecredStakingSpendType, InputScriptType from trezor.enums import InputScriptType
from trezor.messages import PrevOutput from trezor.utils import HashWriter
from trezor.utils import HashWriter, ensure from trezor.wire import DataError, ProcessError
from apps.bitcoin.sign_tx.tx_weight import TxWeightCalculator from apps.bitcoin.sign_tx.tx_weight import TxWeightCalculator
from apps.common.writers import write_compact_size from apps.common.writers import write_compact_size
from .. import multisig, scripts_decred, writers from .. import scripts_decred, writers
from ..common import SigHashType, ecdsa_hash_pubkey, ecdsa_sign from ..common import ecdsa_hash_pubkey
from . import approvers, helpers from ..writers import write_uint32
from . import helpers
from .approvers import BasicApprover from .approvers import BasicApprover
from .bitcoin import Bitcoin from .bitcoin import Bitcoin
from .progress import progress from .progress import progress
@ -35,12 +35,16 @@ if TYPE_CHECKING:
TxOutput, TxOutput,
PrevTx, PrevTx,
PrevInput, PrevInput,
PrevOutput,
) )
from apps.common.coininfo import CoinInfo from apps.common.coininfo import CoinInfo
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
from .sig_hasher import SigHasher from .sig_hasher import SigHasher
from . import approvers
from ..common import SigHashType
from ..writers import Writer
# Decred input size (without script): 32 prevhash, 4 idx, 1 Decred tree, 4 sequence # Decred input size (without script): 32 prevhash, 4 idx, 1 Decred tree, 4 sequence
@ -144,6 +148,8 @@ class Decred(Bitcoin):
coin: CoinInfo, coin: CoinInfo,
approver: approvers.Approver | None, approver: approvers.Approver | None,
) -> None: ) -> None:
from trezor.utils import ensure
ensure(coin.decred) ensure(coin.decred)
self.h_prefix = HashWriter(blake256()) self.h_prefix = HashWriter(blake256())
@ -151,16 +157,14 @@ class Decred(Bitcoin):
approver = DecredApprover(tx, coin) approver = DecredApprover(tx, coin)
super().__init__(tx, keychain, coin, approver) super().__init__(tx, keychain, coin, approver)
if self.serialize: tx = self.tx_info.tx # local_cache_attribute
self.write_tx_header(
self.serialized_tx, self.tx_info.tx, witness_marker=True
)
write_compact_size(self.serialized_tx, self.tx_info.tx.inputs_count)
writers.write_uint32( if self.serialize:
self.h_prefix, self.tx_info.tx.version | _DECRED_SERIALIZE_NO_WITNESS self.write_tx_header(self.serialized_tx, tx, witness_marker=True)
) write_compact_size(self.serialized_tx, tx.inputs_count)
write_compact_size(self.h_prefix, self.tx_info.tx.inputs_count)
write_uint32(self.h_prefix, tx.version | _DECRED_SERIALIZE_NO_WITNESS)
write_compact_size(self.h_prefix, tx.inputs_count)
def create_hash_writer(self) -> HashWriter: def create_hash_writer(self) -> HashWriter:
return HashWriter(blake256()) return HashWriter(blake256())
@ -169,18 +173,20 @@ class Decred(Bitcoin):
return DecredSigHasher(self.h_prefix) return DecredSigHasher(self.h_prefix)
async def step2_approve_outputs(self) -> None: async def step2_approve_outputs(self) -> None:
write_compact_size(self.h_prefix, self.tx_info.tx.outputs_count) tx = self.tx_info.tx # local_cache_attribute
if self.serialize:
write_compact_size(self.serialized_tx, self.tx_info.tx.outputs_count)
if self.tx_info.tx.decred_staking_ticket: write_compact_size(self.h_prefix, tx.outputs_count)
if self.serialize:
write_compact_size(self.serialized_tx, tx.outputs_count)
if tx.decred_staking_ticket:
await self.approve_staking_ticket() await self.approve_staking_ticket()
else: else:
await super().step2_approve_outputs() await super().step2_approve_outputs()
self.write_tx_footer(self.h_prefix, self.tx_info.tx) self.write_tx_footer(self.h_prefix, tx)
if self.serialize: if self.serialize:
self.write_tx_footer(self.serialized_tx, self.tx_info.tx) self.write_tx_footer(self.serialized_tx, tx)
async def process_internal_input(self, txi: TxInput, node: bip32.HDNode) -> None: async def process_internal_input(self, txi: TxInput, node: bip32.HDNode) -> None:
await super().process_internal_input(txi, node) await super().process_internal_input(txi, node)
@ -190,10 +196,10 @@ class Decred(Bitcoin):
self.write_tx_input(self.serialized_tx, txi, bytes()) self.write_tx_input(self.serialized_tx, txi, bytes())
async def process_external_input(self, txi: TxInput) -> None: async def process_external_input(self, txi: TxInput) -> None:
raise wire.DataError("External inputs not supported") raise DataError("External inputs not supported")
async def process_original_input(self, txi: TxInput, script_pubkey: bytes) -> None: async def process_original_input(self, txi: TxInput, script_pubkey: bytes) -> None:
raise wire.DataError("Replacement transactions not supported") raise DataError("Replacement transactions not supported")
async def approve_output( async def approve_output(
self, self,
@ -206,15 +212,23 @@ class Decred(Bitcoin):
self.write_tx_output(self.serialized_tx, txo, script_pubkey) self.write_tx_output(self.serialized_tx, txo, script_pubkey)
async def step4_serialize_inputs(self) -> None: async def step4_serialize_inputs(self) -> None:
from trezor.enums import DecredStakingSpendType
from ..common import SigHashType, ecdsa_sign
from .progress import progress
from .. import multisig
inputs_count = self.tx_info.tx.inputs_count # local_cache_attribute
coin = self.coin # local_cache_attribute
if self.serialize: if self.serialize:
write_compact_size(self.serialized_tx, self.tx_info.tx.inputs_count) write_compact_size(self.serialized_tx, inputs_count)
prefix_hash = self.h_prefix.get_digest() prefix_hash = self.h_prefix.get_digest()
for i_sign in range(self.tx_info.tx.inputs_count): for i_sign in range(inputs_count):
progress.advance() progress.advance()
txi_sign = await helpers.request_tx_input(self.tx_req, i_sign, self.coin) txi_sign = await helpers.request_tx_input(self.tx_req, i_sign, coin)
self.tx_info.check_input(txi_sign) self.tx_info.check_input(txi_sign)
@ -222,20 +236,20 @@ class Decred(Bitcoin):
key_sign_pub = key_sign.public_key() key_sign_pub = key_sign.public_key()
h_witness = self.create_hash_writer() h_witness = self.create_hash_writer()
writers.write_uint32( write_uint32(
h_witness, self.tx_info.tx.version | _DECRED_SERIALIZE_WITNESS_SIGNING h_witness, self.tx_info.tx.version | _DECRED_SERIALIZE_WITNESS_SIGNING
) )
write_compact_size(h_witness, self.tx_info.tx.inputs_count) write_compact_size(h_witness, inputs_count)
for ii in range(self.tx_info.tx.inputs_count): for ii in range(inputs_count):
if ii == i_sign: if ii == i_sign:
if txi_sign.decred_staking_spend == DecredStakingSpendType.SSRTX: if txi_sign.decred_staking_spend == DecredStakingSpendType.SSRTX:
scripts_decred.write_output_script_ssrtx_prefixed( scripts_decred.write_output_script_ssrtx_prefixed(
h_witness, ecdsa_hash_pubkey(key_sign_pub, self.coin) h_witness, ecdsa_hash_pubkey(key_sign_pub, coin)
) )
elif txi_sign.decred_staking_spend == DecredStakingSpendType.SSGen: elif txi_sign.decred_staking_spend == DecredStakingSpendType.SSGen:
scripts_decred.write_output_script_ssgen_prefixed( scripts_decred.write_output_script_ssgen_prefixed(
h_witness, ecdsa_hash_pubkey(key_sign_pub, self.coin) h_witness, ecdsa_hash_pubkey(key_sign_pub, coin)
) )
elif txi_sign.script_type == InputScriptType.SPENDMULTISIG: elif txi_sign.script_type == InputScriptType.SPENDMULTISIG:
assert txi_sign.multisig is not None assert txi_sign.multisig is not None
@ -248,24 +262,24 @@ class Decred(Bitcoin):
elif txi_sign.script_type == InputScriptType.SPENDADDRESS: elif txi_sign.script_type == InputScriptType.SPENDADDRESS:
scripts_decred.write_output_script_p2pkh( scripts_decred.write_output_script_p2pkh(
h_witness, h_witness,
ecdsa_hash_pubkey(key_sign_pub, self.coin), ecdsa_hash_pubkey(key_sign_pub, coin),
prefixed=True, prefixed=True,
) )
else: else:
raise wire.DataError("Unsupported input script type") raise DataError("Unsupported input script type")
else: else:
write_compact_size(h_witness, 0) write_compact_size(h_witness, 0)
witness_hash = writers.get_tx_hash( witness_hash = writers.get_tx_hash(
h_witness, double=self.coin.sign_hash_double, reverse=False h_witness, double=coin.sign_hash_double, reverse=False
) )
h_sign = self.create_hash_writer() h_sign = self.create_hash_writer()
writers.write_uint32(h_sign, SigHashType.SIGHASH_ALL) write_uint32(h_sign, SigHashType.SIGHASH_ALL)
writers.write_bytes_fixed(h_sign, prefix_hash, writers.TX_HASH_SIZE) writers.write_bytes_fixed(h_sign, prefix_hash, writers.TX_HASH_SIZE)
writers.write_bytes_fixed(h_sign, witness_hash, writers.TX_HASH_SIZE) writers.write_bytes_fixed(h_sign, witness_hash, writers.TX_HASH_SIZE)
sig_hash = writers.get_tx_hash(h_sign, double=self.coin.sign_hash_double) sig_hash = writers.get_tx_hash(h_sign, double=coin.sign_hash_double)
signature = ecdsa_sign(key_sign, sig_hash) signature = ecdsa_sign(key_sign, sig_hash)
# serialize input with correct signature # serialize input with correct signature
@ -289,29 +303,31 @@ class Decred(Bitcoin):
def check_prevtx_output(self, txo_bin: PrevOutput) -> None: def check_prevtx_output(self, txo_bin: PrevOutput) -> None:
if txo_bin.decred_script_version != 0: if txo_bin.decred_script_version != 0:
raise wire.ProcessError("Cannot use utxo that has script_version != 0") raise ProcessError("Cannot use utxo that has script_version != 0")
@staticmethod @staticmethod
def write_tx_input( def write_tx_input(
w: writers.Writer, w: Writer,
txi: TxInput | PrevInput, txi: TxInput | PrevInput,
script: bytes, script: bytes,
) -> None: ) -> None:
writers.write_bytes_reversed(w, txi.prev_hash, writers.TX_HASH_SIZE) writers.write_bytes_reversed(w, txi.prev_hash, writers.TX_HASH_SIZE)
writers.write_uint32(w, txi.prev_index or 0) write_uint32(w, txi.prev_index or 0)
writers.write_uint8(w, txi.decred_tree or 0) writers.write_uint8(w, txi.decred_tree or 0)
writers.write_uint32(w, txi.sequence) write_uint32(w, txi.sequence)
@staticmethod @staticmethod
def write_tx_output( def write_tx_output(
w: writers.Writer, w: Writer,
txo: TxOutput | PrevOutput, txo: TxOutput | PrevOutput,
script_pubkey: bytes, script_pubkey: bytes,
) -> None: ) -> None:
from trezor.messages import PrevOutput
writers.write_uint64(w, txo.amount) writers.write_uint64(w, txo.amount)
if PrevOutput.is_type_of(txo): if PrevOutput.is_type_of(txo):
if txo.decred_script_version is None: if txo.decred_script_version is None:
raise wire.DataError("Script version must be provided") raise DataError("Script version must be provided")
writers.write_uint16(w, txo.decred_script_version) writers.write_uint16(w, txo.decred_script_version)
else: else:
writers.write_uint16(w, _DECRED_SCRIPT_VERSION) writers.write_uint16(w, _DECRED_SCRIPT_VERSION)
@ -319,7 +335,7 @@ class Decred(Bitcoin):
def process_sstx_commitment_owned(self, txo: TxOutput) -> bytearray: def process_sstx_commitment_owned(self, txo: TxOutput) -> bytearray:
if not self.tx_info.output_is_change(txo): if not self.tx_info.output_is_change(txo):
raise wire.DataError("Invalid sstxcommitment path.") raise DataError("Invalid sstxcommitment path.")
node = self.keychain.derive(txo.address_n) node = self.keychain.derive(txo.address_n)
pkh = ecdsa_hash_pubkey(node.public_key(), self.coin) pkh = ecdsa_hash_pubkey(node.public_key(), self.coin)
op_return_data = scripts_decred.sstxcommitment_pkh(pkh, txo.amount) op_return_data = scripts_decred.sstxcommitment_pkh(pkh, txo.amount)
@ -327,30 +343,33 @@ class Decred(Bitcoin):
return scripts_decred.output_script_paytoopreturn(op_return_data) return scripts_decred.output_script_paytoopreturn(op_return_data)
async def approve_staking_ticket(self) -> None: async def approve_staking_ticket(self) -> None:
assert isinstance(self.approver, DecredApprover) approver = self.approver # local_cache_attribute
tx_info = self.tx_info # local_cache_attribute
if self.tx_info.tx.outputs_count != 3: assert isinstance(approver, DecredApprover)
raise wire.DataError("Ticket has wrong number of outputs.")
if tx_info.tx.outputs_count != 3:
raise DataError("Ticket has wrong number of outputs.")
# SSTX submission # SSTX submission
progress.advance() progress.advance()
txo = await helpers.request_tx_output(self.tx_req, 0, self.coin) txo = await helpers.request_tx_output(self.tx_req, 0, self.coin)
if txo.address is None: if txo.address is None:
raise wire.DataError("Missing address.") raise DataError("Missing address.")
script_pubkey = scripts_decred.output_script_sstxsubmissionpkh(txo.address) script_pubkey = scripts_decred.output_script_sstxsubmissionpkh(txo.address)
await self.approver.add_decred_sstx_submission(txo, script_pubkey) await approver.add_decred_sstx_submission(txo, script_pubkey)
self.tx_info.add_output(txo, script_pubkey) tx_info.add_output(txo, script_pubkey)
if self.serialize: if self.serialize:
self.write_tx_output(self.serialized_tx, txo, script_pubkey) self.write_tx_output(self.serialized_tx, txo, script_pubkey)
# SSTX commitment # SSTX commitment
progress.advance() progress.advance()
txo = await helpers.request_tx_output(self.tx_req, 1, self.coin) txo = await helpers.request_tx_output(self.tx_req, 1, self.coin)
if txo.amount != self.approver.total_in: if txo.amount != approver.total_in:
raise wire.DataError("Wrong sstxcommitment amount.") raise DataError("Wrong sstxcommitment amount.")
script_pubkey = self.process_sstx_commitment_owned(txo) script_pubkey = self.process_sstx_commitment_owned(txo)
self.approver.add_change_output(txo, script_pubkey) approver.add_change_output(txo, script_pubkey)
self.tx_info.add_output(txo, script_pubkey) tx_info.add_output(txo, script_pubkey)
if self.serialize: if self.serialize:
self.write_tx_output(self.serialized_tx, txo, script_pubkey) self.write_tx_output(self.serialized_tx, txo, script_pubkey)
@ -358,23 +377,23 @@ class Decred(Bitcoin):
progress.advance() progress.advance()
txo = await helpers.request_tx_output(self.tx_req, 2, self.coin) txo = await helpers.request_tx_output(self.tx_req, 2, self.coin)
if txo.address is None: if txo.address is None:
raise wire.DataError("Missing address.") raise DataError("Missing address.")
script_pubkey = scripts_decred.output_script_sstxchange(txo.address) script_pubkey = scripts_decred.output_script_sstxchange(txo.address)
# Using change addresses is no longer common practice. Inputs are split # Using change addresses is no longer common practice. Inputs are split
# beforehand and should be exact. SSTX change should pay zero amount to # beforehand and should be exact. SSTX change should pay zero amount to
# a zeroed hash. # a zeroed hash.
if txo.amount != 0: if txo.amount != 0:
raise wire.DataError("Only value of 0 allowed for sstx change.") raise DataError("Only value of 0 allowed for sstx change.")
if script_pubkey != OUTPUT_SCRIPT_NULL_SSTXCHANGE: if script_pubkey != OUTPUT_SCRIPT_NULL_SSTXCHANGE:
raise wire.DataError("Only zeroed addresses accepted for sstx change.") raise DataError("Only zeroed addresses accepted for sstx change.")
self.approver.add_change_output(txo, script_pubkey) approver.add_change_output(txo, script_pubkey)
self.tx_info.add_output(txo, script_pubkey) tx_info.add_output(txo, script_pubkey)
if self.serialize: if self.serialize:
self.write_tx_output(self.serialized_tx, txo, script_pubkey) self.write_tx_output(self.serialized_tx, txo, script_pubkey)
def write_tx_header( def write_tx_header(
self, self,
w: writers.Writer, w: Writer,
tx: SignTx | PrevTx, tx: SignTx | PrevTx,
witness_marker: bool, witness_marker: bool,
) -> None: ) -> None:
@ -385,19 +404,19 @@ class Decred(Bitcoin):
else: else:
version = tx.version | _DECRED_SERIALIZE_NO_WITNESS version = tx.version | _DECRED_SERIALIZE_NO_WITNESS
writers.write_uint32(w, version) write_uint32(w, version)
def write_tx_footer(self, w: writers.Writer, tx: SignTx | PrevTx) -> None: def write_tx_footer(self, w: Writer, tx: SignTx | PrevTx) -> None:
assert tx.expiry is not None # checked in sanitize_* assert tx.expiry is not None # checked in sanitize_*
writers.write_uint32(w, tx.lock_time) write_uint32(w, tx.lock_time)
writers.write_uint32(w, tx.expiry) write_uint32(w, tx.expiry)
def write_tx_input_witness( def write_tx_input_witness(
self, w: writers.Writer, txi: TxInput, pubkey: bytes, signature: bytes self, w: Writer, txi: TxInput, pubkey: bytes, signature: bytes
) -> None: ) -> None:
writers.write_uint64(w, txi.amount) writers.write_uint64(w, txi.amount)
writers.write_uint32(w, 0) # block height fraud proof write_uint32(w, 0) # block height fraud proof
writers.write_uint32(w, 0xFFFF_FFFF) # block index fraud proof write_uint32(w, 0xFFFF_FFFF) # block index fraud proof
scripts_decred.write_input_script_prefixed( scripts_decred.write_input_script_prefixed(
w, w,
txi.script_type, txi.script_type,

View File

@ -1,19 +1,8 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor import utils, wire from trezor import utils
from trezor.enums import InputScriptType, OutputScriptType, RequestType from trezor.enums import RequestType
from trezor.messages import ( from trezor.wire import DataError
TxAckInput,
TxAckOutput,
TxAckPaymentRequest,
TxAckPrevExtraData,
TxAckPrevInput,
TxAckPrevMeta,
TxAckPrevOutput,
)
from apps.common import paths
from apps.common.coininfo import CoinInfo
from .. import common from .. import common
from ..writers import TX_HASH_SIZE from ..writers import TX_HASH_SIZE
@ -22,6 +11,7 @@ from . import layout
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Any, Awaitable from typing import Any, Awaitable
from trezor.enums import AmountUnit from trezor.enums import AmountUnit
from trezor.wire import Context
from trezor.messages import ( from trezor.messages import (
PrevInput, PrevInput,
@ -31,14 +21,16 @@ if TYPE_CHECKING:
TxInput, TxInput,
TxOutput, TxOutput,
TxRequest, TxRequest,
TxAckPaymentRequest,
) )
from apps.common.coininfo import CoinInfo
# Machine instructions # Machine instructions
# === # ===
class UiConfirm: class UiConfirm:
def confirm_dialog(self, ctx: wire.Context) -> Awaitable[Any]: def confirm_dialog(self, ctx: Context) -> Awaitable[Any]:
raise NotImplementedError raise NotImplementedError
__eq__ = utils.obj_eq __eq__ = utils.obj_eq
@ -50,7 +42,7 @@ class UiConfirmOutput(UiConfirm):
self.coin = coin self.coin = coin
self.amount_unit = amount_unit self.amount_unit = amount_unit
def confirm_dialog(self, ctx: wire.Context) -> Awaitable[Any]: def confirm_dialog(self, ctx: Context) -> Awaitable[Any]:
return layout.confirm_output(ctx, self.output, self.coin, self.amount_unit) return layout.confirm_output(ctx, self.output, self.coin, self.amount_unit)
@ -60,7 +52,7 @@ class UiConfirmDecredSSTXSubmission(UiConfirm):
self.coin = coin self.coin = coin
self.amount_unit = amount_unit self.amount_unit = amount_unit
def confirm_dialog(self, ctx: wire.Context) -> Awaitable[Any]: def confirm_dialog(self, ctx: Context) -> Awaitable[Any]:
return layout.confirm_decred_sstx_submission( return layout.confirm_decred_sstx_submission(
ctx, self.output, self.coin, self.amount_unit ctx, self.output, self.coin, self.amount_unit
) )
@ -77,7 +69,7 @@ class UiConfirmPaymentRequest(UiConfirm):
self.amount_unit = amount_unit self.amount_unit = amount_unit
self.coin = coin self.coin = coin
def confirm_dialog(self, ctx: wire.Context) -> Awaitable[Any]: def confirm_dialog(self, ctx: Context) -> Awaitable[Any]:
return layout.confirm_payment_request( return layout.confirm_payment_request(
ctx, self.payment_req, self.coin, self.amount_unit ctx, self.payment_req, self.coin, self.amount_unit
) )
@ -90,7 +82,7 @@ class UiConfirmReplacement(UiConfirm):
self.description = description self.description = description
self.txid = txid self.txid = txid
def confirm_dialog(self, ctx: wire.Context) -> Awaitable[Any]: def confirm_dialog(self, ctx: Context) -> Awaitable[Any]:
return layout.confirm_replacement(ctx, self.description, self.txid) return layout.confirm_replacement(ctx, self.description, self.txid)
@ -107,7 +99,7 @@ class UiConfirmModifyOutput(UiConfirm):
self.coin = coin self.coin = coin
self.amount_unit = amount_unit self.amount_unit = amount_unit
def confirm_dialog(self, ctx: wire.Context) -> Awaitable[Any]: def confirm_dialog(self, ctx: Context) -> Awaitable[Any]:
return layout.confirm_modify_output( return layout.confirm_modify_output(
ctx, self.txo, self.orig_txo, self.coin, self.amount_unit ctx, self.txo, self.orig_txo, self.coin, self.amount_unit
) )
@ -128,7 +120,7 @@ class UiConfirmModifyFee(UiConfirm):
self.coin = coin self.coin = coin
self.amount_unit = amount_unit self.amount_unit = amount_unit
def confirm_dialog(self, ctx: wire.Context) -> Awaitable[Any]: def confirm_dialog(self, ctx: Context) -> Awaitable[Any]:
return layout.confirm_modify_fee( return layout.confirm_modify_fee(
ctx, ctx,
self.user_fee_change, self.user_fee_change,
@ -154,7 +146,7 @@ class UiConfirmTotal(UiConfirm):
self.coin = coin self.coin = coin
self.amount_unit = amount_unit self.amount_unit = amount_unit
def confirm_dialog(self, ctx: wire.Context) -> Awaitable[Any]: def confirm_dialog(self, ctx: Context) -> Awaitable[Any]:
return layout.confirm_total( return layout.confirm_total(
ctx, self.spending, self.fee, self.fee_rate, self.coin, self.amount_unit ctx, self.spending, self.fee, self.fee_rate, self.coin, self.amount_unit
) )
@ -169,7 +161,7 @@ class UiConfirmJointTotal(UiConfirm):
self.coin = coin self.coin = coin
self.amount_unit = amount_unit self.amount_unit = amount_unit
def confirm_dialog(self, ctx: wire.Context) -> Awaitable[Any]: def confirm_dialog(self, ctx: Context) -> Awaitable[Any]:
return layout.confirm_joint_total( return layout.confirm_joint_total(
ctx, self.spending, self.total, self.coin, self.amount_unit ctx, self.spending, self.total, self.coin, self.amount_unit
) )
@ -181,7 +173,7 @@ class UiConfirmFeeOverThreshold(UiConfirm):
self.coin = coin self.coin = coin
self.amount_unit = amount_unit self.amount_unit = amount_unit
def confirm_dialog(self, ctx: wire.Context) -> Awaitable[Any]: def confirm_dialog(self, ctx: Context) -> Awaitable[Any]:
return layout.confirm_feeoverthreshold( return layout.confirm_feeoverthreshold(
ctx, self.fee, self.coin, self.amount_unit ctx, self.fee, self.coin, self.amount_unit
) )
@ -191,12 +183,12 @@ class UiConfirmChangeCountOverThreshold(UiConfirm):
def __init__(self, change_count: int): def __init__(self, change_count: int):
self.change_count = change_count self.change_count = change_count
def confirm_dialog(self, ctx: wire.Context) -> Awaitable[Any]: def confirm_dialog(self, ctx: Context) -> Awaitable[Any]:
return layout.confirm_change_count_over_threshold(ctx, self.change_count) return layout.confirm_change_count_over_threshold(ctx, self.change_count)
class UiConfirmUnverifiedExternalInput(UiConfirm): class UiConfirmUnverifiedExternalInput(UiConfirm):
def confirm_dialog(self, ctx: wire.Context) -> Awaitable[Any]: def confirm_dialog(self, ctx: Context) -> Awaitable[Any]:
return layout.confirm_unverified_external_input(ctx) return layout.confirm_unverified_external_input(ctx)
@ -204,7 +196,9 @@ class UiConfirmForeignAddress(UiConfirm):
def __init__(self, address_n: list): def __init__(self, address_n: list):
self.address_n = address_n self.address_n = address_n
def confirm_dialog(self, ctx: wire.Context) -> Awaitable[Any]: def confirm_dialog(self, ctx: Context) -> Awaitable[Any]:
from apps.common import paths
return paths.show_path_warning(ctx, self.address_n) return paths.show_path_warning(ctx, self.address_n)
@ -213,7 +207,7 @@ class UiConfirmNonDefaultLocktime(UiConfirm):
self.lock_time = lock_time self.lock_time = lock_time
self.lock_time_disabled = lock_time_disabled self.lock_time_disabled = lock_time_disabled
def confirm_dialog(self, ctx: wire.Context) -> Awaitable[Any]: def confirm_dialog(self, ctx: Context) -> Awaitable[Any]:
return layout.confirm_nondefault_locktime( return layout.confirm_nondefault_locktime(
ctx, self.lock_time, self.lock_time_disabled ctx, self.lock_time, self.lock_time_disabled
) )
@ -276,28 +270,36 @@ def confirm_nondefault_locktime(lock_time: int, lock_time_disabled: bool) -> Awa
def request_tx_meta(tx_req: TxRequest, coin: CoinInfo, tx_hash: bytes | None = None) -> Awaitable[PrevTx]: # type: ignore [awaitable-is-generator] def request_tx_meta(tx_req: TxRequest, coin: CoinInfo, tx_hash: bytes | None = None) -> Awaitable[PrevTx]: # type: ignore [awaitable-is-generator]
from trezor.messages import TxAckPrevMeta
assert tx_req.details is not None assert tx_req.details is not None
tx_req.request_type = RequestType.TXMETA tx_req.request_type = RequestType.TXMETA
tx_req.details.tx_hash = tx_hash tx_req.details.tx_hash = tx_hash
ack = yield TxAckPrevMeta, tx_req ack = yield TxAckPrevMeta, tx_req
_clear_tx_request(tx_req) _clear_tx_request(tx_req)
return sanitize_tx_meta(ack.tx, coin) return _sanitize_tx_meta(ack.tx, coin)
def request_tx_extra_data( def request_tx_extra_data(
tx_req: TxRequest, offset: int, size: int, tx_hash: bytes | None = None tx_req: TxRequest, offset: int, size: int, tx_hash: bytes | None = None
) -> Awaitable[bytearray]: # type: ignore [awaitable-is-generator] ) -> Awaitable[bytearray]: # type: ignore [awaitable-is-generator]
assert tx_req.details is not None from trezor.messages import TxAckPrevExtraData
details = tx_req.details # local_cache_attribute
assert details is not None
tx_req.request_type = RequestType.TXEXTRADATA tx_req.request_type = RequestType.TXEXTRADATA
tx_req.details.extra_data_offset = offset details.extra_data_offset = offset
tx_req.details.extra_data_len = size details.extra_data_len = size
tx_req.details.tx_hash = tx_hash details.tx_hash = tx_hash
ack = yield TxAckPrevExtraData, tx_req ack = yield TxAckPrevExtraData, tx_req
_clear_tx_request(tx_req) _clear_tx_request(tx_req)
return ack.tx.extra_data_chunk return ack.tx.extra_data_chunk
def request_tx_input(tx_req: TxRequest, i: int, coin: CoinInfo, tx_hash: bytes | None = None) -> Awaitable[TxInput]: # type: ignore [awaitable-is-generator] def request_tx_input(tx_req: TxRequest, i: int, coin: CoinInfo, tx_hash: bytes | None = None) -> Awaitable[TxInput]: # type: ignore [awaitable-is-generator]
from trezor.messages import TxAckInput
assert tx_req.details is not None assert tx_req.details is not None
if tx_hash: if tx_hash:
tx_req.request_type = RequestType.TXORIGINPUT tx_req.request_type = RequestType.TXORIGINPUT
@ -307,20 +309,24 @@ def request_tx_input(tx_req: TxRequest, i: int, coin: CoinInfo, tx_hash: bytes |
tx_req.details.request_index = i tx_req.details.request_index = i
ack = yield TxAckInput, tx_req ack = yield TxAckInput, tx_req
_clear_tx_request(tx_req) _clear_tx_request(tx_req)
return sanitize_tx_input(ack.tx.input, coin) return _sanitize_tx_input(ack.tx.input, coin)
def request_tx_prev_input(tx_req: TxRequest, i: int, coin: CoinInfo, tx_hash: bytes | None = None) -> Awaitable[PrevInput]: # type: ignore [awaitable-is-generator] def request_tx_prev_input(tx_req: TxRequest, i: int, coin: CoinInfo, tx_hash: bytes | None = None) -> Awaitable[PrevInput]: # type: ignore [awaitable-is-generator]
from trezor.messages import TxAckPrevInput
assert tx_req.details is not None assert tx_req.details is not None
tx_req.request_type = RequestType.TXINPUT tx_req.request_type = RequestType.TXINPUT
tx_req.details.request_index = i tx_req.details.request_index = i
tx_req.details.tx_hash = tx_hash tx_req.details.tx_hash = tx_hash
ack = yield TxAckPrevInput, tx_req ack = yield TxAckPrevInput, tx_req
_clear_tx_request(tx_req) _clear_tx_request(tx_req)
return sanitize_tx_prev_input(ack.tx.input, coin) return _sanitize_tx_prev_input(ack.tx.input, coin)
def request_tx_output(tx_req: TxRequest, i: int, coin: CoinInfo, tx_hash: bytes | None = None) -> Awaitable[TxOutput]: # type: ignore [awaitable-is-generator] def request_tx_output(tx_req: TxRequest, i: int, coin: CoinInfo, tx_hash: bytes | None = None) -> Awaitable[TxOutput]: # type: ignore [awaitable-is-generator]
from trezor.messages import TxAckOutput
assert tx_req.details is not None assert tx_req.details is not None
if tx_hash: if tx_hash:
tx_req.request_type = RequestType.TXORIGOUTPUT tx_req.request_type = RequestType.TXORIGOUTPUT
@ -330,10 +336,12 @@ def request_tx_output(tx_req: TxRequest, i: int, coin: CoinInfo, tx_hash: bytes
tx_req.details.request_index = i tx_req.details.request_index = i
ack = yield TxAckOutput, tx_req ack = yield TxAckOutput, tx_req
_clear_tx_request(tx_req) _clear_tx_request(tx_req)
return sanitize_tx_output(ack.tx.output, coin) return _sanitize_tx_output(ack.tx.output, coin)
def request_tx_prev_output(tx_req: TxRequest, i: int, coin: CoinInfo, tx_hash: bytes | None = None) -> Awaitable[PrevOutput]: # type: ignore [awaitable-is-generator] def request_tx_prev_output(tx_req: TxRequest, i: int, coin: CoinInfo, tx_hash: bytes | None = None) -> Awaitable[PrevOutput]: # type: ignore [awaitable-is-generator]
from trezor.messages import TxAckPrevOutput
assert tx_req.details is not None assert tx_req.details is not None
tx_req.request_type = RequestType.TXOUTPUT tx_req.request_type = RequestType.TXOUTPUT
tx_req.details.request_index = i tx_req.details.request_index = i
@ -345,12 +353,14 @@ def request_tx_prev_output(tx_req: TxRequest, i: int, coin: CoinInfo, tx_hash: b
def request_payment_req(tx_req: TxRequest, i: int) -> Awaitable[TxAckPaymentRequest]: # type: ignore [awaitable-is-generator] def request_payment_req(tx_req: TxRequest, i: int) -> Awaitable[TxAckPaymentRequest]: # type: ignore [awaitable-is-generator]
from trezor.messages import TxAckPaymentRequest
assert tx_req.details is not None assert tx_req.details is not None
tx_req.request_type = RequestType.TXPAYMENTREQ tx_req.request_type = RequestType.TXPAYMENTREQ
tx_req.details.request_index = i tx_req.details.request_index = i
ack = yield TxAckPaymentRequest, tx_req ack = yield TxAckPaymentRequest, tx_req
_clear_tx_request(tx_req) _clear_tx_request(tx_req)
return sanitize_payment_req(ack) return _sanitize_payment_req(ack)
def request_tx_finish(tx_req: TxRequest) -> Awaitable[None]: # type: ignore [awaitable-is-generator] def request_tx_finish(tx_req: TxRequest) -> Awaitable[None]: # type: ignore [awaitable-is-generator]
@ -360,19 +370,22 @@ def request_tx_finish(tx_req: TxRequest) -> Awaitable[None]: # type: ignore [aw
def _clear_tx_request(tx_req: TxRequest) -> None: def _clear_tx_request(tx_req: TxRequest) -> None:
assert tx_req.details is not None details = tx_req.details # local_cache_attribute
assert tx_req.serialized is not None serialized = tx_req.serialized # local_cache_attribute
assert tx_req.serialized.serialized_tx is not None
assert details is not None
assert serialized is not None
assert serialized.serialized_tx is not None
tx_req.request_type = None tx_req.request_type = None
tx_req.details.request_index = None details.request_index = None
tx_req.details.tx_hash = None details.tx_hash = None
tx_req.details.extra_data_len = None details.extra_data_len = None
tx_req.details.extra_data_offset = None details.extra_data_offset = None
tx_req.serialized.signature = None serialized.signature = None
tx_req.serialized.signature_index = None serialized.signature_index = None
# typechecker thinks serialized_tx is `bytes`, which is immutable # typechecker thinks serialized_tx is `bytes`, which is immutable
# we know that it is `bytearray` in reality # we know that it is `bytearray` in reality
tx_req.serialized.serialized_tx[:] = bytes() # type: ignore ["__setitem__" method not defined on type "bytes"] serialized.serialized_tx[:] = bytes() # type: ignore ["__setitem__" method not defined on type "bytes"]
# Data sanitizers # Data sanitizers
@ -383,149 +396,158 @@ def sanitize_sign_tx(tx: SignTx, coin: CoinInfo) -> SignTx:
if coin.decred or coin.overwintered: if coin.decred or coin.overwintered:
tx.expiry = tx.expiry if tx.expiry is not None else 0 tx.expiry = tx.expiry if tx.expiry is not None else 0
elif tx.expiry: elif tx.expiry:
raise wire.DataError("Expiry not enabled on this coin.") raise DataError("Expiry not enabled on this coin.")
if coin.timestamp and not tx.timestamp: if coin.timestamp and not tx.timestamp:
raise wire.DataError("Timestamp must be set.") raise DataError("Timestamp must be set.")
elif not coin.timestamp and tx.timestamp: elif not coin.timestamp and tx.timestamp:
raise wire.DataError("Timestamp not enabled on this coin.") raise DataError("Timestamp not enabled on this coin.")
if coin.overwintered: if coin.overwintered:
if tx.version_group_id is None: if tx.version_group_id is None:
raise wire.DataError("Version group ID must be set.") raise DataError("Version group ID must be set.")
if tx.branch_id is None: if tx.branch_id is None:
raise wire.DataError("Branch ID must be set.") raise DataError("Branch ID must be set.")
elif not coin.overwintered: elif not coin.overwintered:
if tx.version_group_id is not None: if tx.version_group_id is not None:
raise wire.DataError("Version group ID not enabled on this coin.") raise DataError("Version group ID not enabled on this coin.")
if tx.branch_id is not None: if tx.branch_id is not None:
raise wire.DataError("Branch ID not enabled on this coin.") raise DataError("Branch ID not enabled on this coin.")
return tx return tx
def sanitize_tx_meta(tx: PrevTx, coin: CoinInfo) -> PrevTx: def _sanitize_tx_meta(tx: PrevTx, coin: CoinInfo) -> PrevTx:
if not coin.extra_data and tx.extra_data_len: if not coin.extra_data and tx.extra_data_len:
raise wire.DataError("Extra data not enabled on this coin.") raise DataError("Extra data not enabled on this coin.")
if coin.decred or coin.overwintered: if coin.decred or coin.overwintered:
tx.expiry = tx.expiry if tx.expiry is not None else 0 tx.expiry = tx.expiry if tx.expiry is not None else 0
elif tx.expiry: elif tx.expiry:
raise wire.DataError("Expiry not enabled on this coin.") raise DataError("Expiry not enabled on this coin.")
if coin.timestamp and not tx.timestamp: if coin.timestamp and not tx.timestamp:
raise wire.DataError("Timestamp must be set.") raise DataError("Timestamp must be set.")
elif not coin.timestamp and tx.timestamp: elif not coin.timestamp and tx.timestamp:
raise wire.DataError("Timestamp not enabled on this coin.") raise DataError("Timestamp not enabled on this coin.")
elif not coin.overwintered: elif not coin.overwintered:
if tx.version_group_id is not None: if tx.version_group_id is not None:
raise wire.DataError("Version group ID not enabled on this coin.") raise DataError("Version group ID not enabled on this coin.")
if tx.branch_id is not None: if tx.branch_id is not None:
raise wire.DataError("Branch ID not enabled on this coin.") raise DataError("Branch ID not enabled on this coin.")
return tx return tx
def sanitize_tx_input(txi: TxInput, coin: CoinInfo) -> TxInput: def _sanitize_tx_input(txi: TxInput, coin: CoinInfo) -> TxInput:
from trezor.enums import InputScriptType
from trezor.wire import DataError # local_cache_global
script_type = txi.script_type # local_cache_attribute
if len(txi.prev_hash) != TX_HASH_SIZE: if len(txi.prev_hash) != TX_HASH_SIZE:
raise wire.DataError("Provided prev_hash is invalid.") raise DataError("Provided prev_hash is invalid.")
if txi.multisig and txi.script_type not in common.MULTISIG_INPUT_SCRIPT_TYPES: if txi.multisig and script_type not in common.MULTISIG_INPUT_SCRIPT_TYPES:
raise wire.DataError("Multisig field provided but not expected.") raise DataError("Multisig field provided but not expected.")
if not txi.multisig and txi.script_type == InputScriptType.SPENDMULTISIG: if not txi.multisig and script_type == InputScriptType.SPENDMULTISIG:
raise wire.DataError("Multisig details required.") raise DataError("Multisig details required.")
if txi.script_type in common.INTERNAL_INPUT_SCRIPT_TYPES: if script_type in common.INTERNAL_INPUT_SCRIPT_TYPES:
if not txi.address_n: if not txi.address_n:
raise wire.DataError("Missing address_n field.") raise DataError("Missing address_n field.")
if txi.script_pubkey: if txi.script_pubkey:
raise wire.DataError("Input's script_pubkey provided but not expected.") raise DataError("Input's script_pubkey provided but not expected.")
else: else:
if txi.address_n: if txi.address_n:
raise wire.DataError("Input's address_n provided but not expected.") raise DataError("Input's address_n provided but not expected.")
if not txi.script_pubkey: if not txi.script_pubkey:
raise wire.DataError("Missing script_pubkey field.") raise DataError("Missing script_pubkey field.")
if not coin.decred and txi.decred_tree is not None: if not coin.decred and txi.decred_tree is not None:
raise wire.DataError("Decred details provided but Decred coin not specified.") raise DataError("Decred details provided but Decred coin not specified.")
if txi.script_type in common.SEGWIT_INPUT_SCRIPT_TYPES or txi.witness is not None: if script_type in common.SEGWIT_INPUT_SCRIPT_TYPES or txi.witness is not None:
if not coin.segwit: if not coin.segwit:
raise wire.DataError("Segwit not enabled on this coin.") raise DataError("Segwit not enabled on this coin.")
if txi.script_type == InputScriptType.SPENDTAPROOT and not coin.taproot: if script_type == InputScriptType.SPENDTAPROOT and not coin.taproot:
raise wire.DataError("Taproot not enabled on this coin") raise DataError("Taproot not enabled on this coin")
if txi.commitment_data and not txi.ownership_proof: if txi.commitment_data and not txi.ownership_proof:
raise wire.DataError("commitment_data field provided but not expected.") raise DataError("commitment_data field provided but not expected.")
if txi.orig_hash and txi.orig_index is None: if txi.orig_hash and txi.orig_index is None:
raise wire.DataError("Missing orig_index field.") raise DataError("Missing orig_index field.")
return txi return txi
def sanitize_tx_prev_input(txi: PrevInput, coin: CoinInfo) -> PrevInput: def _sanitize_tx_prev_input(txi: PrevInput, coin: CoinInfo) -> PrevInput:
if len(txi.prev_hash) != TX_HASH_SIZE: if len(txi.prev_hash) != TX_HASH_SIZE:
raise wire.DataError("Provided prev_hash is invalid.") raise DataError("Provided prev_hash is invalid.")
if not coin.decred and txi.decred_tree is not None: if not coin.decred and txi.decred_tree is not None:
raise wire.DataError("Decred details provided but Decred coin not specified.") raise DataError("Decred details provided but Decred coin not specified.")
return txi return txi
def sanitize_tx_output(txo: TxOutput, coin: CoinInfo) -> TxOutput: def _sanitize_tx_output(txo: TxOutput, coin: CoinInfo) -> TxOutput:
if txo.multisig and txo.script_type not in common.MULTISIG_OUTPUT_SCRIPT_TYPES: from trezor.enums import OutputScriptType
raise wire.DataError("Multisig field provided but not expected.") from trezor.wire import DataError # local_cache_global
if not txo.multisig and txo.script_type == OutputScriptType.PAYTOMULTISIG: script_type = txo.script_type # local_cache_attribute
raise wire.DataError("Multisig details required.") address_n = txo.address_n # local_cache_attribute
if txo.address_n and txo.script_type not in common.CHANGE_OUTPUT_SCRIPT_TYPES: if txo.multisig and script_type not in common.MULTISIG_OUTPUT_SCRIPT_TYPES:
raise wire.DataError("Output's address_n provided but not expected.") raise DataError("Multisig field provided but not expected.")
if not txo.multisig and script_type == OutputScriptType.PAYTOMULTISIG:
raise DataError("Multisig details required.")
if address_n and script_type not in common.CHANGE_OUTPUT_SCRIPT_TYPES:
raise DataError("Output's address_n provided but not expected.")
if txo.amount is None: if txo.amount is None:
raise wire.DataError("Missing amount field.") raise DataError("Missing amount field.")
if txo.script_type in common.SEGWIT_OUTPUT_SCRIPT_TYPES: if script_type in common.SEGWIT_OUTPUT_SCRIPT_TYPES:
if not coin.segwit: if not coin.segwit:
raise wire.DataError("Segwit not enabled on this coin.") raise DataError("Segwit not enabled on this coin.")
if txo.script_type == OutputScriptType.PAYTOTAPROOT and not coin.taproot: if script_type == OutputScriptType.PAYTOTAPROOT and not coin.taproot:
raise wire.DataError("Taproot not enabled on this coin") raise DataError("Taproot not enabled on this coin")
if txo.script_type == OutputScriptType.PAYTOOPRETURN: if script_type == OutputScriptType.PAYTOOPRETURN:
# op_return output # op_return output
if txo.op_return_data is None: if txo.op_return_data is None:
raise wire.DataError("OP_RETURN output without op_return_data") raise DataError("OP_RETURN output without op_return_data")
if txo.amount != 0: if txo.amount != 0:
raise wire.DataError("OP_RETURN output with non-zero amount") raise DataError("OP_RETURN output with non-zero amount")
if txo.address or txo.address_n or txo.multisig: if txo.address or address_n or txo.multisig:
raise wire.DataError("OP_RETURN output with address or multisig") raise DataError("OP_RETURN output with address or multisig")
else: else:
if txo.op_return_data: if txo.op_return_data:
raise wire.DataError( raise DataError("OP RETURN data provided but not OP RETURN script type.")
"OP RETURN data provided but not OP RETURN script type." if address_n and txo.address:
) raise DataError("Both address and address_n provided.")
if txo.address_n and txo.address: if not address_n and not txo.address:
raise wire.DataError("Both address and address_n provided.") raise DataError("Missing address")
if not txo.address_n and not txo.address:
raise wire.DataError("Missing address")
if txo.orig_hash and txo.orig_index is None: if txo.orig_hash and txo.orig_index is None:
raise wire.DataError("Missing orig_index field.") raise DataError("Missing orig_index field.")
return txo return txo
def sanitize_payment_req(payment_req: TxAckPaymentRequest) -> TxAckPaymentRequest: def _sanitize_payment_req(payment_req: TxAckPaymentRequest) -> TxAckPaymentRequest:
for memo in payment_req.memos: for memo in payment_req.memos:
if (memo.text_memo, memo.refund_memo, memo.coin_purchase_memo).count(None) != 2: if (memo.text_memo, memo.refund_memo, memo.coin_purchase_memo).count(None) != 2:
raise wire.DataError( raise DataError(
"Exactly one memo type must be specified in each PaymentRequestMemo." "Exactly one memo type must be specified in each PaymentRequestMemo."
) )

View File

@ -1,15 +1,14 @@
from micropython import const from micropython import const
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ubinascii import hexlify
from trezor import ui, utils, wire from trezor import utils
from trezor.enums import AmountUnit, ButtonRequestType, OutputScriptType from trezor.enums import ButtonRequestType
from trezor.strings import format_amount, format_timestamp from trezor.strings import format_amount
from trezor.ui import layouts from trezor.ui import layouts
from trezor.ui.layouts import confirm_metadata
from .. import addresses from .. import addresses
from ..common import format_fee_rate from ..common import format_fee_rate
from . import omni
if not utils.BITCOIN_ONLY: if not utils.BITCOIN_ONLY:
from trezor.ui.layouts import altcoin from trezor.ui.layouts import altcoin
@ -20,6 +19,8 @@ if TYPE_CHECKING:
from trezor.messages import TxAckPaymentRequest, TxOutput from trezor.messages import TxAckPaymentRequest, TxOutput
from trezor.ui.layouts import LayoutType from trezor.ui.layouts import LayoutType
from trezor.enums import AmountUnit
from trezor.wire import Context
from apps.common.coininfo import CoinInfo from apps.common.coininfo import CoinInfo
@ -27,6 +28,8 @@ _LOCKTIME_TIMESTAMP_MIN_VALUE = const(500_000_000)
def format_coin_amount(amount: int, coin: CoinInfo, amount_unit: AmountUnit) -> str: def format_coin_amount(amount: int, coin: CoinInfo, amount_unit: AmountUnit) -> str:
from trezor.enums import AmountUnit
decimals, shortcut = coin.decimals, coin.coin_shortcut decimals, shortcut = coin.decimals, coin.coin_shortcut
if amount_unit == AmountUnit.SATOSHI: if amount_unit == AmountUnit.SATOSHI:
decimals = 0 decimals = 0
@ -44,14 +47,18 @@ def format_coin_amount(amount: int, coin: CoinInfo, amount_unit: AmountUnit) ->
async def confirm_output( async def confirm_output(
ctx: wire.Context, output: TxOutput, coin: CoinInfo, amount_unit: AmountUnit ctx: Context, output: TxOutput, coin: CoinInfo, amount_unit: AmountUnit
) -> None: ) -> None:
from trezor import ui
from . import omni
from trezor.enums import OutputScriptType
if output.script_type == OutputScriptType.PAYTOOPRETURN: if output.script_type == OutputScriptType.PAYTOOPRETURN:
data = output.op_return_data data = output.op_return_data
assert data is not None assert data is not None
if omni.is_valid(data): if omni.is_valid(data):
# OMNI transaction # OMNI transaction
layout: LayoutType = layouts.confirm_metadata( layout: LayoutType = confirm_metadata(
ctx, ctx,
"omni_transaction", "omni_transaction",
"OMNI transaction", "OMNI transaction",
@ -63,8 +70,8 @@ async def confirm_output(
layout = layouts.confirm_blob( layout = layouts.confirm_blob(
ctx, ctx,
"op_return", "op_return",
title="OP_RETURN", "OP_RETURN",
data=data, data,
br_code=ButtonRequestType.ConfirmOutput, br_code=ButtonRequestType.ConfirmOutput,
) )
else: else:
@ -89,7 +96,7 @@ async def confirm_output(
async def confirm_decred_sstx_submission( async def confirm_decred_sstx_submission(
ctx: wire.Context, output: TxOutput, coin: CoinInfo, amount_unit: AmountUnit ctx: Context, output: TxOutput, coin: CoinInfo, amount_unit: AmountUnit
) -> None: ) -> None:
assert output.address is not None assert output.address is not None
address_short = addresses.address_short(coin, output.address) address_short = addresses.address_short(coin, output.address)
@ -100,11 +107,13 @@ async def confirm_decred_sstx_submission(
async def confirm_payment_request( async def confirm_payment_request(
ctx: wire.Context, ctx: Context,
msg: TxAckPaymentRequest, msg: TxAckPaymentRequest,
coin: CoinInfo, coin: CoinInfo,
amount_unit: AmountUnit, amount_unit: AmountUnit,
) -> Any: ) -> Any:
from trezor import wire
memo_texts = [] memo_texts = []
for m in msg.memos: for m in msg.memos:
if m.text_memo is not None: if m.text_memo is not None:
@ -126,7 +135,9 @@ async def confirm_payment_request(
) )
async def confirm_replacement(ctx: wire.Context, description: str, txid: bytes) -> None: async def confirm_replacement(ctx: Context, description: str, txid: bytes) -> None:
from ubinascii import hexlify
await layouts.confirm_replacement( await layouts.confirm_replacement(
ctx, ctx,
description, description,
@ -135,7 +146,7 @@ async def confirm_replacement(ctx: wire.Context, description: str, txid: bytes)
async def confirm_modify_output( async def confirm_modify_output(
ctx: wire.Context, ctx: Context,
txo: TxOutput, txo: TxOutput,
orig_txo: TxOutput, orig_txo: TxOutput,
coin: CoinInfo, coin: CoinInfo,
@ -154,7 +165,7 @@ async def confirm_modify_output(
async def confirm_modify_fee( async def confirm_modify_fee(
ctx: wire.Context, ctx: Context,
user_fee_change: int, user_fee_change: int,
total_fee_new: int, total_fee_new: int,
fee_rate: float, fee_rate: float,
@ -171,7 +182,7 @@ async def confirm_modify_fee(
async def confirm_joint_total( async def confirm_joint_total(
ctx: wire.Context, ctx: Context,
spending: int, spending: int,
total: int, total: int,
coin: CoinInfo, coin: CoinInfo,
@ -185,7 +196,7 @@ async def confirm_joint_total(
async def confirm_total( async def confirm_total(
ctx: wire.Context, ctx: Context,
spending: int, spending: int,
fee: int, fee: int,
fee_rate: float, fee_rate: float,
@ -194,17 +205,17 @@ async def confirm_total(
) -> None: ) -> None:
await layouts.confirm_total( await layouts.confirm_total(
ctx, ctx,
total_amount=format_coin_amount(spending, coin, amount_unit), format_coin_amount(spending, coin, amount_unit),
fee_amount=format_coin_amount(fee, coin, amount_unit), format_coin_amount(fee, coin, amount_unit),
fee_rate_amount=format_fee_rate(fee_rate, coin) if fee_rate >= 0 else None, fee_rate_amount=format_fee_rate(fee_rate, coin) if fee_rate >= 0 else None,
) )
async def confirm_feeoverthreshold( async def confirm_feeoverthreshold(
ctx: wire.Context, fee: int, coin: CoinInfo, amount_unit: AmountUnit ctx: Context, fee: int, coin: CoinInfo, amount_unit: AmountUnit
) -> None: ) -> None:
fee_amount = format_coin_amount(fee, coin, amount_unit) fee_amount = format_coin_amount(fee, coin, amount_unit)
await layouts.confirm_metadata( await confirm_metadata(
ctx, ctx,
"fee_over_threshold", "fee_over_threshold",
"High fee", "High fee",
@ -214,10 +225,8 @@ async def confirm_feeoverthreshold(
) )
async def confirm_change_count_over_threshold( async def confirm_change_count_over_threshold(ctx: Context, change_count: int) -> None:
ctx: wire.Context, change_count: int await confirm_metadata(
) -> None:
await layouts.confirm_metadata(
ctx, ctx,
"change_count_over_threshold", "change_count_over_threshold",
"Warning", "Warning",
@ -227,8 +236,8 @@ async def confirm_change_count_over_threshold(
) )
async def confirm_unverified_external_input(ctx: wire.Context) -> None: async def confirm_unverified_external_input(ctx: Context) -> None:
await layouts.confirm_metadata( await confirm_metadata(
ctx, ctx,
"unverified_external_input", "unverified_external_input",
"Warning", "Warning",
@ -238,8 +247,10 @@ async def confirm_unverified_external_input(ctx: wire.Context) -> None:
async def confirm_nondefault_locktime( async def confirm_nondefault_locktime(
ctx: wire.Context, lock_time: int, lock_time_disabled: bool ctx: Context, lock_time: int, lock_time_disabled: bool
) -> None: ) -> None:
from trezor.strings import format_timestamp
if lock_time_disabled: if lock_time_disabled:
title = "Warning" title = "Warning"
text = "Locktime is set but will\nhave no effect.\n" text = "Locktime is set but will\nhave no effect.\n"
@ -253,7 +264,7 @@ async def confirm_nondefault_locktime(
text = "Locktime for this\ntransaction is set to:\n{}" text = "Locktime for this\ntransaction is set to:\n{}"
param = format_timestamp(lock_time) param = format_timestamp(lock_time)
await layouts.confirm_metadata( await confirm_metadata(
ctx, ctx,
"nondefault_locktime", "nondefault_locktime",
title, title,

View File

@ -1,11 +1,5 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor import wire
from trezor.utils import ensure
from .. import multisig
from ..common import BIP32_WALLET_DEPTH
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Any, Generic, TypeVar from typing import Any, Generic, TypeVar
@ -54,6 +48,8 @@ class MatchChecker(Generic[T]):
raise NotImplementedError raise NotImplementedError
def add_input(self, txi: TxInput) -> None: def add_input(self, txi: TxInput) -> None:
from trezor.utils import ensure
ensure(not self.read_only) ensure(not self.read_only)
if self.attribute is self.MISMATCH: if self.attribute is self.MISMATCH:
@ -68,6 +64,8 @@ class MatchChecker(Generic[T]):
self.attribute = self.MISMATCH self.attribute = self.MISMATCH
def check_input(self, txi: TxInput) -> None: def check_input(self, txi: TxInput) -> None:
from trezor import wire
if self.attribute is self.MISMATCH: if self.attribute is self.MISMATCH:
return # There was already a mismatch when adding inputs, ignore it now. return # There was already a mismatch when adding inputs, ignore it now.
@ -87,6 +85,8 @@ class MatchChecker(Generic[T]):
class WalletPathChecker(MatchChecker): class WalletPathChecker(MatchChecker):
def attribute_from_tx(self, txio: TxInput | TxOutput) -> Any: def attribute_from_tx(self, txio: TxInput | TxOutput) -> Any:
from ..common import BIP32_WALLET_DEPTH
if len(txio.address_n) < BIP32_WALLET_DEPTH: if len(txio.address_n) < BIP32_WALLET_DEPTH:
return None return None
return txio.address_n[:-BIP32_WALLET_DEPTH] return txio.address_n[:-BIP32_WALLET_DEPTH]
@ -94,6 +94,8 @@ class WalletPathChecker(MatchChecker):
class MultisigFingerprintChecker(MatchChecker): class MultisigFingerprintChecker(MatchChecker):
def attribute_from_tx(self, txio: TxInput | TxOutput) -> Any: def attribute_from_tx(self, txio: TxInput | TxOutput) -> Any:
from .. import multisig
if not txio.multisig: if not txio.multisig:
return None return None
return multisig.multisig_fingerprint(txio.multisig) return multisig.multisig_fingerprint(txio.multisig)

View File

@ -1,7 +1,4 @@
from micropython import const from micropython import const
from ustruct import unpack
from trezor.strings import format_amount
_OMNI_DECIMALS = const(8) _OMNI_DECIMALS = const(8)
@ -18,6 +15,9 @@ def is_valid(data: bytes) -> bool:
def parse(data: bytes) -> str: def parse(data: bytes) -> str:
from ustruct import unpack
from trezor.strings import format_amount
if not is_valid(data): if not is_valid(data):
raise ValueError # tried to parse data that fails validation raise ValueError # tried to parse data that fails validation
tx_version, tx_type = unpack(">HH", data[4:8]) tx_version, tx_type = unpack(">HH", data[4:8])

View File

@ -1,20 +1,14 @@
from micropython import const from micropython import const
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from storage import cache from trezor.wire import DataError
from trezor import wire
from trezor.crypto.curve import secp256k1
from trezor.crypto.hashlib import sha256
from trezor.utils import HashWriter
from apps.common import coininfo
from apps.common.address_mac import check_address_mac
from apps.common.keychain import Keychain
from .. import writers from .. import writers
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import TxAckPaymentRequest, TxOutput from trezor.messages import TxAckPaymentRequest, TxOutput
from apps.common import coininfo
from apps.common.keychain import Keychain
_MEMO_TYPE_TEXT = const(1) _MEMO_TYPE_TEXT = const(1)
_MEMO_TYPE_REFUND = const(2) _MEMO_TYPE_REFUND = const(2)
@ -31,6 +25,12 @@ class PaymentRequestVerifier:
def __init__( def __init__(
self, msg: TxAckPaymentRequest, coin: coininfo.CoinInfo, keychain: Keychain self, msg: TxAckPaymentRequest, coin: coininfo.CoinInfo, keychain: Keychain
) -> None: ) -> None:
from storage import cache
from trezor.crypto.hashlib import sha256
from trezor.utils import HashWriter
from apps.common.address_mac import check_address_mac
from .. import writers # pylint: disable=import-outside-toplevel
self.h_outputs = HashWriter(sha256()) self.h_outputs = HashWriter(sha256())
self.amount = 0 self.amount = 0
self.expected_amount = msg.amount self.expected_amount = msg.amount
@ -40,12 +40,12 @@ class PaymentRequestVerifier:
if msg.nonce: if msg.nonce:
nonce = bytes(msg.nonce) nonce = bytes(msg.nonce)
if cache.get(cache.APP_COMMON_NONCE) != nonce: if cache.get(cache.APP_COMMON_NONCE) != nonce:
raise wire.DataError("Invalid nonce in payment request.") raise DataError("Invalid nonce in payment request.")
cache.delete(cache.APP_COMMON_NONCE) cache.delete(cache.APP_COMMON_NONCE)
else: else:
nonce = b"" nonce = b""
if msg.memos: if msg.memos:
wire.DataError("Missing nonce in payment request.") DataError("Missing nonce in payment request.")
writers.write_bytes_fixed(self.h_pr, b"SL\x00\x24", 4) writers.write_bytes_fixed(self.h_pr, b"SL\x00\x24", 4)
writers.write_bytes_prefixed(self.h_pr, nonce) writers.write_bytes_prefixed(self.h_pr, nonce)
@ -73,8 +73,10 @@ class PaymentRequestVerifier:
writers.write_uint32(self.h_pr, coin.slip44) writers.write_uint32(self.h_pr, coin.slip44)
def verify(self) -> None: def verify(self) -> None:
from trezor.crypto.curve import secp256k1
if self.expected_amount is not None and self.amount != self.expected_amount: if self.expected_amount is not None and self.amount != self.expected_amount:
raise wire.DataError("Invalid amount in payment request.") raise DataError("Invalid amount in payment request.")
hash_outputs = writers.get_tx_hash(self.h_outputs) hash_outputs = writers.get_tx_hash(self.h_outputs)
writers.write_bytes_fixed(self.h_pr, hash_outputs, 32) writers.write_bytes_fixed(self.h_pr, hash_outputs, 32)
@ -82,7 +84,7 @@ class PaymentRequestVerifier:
if not secp256k1.verify( if not secp256k1.verify(
self.PUBLIC_KEY, self.signature, self.h_pr.get_digest() self.PUBLIC_KEY, self.signature, self.h_pr.get_digest()
): ):
raise wire.DataError("Invalid signature in payment request.") raise DataError("Invalid signature in payment request.")
def _add_output(self, txo: TxOutput) -> None: def _add_output(self, txo: TxOutput) -> None:
# For change outputs txo.address filled in by output_derive_script(). # For change outputs txo.address filled in by output_derive_script().

View File

@ -1,17 +1,18 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor.crypto.hashlib import sha256 from ..writers import (
from trezor.utils import HashWriter TX_HASH_SIZE,
write_bytes_fixed,
from apps.common import coininfo write_bytes_reversed,
write_uint32,
from .. import scripts, writers write_uint64,
from ..common import tagged_hashwriter )
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Protocol, Sequence from typing import Protocol, Sequence
from ..common import SigHashType from ..common import SigHashType
from trezor.messages import PrevTx, SignTx, TxInput, TxOutput from trezor.messages import PrevTx, SignTx, TxInput, TxOutput
from apps.common import coininfo
class SigHasher(Protocol): class SigHasher(Protocol):
def add_input(self, txi: TxInput, script_pubkey: bytes) -> None: def add_input(self, txi: TxInput, script_pubkey: bytes) -> None:
@ -50,6 +51,9 @@ if TYPE_CHECKING:
# BIP-0143 hash # BIP-0143 hash
class BitcoinSigHasher: class BitcoinSigHasher:
def __init__(self) -> None: def __init__(self) -> None:
from trezor.crypto.hashlib import sha256
from trezor.utils import HashWriter
self.h_prevouts = HashWriter(sha256()) self.h_prevouts = HashWriter(sha256())
self.h_amounts = HashWriter(sha256()) self.h_amounts = HashWriter(sha256())
self.h_scriptpubkeys = HashWriter(sha256()) self.h_scriptpubkeys = HashWriter(sha256())
@ -57,16 +61,18 @@ class BitcoinSigHasher:
self.h_outputs = HashWriter(sha256()) self.h_outputs = HashWriter(sha256())
def add_input(self, txi: TxInput, script_pubkey: bytes) -> None: def add_input(self, txi: TxInput, script_pubkey: bytes) -> None:
writers.write_bytes_reversed( from ..writers import write_bytes_prefixed
self.h_prevouts, txi.prev_hash, writers.TX_HASH_SIZE
) write_bytes_reversed(self.h_prevouts, txi.prev_hash, TX_HASH_SIZE)
writers.write_uint32(self.h_prevouts, txi.prev_index) write_uint32(self.h_prevouts, txi.prev_index)
writers.write_uint64(self.h_amounts, txi.amount) write_uint64(self.h_amounts, txi.amount)
writers.write_bytes_prefixed(self.h_scriptpubkeys, script_pubkey) write_bytes_prefixed(self.h_scriptpubkeys, script_pubkey)
writers.write_uint32(self.h_sequences, txi.sequence) write_uint32(self.h_sequences, txi.sequence)
def add_output(self, txo: TxOutput, script_pubkey: bytes) -> None: def add_output(self, txo: TxOutput, script_pubkey: bytes) -> None:
writers.write_tx_output(self.h_outputs, txo, script_pubkey) from ..writers import write_tx_output
write_tx_output(self.h_outputs, txo, script_pubkey)
def hash143( def hash143(
self, self,
@ -77,26 +83,27 @@ class BitcoinSigHasher:
coin: coininfo.CoinInfo, coin: coininfo.CoinInfo,
hash_type: int, hash_type: int,
) -> bytes: ) -> bytes:
from trezor.crypto.hashlib import sha256
from trezor.utils import HashWriter
from .. import scripts
from ..writers import get_tx_hash
h_preimage = HashWriter(sha256()) h_preimage = HashWriter(sha256())
# nVersion # nVersion
writers.write_uint32(h_preimage, tx.version) write_uint32(h_preimage, tx.version)
# hashPrevouts # hashPrevouts
prevouts_hash = writers.get_tx_hash( prevouts_hash = get_tx_hash(self.h_prevouts, double=coin.sign_hash_double)
self.h_prevouts, double=coin.sign_hash_double write_bytes_fixed(h_preimage, prevouts_hash, TX_HASH_SIZE)
)
writers.write_bytes_fixed(h_preimage, prevouts_hash, writers.TX_HASH_SIZE)
# hashSequence # hashSequence
sequence_hash = writers.get_tx_hash( sequence_hash = get_tx_hash(self.h_sequences, double=coin.sign_hash_double)
self.h_sequences, double=coin.sign_hash_double write_bytes_fixed(h_preimage, sequence_hash, TX_HASH_SIZE)
)
writers.write_bytes_fixed(h_preimage, sequence_hash, writers.TX_HASH_SIZE)
# outpoint # outpoint
writers.write_bytes_reversed(h_preimage, txi.prev_hash, writers.TX_HASH_SIZE) write_bytes_reversed(h_preimage, txi.prev_hash, TX_HASH_SIZE)
writers.write_uint32(h_preimage, txi.prev_index) write_uint32(h_preimage, txi.prev_index)
# scriptCode # scriptCode
scripts.write_bip143_script_code_prefixed( scripts.write_bip143_script_code_prefixed(
@ -104,22 +111,22 @@ class BitcoinSigHasher:
) )
# amount # amount
writers.write_uint64(h_preimage, txi.amount) write_uint64(h_preimage, txi.amount)
# nSequence # nSequence
writers.write_uint32(h_preimage, txi.sequence) write_uint32(h_preimage, txi.sequence)
# hashOutputs # hashOutputs
outputs_hash = writers.get_tx_hash(self.h_outputs, double=coin.sign_hash_double) outputs_hash = get_tx_hash(self.h_outputs, double=coin.sign_hash_double)
writers.write_bytes_fixed(h_preimage, outputs_hash, writers.TX_HASH_SIZE) write_bytes_fixed(h_preimage, outputs_hash, TX_HASH_SIZE)
# nLockTime # nLockTime
writers.write_uint32(h_preimage, tx.lock_time) write_uint32(h_preimage, tx.lock_time)
# nHashType # nHashType
writers.write_uint32(h_preimage, hash_type) write_uint32(h_preimage, hash_type)
return writers.get_tx_hash(h_preimage, double=coin.sign_hash_double) return get_tx_hash(h_preimage, double=coin.sign_hash_double)
def hash341( def hash341(
self, self,
@ -127,50 +134,43 @@ class BitcoinSigHasher:
tx: SignTx | PrevTx, tx: SignTx | PrevTx,
sighash_type: SigHashType, sighash_type: SigHashType,
) -> bytes: ) -> bytes:
from ..common import tagged_hashwriter
from ..writers import write_uint8
h_sigmsg = tagged_hashwriter(b"TapSighash") h_sigmsg = tagged_hashwriter(b"TapSighash")
# sighash epoch 0 # sighash epoch 0
writers.write_uint8(h_sigmsg, 0) write_uint8(h_sigmsg, 0)
# nHashType # nHashType
writers.write_uint8(h_sigmsg, sighash_type & 0xFF) write_uint8(h_sigmsg, sighash_type & 0xFF)
# nVersion # nVersion
writers.write_uint32(h_sigmsg, tx.version) write_uint32(h_sigmsg, tx.version)
# nLockTime # nLockTime
writers.write_uint32(h_sigmsg, tx.lock_time) write_uint32(h_sigmsg, tx.lock_time)
# sha_prevouts # sha_prevouts
writers.write_bytes_fixed( write_bytes_fixed(h_sigmsg, self.h_prevouts.get_digest(), TX_HASH_SIZE)
h_sigmsg, self.h_prevouts.get_digest(), writers.TX_HASH_SIZE
)
# sha_amounts # sha_amounts
writers.write_bytes_fixed( write_bytes_fixed(h_sigmsg, self.h_amounts.get_digest(), TX_HASH_SIZE)
h_sigmsg, self.h_amounts.get_digest(), writers.TX_HASH_SIZE
)
# sha_scriptpubkeys # sha_scriptpubkeys
writers.write_bytes_fixed( write_bytes_fixed(h_sigmsg, self.h_scriptpubkeys.get_digest(), TX_HASH_SIZE)
h_sigmsg, self.h_scriptpubkeys.get_digest(), writers.TX_HASH_SIZE
)
# sha_sequences # sha_sequences
writers.write_bytes_fixed( write_bytes_fixed(h_sigmsg, self.h_sequences.get_digest(), TX_HASH_SIZE)
h_sigmsg, self.h_sequences.get_digest(), writers.TX_HASH_SIZE
)
# sha_outputs # sha_outputs
writers.write_bytes_fixed( write_bytes_fixed(h_sigmsg, self.h_outputs.get_digest(), TX_HASH_SIZE)
h_sigmsg, self.h_outputs.get_digest(), writers.TX_HASH_SIZE
)
# spend_type 0 (no tapscript message extension, no annex) # spend_type 0 (no tapscript message extension, no annex)
writers.write_uint8(h_sigmsg, 0) write_uint8(h_sigmsg, 0)
# input_index # input_index
writers.write_uint32(h_sigmsg, i) write_uint32(h_sigmsg, i)
return h_sigmsg.get_digest() return h_sigmsg.get_digest()

View File

@ -1,13 +1,7 @@
from micropython import const from micropython import const
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor import wire
from trezor.crypto.hashlib import sha256
from trezor.utils import HashWriter
from .. import common, writers from .. import common, writers
from ..common import BIP32_WALLET_DEPTH, input_is_external
from .matchcheck import MultisigFingerprintChecker, WalletPathChecker
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Protocol from typing import Protocol
@ -17,6 +11,7 @@ if TYPE_CHECKING:
TxInput, TxInput,
TxOutput, TxOutput,
) )
from trezor.utils import HashWriter
from .sig_hasher import SigHasher from .sig_hasher import SigHasher
from apps.common.coininfo import CoinInfo from apps.common.coininfo import CoinInfo
@ -61,6 +56,10 @@ _MAX_BIP125_RBF_SEQUENCE = const(0xFFFF_FFFD)
class TxInfoBase: class TxInfoBase:
def __init__(self, signer: Signer, tx: SignTx | PrevTx) -> None: def __init__(self, signer: Signer, tx: SignTx | PrevTx) -> None:
from trezor.crypto.hashlib import sha256
from trezor.utils import HashWriter
from .matchcheck import MultisigFingerprintChecker, WalletPathChecker
# Checksum of multisig inputs, used to validate change-output. # Checksum of multisig inputs, used to validate change-output.
self.multisig_fingerprint = MultisigFingerprintChecker() self.multisig_fingerprint = MultisigFingerprintChecker()
@ -89,7 +88,7 @@ class TxInfoBase:
writers.write_tx_input_check(self.h_tx_check, txi) writers.write_tx_input_check(self.h_tx_check, txi)
self.min_sequence = min(self.min_sequence, txi.sequence) self.min_sequence = min(self.min_sequence, txi.sequence)
if not input_is_external(txi): if not common.input_is_external(txi):
self.wallet_path.add_input(txi) self.wallet_path.add_input(txi)
self.multisig_fingerprint.add_input(txi) self.multisig_fingerprint.add_input(txi)
@ -108,7 +107,7 @@ class TxInfoBase:
return False return False
return ( return (
self.wallet_path.output_matches(txo) self.wallet_path.output_matches(txo)
and len(txo.address_n) >= BIP32_WALLET_DEPTH and len(txo.address_n) >= common.BIP32_WALLET_DEPTH
and txo.address_n[-2] <= _BIP32_CHANGE_CHAIN and txo.address_n[-2] <= _BIP32_CHANGE_CHAIN
and txo.address_n[-1] <= _BIP32_MAX_LAST_ELEMENT and txo.address_n[-1] <= _BIP32_MAX_LAST_ELEMENT
and txo.amount > 0 and txo.amount > 0
@ -163,6 +162,8 @@ class OriginalTxInfo(TxInfoBase):
writers.write_tx_output(self.h_tx, txo, script_pubkey) writers.write_tx_output(self.h_tx, txo, script_pubkey)
async def finalize_tx_hash(self) -> None: async def finalize_tx_hash(self) -> None:
from trezor import wire
await self.signer.write_prev_tx_footer(self.h_tx, self.tx, self.orig_hash) await self.signer.write_prev_tx_footer(self.h_tx, self.tx, self.orig_hash)
if self.orig_hash != writers.get_tx_hash( if self.orig_hash != writers.get_tx_hash(
self.h_tx, double=self.signer.coin.sign_hash_double, reverse=True self.h_tx, double=self.signer.coin.sign_hash_double, reverse=True

View File

@ -11,7 +11,7 @@ from typing import TYPE_CHECKING
from trezor import wire from trezor import wire
from trezor.enums import InputScriptType from trezor.enums import InputScriptType
from .. import common, ownership from .. import common
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import TxInput from trezor.messages import TxInput
@ -50,28 +50,32 @@ class TxWeightCalculator:
@classmethod @classmethod
def input_script_size(cls, i: TxInput) -> int: def input_script_size(cls, i: TxInput) -> int:
script_type = i.script_type script_type = i.script_type # local_cache_attribute
script_pubkey = i.script_pubkey # local_cache_attribute
multisig = i.multisig # local_cache_attribute
IST = InputScriptType # local_cache_global
if common.input_is_external_unverified(i): if common.input_is_external_unverified(i):
assert i.script_pubkey is not None # checked in sanitize_tx_input assert script_pubkey is not None # checked in _sanitize_tx_input
# Guess the script type from the scriptPubKey. # Guess the script type from the scriptPubKey.
if i.script_pubkey[0] == 0x76: # OP_DUP (P2PKH) if script_pubkey[0] == 0x76: # OP_DUP (P2PKH)
script_type = InputScriptType.SPENDADDRESS script_type = IST.SPENDADDRESS
elif i.script_pubkey[0] == 0xA9: # OP_HASH_160 (P2SH) elif script_pubkey[0] == 0xA9: # OP_HASH_160 (P2SH)
# Probably nested P2WPKH. # Probably nested P2WPKH.
script_type = InputScriptType.SPENDP2SHWITNESS script_type = IST.SPENDP2SHWITNESS
elif i.script_pubkey[0] == 0x00: # SegWit v0 (probably P2WPKH) elif script_pubkey[0] == 0x00: # SegWit v0 (probably P2WPKH)
script_type = InputScriptType.SPENDWITNESS script_type = IST.SPENDWITNESS
elif i.script_pubkey[0] == 0x51: # SegWit v1 (P2TR) elif script_pubkey[0] == 0x51: # SegWit v1 (P2TR)
script_type = InputScriptType.SPENDTAPROOT script_type = IST.SPENDTAPROOT
else: # Unknown script type. else: # Unknown script type.
pass pass
if i.multisig: if multisig:
if script_type == InputScriptType.SPENDTAPROOT: if script_type == IST.SPENDTAPROOT:
raise wire.ProcessError("Multisig not supported for taproot") raise wire.ProcessError("Multisig not supported for taproot")
n = len(i.multisig.nodes) if i.multisig.nodes else len(i.multisig.pubkeys) n = len(multisig.nodes) if multisig.nodes else len(multisig.pubkeys)
multisig_script_size = _TXSIZE_MULTISIGSCRIPT + n * (1 + _TXSIZE_PUBKEY) multisig_script_size = _TXSIZE_MULTISIGSCRIPT + n * (1 + _TXSIZE_PUBKEY)
if script_type in common.SEGWIT_INPUT_SCRIPT_TYPES: if script_type in common.SEGWIT_INPUT_SCRIPT_TYPES:
multisig_script_size += cls.compact_size_len(multisig_script_size) multisig_script_size += cls.compact_size_len(multisig_script_size)
@ -80,25 +84,29 @@ class TxWeightCalculator:
return ( return (
1 # the OP_FALSE bug in multisig 1 # the OP_FALSE bug in multisig
+ i.multisig.m * (1 + _TXSIZE_DER_SIGNATURE) + multisig.m * (1 + _TXSIZE_DER_SIGNATURE)
+ multisig_script_size + multisig_script_size
) )
elif script_type == InputScriptType.SPENDTAPROOT: elif script_type == IST.SPENDTAPROOT:
return 1 + _TXSIZE_SCHNORR_SIGNATURE return 1 + _TXSIZE_SCHNORR_SIGNATURE
else: else:
return 1 + _TXSIZE_DER_SIGNATURE + 1 + _TXSIZE_PUBKEY return 1 + _TXSIZE_DER_SIGNATURE + 1 + _TXSIZE_PUBKEY
def add_input(self, i: TxInput) -> None: def add_input(self, i: TxInput) -> None:
from .. import ownership
script_type = i.script_type # local_cache_attribute
self.inputs_count += 1 self.inputs_count += 1
self.counter += 4 * _TXSIZE_INPUT self.counter += 4 * _TXSIZE_INPUT
input_script_size = self.input_script_size(i) input_script_size = self.input_script_size(i)
if i.script_type in common.NONSEGWIT_INPUT_SCRIPT_TYPES: if script_type in common.NONSEGWIT_INPUT_SCRIPT_TYPES:
input_script_size += self.compact_size_len(input_script_size) input_script_size += self.compact_size_len(input_script_size)
self.counter += 4 * input_script_size self.counter += 4 * input_script_size
elif i.script_type in common.SEGWIT_INPUT_SCRIPT_TYPES: elif script_type in common.SEGWIT_INPUT_SCRIPT_TYPES:
self.segwit_inputs_count += 1 self.segwit_inputs_count += 1
if i.script_type == InputScriptType.SPENDP2SHWITNESS: if script_type == InputScriptType.SPENDP2SHWITNESS:
# add script_sig size # add script_sig size
if i.multisig: if i.multisig:
self.counter += 4 * (2 + _TXSIZE_WITNESSSCRIPT) self.counter += 4 * (2 + _TXSIZE_WITNESSSCRIPT)
@ -107,7 +115,7 @@ class TxWeightCalculator:
else: else:
self.counter += 4 # empty script_sig (1 byte) self.counter += 4 # empty script_sig (1 byte)
self.counter += 1 + input_script_size # discounted witness self.counter += 1 + input_script_size # discounted witness
elif i.script_type == InputScriptType.EXTERNAL: elif script_type == InputScriptType.EXTERNAL:
if i.ownership_proof: if i.ownership_proof:
script_sig, witness = ownership.read_scriptsig_witness( script_sig, witness = ownership.read_scriptsig_witness(
i.ownership_proof i.ownership_proof

View File

@ -1,32 +1,19 @@
import ustruct as struct
from micropython import const from micropython import const
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor import wire
from trezor.crypto.hashlib import blake2b from trezor.crypto.hashlib import blake2b
from trezor.utils import HashWriter, ensure from trezor.utils import HashWriter
from trezor.wire import DataError
from apps.common.coininfo import CoinInfo from ..writers import TX_HASH_SIZE, write_bytes_reversed, write_uint32, write_uint64
from apps.common.keychain import Keychain
from apps.common.writers import write_compact_size
from ..scripts import write_bip143_script_code_prefixed
from ..writers import (
TX_HASH_SIZE,
get_tx_hash,
write_bytes_fixed,
write_bytes_reversed,
write_tx_output,
write_uint32,
write_uint64,
)
from . import approvers, helpers
from .bitcoinlike import Bitcoinlike from .bitcoinlike import Bitcoinlike
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import PrevTx, SignTx, TxInput, TxOutput from trezor.messages import PrevTx, SignTx, TxInput, TxOutput
from apps.common.coininfo import CoinInfo
from apps.common.keychain import Keychain
from . import approvers
from typing import Sequence from typing import Sequence
from apps.common import coininfo
from .sig_hasher import SigHasher from .sig_hasher import SigHasher
from .tx_info import OriginalTxInfo, TxInfo from .tx_info import OriginalTxInfo, TxInfo
from ..common import SigHashType from ..common import SigHashType
@ -47,6 +34,8 @@ class Zip243SigHasher:
write_uint32(self.h_sequence, txi.sequence) write_uint32(self.h_sequence, txi.sequence)
def add_output(self, txo: TxOutput, script_pubkey: bytes) -> None: def add_output(self, txo: TxOutput, script_pubkey: bytes) -> None:
from ..writers import write_tx_output
write_tx_output(self.h_outputs, txo, script_pubkey) write_tx_output(self.h_outputs, txo, script_pubkey)
def hash143( def hash143(
@ -55,9 +44,13 @@ class Zip243SigHasher:
public_keys: Sequence[bytes | memoryview], public_keys: Sequence[bytes | memoryview],
threshold: int, threshold: int,
tx: SignTx | PrevTx, tx: SignTx | PrevTx,
coin: coininfo.CoinInfo, coin: CoinInfo,
hash_type: int, hash_type: int,
) -> bytes: ) -> bytes:
import ustruct as struct
from ..scripts import write_bip143_script_code_prefixed
from ..writers import get_tx_hash, write_bytes_fixed
h_preimage = HashWriter( h_preimage = HashWriter(
blake2b( blake2b(
outlen=32, outlen=32,
@ -129,23 +122,30 @@ class ZcashV4(Bitcoinlike):
coin: CoinInfo, coin: CoinInfo,
approver: approvers.Approver | None, approver: approvers.Approver | None,
) -> None: ) -> None:
from trezor.utils import ensure
ensure(coin.overwintered) ensure(coin.overwintered)
super().__init__(tx, keychain, coin, approver) super().__init__(tx, keychain, coin, approver)
if tx.version != 4: if tx.version != 4:
raise wire.DataError("Unsupported transaction version.") raise DataError("Unsupported transaction version.")
def create_sig_hasher(self, tx: SignTx | PrevTx) -> SigHasher: def create_sig_hasher(self, tx: SignTx | PrevTx) -> SigHasher:
return Zip243SigHasher() return Zip243SigHasher()
async def step7_finish(self) -> None: async def step7_finish(self) -> None:
if self.serialize: from apps.common.writers import write_compact_size
self.write_tx_footer(self.serialized_tx, self.tx_info.tx) from . import helpers
write_uint64(self.serialized_tx, 0) # valueBalance serialized_tx = self.serialized_tx # local_cache_attribute
write_compact_size(self.serialized_tx, 0) # nShieldedSpend
write_compact_size(self.serialized_tx, 0) # nShieldedOutput if self.serialize:
write_compact_size(self.serialized_tx, 0) # nJoinSplit self.write_tx_footer(serialized_tx, self.tx_info.tx)
write_uint64(serialized_tx, 0) # valueBalance
write_compact_size(serialized_tx, 0) # nShieldedSpend
write_compact_size(serialized_tx, 0) # nShieldedOutput
write_compact_size(serialized_tx, 0) # nJoinSplit
await helpers.request_tx_finish(self.tx_req) await helpers.request_tx_finish(self.tx_req)
@ -179,7 +179,7 @@ class ZcashV4(Bitcoinlike):
write_uint32(w, tx.version) write_uint32(w, tx.version)
else: else:
if tx.version_group_id is None: if tx.version_group_id is None:
raise wire.DataError("Version group ID is missing") raise DataError("Version group ID is missing")
# nVersion | fOverwintered # nVersion | fOverwintered
write_uint32(w, tx.version | _OVERWINTERED) write_uint32(w, tx.version | _OVERWINTERED)
write_uint32(w, tx.version_group_id) # nVersionGroupId write_uint32(w, tx.version_group_id) # nVersionGroupId

View File

@ -1,29 +1,11 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor import utils, wire from trezor.wire import DataError
from trezor.crypto import der
from trezor.crypto.curve import bip340, secp256k1
from trezor.crypto.hashlib import sha256
from .common import OP_0, OP_1, SigHashType, ecdsa_hash_pubkey
from .scripts import (
output_script_native_segwit,
output_script_p2pkh,
output_script_p2sh,
parse_input_script_multisig,
parse_input_script_p2pkh,
parse_output_script_multisig,
parse_output_script_p2tr,
parse_witness_multisig,
parse_witness_p2tr,
parse_witness_p2wpkh,
write_input_script_p2wpkh_in_p2sh,
write_input_script_p2wsh_in_p2sh,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Sequence from typing import Sequence
from apps.common.coininfo import CoinInfo from apps.common.coininfo import CoinInfo
from .common import SigHashType
class SignatureVerifier: class SignatureVerifier:
@ -34,6 +16,26 @@ class SignatureVerifier:
witness: bytes | None, witness: bytes | None,
coin: CoinInfo, coin: CoinInfo,
): ):
from trezor import utils
from trezor.wire import DataError # local_cache_global
from trezor.crypto.hashlib import sha256
from .common import OP_0, OP_1, SigHashType, ecdsa_hash_pubkey
from .scripts import (
output_script_native_segwit,
output_script_p2pkh,
output_script_p2sh,
parse_input_script_multisig,
parse_input_script_p2pkh,
parse_output_script_multisig,
parse_output_script_p2tr,
parse_witness_multisig,
parse_witness_p2tr,
parse_witness_p2wpkh,
write_input_script_p2wpkh_in_p2sh,
write_input_script_p2wsh_in_p2sh,
)
self.threshold = 1 self.threshold = 1
self.public_keys: list[memoryview] = [] self.public_keys: list[memoryview] = []
self.signatures: list[tuple[memoryview, SigHashType]] = [] self.signatures: list[tuple[memoryview, SigHashType]] = []
@ -41,27 +43,27 @@ class SignatureVerifier:
if not script_sig: if not script_sig:
if not witness: if not witness:
raise wire.DataError("Signature data not provided") raise DataError("Signature data not provided")
if len(script_pubkey) == 22: # P2WPKH if len(script_pubkey) == 22: # P2WPKH
public_key, signature, hash_type = parse_witness_p2wpkh(witness) public_key, signature, hash_type = parse_witness_p2wpkh(witness)
pubkey_hash = ecdsa_hash_pubkey(public_key, coin) pubkey_hash = ecdsa_hash_pubkey(public_key, coin)
if output_script_native_segwit(0, pubkey_hash) != script_pubkey: if output_script_native_segwit(0, pubkey_hash) != script_pubkey:
raise wire.DataError("Invalid public key hash") raise DataError("Invalid public key hash")
self.public_keys = [public_key] self.public_keys = [public_key]
self.signatures = [(signature, hash_type)] self.signatures = [(signature, hash_type)]
elif len(script_pubkey) == 34 and script_pubkey[0] == OP_0: # P2WSH elif len(script_pubkey) == 34 and script_pubkey[0] == OP_0: # P2WSH
script, self.signatures = parse_witness_multisig(witness) script, self.signatures = parse_witness_multisig(witness)
script_hash = sha256(script).digest() script_hash = sha256(script).digest()
if output_script_native_segwit(0, script_hash) != script_pubkey: if output_script_native_segwit(0, script_hash) != script_pubkey:
raise wire.DataError("Invalid script hash") raise DataError("Invalid script hash")
self.public_keys, self.threshold = parse_output_script_multisig(script) self.public_keys, self.threshold = parse_output_script_multisig(script)
elif len(script_pubkey) == 34 and script_pubkey[0] == OP_1: # P2TR elif len(script_pubkey) == 34 and script_pubkey[0] == OP_1: # P2TR
self.is_taproot = True self.is_taproot = True
self.public_keys = [parse_output_script_p2tr(script_pubkey)] self.public_keys = [parse_output_script_p2tr(script_pubkey)]
self.signatures = [parse_witness_p2tr(witness)] self.signatures = [parse_witness_p2tr(witness)]
else: else:
raise wire.DataError("Unsupported signature script") raise DataError("Unsupported signature script")
elif witness and witness != b"\x00": elif witness and witness != b"\x00":
if len(script_sig) == 23: # P2WPKH nested in BIP16 P2SH if len(script_sig) == 23: # P2WPKH nested in BIP16 P2SH
public_key, signature, hash_type = parse_witness_p2wpkh(witness) public_key, signature, hash_type = parse_witness_p2wpkh(witness)
@ -69,10 +71,10 @@ class SignatureVerifier:
w = utils.empty_bytearray(23) w = utils.empty_bytearray(23)
write_input_script_p2wpkh_in_p2sh(w, pubkey_hash) write_input_script_p2wpkh_in_p2sh(w, pubkey_hash)
if w != script_sig: if w != script_sig:
raise wire.DataError("Invalid public key hash") raise DataError("Invalid public key hash")
script_hash = coin.script_hash(script_sig[1:]).digest() script_hash = coin.script_hash(script_sig[1:]).digest()
if output_script_p2sh(script_hash) != script_pubkey: if output_script_p2sh(script_hash) != script_pubkey:
raise wire.DataError("Invalid script hash") raise DataError("Invalid script hash")
self.public_keys = [public_key] self.public_keys = [public_key]
self.signatures = [(signature, hash_type)] self.signatures = [(signature, hash_type)]
elif len(script_sig) == 35: # P2WSH nested in BIP16 P2SH elif len(script_sig) == 35: # P2WSH nested in BIP16 P2SH
@ -81,36 +83,36 @@ class SignatureVerifier:
w = utils.empty_bytearray(35) w = utils.empty_bytearray(35)
write_input_script_p2wsh_in_p2sh(w, script_hash) write_input_script_p2wsh_in_p2sh(w, script_hash)
if w != script_sig: if w != script_sig:
raise wire.DataError("Invalid script hash") raise DataError("Invalid script hash")
script_hash = coin.script_hash(script_sig[1:]).digest() script_hash = coin.script_hash(script_sig[1:]).digest()
if output_script_p2sh(script_hash) != script_pubkey: if output_script_p2sh(script_hash) != script_pubkey:
raise wire.DataError("Invalid script hash") raise DataError("Invalid script hash")
self.public_keys, self.threshold = parse_output_script_multisig(script) self.public_keys, self.threshold = parse_output_script_multisig(script)
else: else:
raise wire.DataError("Unsupported signature script") raise DataError("Unsupported signature script")
else: else:
if len(script_pubkey) == 25: # P2PKH if len(script_pubkey) == 25: # P2PKH
public_key, signature, hash_type = parse_input_script_p2pkh(script_sig) public_key, signature, hash_type = parse_input_script_p2pkh(script_sig)
pubkey_hash = ecdsa_hash_pubkey(public_key, coin) pubkey_hash = ecdsa_hash_pubkey(public_key, coin)
if output_script_p2pkh(pubkey_hash) != script_pubkey: if output_script_p2pkh(pubkey_hash) != script_pubkey:
raise wire.DataError("Invalid public key hash") raise DataError("Invalid public key hash")
self.public_keys = [public_key] self.public_keys = [public_key]
self.signatures = [(signature, hash_type)] self.signatures = [(signature, hash_type)]
elif len(script_pubkey) == 23: # P2SH elif len(script_pubkey) == 23: # P2SH
script, self.signatures = parse_input_script_multisig(script_sig) script, self.signatures = parse_input_script_multisig(script_sig)
script_hash = coin.script_hash(script).digest() script_hash = coin.script_hash(script).digest()
if output_script_p2sh(script_hash) != script_pubkey: if output_script_p2sh(script_hash) != script_pubkey:
raise wire.DataError("Invalid script hash") raise DataError("Invalid script hash")
self.public_keys, self.threshold = parse_output_script_multisig(script) self.public_keys, self.threshold = parse_output_script_multisig(script)
else: else:
raise wire.DataError("Unsupported signature script") raise DataError("Unsupported signature script")
if self.threshold != len(self.signatures): if self.threshold != len(self.signatures):
raise wire.DataError("Invalid signature") raise DataError("Invalid signature")
def ensure_hash_type(self, sighash_types: Sequence[SigHashType]) -> None: def ensure_hash_type(self, sighash_types: Sequence[SigHashType]) -> None:
if any(h not in sighash_types for _, h in self.signatures): if any(h not in sighash_types for _, h in self.signatures):
raise wire.DataError("Unsupported sighash type") raise DataError("Unsupported sighash type")
def verify(self, digest: bytes) -> None: def verify(self, digest: bytes) -> None:
# It is up to the caller to ensure that the digest is appropriate for # It is up to the caller to ensure that the digest is appropriate for
@ -125,21 +127,27 @@ class SignatureVerifier:
self.verify_ecdsa(digest) self.verify_ecdsa(digest)
def verify_bip340(self, digest: bytes) -> None: def verify_bip340(self, digest: bytes) -> None:
from trezor.crypto.curve import bip340
if not bip340.verify(self.public_keys[0], self.signatures[0][0], digest): if not bip340.verify(self.public_keys[0], self.signatures[0][0], digest):
raise wire.DataError("Invalid signature") raise DataError("Invalid signature")
def verify_ecdsa(self, digest: bytes) -> None: def verify_ecdsa(self, digest: bytes) -> None:
from trezor.crypto.curve import secp256k1
try: try:
i = 0 i = 0
for der_signature, _ in self.signatures: for der_signature, _ in self.signatures:
signature = decode_der_signature(der_signature) signature = _decode_der_signature(der_signature)
while not secp256k1.verify(self.public_keys[i], signature, digest): while not secp256k1.verify(self.public_keys[i], signature, digest):
i += 1 i += 1
except Exception: except Exception:
raise wire.DataError("Invalid signature") raise DataError("Invalid signature")
def decode_der_signature(der_signature: memoryview) -> bytearray: def _decode_der_signature(der_signature: memoryview) -> bytearray:
from trezor.crypto import der
seq = der.decode_seq(der_signature) seq = der.decode_seq(der_signature)
if len(seq) != 2 or any(len(i) > 32 for i in seq): if len(seq) != 2 or any(len(i) > 32 for i in seq):
raise ValueError raise ValueError

View File

@ -1,30 +1,20 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor import utils, wire
from trezor.crypto import base58
from trezor.crypto.curve import secp256k1
from trezor.enums import InputScriptType
from trezor.messages import Success
from trezor.ui.layouts import confirm_signverify, show_success
from apps.common import address_type, coins
from apps.common.signverify import decode_message, message_digest
from . import common
from .addresses import (
address_p2wpkh,
address_p2wpkh_in_p2sh,
address_pkh,
address_short,
address_to_cashaddr,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from apps.common.coininfo import CoinInfo from apps.common.coininfo import CoinInfo
from trezor.messages import VerifyMessage from trezor.messages import VerifyMessage, Success
from trezor.wire import Context
from trezor.enums import InputScriptType
def address_to_script_type(address: str, coin: CoinInfo) -> InputScriptType: def _address_to_script_type(address: str, coin: CoinInfo) -> InputScriptType:
from trezor.crypto import base58
from trezor.wire import DataError
from trezor.enums import InputScriptType
from trezor import utils
from apps.common import address_type
from . import common
# Determines the script type from a non-multisig address. # Determines the script type from a non-multisig address.
if coin.bech32_prefix and address.startswith(coin.bech32_prefix): if coin.bech32_prefix and address.startswith(coin.bech32_prefix):
@ -34,7 +24,7 @@ def address_to_script_type(address: str, coin: CoinInfo) -> InputScriptType:
elif witver == 1: elif witver == 1:
return InputScriptType.SPENDTAPROOT return InputScriptType.SPENDTAPROOT
else: else:
raise wire.DataError("Invalid address") raise DataError("Invalid address")
if ( if (
not utils.BITCOIN_ONLY not utils.BITCOIN_ONLY
@ -46,7 +36,7 @@ def address_to_script_type(address: str, coin: CoinInfo) -> InputScriptType:
try: try:
raw_address = base58.decode_check(address, coin.b58_hash) raw_address = base58.decode_check(address, coin.b58_hash)
except ValueError: except ValueError:
raise wire.DataError("Invalid address") raise DataError("Invalid address")
if address_type.check(coin.address_type, raw_address): if address_type.check(coin.address_type, raw_address):
# p2pkh # p2pkh
@ -55,10 +45,28 @@ def address_to_script_type(address: str, coin: CoinInfo) -> InputScriptType:
# p2sh # p2sh
return InputScriptType.SPENDP2SHWITNESS return InputScriptType.SPENDP2SHWITNESS
raise wire.DataError("Invalid address") raise DataError("Invalid address")
async def verify_message(ctx: wire.Context, msg: VerifyMessage) -> Success: async def verify_message(ctx: Context, msg: VerifyMessage) -> Success:
from trezor import utils
from trezor.wire import ProcessError
from trezor.crypto.curve import secp256k1
from trezor.enums import InputScriptType
from trezor.messages import Success
from trezor.ui.layouts import confirm_signverify, show_success
from apps.common import coins
from apps.common.signverify import decode_message, message_digest
from .addresses import (
address_p2wpkh,
address_p2wpkh_in_p2sh,
address_pkh,
address_short,
address_to_cashaddr,
)
message = msg.message message = msg.message
address = msg.address address = msg.address
signature = msg.signature signature = msg.signature
@ -67,7 +75,7 @@ async def verify_message(ctx: wire.Context, msg: VerifyMessage) -> Success:
digest = message_digest(coin, message) digest = message_digest(coin, message)
script_type = address_to_script_type(address, coin) script_type = _address_to_script_type(address, coin)
recid = signature[0] recid = signature[0]
if 27 <= recid <= 34: if 27 <= recid <= 34:
# p2pkh or no script type provided # p2pkh or no script type provided
@ -79,12 +87,12 @@ async def verify_message(ctx: wire.Context, msg: VerifyMessage) -> Success:
# native segwit # native segwit
signature = bytes([signature[0] - 8]) + signature[1:] signature = bytes([signature[0] - 8]) + signature[1:]
else: else:
raise wire.ProcessError("Invalid signature") raise ProcessError("Invalid signature")
pubkey = secp256k1.verify_recover(signature, digest) pubkey = secp256k1.verify_recover(signature, digest)
if not pubkey: if not pubkey:
raise wire.ProcessError("Invalid signature") raise ProcessError("Invalid signature")
if script_type == InputScriptType.SPENDADDRESS: if script_type == InputScriptType.SPENDADDRESS:
addr = address_pkh(pubkey, coin) addr = address_pkh(pubkey, coin)
@ -95,16 +103,16 @@ async def verify_message(ctx: wire.Context, msg: VerifyMessage) -> Success:
elif script_type == InputScriptType.SPENDWITNESS: elif script_type == InputScriptType.SPENDWITNESS:
addr = address_p2wpkh(pubkey, coin) addr = address_p2wpkh(pubkey, coin)
else: else:
raise wire.ProcessError("Invalid signature") raise ProcessError("Invalid signature")
if addr != address: if addr != address:
raise wire.ProcessError("Invalid signature") raise ProcessError("Invalid signature")
await confirm_signverify( await confirm_signverify(
ctx, ctx,
coin.coin_shortcut, coin.coin_shortcut,
decode_message(message), decode_message(message),
address=address_short(coin, address), address_short(coin, address),
verify=True, verify=True,
) )

View File

@ -1,7 +1,6 @@
from micropython import const from micropython import const
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor.crypto.hashlib import sha256
from trezor.utils import ensure from trezor.utils import ensure
from apps.common.writers import ( # noqa: F401 from apps.common.writers import ( # noqa: F401
@ -72,22 +71,24 @@ def write_tx_output(w: Writer, o: TxOutput | PrevOutput, script_pubkey: bytes) -
def write_op_push(w: Writer, n: int) -> None: def write_op_push(w: Writer, n: int) -> None:
append = w.append # local_cache_attribute
ensure(0 <= n <= 0xFFFF_FFFF) ensure(0 <= n <= 0xFFFF_FFFF)
if n < 0x4C: if n < 0x4C:
w.append(n & 0xFF) append(n & 0xFF)
elif n < 0x100: elif n < 0x100:
w.append(0x4C) append(0x4C)
w.append(n & 0xFF) append(n & 0xFF)
elif n < 0x1_0000: elif n < 0x1_0000:
w.append(0x4D) append(0x4D)
w.append(n & 0xFF) append(n & 0xFF)
w.append((n >> 8) & 0xFF) append((n >> 8) & 0xFF)
else: else:
w.append(0x4E) append(0x4E)
w.append(n & 0xFF) append(n & 0xFF)
w.append((n >> 8) & 0xFF) append((n >> 8) & 0xFF)
w.append((n >> 16) & 0xFF) append((n >> 16) & 0xFF)
w.append((n >> 24) & 0xFF) append((n >> 24) & 0xFF)
def op_push_length(n: int) -> int: def op_push_length(n: int) -> int:
@ -103,6 +104,8 @@ def op_push_length(n: int) -> int:
def get_tx_hash(w: HashWriter, double: bool = False, reverse: bool = False) -> bytes: def get_tx_hash(w: HashWriter, double: bool = False, reverse: bool = False) -> bytes:
from trezor.crypto.hashlib import sha256
d = w.get_digest() d = w.get_digest()
if double: if double:
d = sha256(d).digest() d = sha256(d).digest()

View File

@ -1,11 +1,16 @@
from common import * from common import *
from trezor.crypto import bip32, bip39 from trezor.crypto import bip32, bip39
from trezor import wire
from trezor.messages import GetAddress from trezor.messages import GetAddress
from trezor.enums import InputScriptType
from trezor.utils import HashWriter from trezor.utils import HashWriter
from apps.common import coins from apps.common import coins
from apps.bitcoin import scripts
from apps.bitcoin.addresses import * from apps.bitcoin.addresses import *
from apps.bitcoin.addresses import (
_address_p2wsh, _address_p2wsh_in_p2sh,
_address_multisig_p2wsh_in_p2sh, _address_multisig_p2sh
)
from apps.bitcoin.keychain import validate_path_against_script_type from apps.bitcoin.keychain import validate_path_against_script_type
from apps.bitcoin.writers import * from apps.bitcoin.writers import *
@ -73,7 +78,7 @@ class TestAddress(unittest.TestCase):
h = HashWriter(sha256()) h = HashWriter(sha256())
write_bytes_unchecked(h, script) write_bytes_unchecked(h, script)
address = address_p2wsh( address = _address_p2wsh(
h.get_digest(), h.get_digest(),
coin.bech32_prefix coin.bech32_prefix
) )
@ -83,7 +88,7 @@ class TestAddress(unittest.TestCase):
coin = coins.by_name('Bitcoin') coin = coins.by_name('Bitcoin')
# test data from Mastering Bitcoin # test data from Mastering Bitcoin
address = address_p2wsh_in_p2sh( address = _address_p2wsh_in_p2sh(
unhexlify('9592d601848d04b172905e0ddb0adde59f1590f1e553ffc81ddc4b0ed927dd73'), unhexlify('9592d601848d04b172905e0ddb0adde59f1590f1e553ffc81ddc4b0ed927dd73'),
coin coin
) )
@ -99,7 +104,7 @@ class TestAddress(unittest.TestCase):
# unhexlify('046ce31db9bdd543e72fe3039a1f1c047dab87037c36a669ff90e28da1848f640de68c2fe913d363a51154a0c62d7adea1b822d05035077418267b1a1379790187'), # unhexlify('046ce31db9bdd543e72fe3039a1f1c047dab87037c36a669ff90e28da1848f640de68c2fe913d363a51154a0c62d7adea1b822d05035077418267b1a1379790187'),
# unhexlify('0411ffd36c70776538d079fbae117dc38effafb33304af83ce4894589747aee1ef992f63280567f52f5ba870678b4ab4ff6c8ea600bd217870a8b4f1f09f3a8e83'), # unhexlify('0411ffd36c70776538d079fbae117dc38effafb33304af83ce4894589747aee1ef992f63280567f52f5ba870678b4ab4ff6c8ea600bd217870a8b4f1f09f3a8e83'),
# ] # ]
# address = address_multisig_p2sh(pubkeys, 2, coin.address_type_p2sh) # address = _address_multisig_p2sh(pubkeys, 2, coin.address_type_p2sh)
# self.assertEqual(address, '347N1Thc213QqfYCz3PZkjoJpNv5b14kBd') # self.assertEqual(address, '347N1Thc213QqfYCz3PZkjoJpNv5b14kBd')
coin = coins.by_name('Bitcoin') coin = coins.by_name('Bitcoin')
@ -107,12 +112,12 @@ class TestAddress(unittest.TestCase):
unhexlify('02fe6f0a5a297eb38c391581c4413e084773ea23954d93f7753db7dc0adc188b2f'), unhexlify('02fe6f0a5a297eb38c391581c4413e084773ea23954d93f7753db7dc0adc188b2f'),
unhexlify('02ff12471208c14bd580709cb2358d98975247d8765f92bc25eab3b2763ed605f8'), unhexlify('02ff12471208c14bd580709cb2358d98975247d8765f92bc25eab3b2763ed605f8'),
] ]
address = address_multisig_p2sh(pubkeys, 2, coin) address = _address_multisig_p2sh(pubkeys, 2, coin)
self.assertEqual(address, '39bgKC7RFbpoCRbtD5KEdkYKtNyhpsNa3Z') self.assertEqual(address, '39bgKC7RFbpoCRbtD5KEdkYKtNyhpsNa3Z')
for invalid_m in (-1, 0, len(pubkeys) + 1, 16): for invalid_m in (-1, 0, len(pubkeys) + 1, 16):
with self.assertRaises(wire.DataError): with self.assertRaises(wire.DataError):
address_multisig_p2sh(pubkeys, invalid_m, coin) _address_multisig_p2sh(pubkeys, invalid_m, coin)
def test_multisig_address_p2wsh_in_p2sh(self): def test_multisig_address_p2wsh_in_p2sh(self):
# test data from # test data from
@ -123,7 +128,7 @@ class TestAddress(unittest.TestCase):
unhexlify('0320ce424c6d61f352ccfea60d209651672cfb03b2dc77d1d64d3ba519aec756ae'), unhexlify('0320ce424c6d61f352ccfea60d209651672cfb03b2dc77d1d64d3ba519aec756ae'),
] ]
address = address_multisig_p2wsh_in_p2sh(pubkeys, 2, coin) address = _address_multisig_p2wsh_in_p2sh(pubkeys, 2, coin)
self.assertEqual(address, '2MsZ2fpGKUydzY62v6trPHR8eCx5JTy1Dpa') self.assertEqual(address, '2MsZ2fpGKUydzY62v6trPHR8eCx5JTy1Dpa')
# def test_multisig_address_p2wsh(self): # def test_multisig_address_p2wsh(self):

View File

@ -4,7 +4,7 @@ from trezor import wire
from trezor.crypto import bip39 from trezor.crypto import bip39
from apps.common.paths import HARDENED from apps.common.paths import HARDENED
from apps.bitcoin.keychain import get_coin_by_name, get_keychain_for_coin from apps.bitcoin.keychain import _get_coin_by_name, _get_keychain_for_coin
class TestBitcoinKeychain(unittest.TestCase): class TestBitcoinKeychain(unittest.TestCase):
@ -14,8 +14,8 @@ class TestBitcoinKeychain(unittest.TestCase):
cache.set(cache.APP_COMMON_SEED, seed) cache.set(cache.APP_COMMON_SEED, seed)
def test_bitcoin(self): def test_bitcoin(self):
coin = get_coin_by_name("Bitcoin") coin = _get_coin_by_name("Bitcoin")
keychain = await_result(get_keychain_for_coin(wire.DUMMY_CONTEXT, coin)) keychain = await_result(_get_keychain_for_coin(wire.DUMMY_CONTEXT, coin))
self.assertEqual(coin.coin_name, "Bitcoin") self.assertEqual(coin.coin_name, "Bitcoin")
valid_addresses = ( valid_addresses = (
@ -45,8 +45,8 @@ class TestBitcoinKeychain(unittest.TestCase):
self.assertRaises(wire.DataError, keychain.derive, addr) self.assertRaises(wire.DataError, keychain.derive, addr)
def test_testnet(self): def test_testnet(self):
coin = get_coin_by_name("Testnet") coin = _get_coin_by_name("Testnet")
keychain = await_result(get_keychain_for_coin(wire.DUMMY_CONTEXT, coin)) keychain = await_result(_get_keychain_for_coin(wire.DUMMY_CONTEXT, coin))
self.assertEqual(coin.coin_name, "Testnet") self.assertEqual(coin.coin_name, "Testnet")
valid_addresses = ( valid_addresses = (
@ -76,14 +76,14 @@ class TestBitcoinKeychain(unittest.TestCase):
self.assertRaises(wire.DataError, keychain.derive, addr) self.assertRaises(wire.DataError, keychain.derive, addr)
def test_unspecified(self): def test_unspecified(self):
coin = get_coin_by_name(None) coin = _get_coin_by_name(None)
keychain = await_result(get_keychain_for_coin(wire.DUMMY_CONTEXT, coin)) keychain = await_result(_get_keychain_for_coin(wire.DUMMY_CONTEXT, coin))
self.assertEqual(coin.coin_name, "Bitcoin") self.assertEqual(coin.coin_name, "Bitcoin")
keychain.derive([H_(44), H_(0), H_(0), 0, 0]) keychain.derive([H_(44), H_(0), H_(0), 0, 0])
def test_unknown(self): def test_unknown(self):
with self.assertRaises(wire.DataError): with self.assertRaises(wire.DataError):
get_coin_by_name("MadeUpCoin2020") _get_coin_by_name("MadeUpCoin2020")
@unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin") @unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin")
@ -94,8 +94,8 @@ class TestAltcoinKeychains(unittest.TestCase):
cache.set(cache.APP_COMMON_SEED, seed) cache.set(cache.APP_COMMON_SEED, seed)
def test_bcash(self): def test_bcash(self):
coin = get_coin_by_name("Bcash") coin = _get_coin_by_name("Bcash")
keychain = await_result(get_keychain_for_coin(wire.DUMMY_CONTEXT, coin)) keychain = await_result(_get_keychain_for_coin(wire.DUMMY_CONTEXT, coin))
self.assertEqual(coin.coin_name, "Bcash") self.assertEqual(coin.coin_name, "Bcash")
self.assertFalse(coin.segwit) self.assertFalse(coin.segwit)
@ -131,8 +131,8 @@ class TestAltcoinKeychains(unittest.TestCase):
self.assertRaises(wire.DataError, keychain.derive, addr) self.assertRaises(wire.DataError, keychain.derive, addr)
def test_litecoin(self): def test_litecoin(self):
coin = get_coin_by_name("Litecoin") coin = _get_coin_by_name("Litecoin")
keychain = await_result(get_keychain_for_coin(wire.DUMMY_CONTEXT, coin)) keychain = await_result(_get_keychain_for_coin(wire.DUMMY_CONTEXT, coin))
self.assertEqual(coin.coin_name, "Litecoin") self.assertEqual(coin.coin_name, "Litecoin")
self.assertTrue(coin.segwit) self.assertTrue(coin.segwit)

View File

@ -8,7 +8,7 @@ from apps.common import coins
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
from apps.common.paths import HARDENED, AlwaysMatchingSchema from apps.common.paths import HARDENED, AlwaysMatchingSchema
from apps.bitcoin import ownership, scripts from apps.bitcoin import ownership, scripts
from apps.bitcoin.addresses import address_p2tr, address_p2wpkh, address_p2wpkh_in_p2sh, address_multisig_p2wsh, address_multisig_p2wsh_in_p2sh, address_multisig_p2sh from apps.bitcoin.addresses import _address_p2tr, address_p2wpkh, address_p2wpkh_in_p2sh, _address_multisig_p2wsh, _address_multisig_p2wsh_in_p2sh, _address_multisig_p2sh
from apps.bitcoin.multisig import multisig_get_pubkeys from apps.bitcoin.multisig import multisig_get_pubkeys
@ -77,7 +77,7 @@ class TestOwnershipProof(unittest.TestCase):
commitment_data = b"" commitment_data = b""
node = keychain.derive([86 | HARDENED, 0 | HARDENED, 0 | HARDENED, 1, 0]) node = keychain.derive([86 | HARDENED, 0 | HARDENED, 0 | HARDENED, 1, 0])
address = address_p2tr(node.public_key(), coin) address = _address_p2tr(node.public_key(), coin)
script_pubkey = scripts.output_derive_script(address, coin) script_pubkey = scripts.output_derive_script(address, coin)
ownership_id = ownership.get_identifier(script_pubkey, keychain) ownership_id = ownership.get_identifier(script_pubkey, keychain)
self.assertEqual(ownership_id, unhexlify("dc18066224b9e30e306303436dc18ab881c7266c13790350a3fe415e438135ec")) self.assertEqual(ownership_id, unhexlify("dc18066224b9e30e306303436dc18ab881c7266c13790350a3fe415e438135ec"))
@ -177,7 +177,7 @@ class TestOwnershipProof(unittest.TestCase):
) )
pubkeys = multisig_get_pubkeys(multisig) pubkeys = multisig_get_pubkeys(multisig)
address = address_multisig_p2wsh(pubkeys, multisig.m, coin.bech32_prefix) address = _address_multisig_p2wsh(pubkeys, multisig.m, coin.bech32_prefix)
script_pubkey = scripts.output_derive_script(address, coin) script_pubkey = scripts.output_derive_script(address, coin)
ownership_ids = [ownership.get_identifier(script_pubkey, keychain) for keychain in keychains] ownership_ids = [ownership.get_identifier(script_pubkey, keychain) for keychain in keychains]
self.assertEqual(ownership_ids[0], unhexlify("309c4ffec5c228cc836b51d572c0a730dbabd39df9f01862502ac9eabcdeb94a")) self.assertEqual(ownership_ids[0], unhexlify("309c4ffec5c228cc836b51d572c0a730dbabd39df9f01862502ac9eabcdeb94a"))
@ -238,7 +238,7 @@ class TestOwnershipProof(unittest.TestCase):
) )
pubkeys = multisig_get_pubkeys(multisig) pubkeys = multisig_get_pubkeys(multisig)
address = address_multisig_p2wsh_in_p2sh(pubkeys, multisig.m, coin) address = _address_multisig_p2wsh_in_p2sh(pubkeys, multisig.m, coin)
script_pubkey = scripts.output_derive_script(address, coin) script_pubkey = scripts.output_derive_script(address, coin)
ownership_id = ownership.get_identifier(script_pubkey, keychain) ownership_id = ownership.get_identifier(script_pubkey, keychain)
ownership_ids = [b'\x00' * 32, b'\x01' * 32, b'\x02' * 32, ownership_id] ownership_ids = [b'\x00' * 32, b'\x01' * 32, b'\x02' * 32, ownership_id]
@ -312,7 +312,7 @@ class TestOwnershipProof(unittest.TestCase):
) )
pubkeys = multisig_get_pubkeys(multisig) pubkeys = multisig_get_pubkeys(multisig)
address = address_multisig_p2sh(pubkeys, multisig.m, coin) address = _address_multisig_p2sh(pubkeys, multisig.m, coin)
script_pubkey = scripts.output_derive_script(address, coin) script_pubkey = scripts.output_derive_script(address, coin)
ownership_id = ownership.get_identifier(script_pubkey, keychain) ownership_id = ownership.get_identifier(script_pubkey, keychain)
ownership_ids = [b'\x00' * 32, ownership_id] ownership_ids = [b'\x00' * 32, ownership_id]

View File

@ -2,7 +2,7 @@ from common import *
from apps.bitcoin.common import SigHashType from apps.bitcoin.common import SigHashType
from apps.bitcoin.scripts import output_derive_script from apps.bitcoin.scripts import output_derive_script
from apps.bitcoin.sign_tx.bitcoin import BitcoinSigHasher from apps.bitcoin.sign_tx.sig_hasher import BitcoinSigHasher
from apps.bitcoin.writers import get_tx_hash from apps.bitcoin.writers import get_tx_hash
from apps.common import coins from apps.common import coins
from apps.common.keychain import Keychain from apps.common.keychain import Keychain

View File

@ -2,7 +2,7 @@ from common import *
from apps.bitcoin.common import SigHashType from apps.bitcoin.common import SigHashType
from apps.bitcoin.scripts import output_derive_script from apps.bitcoin.scripts import output_derive_script
from apps.bitcoin.sign_tx.bitcoin import BitcoinSigHasher from apps.bitcoin.sign_tx.sig_hasher import BitcoinSigHasher
from apps.bitcoin.writers import get_tx_hash from apps.bitcoin.writers import get_tx_hash
from apps.common import coins from apps.common import coins
from apps.common.keychain import Keychain from apps.common.keychain import Keychain

View File

@ -1,15 +1,12 @@
from common import * from common import *
from apps.bitcoin.common import SigHashType from apps.bitcoin.common import SigHashType
from apps.bitcoin.scripts import output_derive_script from apps.bitcoin.sign_tx.sig_hasher import BitcoinSigHasher
from apps.bitcoin.sign_tx.bitcoin import BitcoinSigHasher
from apps.bitcoin.writers import get_tx_hash from apps.bitcoin.writers import get_tx_hash
from trezor.messages import SignTx from trezor.messages import SignTx
from trezor.messages import TxInput from trezor.messages import TxInput
from trezor.messages import TxOutput
from trezor.messages import PrevOutput from trezor.messages import PrevOutput
from trezor.enums import InputScriptType from trezor.enums import InputScriptType
from trezor.enums import OutputScriptType
VECTORS = [ VECTORS = [

View File

@ -28,7 +28,7 @@ from trezor import wire
from apps.common import coins from apps.common import coins
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
from apps.bitcoin.keychain import get_schemas_for_coin from apps.bitcoin.keychain import _get_schemas_for_coin
from apps.bitcoin.sign_tx import helpers, bitcoin from apps.bitcoin.sign_tx import helpers, bitcoin
@ -160,7 +160,7 @@ class TestSignSegwitTxNativeP2WPKH(unittest.TestCase):
)), )),
] ]
ns = get_schemas_for_coin(coin) ns = _get_schemas_for_coin(coin)
keychain = Keychain(seed, coin.curve_name, ns) keychain = Keychain(seed, coin.curve_name, ns)
signer = bitcoin.Bitcoin(tx, keychain, coin, None).signer() signer = bitcoin.Bitcoin(tx, keychain, coin, None).signer()
for request, response in chunks(messages, 2): for request, response in chunks(messages, 2):
@ -292,7 +292,7 @@ class TestSignSegwitTxNativeP2WPKH(unittest.TestCase):
)), )),
] ]
ns = get_schemas_for_coin(coin) ns = _get_schemas_for_coin(coin)
keychain = Keychain(seed, coin.curve_name, ns) keychain = Keychain(seed, coin.curve_name, ns)
signer = bitcoin.Bitcoin(tx, keychain, coin, None).signer() signer = bitcoin.Bitcoin(tx, keychain, coin, None).signer()
for request, expected_response in chunks(messages, 2): for request, expected_response in chunks(messages, 2):
@ -352,7 +352,7 @@ class TestSignSegwitTxNativeP2WPKH(unittest.TestCase):
None None
] ]
ns = get_schemas_for_coin(coin) ns = _get_schemas_for_coin(coin)
keychain = Keychain(seed, coin.curve_name, ns) keychain = Keychain(seed, coin.curve_name, ns)
signer = bitcoin.Bitcoin(tx, keychain, coin, None).signer() signer = bitcoin.Bitcoin(tx, keychain, coin, None).signer()
for request, expected_response in chunks(messages, 2): for request, expected_response in chunks(messages, 2):

View File

@ -27,7 +27,7 @@ from trezor.enums import OutputScriptType
from apps.common import coins from apps.common import coins
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
from apps.bitcoin.keychain import get_schemas_for_coin from apps.bitcoin.keychain import _get_schemas_for_coin
from apps.bitcoin.sign_tx import bitcoinlike, helpers from apps.bitcoin.sign_tx import bitcoinlike, helpers
@ -161,7 +161,7 @@ class TestSignSegwitTxNativeP2WPKH_GRS(unittest.TestCase):
)), )),
] ]
ns = get_schemas_for_coin(coin) ns = _get_schemas_for_coin(coin)
keychain = Keychain(seed, coin.curve_name, ns) keychain = Keychain(seed, coin.curve_name, ns)
signer = bitcoinlike.Bitcoinlike(tx, keychain, coin, None).signer() signer = bitcoinlike.Bitcoinlike(tx, keychain, coin, None).signer()
for request, expected_response in chunks(messages, 2): for request, expected_response in chunks(messages, 2):
@ -293,7 +293,7 @@ class TestSignSegwitTxNativeP2WPKH_GRS(unittest.TestCase):
)), )),
] ]
ns = get_schemas_for_coin(coin) ns = _get_schemas_for_coin(coin)
keychain = Keychain(seed, coin.curve_name, ns) keychain = Keychain(seed, coin.curve_name, ns)
signer = bitcoinlike.Bitcoinlike(tx, keychain, coin, None).signer() signer = bitcoinlike.Bitcoinlike(tx, keychain, coin, None).signer()
for request, expected_response in chunks(messages, 2): for request, expected_response in chunks(messages, 2):

View File

@ -28,7 +28,7 @@ from trezor import wire
from apps.common import coins from apps.common import coins
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
from apps.bitcoin.keychain import get_schemas_for_coin from apps.bitcoin.keychain import _get_schemas_for_coin
from apps.bitcoin.sign_tx import bitcoin, helpers from apps.bitcoin.sign_tx import bitcoin, helpers
@ -157,7 +157,7 @@ class TestSignSegwitTxP2WPKHInP2SH(unittest.TestCase):
)), )),
] ]
ns = get_schemas_for_coin(coin) ns = _get_schemas_for_coin(coin)
keychain = Keychain(seed, coin.curve_name, ns) keychain = Keychain(seed, coin.curve_name, ns)
signer = bitcoin.Bitcoin(tx, keychain, coin, None).signer() signer = bitcoin.Bitcoin(tx, keychain, coin, None).signer()
for request, expected_response in chunks(messages, 2): for request, expected_response in chunks(messages, 2):
@ -296,7 +296,7 @@ class TestSignSegwitTxP2WPKHInP2SH(unittest.TestCase):
)), )),
] ]
ns = get_schemas_for_coin(coin) ns = _get_schemas_for_coin(coin)
keychain = Keychain(seed, coin.curve_name, ns) keychain = Keychain(seed, coin.curve_name, ns)
signer = bitcoin.Bitcoin(tx, keychain, coin, None).signer() signer = bitcoin.Bitcoin(tx, keychain, coin, None).signer()
for request, expected_response in chunks(messages, 2): for request, expected_response in chunks(messages, 2):
@ -405,7 +405,7 @@ class TestSignSegwitTxP2WPKHInP2SH(unittest.TestCase):
)), )),
] ]
ns = get_schemas_for_coin(coin) ns = _get_schemas_for_coin(coin)
keychain = Keychain(seed, coin.curve_name, ns) keychain = Keychain(seed, coin.curve_name, ns)
signer = bitcoin.Bitcoin(tx, keychain, coin, None).signer() signer = bitcoin.Bitcoin(tx, keychain, coin, None).signer()
i = 0 i = 0

View File

@ -27,7 +27,7 @@ from trezor.enums import OutputScriptType
from apps.common import coins from apps.common import coins
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
from apps.bitcoin.keychain import get_schemas_for_coin from apps.bitcoin.keychain import _get_schemas_for_coin
from apps.bitcoin.sign_tx import bitcoinlike, helpers from apps.bitcoin.sign_tx import bitcoinlike, helpers
@ -161,7 +161,7 @@ class TestSignSegwitTxP2WPKHInP2SH_GRS(unittest.TestCase):
)), )),
] ]
ns = get_schemas_for_coin(coin) ns = _get_schemas_for_coin(coin)
keychain = Keychain(seed, coin.curve_name, ns) keychain = Keychain(seed, coin.curve_name, ns)
signer = bitcoinlike.Bitcoinlike(tx, keychain, coin, None).signer() signer = bitcoinlike.Bitcoinlike(tx, keychain, coin, None).signer()
for request, expected_response in chunks(messages, 2): for request, expected_response in chunks(messages, 2):
@ -300,7 +300,7 @@ class TestSignSegwitTxP2WPKHInP2SH_GRS(unittest.TestCase):
)), )),
] ]
ns = get_schemas_for_coin(coin) ns = _get_schemas_for_coin(coin)
keychain = Keychain(seed, coin.curve_name, ns) keychain = Keychain(seed, coin.curve_name, ns)
signer = bitcoinlike.Bitcoinlike(tx, keychain, coin, None).signer() signer = bitcoinlike.Bitcoinlike(tx, keychain, coin, None).signer()
for request, expected_response in chunks(messages, 2): for request, expected_response in chunks(messages, 2):

View File

@ -26,7 +26,7 @@ from trezor.enums import OutputScriptType
from apps.common import coins from apps.common import coins
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
from apps.bitcoin.keychain import get_schemas_for_coin from apps.bitcoin.keychain import _get_schemas_for_coin
from apps.bitcoin.sign_tx import bitcoin, helpers from apps.bitcoin.sign_tx import bitcoin, helpers
@ -212,7 +212,7 @@ class TestSignTx(unittest.TestCase):
" ".join(["all"] * 12), " ".join(["all"] * 12),
"", "",
) )
ns = get_schemas_for_coin(coin_bitcoin) ns = _get_schemas_for_coin(coin_bitcoin)
keychain = Keychain(seed, coin_bitcoin.curve_name, ns) keychain = Keychain(seed, coin_bitcoin.curve_name, ns)
signer = bitcoin.Bitcoin(tx, keychain, coin_bitcoin, None).signer() signer = bitcoin.Bitcoin(tx, keychain, coin_bitcoin, None).signer()

View File

@ -26,7 +26,7 @@ from trezor.enums import OutputScriptType
from apps.common import coins from apps.common import coins
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
from apps.bitcoin.keychain import get_schemas_for_coin from apps.bitcoin.keychain import _get_schemas_for_coin
from apps.bitcoin.sign_tx import decred, helpers from apps.bitcoin.sign_tx import decred, helpers
@ -194,7 +194,7 @@ class TestSignTxDecred(unittest.TestCase):
" ".join(["all"] * 12), " ".join(["all"] * 12),
"", "",
) )
ns = get_schemas_for_coin(coin_decred) ns = _get_schemas_for_coin(coin_decred)
keychain = Keychain(seed, coin_decred.curve_name, ns) keychain = Keychain(seed, coin_decred.curve_name, ns)
signer = decred.Decred(tx, keychain, coin_decred, None).signer() signer = decred.Decred(tx, keychain, coin_decred, None).signer()
@ -376,7 +376,7 @@ class TestSignTxDecred(unittest.TestCase):
" ".join(["all"] * 12), " ".join(["all"] * 12),
"", "",
) )
ns = get_schemas_for_coin(coin_decred) ns = _get_schemas_for_coin(coin_decred)
keychain = Keychain(seed, coin_decred.curve_name, ns) keychain = Keychain(seed, coin_decred.curve_name, ns)
signer = decred.Decred(tx, keychain, coin_decred, None).signer() signer = decred.Decred(tx, keychain, coin_decred, None).signer()

View File

@ -26,7 +26,7 @@ from trezor.enums import OutputScriptType
from apps.common import coins from apps.common import coins
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
from apps.bitcoin.keychain import get_schemas_for_coin from apps.bitcoin.keychain import _get_schemas_for_coin
from apps.bitcoin.sign_tx import bitcoinlike, helpers from apps.bitcoin.sign_tx import bitcoinlike, helpers
@ -104,7 +104,7 @@ class TestSignTx_GRS(unittest.TestCase):
] ]
seed = bip39.seed(' '.join(['all'] * 12), '') seed = bip39.seed(' '.join(['all'] * 12), '')
ns = get_schemas_for_coin(coin) ns = _get_schemas_for_coin(coin)
keychain = Keychain(seed, coin.curve_name, ns) keychain = Keychain(seed, coin.curve_name, ns)
signer = bitcoinlike.Bitcoinlike(tx, keychain, coin, None).signer() signer = bitcoinlike.Bitcoinlike(tx, keychain, coin, None).signer()
for request, expected_response in chunks(messages, 2): for request, expected_response in chunks(messages, 2):