mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-18 04:18:10 +00:00
feat(solana): add sign_tx
implementation
This commit is contained in:
parent
081995788e
commit
68c0e6c43e
@ -683,14 +683,28 @@ if not utils.BITCOIN_ONLY:
|
||||
import apps.ripple.sign_tx
|
||||
apps.solana
|
||||
import apps.solana
|
||||
apps.solana.constants
|
||||
import apps.solana.constants
|
||||
apps.solana.format
|
||||
import apps.solana.format
|
||||
apps.solana.get_address
|
||||
import apps.solana.get_address
|
||||
apps.solana.get_public_key
|
||||
import apps.solana.get_public_key
|
||||
apps.solana.layout
|
||||
import apps.solana.layout
|
||||
apps.solana.sign_tx
|
||||
import apps.solana.sign_tx
|
||||
apps.solana.transaction
|
||||
import apps.solana.transaction
|
||||
apps.solana.transaction.instruction
|
||||
import apps.solana.transaction.instruction
|
||||
apps.solana.transaction.instructions
|
||||
import apps.solana.transaction.instructions
|
||||
apps.solana.transaction.parse
|
||||
import apps.solana.transaction.parse
|
||||
apps.solana.types
|
||||
import apps.solana.types
|
||||
apps.stellar
|
||||
import apps.stellar
|
||||
apps.stellar.consts
|
||||
|
6
core/src/apps/solana/constants.py
Normal file
6
core/src/apps/solana/constants.py
Normal file
@ -0,0 +1,6 @@
|
||||
from micropython import const
|
||||
|
||||
ADDRESS_SIZE = const(32)
|
||||
|
||||
SOLANA_BASE_FEE_LAMPORTS = const(5000)
|
||||
SOLANA_COMPUTE_UNIT_LIMIT = const(200000)
|
39
core/src/apps/solana/format.py
Normal file
39
core/src/apps/solana/format.py
Normal file
@ -0,0 +1,39 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trezor.strings import format_amount, format_timestamp
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .transaction.instructions import Instruction
|
||||
|
||||
|
||||
def format_pubkey(_: Instruction, value: bytes | None) -> str:
|
||||
from trezor.crypto import base58
|
||||
|
||||
if value is None:
|
||||
raise ValueError # should not be called with optional pubkey
|
||||
|
||||
return base58.encode(value)
|
||||
|
||||
|
||||
def format_lamports(_: Instruction, value: int) -> str:
|
||||
formatted = format_amount(value, decimals=9)
|
||||
return f"{formatted} SOL"
|
||||
|
||||
|
||||
def format_token_amount(instruction: Instruction, value: int) -> str:
|
||||
assert hasattr(instruction, "decimals") # enforced in instructions.py.mako
|
||||
|
||||
formatted = format_amount(value, decimals=instruction.decimals)
|
||||
return f"{formatted}"
|
||||
|
||||
|
||||
def format_unix_timestamp(_: Instruction, value: int) -> str:
|
||||
return format_timestamp(value)
|
||||
|
||||
|
||||
def format_int(_: Instruction, value: int) -> str:
|
||||
return str(value)
|
||||
|
||||
|
||||
def format_identity(_: Instruction, value: str) -> str:
|
||||
return value
|
357
core/src/apps/solana/layout.py
Normal file
357
core/src/apps/solana/layout.py
Normal file
@ -0,0 +1,357 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trezor.crypto import base58
|
||||
from trezor.enums import ButtonRequestType
|
||||
from trezor.strings import format_amount
|
||||
from trezor.ui.layouts import (
|
||||
confirm_metadata,
|
||||
confirm_properties,
|
||||
confirm_solana_tx,
|
||||
confirm_value,
|
||||
)
|
||||
|
||||
from apps.common.paths import address_n_to_str
|
||||
|
||||
from .types import AddressType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezor.ui.layouts import PropertyType
|
||||
|
||||
from .transaction.instructions import (
|
||||
AssociatedTokenAccountProgramCreateInstruction,
|
||||
Instruction,
|
||||
SystemProgramTransferInstruction,
|
||||
TokenProgramTransferCheckedInstruction,
|
||||
)
|
||||
from .types import AddressReference
|
||||
|
||||
|
||||
def _format_path(path: list[int]) -> str:
|
||||
from micropython import const
|
||||
|
||||
from apps.common.paths import unharden
|
||||
|
||||
if len(path) < 4:
|
||||
return address_n_to_str(path)
|
||||
|
||||
ACCOUNT_PATH_INDEX = const(3)
|
||||
account_index = path[ACCOUNT_PATH_INDEX]
|
||||
return f"Solana #{unharden(account_index) + 1}"
|
||||
|
||||
|
||||
def _get_address_reference_props(address: AddressReference, display_name: str):
|
||||
return (
|
||||
(f"{display_name} is provided via a lookup table.", ""),
|
||||
("Lookup table address:", base58.encode(address[0])),
|
||||
("Account index:", f"{address[1]}"),
|
||||
)
|
||||
|
||||
|
||||
async def confirm_instruction(
|
||||
instruction: Instruction,
|
||||
instructions_count: int,
|
||||
instruction_index: int,
|
||||
signer_path: list[int],
|
||||
signer_public_key: bytes,
|
||||
) -> None:
|
||||
instruction_title = (
|
||||
f"{instruction_index}/{instructions_count}: {instruction.ui_name}"
|
||||
)
|
||||
|
||||
if instruction.is_deprecated_warning is not None:
|
||||
await confirm_metadata(
|
||||
"confirm_deprecated_warning",
|
||||
instruction_title,
|
||||
instruction.is_deprecated_warning,
|
||||
br_code=ButtonRequestType.Other,
|
||||
)
|
||||
|
||||
for ui_property in instruction.ui_properties:
|
||||
if ui_property.parameter is not None:
|
||||
property_template = instruction.get_property_template(ui_property.parameter)
|
||||
value = instruction.parsed_data[ui_property.parameter]
|
||||
|
||||
if property_template.is_authority and signer_public_key == value:
|
||||
continue
|
||||
|
||||
if property_template.is_optional and value is None:
|
||||
continue
|
||||
|
||||
if ui_property.default_value_to_hide == value:
|
||||
continue
|
||||
|
||||
await confirm_properties(
|
||||
"confirm_instruction",
|
||||
f"{instruction_index}/{instructions_count}: {instruction.ui_name}",
|
||||
(
|
||||
(
|
||||
ui_property.display_name,
|
||||
property_template.format(instruction, value),
|
||||
),
|
||||
),
|
||||
)
|
||||
elif ui_property.account is not None:
|
||||
account_template = instruction.get_account_template(ui_property.account)
|
||||
|
||||
# optional account, skip if not present
|
||||
if ui_property.account not in instruction.parsed_accounts:
|
||||
continue
|
||||
|
||||
account_value = instruction.parsed_accounts[ui_property.account]
|
||||
|
||||
if account_template.is_authority:
|
||||
if signer_public_key == account_value[0]:
|
||||
continue
|
||||
|
||||
account_data: list[tuple[str, str]] = []
|
||||
if len(account_value) == 2:
|
||||
signer_suffix = ""
|
||||
if account_value[0] == signer_public_key:
|
||||
signer_suffix = " (Signer)"
|
||||
|
||||
account_data.append(
|
||||
(
|
||||
ui_property.display_name,
|
||||
f"{base58.encode(account_value[0])}{signer_suffix}",
|
||||
)
|
||||
)
|
||||
elif len(account_value) == 3:
|
||||
account_data += _get_address_reference_props(
|
||||
account_value, ui_property.display_name
|
||||
)
|
||||
else:
|
||||
raise ValueError # Invalid account value
|
||||
|
||||
await confirm_properties(
|
||||
"confirm_instruction",
|
||||
f"{instruction_index}/{instructions_count}: {instruction.ui_name}",
|
||||
account_data,
|
||||
)
|
||||
else:
|
||||
raise ValueError # Invalid ui property
|
||||
|
||||
if instruction.multisig_signers:
|
||||
await confirm_metadata(
|
||||
"confirm_multisig",
|
||||
"Confirm multisig",
|
||||
"The following instruction is a multisig instruction.",
|
||||
br_code=ButtonRequestType.Other,
|
||||
)
|
||||
|
||||
signers: list[tuple[str, str]] = []
|
||||
for i, multisig_signer in enumerate(instruction.multisig_signers, 1):
|
||||
multisig_signer_public_key = multisig_signer[0]
|
||||
|
||||
path_str = ""
|
||||
if multisig_signer_public_key == signer_public_key:
|
||||
path_str = f" ({address_n_to_str(signer_path)})"
|
||||
|
||||
signers.append(
|
||||
(f"Signer {i}{path_str}:", base58.encode(multisig_signer[0]))
|
||||
)
|
||||
|
||||
await confirm_properties(
|
||||
"confirm_instruction",
|
||||
f"{instruction_index}/{instructions_count}: {instruction.ui_name}",
|
||||
signers,
|
||||
)
|
||||
|
||||
|
||||
def get_address_type(address_type: AddressType) -> str:
|
||||
if address_type == AddressType.AddressSig:
|
||||
return "(Writable, Signer)"
|
||||
if address_type == AddressType.AddressSigReadOnly:
|
||||
return "(Signer)"
|
||||
if address_type == AddressType.AddressReadOnly:
|
||||
return ""
|
||||
if address_type == AddressType.AddressRw:
|
||||
return "(Writable)"
|
||||
raise ValueError # Invalid address type
|
||||
|
||||
|
||||
async def confirm_unsupported_instruction_details(
|
||||
instruction: Instruction,
|
||||
title: str,
|
||||
signer_path: list[int],
|
||||
signer_public_key: bytes,
|
||||
) -> None:
|
||||
from trezor.ui import NORMAL
|
||||
from trezor.ui.layouts import confirm_properties, should_show_more
|
||||
|
||||
should_show_instruction_details = await should_show_more(
|
||||
title,
|
||||
(
|
||||
(
|
||||
NORMAL,
|
||||
f"Instruction contains {len(instruction.accounts)} accounts and its data is {len(instruction.instruction_data)} bytes long.",
|
||||
),
|
||||
),
|
||||
"Show details",
|
||||
confirm="Continue",
|
||||
)
|
||||
|
||||
if should_show_instruction_details:
|
||||
await confirm_properties(
|
||||
"instruction_data",
|
||||
title,
|
||||
(("Instruction data:", bytes(instruction.instruction_data)),),
|
||||
)
|
||||
|
||||
accounts = []
|
||||
for i, account in enumerate(instruction.accounts, 1):
|
||||
if len(account) == 2:
|
||||
account_public_key = account[0]
|
||||
address_type = get_address_type(account[1])
|
||||
|
||||
path_str = ""
|
||||
if account_public_key == signer_public_key:
|
||||
path_str = f" ({address_n_to_str(signer_path)})"
|
||||
|
||||
accounts.append(
|
||||
(
|
||||
f"Account {i}{path_str} {address_type}:",
|
||||
base58.encode(account_public_key),
|
||||
)
|
||||
)
|
||||
elif len(account) == 3:
|
||||
address_type = get_address_type(account[2])
|
||||
accounts += _get_address_reference_props(
|
||||
account, f"Account {i} {address_type}"
|
||||
)
|
||||
else:
|
||||
raise ValueError # Invalid account value
|
||||
|
||||
await confirm_properties(
|
||||
"accounts",
|
||||
title,
|
||||
accounts,
|
||||
)
|
||||
|
||||
|
||||
async def confirm_unsupported_instruction_confirm(
|
||||
instruction: Instruction,
|
||||
instructions_count: int,
|
||||
instruction_index: int,
|
||||
signer_path: list[int],
|
||||
signer_public_key: bytes,
|
||||
) -> None:
|
||||
formatted_instruction_id = (
|
||||
instruction.instruction_id if instruction.instruction_id is not None else "N/A"
|
||||
)
|
||||
title = f"{instruction_index}/{instructions_count}: {instruction.ui_name}: instruction id ({formatted_instruction_id})"
|
||||
|
||||
return await confirm_unsupported_instruction_details(
|
||||
instruction, title, signer_path, signer_public_key
|
||||
)
|
||||
|
||||
|
||||
async def confirm_unsupported_program_confirm(
|
||||
instruction: Instruction,
|
||||
instructions_count: int,
|
||||
instruction_index: int,
|
||||
signer_path: list[int],
|
||||
signer_public_key: bytes,
|
||||
) -> None:
|
||||
title = f"{instruction_index}/{instructions_count}: {instruction.ui_name}"
|
||||
|
||||
return await confirm_unsupported_instruction_details(
|
||||
instruction, title, signer_path, signer_public_key
|
||||
)
|
||||
|
||||
|
||||
async def confirm_system_transfer(
|
||||
transfer_instruction: SystemProgramTransferInstruction,
|
||||
fee: int,
|
||||
signer_path: list[int],
|
||||
blockhash: bytes,
|
||||
) -> None:
|
||||
await confirm_value(
|
||||
title="Recipient",
|
||||
value=base58.encode(transfer_instruction.recipient_account[0]),
|
||||
description="",
|
||||
br_type="confirm_recipient",
|
||||
br_code=ButtonRequestType.ConfirmOutput,
|
||||
verb="CONTINUE",
|
||||
)
|
||||
|
||||
await confirm_custom_transaction(
|
||||
transfer_instruction.lamports,
|
||||
9,
|
||||
"SOL",
|
||||
fee,
|
||||
signer_path,
|
||||
blockhash,
|
||||
)
|
||||
|
||||
|
||||
async def confirm_token_transfer(
|
||||
create_token_account_instruction: AssociatedTokenAccountProgramCreateInstruction
|
||||
| None,
|
||||
transfer_token_instruction: TokenProgramTransferCheckedInstruction,
|
||||
fee: int,
|
||||
signer_path: list[int],
|
||||
blockhash: bytes,
|
||||
):
|
||||
recipient_props: list[PropertyType] = [
|
||||
("", base58.encode(transfer_token_instruction.destination_account[0]))
|
||||
]
|
||||
if create_token_account_instruction is not None:
|
||||
recipient_props.append(("(account will be created)", ""))
|
||||
|
||||
await confirm_properties(
|
||||
"confirm_recipient",
|
||||
"Recipient",
|
||||
recipient_props,
|
||||
)
|
||||
|
||||
await confirm_value(
|
||||
title="Token address",
|
||||
value=base58.encode(transfer_token_instruction.token_mint[0]),
|
||||
description="",
|
||||
br_type="confirm_token_address",
|
||||
br_code=ButtonRequestType.ConfirmOutput,
|
||||
verb="CONTINUE",
|
||||
)
|
||||
|
||||
await confirm_custom_transaction(
|
||||
transfer_token_instruction.amount,
|
||||
transfer_token_instruction.decimals,
|
||||
"[TOKEN]",
|
||||
fee,
|
||||
signer_path,
|
||||
blockhash,
|
||||
)
|
||||
|
||||
|
||||
async def confirm_custom_transaction(
|
||||
amount: int,
|
||||
decimals: int,
|
||||
unit: str,
|
||||
fee: int,
|
||||
signer_path: list[int],
|
||||
blockhash: bytes,
|
||||
) -> None:
|
||||
await confirm_solana_tx(
|
||||
amount=f"{format_amount(amount, decimals)} {unit}",
|
||||
fee=f"{format_amount(fee, 9)} SOL",
|
||||
fee_title="Expected fee:",
|
||||
items=(
|
||||
("Account:", _format_path(signer_path)),
|
||||
("Blockhash:", base58.encode(blockhash)),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def confirm_transaction(
|
||||
signer_path: list[int], blockhash: bytes, fee: int
|
||||
) -> None:
|
||||
await confirm_solana_tx(
|
||||
amount="",
|
||||
amount_title="",
|
||||
fee=f"{format_amount(fee, 9)} SOL",
|
||||
fee_title="Expected fee:",
|
||||
items=(
|
||||
("Account:", _format_path(signer_path)),
|
||||
("Blockhash:", base58.encode(blockhash)),
|
||||
),
|
||||
)
|
@ -1,8 +1,11 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trezor.wire import DataError
|
||||
|
||||
from apps.common.keychain import with_slip44_keychain
|
||||
|
||||
from . import CURVE, PATTERNS, SLIP44_ID
|
||||
from .transaction import Transaction
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezor.messages import SolanaSignTx, SolanaTxSignature
|
||||
@ -16,12 +19,178 @@ async def sign_tx(
|
||||
keychain: Keychain,
|
||||
) -> SolanaTxSignature:
|
||||
from trezor.crypto.curve import ed25519
|
||||
from trezor.enums import ButtonRequestType
|
||||
from trezor.messages import SolanaTxSignature
|
||||
from trezor.ui.layouts import confirm_metadata, show_warning
|
||||
|
||||
address_n = msg.address_n
|
||||
serialized_tx = msg.serialized_tx
|
||||
from apps.common import seed
|
||||
|
||||
from .layout import confirm_transaction
|
||||
|
||||
address_n = msg.address_n # local_cache_attribute
|
||||
serialized_tx = msg.serialized_tx # local_cache_attribute
|
||||
|
||||
node = keychain.derive(address_n)
|
||||
signer_public_key = seed.remove_ed25519_prefix(node.public_key())
|
||||
|
||||
try:
|
||||
transaction: Transaction = Transaction(serialized_tx)
|
||||
except Exception:
|
||||
raise DataError("Invalid transaction")
|
||||
|
||||
if transaction.blind_signing:
|
||||
await show_warning(
|
||||
"warning_blind_signing", "Transaction contains unknown instructions."
|
||||
)
|
||||
|
||||
if transaction.required_signers_count > 1:
|
||||
await confirm_metadata(
|
||||
"multiple_signers",
|
||||
"Multiple signers",
|
||||
f"Transaction requires {transaction.required_signers_count} signers which increases the fee.",
|
||||
br_code=ButtonRequestType.Other,
|
||||
)
|
||||
|
||||
fee = calculate_fee(transaction)
|
||||
|
||||
if not await try_confirm_predefined_transaction(
|
||||
transaction, fee, address_n, transaction.blockhash
|
||||
):
|
||||
await confirm_instructions(address_n, signer_public_key, transaction)
|
||||
await confirm_transaction(
|
||||
address_n,
|
||||
transaction.blockhash,
|
||||
calculate_fee(transaction),
|
||||
)
|
||||
|
||||
signature = ed25519.sign(node.private_key(), serialized_tx)
|
||||
|
||||
return SolanaTxSignature(signature=signature)
|
||||
|
||||
|
||||
async def try_confirm_predefined_transaction(
|
||||
transaction: Transaction, fee: int, signer_path: list[int], blockhash: bytes
|
||||
) -> bool:
|
||||
from .layout import confirm_system_transfer, confirm_token_transfer
|
||||
from .transaction.instructions import (
|
||||
AssociatedTokenAccountProgramCreateInstruction,
|
||||
SystemProgramTransferInstruction,
|
||||
TokenProgramTransferCheckedInstruction,
|
||||
)
|
||||
|
||||
instructions = transaction.instructions
|
||||
instructions_count = len(instructions)
|
||||
|
||||
if instructions_count == 1:
|
||||
if SystemProgramTransferInstruction.is_type_of(instructions[0]):
|
||||
await confirm_system_transfer(instructions[0], fee, signer_path, blockhash)
|
||||
return True
|
||||
|
||||
if TokenProgramTransferCheckedInstruction.is_type_of(instructions[0]):
|
||||
await confirm_token_transfer(
|
||||
None, instructions[0], fee, signer_path, blockhash
|
||||
)
|
||||
return True
|
||||
elif instructions_count == 2:
|
||||
if AssociatedTokenAccountProgramCreateInstruction.is_type_of(
|
||||
instructions[0]
|
||||
) and TokenProgramTransferCheckedInstruction.is_type_of(instructions[1]):
|
||||
create_token_account_instruction = instructions[0]
|
||||
transfer_token_instruction = instructions[1]
|
||||
|
||||
# If the account being created is different from the recipient account we need
|
||||
# to display all the instruction information.
|
||||
if (
|
||||
create_token_account_instruction.associated_token_account[0]
|
||||
!= transfer_token_instruction.destination_account[0]
|
||||
):
|
||||
return False
|
||||
|
||||
await confirm_token_transfer(
|
||||
instructions[0],
|
||||
instructions[1],
|
||||
fee,
|
||||
signer_path,
|
||||
blockhash,
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
async def confirm_instructions(
|
||||
signer_path: list[int], signer_public_key: bytes, transaction: Transaction
|
||||
) -> None:
|
||||
instructions_count = len(transaction.instructions)
|
||||
for instruction_index, instruction in enumerate(transaction.instructions, 1):
|
||||
if not instruction.is_program_supported:
|
||||
from .layout import confirm_unsupported_program_confirm
|
||||
|
||||
await confirm_unsupported_program_confirm(
|
||||
instruction,
|
||||
instructions_count,
|
||||
instruction_index,
|
||||
signer_path,
|
||||
signer_public_key,
|
||||
)
|
||||
elif not instruction.is_instruction_supported:
|
||||
from .layout import confirm_unsupported_instruction_confirm
|
||||
|
||||
await confirm_unsupported_instruction_confirm(
|
||||
instruction,
|
||||
instructions_count,
|
||||
instruction_index,
|
||||
signer_path,
|
||||
signer_public_key,
|
||||
)
|
||||
else:
|
||||
from .layout import confirm_instruction
|
||||
|
||||
await confirm_instruction(
|
||||
instruction,
|
||||
instructions_count,
|
||||
instruction_index,
|
||||
signer_path,
|
||||
signer_public_key,
|
||||
)
|
||||
|
||||
|
||||
def calculate_fee(transaction: Transaction) -> int:
|
||||
from .constants import SOLANA_BASE_FEE_LAMPORTS, SOLANA_COMPUTE_UNIT_LIMIT
|
||||
from .transaction.instructions import (
|
||||
COMPUTE_BUDGET_PROGRAM_ID,
|
||||
COMPUTE_BUDGET_PROGRAM_ID_INS_SET_COMPUTE_UNIT_LIMIT,
|
||||
COMPUTE_BUDGET_PROGRAM_ID_INS_SET_COMPUTE_UNIT_PRICE,
|
||||
)
|
||||
from .types import AddressType
|
||||
|
||||
number_of_signers = 0
|
||||
for address in transaction.addresses:
|
||||
if address[1] == AddressType.AddressSig:
|
||||
number_of_signers += 1
|
||||
|
||||
base_fee = SOLANA_BASE_FEE_LAMPORTS * number_of_signers
|
||||
|
||||
unit_price = 0
|
||||
is_unit_price_set = False
|
||||
unit_limit = SOLANA_COMPUTE_UNIT_LIMIT
|
||||
is_unit_limit_set = False
|
||||
|
||||
for instruction in transaction.instructions[:3]:
|
||||
if instruction.program_id == COMPUTE_BUDGET_PROGRAM_ID:
|
||||
if (
|
||||
instruction.instruction_id
|
||||
== COMPUTE_BUDGET_PROGRAM_ID_INS_SET_COMPUTE_UNIT_LIMIT
|
||||
and not is_unit_limit_set
|
||||
):
|
||||
unit_limit = instruction.units
|
||||
is_unit_limit_set = True
|
||||
elif (
|
||||
instruction.instruction_id
|
||||
== COMPUTE_BUDGET_PROGRAM_ID_INS_SET_COMPUTE_UNIT_PRICE
|
||||
and not is_unit_price_set
|
||||
):
|
||||
unit_price = instruction.lamports
|
||||
is_unit_price_set = True
|
||||
|
||||
return int(base_fee + unit_price * unit_limit / 1000000)
|
||||
|
204
core/src/apps/solana/transaction/__init__.py
Normal file
204
core/src/apps/solana/transaction/__init__.py
Normal file
@ -0,0 +1,204 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trezor.crypto import base58
|
||||
from trezor.utils import BufferReader
|
||||
from trezor.wire import DataError
|
||||
|
||||
from ..types import AddressType
|
||||
from .instruction import Instruction
|
||||
from .instructions import get_instruction, get_instruction_id_length
|
||||
from .parse import parse_block_hash, parse_pubkey, parse_var_int
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..types import Account, Address, AddressReference, RawInstruction
|
||||
|
||||
|
||||
class Transaction:
|
||||
blind_signing = False
|
||||
required_signers_count = 0
|
||||
|
||||
version: int | None = None
|
||||
|
||||
addresses: list[Address]
|
||||
|
||||
blockhash: bytes
|
||||
|
||||
raw_instructions: list[RawInstruction]
|
||||
instructions: list[Instruction]
|
||||
|
||||
address_lookup_tables_rw_addresses: list[AddressReference]
|
||||
address_lookup_tables_ro_addresses: list[AddressReference]
|
||||
|
||||
def __init__(self, serialized_tx: bytes) -> None:
|
||||
self._parse_transaction(serialized_tx)
|
||||
self._create_instructions()
|
||||
self._determine_if_blind_signing()
|
||||
|
||||
def _parse_transaction(self, serialized_tx: bytes) -> None:
|
||||
serialized_tx_reader = BufferReader(serialized_tx)
|
||||
self._parse_header(serialized_tx_reader)
|
||||
|
||||
self._parse_addresses(serialized_tx_reader)
|
||||
|
||||
self.blockhash = parse_block_hash(serialized_tx_reader)
|
||||
|
||||
self._parse_instructions(serialized_tx_reader)
|
||||
|
||||
self._parse_address_lookup_tables(serialized_tx_reader)
|
||||
|
||||
if serialized_tx_reader.remaining_count() != 0:
|
||||
raise DataError("Invalid transaction")
|
||||
|
||||
def _parse_header(self, serialized_tx_reader: BufferReader) -> None:
|
||||
self.version: int | None = None
|
||||
|
||||
if serialized_tx_reader.peek() & 0b10000000:
|
||||
self.version = serialized_tx_reader.get() & 0b01111111
|
||||
# only version 0 is supported
|
||||
if self.version > 0:
|
||||
raise DataError("Unsupported transaction version")
|
||||
|
||||
self.required_signers_count: int = serialized_tx_reader.get()
|
||||
self.num_signature_read_only_addresses: int = serialized_tx_reader.get()
|
||||
self.num_read_only_addresses: int = serialized_tx_reader.get()
|
||||
|
||||
def _parse_addresses(self, serialized_tx_reader: BufferReader) -> None:
|
||||
num_of_addresses = parse_var_int(serialized_tx_reader)
|
||||
|
||||
assert (
|
||||
num_of_addresses
|
||||
>= self.required_signers_count
|
||||
+ self.num_signature_read_only_addresses
|
||||
+ self.num_read_only_addresses
|
||||
)
|
||||
|
||||
addresses: list[Address] = []
|
||||
for i in range(num_of_addresses):
|
||||
if i < self.required_signers_count:
|
||||
type = AddressType.AddressSig
|
||||
elif (
|
||||
i < self.required_signers_count + self.num_signature_read_only_addresses
|
||||
):
|
||||
type = AddressType.AddressSigReadOnly
|
||||
elif (
|
||||
i
|
||||
< self.required_signers_count
|
||||
+ self.num_signature_read_only_addresses
|
||||
+ self.num_read_only_addresses
|
||||
):
|
||||
type = AddressType.AddressRw
|
||||
else:
|
||||
type = AddressType.AddressReadOnly
|
||||
|
||||
address = parse_pubkey(serialized_tx_reader)
|
||||
|
||||
addresses.append((address, type))
|
||||
|
||||
self.addresses = addresses
|
||||
|
||||
def _parse_instructions(self, serialized_tx_reader: BufferReader) -> None:
|
||||
num_of_instructions = parse_var_int(serialized_tx_reader)
|
||||
|
||||
self.raw_instructions = []
|
||||
|
||||
for _ in range(num_of_instructions):
|
||||
program_index = serialized_tx_reader.get()
|
||||
program_id = base58.encode(self.addresses[program_index][0])
|
||||
num_of_accounts = parse_var_int(serialized_tx_reader)
|
||||
accounts: list[int] = []
|
||||
for _ in range(num_of_accounts):
|
||||
account_index = serialized_tx_reader.get()
|
||||
accounts.append(account_index)
|
||||
|
||||
data_length = parse_var_int(serialized_tx_reader)
|
||||
|
||||
instruction_id_length = get_instruction_id_length(program_id)
|
||||
if instruction_id_length <= data_length:
|
||||
instruction_id = int.from_bytes(
|
||||
serialized_tx_reader.read_memoryview(instruction_id_length),
|
||||
"little",
|
||||
)
|
||||
else:
|
||||
instruction_id = None
|
||||
|
||||
instruction_data = serialized_tx_reader.read_memoryview(
|
||||
max(0, data_length - instruction_id_length)
|
||||
)
|
||||
|
||||
self.raw_instructions.append(
|
||||
(program_index, instruction_id, accounts, instruction_data)
|
||||
)
|
||||
|
||||
def _parse_address_lookup_tables(self, serialized_tx: BufferReader) -> None:
|
||||
self.address_lookup_tables_rw_addresses = []
|
||||
self.address_lookup_tables_ro_addresses = []
|
||||
|
||||
if self.version is None:
|
||||
return
|
||||
|
||||
address_lookup_tables_count = parse_var_int(serialized_tx)
|
||||
for _ in range(address_lookup_tables_count):
|
||||
account = parse_pubkey(serialized_tx)
|
||||
|
||||
table_rw_indexes_count = parse_var_int(serialized_tx)
|
||||
for _ in range(table_rw_indexes_count):
|
||||
index = serialized_tx.get()
|
||||
self.address_lookup_tables_rw_addresses.append(
|
||||
(account, index, AddressType.AddressRw)
|
||||
)
|
||||
|
||||
table_ro_indexes_count = parse_var_int(serialized_tx)
|
||||
for _ in range(table_ro_indexes_count):
|
||||
index = serialized_tx.get()
|
||||
self.address_lookup_tables_ro_addresses.append(
|
||||
(account, index, AddressType.AddressReadOnly)
|
||||
)
|
||||
|
||||
def _get_combined_accounts(self) -> list[Account]:
|
||||
accounts: list[Account] = []
|
||||
for address in self.addresses:
|
||||
accounts.append(address)
|
||||
|
||||
for rw_address in self.address_lookup_tables_rw_addresses:
|
||||
accounts.append(rw_address)
|
||||
for ro_address in self.address_lookup_tables_ro_addresses:
|
||||
accounts.append(ro_address)
|
||||
|
||||
return accounts
|
||||
|
||||
def _create_instructions(self) -> None:
|
||||
# Instructions reference accounts by index in this combined list.
|
||||
combined_accounts = (
|
||||
self.addresses # type: ignore [Operator "+" not supported for types "list[Address]" and "list[AddressReference]"]
|
||||
+ self.address_lookup_tables_rw_addresses
|
||||
+ self.address_lookup_tables_ro_addresses
|
||||
)
|
||||
|
||||
self.instructions = []
|
||||
for (
|
||||
program_index,
|
||||
instruction_id,
|
||||
accounts,
|
||||
instruction_data,
|
||||
) in self.raw_instructions:
|
||||
program_id = base58.encode(self.addresses[program_index][0])
|
||||
instruction_accounts = [
|
||||
combined_accounts[account_index] for account_index in accounts
|
||||
]
|
||||
instruction = get_instruction(
|
||||
program_id,
|
||||
instruction_id,
|
||||
instruction_accounts,
|
||||
instruction_data,
|
||||
)
|
||||
|
||||
self.instructions.append(instruction)
|
||||
|
||||
def _determine_if_blind_signing(self) -> None:
|
||||
for instruction in self.instructions:
|
||||
if (
|
||||
not instruction.is_program_supported
|
||||
or not instruction.is_instruction_supported
|
||||
):
|
||||
self.blind_signing = True
|
||||
break
|
154
core/src/apps/solana/transaction/instruction.py
Normal file
154
core/src/apps/solana/transaction/instruction.py
Normal file
@ -0,0 +1,154 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, TypeGuard
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from ..types import (
|
||||
Account,
|
||||
AccountTemplate,
|
||||
InstructionData,
|
||||
PropertyTemplate,
|
||||
UIProperty,
|
||||
)
|
||||
|
||||
|
||||
class Instruction:
|
||||
program_id: str
|
||||
instruction_id: int | None
|
||||
|
||||
property_templates: list[PropertyTemplate]
|
||||
accounts_template: list[AccountTemplate]
|
||||
|
||||
ui_name: str
|
||||
|
||||
ui_properties: list[UIProperty]
|
||||
|
||||
parsed_data: dict[str, Any]
|
||||
parsed_accounts: dict[str, Account]
|
||||
|
||||
is_program_supported: bool
|
||||
is_instruction_supported: bool
|
||||
instruction_data: InstructionData
|
||||
accounts: list[Account]
|
||||
|
||||
multisig_signers: list[Account]
|
||||
|
||||
is_deprecated_warning: str | None = None
|
||||
|
||||
@staticmethod
|
||||
def parse_instruction_data(
|
||||
instruction_data: InstructionData, property_templates: list[PropertyTemplate]
|
||||
):
|
||||
from trezor.utils import BufferReader
|
||||
from trezor.wire import DataError
|
||||
|
||||
reader = BufferReader(instruction_data)
|
||||
|
||||
parsed_data = {}
|
||||
for property_template in property_templates:
|
||||
is_included = True
|
||||
if property_template.is_optional:
|
||||
is_included = True if reader.get() == 1 else False
|
||||
|
||||
parsed_data[property_template.name] = (
|
||||
property_template.parse(reader) if is_included else None
|
||||
)
|
||||
|
||||
if reader.remaining_count() != 0:
|
||||
raise DataError("Invalid transaction data")
|
||||
|
||||
return parsed_data
|
||||
|
||||
@staticmethod
|
||||
def parse_instruction_accounts(
|
||||
accounts: list[Account], accounts_template: list[AccountTemplate]
|
||||
):
|
||||
parsed_account = {}
|
||||
for i, account_template in enumerate(accounts_template):
|
||||
if i >= len(accounts):
|
||||
if account_template.optional:
|
||||
continue
|
||||
else:
|
||||
raise ValueError # "Account is missing
|
||||
|
||||
parsed_account[account_template.name] = accounts[i]
|
||||
return parsed_account
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
instruction_data: InstructionData,
|
||||
program_id: str,
|
||||
accounts: list[Account],
|
||||
instruction_id: int | None,
|
||||
property_templates: list[PropertyTemplate],
|
||||
accounts_template: list[AccountTemplate],
|
||||
ui_properties: list[UIProperty],
|
||||
ui_name: str,
|
||||
is_program_supported: bool = True,
|
||||
is_instruction_supported: bool = True,
|
||||
supports_multisig: bool = False,
|
||||
is_deprecated_warning: str | None = None,
|
||||
) -> None:
|
||||
self.program_id = program_id
|
||||
self.instruction_id = instruction_id
|
||||
|
||||
self.property_templates = property_templates
|
||||
self.accounts_template = accounts_template
|
||||
|
||||
self.ui_name = ui_name
|
||||
|
||||
self.ui_properties = ui_properties
|
||||
|
||||
self.is_program_supported = is_program_supported
|
||||
self.is_instruction_supported = is_instruction_supported
|
||||
|
||||
self.is_deprecated_warning = is_deprecated_warning
|
||||
|
||||
self.instruction_data = instruction_data
|
||||
self.accounts = accounts
|
||||
|
||||
if self.is_instruction_supported:
|
||||
self.parsed_data = self.parse_instruction_data(
|
||||
instruction_data, property_templates
|
||||
)
|
||||
|
||||
self.parsed_accounts = self.parse_instruction_accounts(
|
||||
accounts, accounts_template
|
||||
)
|
||||
|
||||
self.multisig_signers = accounts[len(accounts_template) :]
|
||||
if self.multisig_signers and not supports_multisig:
|
||||
raise ValueError # Multisig not supported
|
||||
else:
|
||||
self.parsed_data = {}
|
||||
self.parsed_accounts = {}
|
||||
self.multisig_signers = []
|
||||
|
||||
def __getattr__(self, attr: str) -> Any:
|
||||
if attr in self.parsed_data:
|
||||
return self.parsed_data[attr]
|
||||
if attr in self.parsed_accounts:
|
||||
return self.parsed_accounts[attr]
|
||||
|
||||
raise AttributeError # Attribute not found
|
||||
|
||||
def get_property_template(self, property: str) -> PropertyTemplate:
|
||||
for property_template in self.property_templates:
|
||||
if property_template.name == property:
|
||||
return property_template
|
||||
|
||||
raise ValueError # Property not found
|
||||
|
||||
def get_account_template(self, account_name: str) -> AccountTemplate:
|
||||
for account_template in self.accounts_template:
|
||||
if account_template.name == account_name:
|
||||
return account_template
|
||||
|
||||
raise ValueError # Account not found
|
||||
|
||||
@classmethod
|
||||
def is_type_of(cls, ins: Any) -> TypeGuard[Self]:
|
||||
# gets overridden in `instructions.py` `FakeClass`
|
||||
raise NotImplementedError
|
50
core/src/apps/solana/transaction/parse.py
Normal file
50
core/src/apps/solana/transaction/parse.py
Normal file
@ -0,0 +1,50 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from apps.common.readers import read_uint64_le
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezor.utils import BufferReader
|
||||
|
||||
|
||||
def parse_var_int(serialized_tx: BufferReader) -> int:
|
||||
value = 0
|
||||
shift = 0
|
||||
while serialized_tx.remaining_count():
|
||||
B = serialized_tx.get()
|
||||
value += (B & 0b01111111) << shift
|
||||
shift += 7
|
||||
if B & 0b10000000 == 0:
|
||||
break
|
||||
|
||||
if value > 0xFFFF:
|
||||
raise ValueError # compact-u16 value too large
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def parse_block_hash(serialized_tx: BufferReader) -> bytes:
|
||||
return bytes(serialized_tx.read_memoryview(32))
|
||||
|
||||
|
||||
def parse_pubkey(serialized_tx: BufferReader) -> bytes:
|
||||
return bytes(serialized_tx.read_memoryview(32))
|
||||
|
||||
|
||||
def parse_enum(serialized_tx: BufferReader) -> int:
|
||||
return serialized_tx.get()
|
||||
|
||||
|
||||
def parse_string(serialized_tx: BufferReader) -> str:
|
||||
# TODO SOL: validation shall be checked (length is less than 2^32 or even less)
|
||||
length = read_uint64_le(serialized_tx)
|
||||
return bytes(serialized_tx.read_memoryview(length)).decode("utf-8")
|
||||
|
||||
|
||||
def parse_memo(serialized_tx: BufferReader) -> str:
|
||||
return bytes(serialized_tx.read_memoryview(serialized_tx.remaining_count())).decode(
|
||||
"utf-8"
|
||||
)
|
||||
|
||||
|
||||
def parse_byte(serialized_tx: BufferReader) -> int:
|
||||
return serialized_tx.get()
|
73
core/src/apps/solana/types.py
Normal file
73
core/src/apps/solana/types.py
Normal file
@ -0,0 +1,73 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from enum import IntEnum
|
||||
from typing import Any, Callable, Generic, TypeVar
|
||||
|
||||
from trezor.utils import BufferReader
|
||||
|
||||
from .transaction import Instruction
|
||||
|
||||
Address = tuple[bytes, "AddressType"]
|
||||
AddressReference = tuple[bytes, int, "AddressType"]
|
||||
Account = Address | AddressReference
|
||||
|
||||
ProgramIndex = int
|
||||
InstructionId = int | None
|
||||
AccountIndex = int
|
||||
InstructionData = memoryview
|
||||
RawInstruction = tuple[
|
||||
ProgramIndex, InstructionId, list[AccountIndex], InstructionData
|
||||
]
|
||||
|
||||
T = TypeVar("T")
|
||||
else:
|
||||
IntEnum = object
|
||||
T = 0
|
||||
Generic = {T: object}
|
||||
|
||||
|
||||
class AddressType(IntEnum):
|
||||
AddressSig = 0
|
||||
AddressSigReadOnly = 1
|
||||
AddressReadOnly = 2
|
||||
AddressRw = 3
|
||||
|
||||
|
||||
class PropertyTemplate(Generic[T]):
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
is_authority: bool,
|
||||
is_optional: bool,
|
||||
parse: Callable[[BufferReader], T],
|
||||
format: Callable[[Instruction, T], str],
|
||||
):
|
||||
self.name = name
|
||||
self.is_authority = is_authority
|
||||
self.is_optional = is_optional
|
||||
self.parse = parse
|
||||
self.format = format
|
||||
|
||||
|
||||
class AccountTemplate:
|
||||
def __init__(self, name: str, is_authority: bool, optional: bool):
|
||||
self.name = name
|
||||
self.is_authority = is_authority
|
||||
self.optional = optional
|
||||
|
||||
|
||||
class UIProperty:
|
||||
def __init__(
|
||||
self,
|
||||
parameter: str | None,
|
||||
account: str | None,
|
||||
display_name: str,
|
||||
is_authority: bool,
|
||||
default_value_to_hide: Any | None,
|
||||
) -> None:
|
||||
self.parameter = parameter
|
||||
self.account = account
|
||||
self.display_name = display_name
|
||||
self.is_authority = is_authority
|
||||
self.default_value_to_hide = default_value_to_hide
|
Loading…
Reference in New Issue
Block a user