From d33244912b08ec9273449a7e25fff7dbcac26c44 Mon Sep 17 00:00:00 2001 From: Roman Zeyde Date: Mon, 17 Feb 2025 10:35:55 +0200 Subject: [PATCH] refactor(core): separate base and priority Solana fees Move fee calculation into `Transaction` class. Also, replace floating-point division by integer division in fee calculation. [no changelog] --- core/src/apps/solana/constants.py | 3 + core/src/apps/solana/layout.py | 13 ++-- .../src/apps/solana/predefined_transaction.py | 6 +- core/src/apps/solana/sign_tx.py | 47 +------------- core/src/apps/solana/transaction/__init__.py | 61 ++++++++++++++++++- 5 files changed, 76 insertions(+), 54 deletions(-) diff --git a/core/src/apps/solana/constants.py b/core/src/apps/solana/constants.py index 79d6115b5c..f74d80fd54 100644 --- a/core/src/apps/solana/constants.py +++ b/core/src/apps/solana/constants.py @@ -4,3 +4,6 @@ ADDRESS_SIZE = const(32) SOLANA_BASE_FEE_LAMPORTS = const(5000) SOLANA_COMPUTE_UNIT_LIMIT = const(200000) + +# 1 lamport has 1M microlamports +MICROLAMPORTS_PER_LAMPORT = const(1000000) diff --git a/core/src/apps/solana/layout.py b/core/src/apps/solana/layout.py index c03c0c8c96..e4d44b09e4 100644 --- a/core/src/apps/solana/layout.py +++ b/core/src/apps/solana/layout.py @@ -18,6 +18,7 @@ from .types import AddressType if TYPE_CHECKING: from typing import Sequence + from .transaction import Fee from .transaction.instructions import Instruction, SystemProgramTransferInstruction from .types import AddressReference @@ -264,7 +265,7 @@ async def confirm_unsupported_program_confirm( async def confirm_system_transfer( transfer_instruction: SystemProgramTransferInstruction, - fee: int, + fee: Fee, signer_path: list[int], blockhash: bytes, ) -> None: @@ -293,7 +294,7 @@ async def confirm_token_transfer( token_mint: bytes, amount: int, decimals: int, - fee: int, + fee: Fee, signer_path: list[int], blockhash: bytes, ) -> None: @@ -334,13 +335,13 @@ async def confirm_custom_transaction( amount: int, decimals: int, unit: str, - fee: int, + fee: Fee, signer_path: list[int], blockhash: bytes, ) -> None: await confirm_solana_tx( amount=f"{format_amount(amount, decimals)} {unit}", - fee=f"{format_amount(fee, 9)} SOL", + fee=f"{format_amount(fee.total, 9)} SOL", fee_title=f"{TR.solana__expected_fee}:", items=( (f"{TR.words__account}:", _format_path(signer_path)), @@ -350,12 +351,12 @@ async def confirm_custom_transaction( async def confirm_transaction( - signer_path: list[int], blockhash: bytes, fee: int + signer_path: list[int], blockhash: bytes, fee: Fee ) -> None: await confirm_solana_tx( amount="", amount_title="", - fee=f"{format_amount(fee, 9)} SOL", + fee=f"{format_amount(fee.total, 9)} SOL", fee_title=f"{TR.solana__expected_fee}:", items=( (f"{TR.words__account}:", _format_path(signer_path)), diff --git a/core/src/apps/solana/predefined_transaction.py b/core/src/apps/solana/predefined_transaction.py index e2f6edf5e0..f3d7bf595f 100644 --- a/core/src/apps/solana/predefined_transaction.py +++ b/core/src/apps/solana/predefined_transaction.py @@ -13,6 +13,8 @@ from .transaction.instructions import ( if TYPE_CHECKING: from trezor.messages import SolanaTxAdditionalInfo + from .transaction import Fee + TransferTokenInstruction = ( TokenProgramTransferCheckedInstruction | Token2022ProgramTransferCheckedInstruction @@ -114,7 +116,7 @@ def is_predefined_token_transfer( async def try_confirm_token_transfer_transaction( transaction: Transaction, - fee: int, + fee: Fee, signer_path: list[int], blockhash: bytes, additional_info: SolanaTxAdditionalInfo | None = None, @@ -169,7 +171,7 @@ async def try_confirm_token_transfer_transaction( async def try_confirm_predefined_transaction( transaction: Transaction, - fee: int, + fee: Fee, signer_path: list[int], blockhash: bytes, additional_info: SolanaTxAdditionalInfo | None = None, diff --git a/core/src/apps/solana/sign_tx.py b/core/src/apps/solana/sign_tx.py index 828f4b9561..39d9ed34ac 100644 --- a/core/src/apps/solana/sign_tx.py +++ b/core/src/apps/solana/sign_tx.py @@ -56,7 +56,7 @@ async def sign_tx( br_code=ButtonRequestType.Other, ) - fee = calculate_fee(transaction) + fee = transaction.calculate_fee() if not await try_confirm_predefined_transaction( transaction, fee, address_n, transaction.blockhash, msg.additional_info @@ -65,7 +65,7 @@ async def sign_tx( await confirm_transaction( address_n, transaction.blockhash, - calculate_fee(transaction), + fee, ) signature = ed25519.sign(node.private_key(), serialized_tx) @@ -110,46 +110,3 @@ async def confirm_instructions( signer_path, signer_public_key, ) - - -def calculate_fee(transaction: Transaction) -> int: - import math - - from .constants import SOLANA_BASE_FEE_LAMPORTS, SOLANA_COMPUTE_UNIT_LIMIT - from .transaction.instructions import ( - COMPUTE_BUDGET_PROGRAM_ID, - COMPUTE_BUDGET_PROGRAM_ID_INS_SET_COMPUTE_UNIT_LIMIT, - COMPUTE_BUDGET_PROGRAM_ID_INS_SET_COMPUTE_UNIT_PRICE, - ) - from .types import AddressType - - number_of_signers = 0 - for address in transaction.addresses: - if address[1] == AddressType.AddressSig: - number_of_signers += 1 - - base_fee = SOLANA_BASE_FEE_LAMPORTS * number_of_signers - - unit_price = 0 - is_unit_price_set = False - unit_limit = SOLANA_COMPUTE_UNIT_LIMIT - is_unit_limit_set = False - - for instruction in transaction.instructions: - if instruction.program_id == COMPUTE_BUDGET_PROGRAM_ID: - if ( - instruction.instruction_id - == COMPUTE_BUDGET_PROGRAM_ID_INS_SET_COMPUTE_UNIT_LIMIT - and not is_unit_limit_set - ): - unit_limit = instruction.units - is_unit_limit_set = True - elif ( - instruction.instruction_id - == COMPUTE_BUDGET_PROGRAM_ID_INS_SET_COMPUTE_UNIT_PRICE - and not is_unit_price_set - ): - unit_price = instruction.lamports - is_unit_price_set = True - - return int(base_fee + math.ceil(unit_price * unit_limit / 1000000)) diff --git a/core/src/apps/solana/transaction/__init__.py b/core/src/apps/solana/transaction/__init__.py index 21e2a1cde1..2af18cfdd6 100644 --- a/core/src/apps/solana/transaction/__init__.py +++ b/core/src/apps/solana/transaction/__init__.py @@ -4,15 +4,37 @@ from trezor.crypto import base58 from trezor.utils import BufferReader from trezor.wire import DataError +from ..constants import ( + MICROLAMPORTS_PER_LAMPORT, + SOLANA_BASE_FEE_LAMPORTS, + SOLANA_COMPUTE_UNIT_LIMIT, +) from ..types import AddressType from .instruction import Instruction -from .instructions import get_instruction, get_instruction_id_length +from .instructions import ( + COMPUTE_BUDGET_PROGRAM_ID, + COMPUTE_BUDGET_PROGRAM_ID_INS_SET_COMPUTE_UNIT_LIMIT, + COMPUTE_BUDGET_PROGRAM_ID_INS_SET_COMPUTE_UNIT_PRICE, + get_instruction, + get_instruction_id_length, +) from .parse import parse_block_hash, parse_pubkey, parse_var_int if TYPE_CHECKING: from ..types import Account, Address, AddressReference, RawInstruction +class Fee: + def __init__( + self, + base: int, + priority: int, + ) -> None: + self.base = base + self.priority = priority + self.total = base + priority + + class Transaction: blind_signing = False required_signers_count = 0 @@ -209,3 +231,40 @@ class Transaction: for instruction in self.instructions if not instruction.is_ui_hidden ] + + def calculate_fee(self) -> Fee: + number_of_signers = 0 + for address in self.addresses: + if address[1] == AddressType.AddressSig: + number_of_signers += 1 + + base_fee = SOLANA_BASE_FEE_LAMPORTS * number_of_signers + + unit_price = 0 + is_unit_price_set = False + unit_limit = SOLANA_COMPUTE_UNIT_LIMIT + is_unit_limit_set = False + + for instruction in self.instructions: + if instruction.program_id == COMPUTE_BUDGET_PROGRAM_ID: + if ( + instruction.instruction_id + == COMPUTE_BUDGET_PROGRAM_ID_INS_SET_COMPUTE_UNIT_LIMIT + and not is_unit_limit_set + ): + unit_limit = instruction.units + is_unit_limit_set = True + elif ( + instruction.instruction_id + == COMPUTE_BUDGET_PROGRAM_ID_INS_SET_COMPUTE_UNIT_PRICE + and not is_unit_price_set + ): + unit_price = instruction.lamports + is_unit_price_set = True + + priority_fee = unit_price * unit_limit # in microlamports + return Fee( + base=base_fee, + priority=(priority_fee + MICROLAMPORTS_PER_LAMPORT - 1) + // MICROLAMPORTS_PER_LAMPORT, + )