1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-04-14 14:27:25 +00:00

refactor(core): separate base and priority Solana fees

Move fee calculation into `Transaction` class.

Also, replace floating-point division by integer division in fee calculation.

[no changelog]
This commit is contained in:
Roman Zeyde 2025-02-17 10:35:55 +02:00 committed by Roman Zeyde
parent c55893ffe5
commit d33244912b
5 changed files with 76 additions and 54 deletions

View File

@ -4,3 +4,6 @@ ADDRESS_SIZE = const(32)
SOLANA_BASE_FEE_LAMPORTS = const(5000)
SOLANA_COMPUTE_UNIT_LIMIT = const(200000)
# 1 lamport has 1M microlamports
MICROLAMPORTS_PER_LAMPORT = const(1000000)

View File

@ -18,6 +18,7 @@ from .types import AddressType
if TYPE_CHECKING:
from typing import Sequence
from .transaction import Fee
from .transaction.instructions import Instruction, SystemProgramTransferInstruction
from .types import AddressReference
@ -264,7 +265,7 @@ async def confirm_unsupported_program_confirm(
async def confirm_system_transfer(
transfer_instruction: SystemProgramTransferInstruction,
fee: int,
fee: Fee,
signer_path: list[int],
blockhash: bytes,
) -> None:
@ -293,7 +294,7 @@ async def confirm_token_transfer(
token_mint: bytes,
amount: int,
decimals: int,
fee: int,
fee: Fee,
signer_path: list[int],
blockhash: bytes,
) -> None:
@ -334,13 +335,13 @@ async def confirm_custom_transaction(
amount: int,
decimals: int,
unit: str,
fee: int,
fee: Fee,
signer_path: list[int],
blockhash: bytes,
) -> None:
await confirm_solana_tx(
amount=f"{format_amount(amount, decimals)} {unit}",
fee=f"{format_amount(fee, 9)} SOL",
fee=f"{format_amount(fee.total, 9)} SOL",
fee_title=f"{TR.solana__expected_fee}:",
items=(
(f"{TR.words__account}:", _format_path(signer_path)),
@ -350,12 +351,12 @@ async def confirm_custom_transaction(
async def confirm_transaction(
signer_path: list[int], blockhash: bytes, fee: int
signer_path: list[int], blockhash: bytes, fee: Fee
) -> None:
await confirm_solana_tx(
amount="",
amount_title="",
fee=f"{format_amount(fee, 9)} SOL",
fee=f"{format_amount(fee.total, 9)} SOL",
fee_title=f"{TR.solana__expected_fee}:",
items=(
(f"{TR.words__account}:", _format_path(signer_path)),

View File

@ -13,6 +13,8 @@ from .transaction.instructions import (
if TYPE_CHECKING:
from trezor.messages import SolanaTxAdditionalInfo
from .transaction import Fee
TransferTokenInstruction = (
TokenProgramTransferCheckedInstruction
| Token2022ProgramTransferCheckedInstruction
@ -114,7 +116,7 @@ def is_predefined_token_transfer(
async def try_confirm_token_transfer_transaction(
transaction: Transaction,
fee: int,
fee: Fee,
signer_path: list[int],
blockhash: bytes,
additional_info: SolanaTxAdditionalInfo | None = None,
@ -169,7 +171,7 @@ async def try_confirm_token_transfer_transaction(
async def try_confirm_predefined_transaction(
transaction: Transaction,
fee: int,
fee: Fee,
signer_path: list[int],
blockhash: bytes,
additional_info: SolanaTxAdditionalInfo | None = None,

View File

@ -56,7 +56,7 @@ async def sign_tx(
br_code=ButtonRequestType.Other,
)
fee = calculate_fee(transaction)
fee = transaction.calculate_fee()
if not await try_confirm_predefined_transaction(
transaction, fee, address_n, transaction.blockhash, msg.additional_info
@ -65,7 +65,7 @@ async def sign_tx(
await confirm_transaction(
address_n,
transaction.blockhash,
calculate_fee(transaction),
fee,
)
signature = ed25519.sign(node.private_key(), serialized_tx)
@ -110,46 +110,3 @@ async def confirm_instructions(
signer_path,
signer_public_key,
)
def calculate_fee(transaction: Transaction) -> int:
import math
from .constants import SOLANA_BASE_FEE_LAMPORTS, SOLANA_COMPUTE_UNIT_LIMIT
from .transaction.instructions import (
COMPUTE_BUDGET_PROGRAM_ID,
COMPUTE_BUDGET_PROGRAM_ID_INS_SET_COMPUTE_UNIT_LIMIT,
COMPUTE_BUDGET_PROGRAM_ID_INS_SET_COMPUTE_UNIT_PRICE,
)
from .types import AddressType
number_of_signers = 0
for address in transaction.addresses:
if address[1] == AddressType.AddressSig:
number_of_signers += 1
base_fee = SOLANA_BASE_FEE_LAMPORTS * number_of_signers
unit_price = 0
is_unit_price_set = False
unit_limit = SOLANA_COMPUTE_UNIT_LIMIT
is_unit_limit_set = False
for instruction in transaction.instructions:
if instruction.program_id == COMPUTE_BUDGET_PROGRAM_ID:
if (
instruction.instruction_id
== COMPUTE_BUDGET_PROGRAM_ID_INS_SET_COMPUTE_UNIT_LIMIT
and not is_unit_limit_set
):
unit_limit = instruction.units
is_unit_limit_set = True
elif (
instruction.instruction_id
== COMPUTE_BUDGET_PROGRAM_ID_INS_SET_COMPUTE_UNIT_PRICE
and not is_unit_price_set
):
unit_price = instruction.lamports
is_unit_price_set = True
return int(base_fee + math.ceil(unit_price * unit_limit / 1000000))

View File

@ -4,15 +4,37 @@ from trezor.crypto import base58
from trezor.utils import BufferReader
from trezor.wire import DataError
from ..constants import (
MICROLAMPORTS_PER_LAMPORT,
SOLANA_BASE_FEE_LAMPORTS,
SOLANA_COMPUTE_UNIT_LIMIT,
)
from ..types import AddressType
from .instruction import Instruction
from .instructions import get_instruction, get_instruction_id_length
from .instructions import (
COMPUTE_BUDGET_PROGRAM_ID,
COMPUTE_BUDGET_PROGRAM_ID_INS_SET_COMPUTE_UNIT_LIMIT,
COMPUTE_BUDGET_PROGRAM_ID_INS_SET_COMPUTE_UNIT_PRICE,
get_instruction,
get_instruction_id_length,
)
from .parse import parse_block_hash, parse_pubkey, parse_var_int
if TYPE_CHECKING:
from ..types import Account, Address, AddressReference, RawInstruction
class Fee:
def __init__(
self,
base: int,
priority: int,
) -> None:
self.base = base
self.priority = priority
self.total = base + priority
class Transaction:
blind_signing = False
required_signers_count = 0
@ -209,3 +231,40 @@ class Transaction:
for instruction in self.instructions
if not instruction.is_ui_hidden
]
def calculate_fee(self) -> Fee:
number_of_signers = 0
for address in self.addresses:
if address[1] == AddressType.AddressSig:
number_of_signers += 1
base_fee = SOLANA_BASE_FEE_LAMPORTS * number_of_signers
unit_price = 0
is_unit_price_set = False
unit_limit = SOLANA_COMPUTE_UNIT_LIMIT
is_unit_limit_set = False
for instruction in self.instructions:
if instruction.program_id == COMPUTE_BUDGET_PROGRAM_ID:
if (
instruction.instruction_id
== COMPUTE_BUDGET_PROGRAM_ID_INS_SET_COMPUTE_UNIT_LIMIT
and not is_unit_limit_set
):
unit_limit = instruction.units
is_unit_limit_set = True
elif (
instruction.instruction_id
== COMPUTE_BUDGET_PROGRAM_ID_INS_SET_COMPUTE_UNIT_PRICE
and not is_unit_price_set
):
unit_price = instruction.lamports
is_unit_price_set = True
priority_fee = unit_price * unit_limit # in microlamports
return Fee(
base=base_fee,
priority=(priority_fee + MICROLAMPORTS_PER_LAMPORT - 1)
// MICROLAMPORTS_PER_LAMPORT,
)