1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-04-24 02:59:03 +00:00

refactor(core): Move PaymentRequestVerifier to common.

This commit is contained in:
Andrew Kozlik 2025-04-08 15:40:29 +02:00
parent 9c3592e8d0
commit 79bb9b96f7
4 changed files with 33 additions and 36 deletions

View File

@ -18,10 +18,10 @@ if TYPE_CHECKING:
from apps.common.coininfo import CoinInfo
from apps.common.keychain import Keychain
from apps.common.payment_request import PaymentRequestVerifier
from ..authorization import CoinJoinAuthorization
from .bitcoin import Bitcoin
from .payment_request import PaymentRequestVerifier
from .tx_info import TxInfo
@ -91,10 +91,12 @@ class Approver:
async def add_payment_request(
self, msg: PaymentRequest, keychain: Keychain
) -> None:
from .payment_request import PaymentRequestVerifier
from apps.common.payment_request import PaymentRequestVerifier
self.finish_payment_request()
self.payment_req_verifier = PaymentRequestVerifier(msg, self.coin, keychain)
self.payment_req_verifier = PaymentRequestVerifier(
msg, self.coin.slip44, keychain
)
def finish_payment_request(self) -> None:
if self.payment_req_verifier:
@ -106,7 +108,9 @@ class Approver:
await self._add_output(txo, script_pubkey)
self.change_out += txo.amount
if self.payment_req_verifier:
self.payment_req_verifier.add_change_output(txo)
# txo.address filled in by output_derive_script().
assert txo.address is not None
self.payment_req_verifier.add_output(txo.amount, txo.address, change=True)
def add_orig_change_output(self, txo: TxOutput) -> None:
self.orig_total_out += txo.amount
@ -121,7 +125,9 @@ class Approver:
) -> None:
await self._add_output(txo, script_pubkey)
if self.payment_req_verifier:
self.payment_req_verifier.add_external_output(txo)
# External outputs have txo.address filled by definition.
assert txo.address is not None
self.payment_req_verifier.add_output(txo.amount, txo.address)
def add_orig_external_output(self, txo: TxOutput) -> None:
self.orig_total_out += txo.amount

View File

@ -5,6 +5,7 @@ from trezor.utils import ensure
from apps.common.writers import ( # noqa: F401
write_bytes_fixed,
write_bytes_prefixed,
write_bytes_reversed,
write_bytes_unchecked,
write_compact_size,
@ -29,11 +30,6 @@ write_uint64 = write_uint64_le
TX_HASH_SIZE = const(32)
def write_bytes_prefixed(w: Writer, b: bytes) -> None:
write_compact_size(w, len(b))
write_bytes_unchecked(w, b)
def write_tx_input(w: Writer, i: TxInput | PrevInput, script: bytes) -> None:
write_bytes_reversed(w, i.prev_hash, TX_HASH_SIZE)
write_uint32(w, i.prev_index)

View File

@ -3,12 +3,11 @@ from typing import TYPE_CHECKING
from trezor.wire import DataError, context
from .. import writers
from . import writers
if TYPE_CHECKING:
from trezor.messages import PaymentRequest, TxOutput
from trezor.messages import PaymentRequest
from apps.common import coininfo
from apps.common.keychain import Keychain
_MEMO_TYPE_TEXT = const(1)
@ -23,16 +22,14 @@ class PaymentRequestVerifier:
else:
PUBLIC_KEY = b""
def __init__(
self, msg: PaymentRequest, coin: coininfo.CoinInfo, keychain: Keychain
) -> None:
def __init__(self, msg: PaymentRequest, slip44_id: int, keychain: Keychain) -> None:
from storage.cache_common import APP_COMMON_NONCE
from trezor.crypto.hashlib import sha256
from trezor.utils import HashWriter
from apps.common.address_mac import check_address_mac
from .. import writers # pylint: disable=import-outside-toplevel
from . import writers # pylint: disable=import-outside-toplevel
self.h_outputs = HashWriter(sha256())
self.amount = 0
@ -57,23 +54,23 @@ class PaymentRequestVerifier:
for m in msg.memos:
if m.text_memo is not None:
memo = m.text_memo
writers.write_uint32(self.h_pr, _MEMO_TYPE_TEXT)
writers.write_uint32_le(self.h_pr, _MEMO_TYPE_TEXT)
writers.write_bytes_prefixed(self.h_pr, memo.text.encode())
elif m.refund_memo is not None:
memo = m.refund_memo
# Unlike in a coin purchase memo, the coin type is implied by the payment request.
check_address_mac(memo.address, memo.mac, coin.slip44, keychain)
writers.write_uint32(self.h_pr, _MEMO_TYPE_REFUND)
check_address_mac(memo.address, memo.mac, slip44_id, keychain)
writers.write_uint32_le(self.h_pr, _MEMO_TYPE_REFUND)
writers.write_bytes_prefixed(self.h_pr, memo.address.encode())
elif m.coin_purchase_memo is not None:
memo = m.coin_purchase_memo
check_address_mac(memo.address, memo.mac, memo.coin_type, keychain)
writers.write_uint32(self.h_pr, _MEMO_TYPE_COIN_PURCHASE)
writers.write_uint32(self.h_pr, memo.coin_type)
writers.write_uint32_le(self.h_pr, _MEMO_TYPE_COIN_PURCHASE)
writers.write_uint32_le(self.h_pr, memo.coin_type)
writers.write_bytes_prefixed(self.h_pr, memo.amount.encode())
writers.write_bytes_prefixed(self.h_pr, memo.address.encode())
writers.write_uint32(self.h_pr, coin.slip44)
writers.write_uint32_le(self.h_pr, slip44_id)
def verify(self) -> None:
from trezor.crypto.curve import secp256k1
@ -81,7 +78,7 @@ class PaymentRequestVerifier:
if self.expected_amount is not None and self.amount != self.expected_amount:
raise DataError("Invalid amount in payment request.")
hash_outputs = writers.get_tx_hash(self.h_outputs)
hash_outputs = self.h_outputs.get_digest()
writers.write_bytes_fixed(self.h_pr, hash_outputs, 32)
if not secp256k1.verify(
@ -89,15 +86,8 @@ class PaymentRequestVerifier:
):
raise DataError("Invalid signature in payment request.")
def _add_output(self, txo: TxOutput) -> None:
# For change outputs txo.address filled in by output_derive_script().
assert txo.address is not None
writers.write_uint64(self.h_outputs, txo.amount)
writers.write_bytes_prefixed(self.h_outputs, txo.address.encode())
def add_external_output(self, txo: TxOutput) -> None:
self._add_output(txo)
self.amount += txo.amount
def add_change_output(self, txo: TxOutput) -> None:
self._add_output(txo)
def add_output(self, amount: int, address: str, change: bool = False) -> None:
writers.write_uint64_le(self.h_outputs, amount)
writers.write_bytes_prefixed(self.h_outputs, address.encode())
if not change:
self.amount += amount

View File

@ -51,6 +51,11 @@ def write_bytes_fixed(w: Writer, b: bytes, length: int) -> int:
return length
def write_bytes_prefixed(w: Writer, b: bytes) -> None:
write_compact_size(w, len(b))
write_bytes_unchecked(w, b)
def write_bytes_reversed(w: Writer, b: bytes, length: int) -> int:
ensure(len(b) == length)
w.extend(bytes(reversed(b)))