diff --git a/core/src/all_modules.py b/core/src/all_modules.py index 58278aa7e2..4c9748e504 100644 --- a/core/src/all_modules.py +++ b/core/src/all_modules.py @@ -683,14 +683,28 @@ if not utils.BITCOIN_ONLY: import apps.ripple.sign_tx apps.solana import apps.solana + apps.solana.constants + import apps.solana.constants + apps.solana.format + import apps.solana.format apps.solana.get_address import apps.solana.get_address apps.solana.get_public_key import apps.solana.get_public_key + apps.solana.layout + import apps.solana.layout apps.solana.sign_tx import apps.solana.sign_tx + apps.solana.transaction + import apps.solana.transaction + apps.solana.transaction.instruction + import apps.solana.transaction.instruction apps.solana.transaction.instructions import apps.solana.transaction.instructions + apps.solana.transaction.parse + import apps.solana.transaction.parse + apps.solana.types + import apps.solana.types apps.stellar import apps.stellar apps.stellar.consts diff --git a/core/src/apps/solana/constants.py b/core/src/apps/solana/constants.py new file mode 100644 index 0000000000..79d6115b5c --- /dev/null +++ b/core/src/apps/solana/constants.py @@ -0,0 +1,6 @@ +from micropython import const + +ADDRESS_SIZE = const(32) + +SOLANA_BASE_FEE_LAMPORTS = const(5000) +SOLANA_COMPUTE_UNIT_LIMIT = const(200000) diff --git a/core/src/apps/solana/format.py b/core/src/apps/solana/format.py new file mode 100644 index 0000000000..0242b47b48 --- /dev/null +++ b/core/src/apps/solana/format.py @@ -0,0 +1,39 @@ +from typing import TYPE_CHECKING + +from trezor.strings import format_amount, format_timestamp + +if TYPE_CHECKING: + from .transaction.instructions import Instruction + + +def format_pubkey(_: Instruction, value: bytes | None) -> str: + from trezor.crypto import base58 + + if value is None: + raise ValueError # should not be called with optional pubkey + + return base58.encode(value) + + +def format_lamports(_: Instruction, value: int) -> str: + formatted = format_amount(value, decimals=9) + return f"{formatted} SOL" + + +def format_token_amount(instruction: Instruction, value: int) -> str: + assert hasattr(instruction, "decimals") # enforced in instructions.py.mako + + formatted = format_amount(value, decimals=instruction.decimals) + return f"{formatted}" + + +def format_unix_timestamp(_: Instruction, value: int) -> str: + return format_timestamp(value) + + +def format_int(_: Instruction, value: int) -> str: + return str(value) + + +def format_identity(_: Instruction, value: str) -> str: + return value diff --git a/core/src/apps/solana/layout.py b/core/src/apps/solana/layout.py new file mode 100644 index 0000000000..9550f4a79c --- /dev/null +++ b/core/src/apps/solana/layout.py @@ -0,0 +1,357 @@ +from typing import TYPE_CHECKING + +from trezor.crypto import base58 +from trezor.enums import ButtonRequestType +from trezor.strings import format_amount +from trezor.ui.layouts import ( + confirm_metadata, + confirm_properties, + confirm_solana_tx, + confirm_value, +) + +from apps.common.paths import address_n_to_str + +from .types import AddressType + +if TYPE_CHECKING: + from trezor.ui.layouts import PropertyType + + from .transaction.instructions import ( + AssociatedTokenAccountProgramCreateInstruction, + Instruction, + SystemProgramTransferInstruction, + TokenProgramTransferCheckedInstruction, + ) + from .types import AddressReference + + +def _format_path(path: list[int]) -> str: + from micropython import const + + from apps.common.paths import unharden + + if len(path) < 4: + return address_n_to_str(path) + + ACCOUNT_PATH_INDEX = const(3) + account_index = path[ACCOUNT_PATH_INDEX] + return f"Solana #{unharden(account_index) + 1}" + + +def _get_address_reference_props(address: AddressReference, display_name: str): + return ( + (f"{display_name} is provided via a lookup table.", ""), + ("Lookup table address:", base58.encode(address[0])), + ("Account index:", f"{address[1]}"), + ) + + +async def confirm_instruction( + instruction: Instruction, + instructions_count: int, + instruction_index: int, + signer_path: list[int], + signer_public_key: bytes, +) -> None: + instruction_title = ( + f"{instruction_index}/{instructions_count}: {instruction.ui_name}" + ) + + if instruction.is_deprecated_warning is not None: + await confirm_metadata( + "confirm_deprecated_warning", + instruction_title, + instruction.is_deprecated_warning, + br_code=ButtonRequestType.Other, + ) + + for ui_property in instruction.ui_properties: + if ui_property.parameter is not None: + property_template = instruction.get_property_template(ui_property.parameter) + value = instruction.parsed_data[ui_property.parameter] + + if property_template.is_authority and signer_public_key == value: + continue + + if property_template.is_optional and value is None: + continue + + if ui_property.default_value_to_hide == value: + continue + + await confirm_properties( + "confirm_instruction", + f"{instruction_index}/{instructions_count}: {instruction.ui_name}", + ( + ( + ui_property.display_name, + property_template.format(instruction, value), + ), + ), + ) + elif ui_property.account is not None: + account_template = instruction.get_account_template(ui_property.account) + + # optional account, skip if not present + if ui_property.account not in instruction.parsed_accounts: + continue + + account_value = instruction.parsed_accounts[ui_property.account] + + if account_template.is_authority: + if signer_public_key == account_value[0]: + continue + + account_data: list[tuple[str, str]] = [] + if len(account_value) == 2: + signer_suffix = "" + if account_value[0] == signer_public_key: + signer_suffix = " (Signer)" + + account_data.append( + ( + ui_property.display_name, + f"{base58.encode(account_value[0])}{signer_suffix}", + ) + ) + elif len(account_value) == 3: + account_data += _get_address_reference_props( + account_value, ui_property.display_name + ) + else: + raise ValueError # Invalid account value + + await confirm_properties( + "confirm_instruction", + f"{instruction_index}/{instructions_count}: {instruction.ui_name}", + account_data, + ) + else: + raise ValueError # Invalid ui property + + if instruction.multisig_signers: + await confirm_metadata( + "confirm_multisig", + "Confirm multisig", + "The following instruction is a multisig instruction.", + br_code=ButtonRequestType.Other, + ) + + signers: list[tuple[str, str]] = [] + for i, multisig_signer in enumerate(instruction.multisig_signers, 1): + multisig_signer_public_key = multisig_signer[0] + + path_str = "" + if multisig_signer_public_key == signer_public_key: + path_str = f" ({address_n_to_str(signer_path)})" + + signers.append( + (f"Signer {i}{path_str}:", base58.encode(multisig_signer[0])) + ) + + await confirm_properties( + "confirm_instruction", + f"{instruction_index}/{instructions_count}: {instruction.ui_name}", + signers, + ) + + +def get_address_type(address_type: AddressType) -> str: + if address_type == AddressType.AddressSig: + return "(Writable, Signer)" + if address_type == AddressType.AddressSigReadOnly: + return "(Signer)" + if address_type == AddressType.AddressReadOnly: + return "" + if address_type == AddressType.AddressRw: + return "(Writable)" + raise ValueError # Invalid address type + + +async def confirm_unsupported_instruction_details( + instruction: Instruction, + title: str, + signer_path: list[int], + signer_public_key: bytes, +) -> None: + from trezor.ui import NORMAL + from trezor.ui.layouts import confirm_properties, should_show_more + + should_show_instruction_details = await should_show_more( + title, + ( + ( + NORMAL, + f"Instruction contains {len(instruction.accounts)} accounts and its data is {len(instruction.instruction_data)} bytes long.", + ), + ), + "Show details", + confirm="Continue", + ) + + if should_show_instruction_details: + await confirm_properties( + "instruction_data", + title, + (("Instruction data:", bytes(instruction.instruction_data)),), + ) + + accounts = [] + for i, account in enumerate(instruction.accounts, 1): + if len(account) == 2: + account_public_key = account[0] + address_type = get_address_type(account[1]) + + path_str = "" + if account_public_key == signer_public_key: + path_str = f" ({address_n_to_str(signer_path)})" + + accounts.append( + ( + f"Account {i}{path_str} {address_type}:", + base58.encode(account_public_key), + ) + ) + elif len(account) == 3: + address_type = get_address_type(account[2]) + accounts += _get_address_reference_props( + account, f"Account {i} {address_type}" + ) + else: + raise ValueError # Invalid account value + + await confirm_properties( + "accounts", + title, + accounts, + ) + + +async def confirm_unsupported_instruction_confirm( + instruction: Instruction, + instructions_count: int, + instruction_index: int, + signer_path: list[int], + signer_public_key: bytes, +) -> None: + formatted_instruction_id = ( + instruction.instruction_id if instruction.instruction_id is not None else "N/A" + ) + title = f"{instruction_index}/{instructions_count}: {instruction.ui_name}: instruction id ({formatted_instruction_id})" + + return await confirm_unsupported_instruction_details( + instruction, title, signer_path, signer_public_key + ) + + +async def confirm_unsupported_program_confirm( + instruction: Instruction, + instructions_count: int, + instruction_index: int, + signer_path: list[int], + signer_public_key: bytes, +) -> None: + title = f"{instruction_index}/{instructions_count}: {instruction.ui_name}" + + return await confirm_unsupported_instruction_details( + instruction, title, signer_path, signer_public_key + ) + + +async def confirm_system_transfer( + transfer_instruction: SystemProgramTransferInstruction, + fee: int, + signer_path: list[int], + blockhash: bytes, +) -> None: + await confirm_value( + title="Recipient", + value=base58.encode(transfer_instruction.recipient_account[0]), + description="", + br_type="confirm_recipient", + br_code=ButtonRequestType.ConfirmOutput, + verb="CONTINUE", + ) + + await confirm_custom_transaction( + transfer_instruction.lamports, + 9, + "SOL", + fee, + signer_path, + blockhash, + ) + + +async def confirm_token_transfer( + create_token_account_instruction: AssociatedTokenAccountProgramCreateInstruction + | None, + transfer_token_instruction: TokenProgramTransferCheckedInstruction, + fee: int, + signer_path: list[int], + blockhash: bytes, +): + recipient_props: list[PropertyType] = [ + ("", base58.encode(transfer_token_instruction.destination_account[0])) + ] + if create_token_account_instruction is not None: + recipient_props.append(("(account will be created)", "")) + + await confirm_properties( + "confirm_recipient", + "Recipient", + recipient_props, + ) + + await confirm_value( + title="Token address", + value=base58.encode(transfer_token_instruction.token_mint[0]), + description="", + br_type="confirm_token_address", + br_code=ButtonRequestType.ConfirmOutput, + verb="CONTINUE", + ) + + await confirm_custom_transaction( + transfer_token_instruction.amount, + transfer_token_instruction.decimals, + "[TOKEN]", + fee, + signer_path, + blockhash, + ) + + +async def confirm_custom_transaction( + amount: int, + decimals: int, + unit: str, + fee: int, + signer_path: list[int], + blockhash: bytes, +) -> None: + await confirm_solana_tx( + amount=f"{format_amount(amount, decimals)} {unit}", + fee=f"{format_amount(fee, 9)} SOL", + fee_title="Expected fee:", + items=( + ("Account:", _format_path(signer_path)), + ("Blockhash:", base58.encode(blockhash)), + ), + ) + + +async def confirm_transaction( + signer_path: list[int], blockhash: bytes, fee: int +) -> None: + await confirm_solana_tx( + amount="", + amount_title="", + fee=f"{format_amount(fee, 9)} SOL", + fee_title="Expected fee:", + items=( + ("Account:", _format_path(signer_path)), + ("Blockhash:", base58.encode(blockhash)), + ), + ) diff --git a/core/src/apps/solana/sign_tx.py b/core/src/apps/solana/sign_tx.py index 9f022b2a0f..8bf8faf978 100644 --- a/core/src/apps/solana/sign_tx.py +++ b/core/src/apps/solana/sign_tx.py @@ -1,8 +1,11 @@ from typing import TYPE_CHECKING +from trezor.wire import DataError + from apps.common.keychain import with_slip44_keychain from . import CURVE, PATTERNS, SLIP44_ID +from .transaction import Transaction if TYPE_CHECKING: from trezor.messages import SolanaSignTx, SolanaTxSignature @@ -16,12 +19,178 @@ async def sign_tx( keychain: Keychain, ) -> SolanaTxSignature: from trezor.crypto.curve import ed25519 + from trezor.enums import ButtonRequestType from trezor.messages import SolanaTxSignature + from trezor.ui.layouts import confirm_metadata, show_warning - address_n = msg.address_n - serialized_tx = msg.serialized_tx + from apps.common import seed + + from .layout import confirm_transaction + + address_n = msg.address_n # local_cache_attribute + serialized_tx = msg.serialized_tx # local_cache_attribute node = keychain.derive(address_n) + signer_public_key = seed.remove_ed25519_prefix(node.public_key()) + + try: + transaction: Transaction = Transaction(serialized_tx) + except Exception: + raise DataError("Invalid transaction") + + if transaction.blind_signing: + await show_warning( + "warning_blind_signing", "Transaction contains unknown instructions." + ) + + if transaction.required_signers_count > 1: + await confirm_metadata( + "multiple_signers", + "Multiple signers", + f"Transaction requires {transaction.required_signers_count} signers which increases the fee.", + br_code=ButtonRequestType.Other, + ) + + fee = calculate_fee(transaction) + + if not await try_confirm_predefined_transaction( + transaction, fee, address_n, transaction.blockhash + ): + await confirm_instructions(address_n, signer_public_key, transaction) + await confirm_transaction( + address_n, + transaction.blockhash, + calculate_fee(transaction), + ) + signature = ed25519.sign(node.private_key(), serialized_tx) return SolanaTxSignature(signature=signature) + + +async def try_confirm_predefined_transaction( + transaction: Transaction, fee: int, signer_path: list[int], blockhash: bytes +) -> bool: + from .layout import confirm_system_transfer, confirm_token_transfer + from .transaction.instructions import ( + AssociatedTokenAccountProgramCreateInstruction, + SystemProgramTransferInstruction, + TokenProgramTransferCheckedInstruction, + ) + + instructions = transaction.instructions + instructions_count = len(instructions) + + if instructions_count == 1: + if SystemProgramTransferInstruction.is_type_of(instructions[0]): + await confirm_system_transfer(instructions[0], fee, signer_path, blockhash) + return True + + if TokenProgramTransferCheckedInstruction.is_type_of(instructions[0]): + await confirm_token_transfer( + None, instructions[0], fee, signer_path, blockhash + ) + return True + elif instructions_count == 2: + if AssociatedTokenAccountProgramCreateInstruction.is_type_of( + instructions[0] + ) and TokenProgramTransferCheckedInstruction.is_type_of(instructions[1]): + create_token_account_instruction = instructions[0] + transfer_token_instruction = instructions[1] + + # If the account being created is different from the recipient account we need + # to display all the instruction information. + if ( + create_token_account_instruction.associated_token_account[0] + != transfer_token_instruction.destination_account[0] + ): + return False + + await confirm_token_transfer( + instructions[0], + instructions[1], + fee, + signer_path, + blockhash, + ) + return True + + return False + + +async def confirm_instructions( + signer_path: list[int], signer_public_key: bytes, transaction: Transaction +) -> None: + instructions_count = len(transaction.instructions) + for instruction_index, instruction in enumerate(transaction.instructions, 1): + if not instruction.is_program_supported: + from .layout import confirm_unsupported_program_confirm + + await confirm_unsupported_program_confirm( + instruction, + instructions_count, + instruction_index, + signer_path, + signer_public_key, + ) + elif not instruction.is_instruction_supported: + from .layout import confirm_unsupported_instruction_confirm + + await confirm_unsupported_instruction_confirm( + instruction, + instructions_count, + instruction_index, + signer_path, + signer_public_key, + ) + else: + from .layout import confirm_instruction + + await confirm_instruction( + instruction, + instructions_count, + instruction_index, + signer_path, + signer_public_key, + ) + + +def calculate_fee(transaction: Transaction) -> int: + from .constants import SOLANA_BASE_FEE_LAMPORTS, SOLANA_COMPUTE_UNIT_LIMIT + from .transaction.instructions import ( + COMPUTE_BUDGET_PROGRAM_ID, + COMPUTE_BUDGET_PROGRAM_ID_INS_SET_COMPUTE_UNIT_LIMIT, + COMPUTE_BUDGET_PROGRAM_ID_INS_SET_COMPUTE_UNIT_PRICE, + ) + from .types import AddressType + + number_of_signers = 0 + for address in transaction.addresses: + if address[1] == AddressType.AddressSig: + number_of_signers += 1 + + base_fee = SOLANA_BASE_FEE_LAMPORTS * number_of_signers + + unit_price = 0 + is_unit_price_set = False + unit_limit = SOLANA_COMPUTE_UNIT_LIMIT + is_unit_limit_set = False + + for instruction in transaction.instructions[:3]: + if instruction.program_id == COMPUTE_BUDGET_PROGRAM_ID: + if ( + instruction.instruction_id + == COMPUTE_BUDGET_PROGRAM_ID_INS_SET_COMPUTE_UNIT_LIMIT + and not is_unit_limit_set + ): + unit_limit = instruction.units + is_unit_limit_set = True + elif ( + instruction.instruction_id + == COMPUTE_BUDGET_PROGRAM_ID_INS_SET_COMPUTE_UNIT_PRICE + and not is_unit_price_set + ): + unit_price = instruction.lamports + is_unit_price_set = True + + return int(base_fee + unit_price * unit_limit / 1000000) diff --git a/core/src/apps/solana/transaction/__init__.py b/core/src/apps/solana/transaction/__init__.py new file mode 100644 index 0000000000..932a1176ad --- /dev/null +++ b/core/src/apps/solana/transaction/__init__.py @@ -0,0 +1,204 @@ +from typing import TYPE_CHECKING + +from trezor.crypto import base58 +from trezor.utils import BufferReader +from trezor.wire import DataError + +from ..types import AddressType +from .instruction import Instruction +from .instructions import get_instruction, get_instruction_id_length +from .parse import parse_block_hash, parse_pubkey, parse_var_int + +if TYPE_CHECKING: + from ..types import Account, Address, AddressReference, RawInstruction + + +class Transaction: + blind_signing = False + required_signers_count = 0 + + version: int | None = None + + addresses: list[Address] + + blockhash: bytes + + raw_instructions: list[RawInstruction] + instructions: list[Instruction] + + address_lookup_tables_rw_addresses: list[AddressReference] + address_lookup_tables_ro_addresses: list[AddressReference] + + def __init__(self, serialized_tx: bytes) -> None: + self._parse_transaction(serialized_tx) + self._create_instructions() + self._determine_if_blind_signing() + + def _parse_transaction(self, serialized_tx: bytes) -> None: + serialized_tx_reader = BufferReader(serialized_tx) + self._parse_header(serialized_tx_reader) + + self._parse_addresses(serialized_tx_reader) + + self.blockhash = parse_block_hash(serialized_tx_reader) + + self._parse_instructions(serialized_tx_reader) + + self._parse_address_lookup_tables(serialized_tx_reader) + + if serialized_tx_reader.remaining_count() != 0: + raise DataError("Invalid transaction") + + def _parse_header(self, serialized_tx_reader: BufferReader) -> None: + self.version: int | None = None + + if serialized_tx_reader.peek() & 0b10000000: + self.version = serialized_tx_reader.get() & 0b01111111 + # only version 0 is supported + if self.version > 0: + raise DataError("Unsupported transaction version") + + self.required_signers_count: int = serialized_tx_reader.get() + self.num_signature_read_only_addresses: int = serialized_tx_reader.get() + self.num_read_only_addresses: int = serialized_tx_reader.get() + + def _parse_addresses(self, serialized_tx_reader: BufferReader) -> None: + num_of_addresses = parse_var_int(serialized_tx_reader) + + assert ( + num_of_addresses + >= self.required_signers_count + + self.num_signature_read_only_addresses + + self.num_read_only_addresses + ) + + addresses: list[Address] = [] + for i in range(num_of_addresses): + if i < self.required_signers_count: + type = AddressType.AddressSig + elif ( + i < self.required_signers_count + self.num_signature_read_only_addresses + ): + type = AddressType.AddressSigReadOnly + elif ( + i + < self.required_signers_count + + self.num_signature_read_only_addresses + + self.num_read_only_addresses + ): + type = AddressType.AddressRw + else: + type = AddressType.AddressReadOnly + + address = parse_pubkey(serialized_tx_reader) + + addresses.append((address, type)) + + self.addresses = addresses + + def _parse_instructions(self, serialized_tx_reader: BufferReader) -> None: + num_of_instructions = parse_var_int(serialized_tx_reader) + + self.raw_instructions = [] + + for _ in range(num_of_instructions): + program_index = serialized_tx_reader.get() + program_id = base58.encode(self.addresses[program_index][0]) + num_of_accounts = parse_var_int(serialized_tx_reader) + accounts: list[int] = [] + for _ in range(num_of_accounts): + account_index = serialized_tx_reader.get() + accounts.append(account_index) + + data_length = parse_var_int(serialized_tx_reader) + + instruction_id_length = get_instruction_id_length(program_id) + if instruction_id_length <= data_length: + instruction_id = int.from_bytes( + serialized_tx_reader.read_memoryview(instruction_id_length), + "little", + ) + else: + instruction_id = None + + instruction_data = serialized_tx_reader.read_memoryview( + max(0, data_length - instruction_id_length) + ) + + self.raw_instructions.append( + (program_index, instruction_id, accounts, instruction_data) + ) + + def _parse_address_lookup_tables(self, serialized_tx: BufferReader) -> None: + self.address_lookup_tables_rw_addresses = [] + self.address_lookup_tables_ro_addresses = [] + + if self.version is None: + return + + address_lookup_tables_count = parse_var_int(serialized_tx) + for _ in range(address_lookup_tables_count): + account = parse_pubkey(serialized_tx) + + table_rw_indexes_count = parse_var_int(serialized_tx) + for _ in range(table_rw_indexes_count): + index = serialized_tx.get() + self.address_lookup_tables_rw_addresses.append( + (account, index, AddressType.AddressRw) + ) + + table_ro_indexes_count = parse_var_int(serialized_tx) + for _ in range(table_ro_indexes_count): + index = serialized_tx.get() + self.address_lookup_tables_ro_addresses.append( + (account, index, AddressType.AddressReadOnly) + ) + + def _get_combined_accounts(self) -> list[Account]: + accounts: list[Account] = [] + for address in self.addresses: + accounts.append(address) + + for rw_address in self.address_lookup_tables_rw_addresses: + accounts.append(rw_address) + for ro_address in self.address_lookup_tables_ro_addresses: + accounts.append(ro_address) + + return accounts + + def _create_instructions(self) -> None: + # Instructions reference accounts by index in this combined list. + combined_accounts = ( + self.addresses # type: ignore [Operator "+" not supported for types "list[Address]" and "list[AddressReference]"] + + self.address_lookup_tables_rw_addresses + + self.address_lookup_tables_ro_addresses + ) + + self.instructions = [] + for ( + program_index, + instruction_id, + accounts, + instruction_data, + ) in self.raw_instructions: + program_id = base58.encode(self.addresses[program_index][0]) + instruction_accounts = [ + combined_accounts[account_index] for account_index in accounts + ] + instruction = get_instruction( + program_id, + instruction_id, + instruction_accounts, + instruction_data, + ) + + self.instructions.append(instruction) + + def _determine_if_blind_signing(self) -> None: + for instruction in self.instructions: + if ( + not instruction.is_program_supported + or not instruction.is_instruction_supported + ): + self.blind_signing = True + break diff --git a/core/src/apps/solana/transaction/instruction.py b/core/src/apps/solana/transaction/instruction.py new file mode 100644 index 0000000000..77320f8975 --- /dev/null +++ b/core/src/apps/solana/transaction/instruction.py @@ -0,0 +1,154 @@ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Any, TypeGuard + + from typing_extensions import Self + + from ..types import ( + Account, + AccountTemplate, + InstructionData, + PropertyTemplate, + UIProperty, + ) + + +class Instruction: + program_id: str + instruction_id: int | None + + property_templates: list[PropertyTemplate] + accounts_template: list[AccountTemplate] + + ui_name: str + + ui_properties: list[UIProperty] + + parsed_data: dict[str, Any] + parsed_accounts: dict[str, Account] + + is_program_supported: bool + is_instruction_supported: bool + instruction_data: InstructionData + accounts: list[Account] + + multisig_signers: list[Account] + + is_deprecated_warning: str | None = None + + @staticmethod + def parse_instruction_data( + instruction_data: InstructionData, property_templates: list[PropertyTemplate] + ): + from trezor.utils import BufferReader + from trezor.wire import DataError + + reader = BufferReader(instruction_data) + + parsed_data = {} + for property_template in property_templates: + is_included = True + if property_template.is_optional: + is_included = True if reader.get() == 1 else False + + parsed_data[property_template.name] = ( + property_template.parse(reader) if is_included else None + ) + + if reader.remaining_count() != 0: + raise DataError("Invalid transaction data") + + return parsed_data + + @staticmethod + def parse_instruction_accounts( + accounts: list[Account], accounts_template: list[AccountTemplate] + ): + parsed_account = {} + for i, account_template in enumerate(accounts_template): + if i >= len(accounts): + if account_template.optional: + continue + else: + raise ValueError # "Account is missing + + parsed_account[account_template.name] = accounts[i] + return parsed_account + + def __init__( + self, + instruction_data: InstructionData, + program_id: str, + accounts: list[Account], + instruction_id: int | None, + property_templates: list[PropertyTemplate], + accounts_template: list[AccountTemplate], + ui_properties: list[UIProperty], + ui_name: str, + is_program_supported: bool = True, + is_instruction_supported: bool = True, + supports_multisig: bool = False, + is_deprecated_warning: str | None = None, + ) -> None: + self.program_id = program_id + self.instruction_id = instruction_id + + self.property_templates = property_templates + self.accounts_template = accounts_template + + self.ui_name = ui_name + + self.ui_properties = ui_properties + + self.is_program_supported = is_program_supported + self.is_instruction_supported = is_instruction_supported + + self.is_deprecated_warning = is_deprecated_warning + + self.instruction_data = instruction_data + self.accounts = accounts + + if self.is_instruction_supported: + self.parsed_data = self.parse_instruction_data( + instruction_data, property_templates + ) + + self.parsed_accounts = self.parse_instruction_accounts( + accounts, accounts_template + ) + + self.multisig_signers = accounts[len(accounts_template) :] + if self.multisig_signers and not supports_multisig: + raise ValueError # Multisig not supported + else: + self.parsed_data = {} + self.parsed_accounts = {} + self.multisig_signers = [] + + def __getattr__(self, attr: str) -> Any: + if attr in self.parsed_data: + return self.parsed_data[attr] + if attr in self.parsed_accounts: + return self.parsed_accounts[attr] + + raise AttributeError # Attribute not found + + def get_property_template(self, property: str) -> PropertyTemplate: + for property_template in self.property_templates: + if property_template.name == property: + return property_template + + raise ValueError # Property not found + + def get_account_template(self, account_name: str) -> AccountTemplate: + for account_template in self.accounts_template: + if account_template.name == account_name: + return account_template + + raise ValueError # Account not found + + @classmethod + def is_type_of(cls, ins: Any) -> TypeGuard[Self]: + # gets overridden in `instructions.py` `FakeClass` + raise NotImplementedError diff --git a/core/src/apps/solana/transaction/parse.py b/core/src/apps/solana/transaction/parse.py new file mode 100644 index 0000000000..12a43b17de --- /dev/null +++ b/core/src/apps/solana/transaction/parse.py @@ -0,0 +1,50 @@ +from typing import TYPE_CHECKING + +from apps.common.readers import read_uint64_le + +if TYPE_CHECKING: + from trezor.utils import BufferReader + + +def parse_var_int(serialized_tx: BufferReader) -> int: + value = 0 + shift = 0 + while serialized_tx.remaining_count(): + B = serialized_tx.get() + value += (B & 0b01111111) << shift + shift += 7 + if B & 0b10000000 == 0: + break + + if value > 0xFFFF: + raise ValueError # compact-u16 value too large + + return value + + +def parse_block_hash(serialized_tx: BufferReader) -> bytes: + return bytes(serialized_tx.read_memoryview(32)) + + +def parse_pubkey(serialized_tx: BufferReader) -> bytes: + return bytes(serialized_tx.read_memoryview(32)) + + +def parse_enum(serialized_tx: BufferReader) -> int: + return serialized_tx.get() + + +def parse_string(serialized_tx: BufferReader) -> str: + # TODO SOL: validation shall be checked (length is less than 2^32 or even less) + length = read_uint64_le(serialized_tx) + return bytes(serialized_tx.read_memoryview(length)).decode("utf-8") + + +def parse_memo(serialized_tx: BufferReader) -> str: + return bytes(serialized_tx.read_memoryview(serialized_tx.remaining_count())).decode( + "utf-8" + ) + + +def parse_byte(serialized_tx: BufferReader) -> int: + return serialized_tx.get() diff --git a/core/src/apps/solana/types.py b/core/src/apps/solana/types.py new file mode 100644 index 0000000000..2c630dbcf6 --- /dev/null +++ b/core/src/apps/solana/types.py @@ -0,0 +1,73 @@ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from enum import IntEnum + from typing import Any, Callable, Generic, TypeVar + + from trezor.utils import BufferReader + + from .transaction import Instruction + + Address = tuple[bytes, "AddressType"] + AddressReference = tuple[bytes, int, "AddressType"] + Account = Address | AddressReference + + ProgramIndex = int + InstructionId = int | None + AccountIndex = int + InstructionData = memoryview + RawInstruction = tuple[ + ProgramIndex, InstructionId, list[AccountIndex], InstructionData + ] + + T = TypeVar("T") +else: + IntEnum = object + T = 0 + Generic = {T: object} + + +class AddressType(IntEnum): + AddressSig = 0 + AddressSigReadOnly = 1 + AddressReadOnly = 2 + AddressRw = 3 + + +class PropertyTemplate(Generic[T]): + def __init__( + self, + name: str, + is_authority: bool, + is_optional: bool, + parse: Callable[[BufferReader], T], + format: Callable[[Instruction, T], str], + ): + self.name = name + self.is_authority = is_authority + self.is_optional = is_optional + self.parse = parse + self.format = format + + +class AccountTemplate: + def __init__(self, name: str, is_authority: bool, optional: bool): + self.name = name + self.is_authority = is_authority + self.optional = optional + + +class UIProperty: + def __init__( + self, + parameter: str | None, + account: str | None, + display_name: str, + is_authority: bool, + default_value_to_hide: Any | None, + ) -> None: + self.parameter = parameter + self.account = account + self.display_name = display_name + self.is_authority = is_authority + self.default_value_to_hide = default_value_to_hide