From f3ae77e1afc02a031b83af68b16c48ead9d28902 Mon Sep 17 00:00:00 2001 From: gabrielkerekes Date: Fri, 4 Aug 2023 13:35:24 +0200 Subject: [PATCH] instructions as classes --- core/src/all_modules.py | 20 +- core/src/apps/solana/instructions/__init__.py | 23 +- .../apps/solana/instructions/stake_program.py | 102 ++++-- .../solana/instructions/system_program.py | 321 ++++++++---------- core/src/apps/solana/parsing/parse.py | 4 +- core/src/apps/solana/sign_tx.py | 2 +- core/src/apps/solana/types.py | 2 +- 7 files changed, 244 insertions(+), 230 deletions(-) diff --git a/core/src/all_modules.py b/core/src/all_modules.py index 3800766ec5..3689965490 100644 --- a/core/src/all_modules.py +++ b/core/src/all_modules.py @@ -365,14 +365,30 @@ apps.misc.sign_identity import apps.misc.sign_identity apps.solana import apps.solana +apps.solana.constants +import apps.solana.constants apps.solana.get_address import apps.solana.get_address apps.solana.get_public_key import apps.solana.get_public_key -apps.solana.helpers.paths -import apps.solana.helpers.paths +apps.solana.instructions +import apps.solana.instructions +apps.solana.instructions.stake_program +import apps.solana.instructions.stake_program +apps.solana.instructions.system_program +import apps.solana.instructions.system_program +apps.solana.parsing +import apps.solana.parsing +apps.solana.parsing.parse +import apps.solana.parsing.parse +apps.solana.parsing.parse_instructions +import apps.solana.parsing.parse_instructions +apps.solana.parsing.utils +import apps.solana.parsing.utils apps.solana.sign_tx import apps.solana.sign_tx +apps.solana.types +import apps.solana.types apps.workflow_handlers import apps.workflow_handlers diff --git a/core/src/apps/solana/instructions/__init__.py b/core/src/apps/solana/instructions/__init__.py index 3455bf69c1..53ab835a08 100644 --- a/core/src/apps/solana/instructions/__init__.py +++ b/core/src/apps/solana/instructions/__init__.py @@ -1,7 +1,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from ..types import Instruction + from ..types import Address, Data, RawInstruction SYSTEM_PROGRAM_ID = "11111111111111111111111111111111" STAKE_PROGRAM_ID = "Stake11111111111111111111111111111111111111" @@ -9,8 +9,27 @@ STAKE_PROGRAM_ID = "Stake11111111111111111111111111111111111111" SYSTEM_TRANSFER_ID = 2 +class Instruction: + program_id: bytes + accounts: list[Address] + data: Data + + def __init__(self, raw_instruction: RawInstruction): + self.program_id, self.accounts, self.data = raw_instruction + + def parse(self) -> None: + pass + + def validate(self, signer_pub_key: bytes) -> None: + pass + + async def show(self) -> None: + # TODO SOL: blind signing could be here? + pass + + async def handle_instructions( - instructions: list[Instruction], signer_pub_key: bytes + instructions: list[RawInstruction], signer_pub_key: bytes ) -> None: from trezor.crypto import base58 from trezor.wire import ProcessError diff --git a/core/src/apps/solana/instructions/stake_program.py b/core/src/apps/solana/instructions/stake_program.py index fd204b0bc4..3e569e7914 100644 --- a/core/src/apps/solana/instructions/stake_program.py +++ b/core/src/apps/solana/instructions/stake_program.py @@ -5,68 +5,96 @@ from trezor.ui.layouts import confirm_properties from trezor.wire import ProcessError from ..constants import ADDRESS_READ_ONLY, ADDRESS_RW -from . import STAKE_PROGRAM_ID +from . import STAKE_PROGRAM_ID, Instruction if TYPE_CHECKING: from typing import Awaitable - from ..types import Instruction + from ..types import RawInstruction INS_INITIALIZE_STAKE = 0 def handle_stake_program_instruction( - instruction: Instruction, signer_pub_key: bytes + raw_instruction: RawInstruction, signer_pub_key: bytes ) -> Awaitable[None]: - program_id, _, data = instruction + program_id, _, data = raw_instruction assert base58.encode(program_id) == STAKE_PROGRAM_ID assert data.remaining_count() >= 4 + instruction = _get_instruction(raw_instruction) + + instruction.parse() + instruction.validate(signer_pub_key) + return instruction.show() + + +def _get_instruction(raw_instruction: RawInstruction) -> Instruction: + _, _, data = raw_instruction + instruction_id = int.from_bytes(data.read(4), "little") data.seek(0) if instruction_id == INS_INITIALIZE_STAKE: - return _handle_initialize_stake(instruction, signer_pub_key) + return InitializeStakeInstruction(raw_instruction) else: # TODO SOL: blind signing - raise ProcessError("Unknown stake program instruction") + raise ProcessError("Unknown system program instruction") -def _handle_initialize_stake(instruction: Instruction, signer_address: bytes): - _, accounts, data = instruction +class InitializeStakeInstruction(Instruction): + PROGRAM_ID = STAKE_PROGRAM_ID + INSTRUCTION_ID = INS_INITIALIZE_STAKE - assert data.remaining_count() == 116 - assert len(accounts) == 2 + staker: bytes + withdrawer: bytes + unix_timestamp: int + epoch: int + custodian: bytes + uninitialized_stake_account: bytes + rent_sysvar: bytes - instruction_id = int.from_bytes(data.read(4), "little") - assert instruction_id == INS_INITIALIZE_STAKE + def parse(self) -> None: + assert self.data.remaining_count() == 116 + assert len(self.accounts) == 2 - # TODO SOL: validate staker, withdrawer, custodian - staker = data.read(32) - withdrawer = data.read(32) - # TODO SOL: should be signed int but from_bytes doesn't take the third arg - unix_timestamp = int.from_bytes(data.read(8), "little") - epoch = int.from_bytes(data.read(8), "little") - custodian = data.read(32) + instruction_id = int.from_bytes(self.data.read(4), "little") + assert instruction_id == INS_INITIALIZE_STAKE - uninitialized_stake_account, uninitialized_stake_account_type = accounts[0] - assert uninitialized_stake_account_type == ADDRESS_RW + # TODO SOL: validate staker, withdrawer, custodian + self.staker = self.data.read(32) + self.withdrawer = self.data.read(32) + # TODO SOL: should be signed int but from_bytes doesn't take the third arg + self.unix_timestamp = int.from_bytes(self.data.read(8), "little") + self.epoch = int.from_bytes(self.data.read(8), "little") + self.custodian = self.data.read(32) - rent_sysvar, rent_sysvar_type = accounts[1] - assert rent_sysvar_type == ADDRESS_READ_ONLY - - return confirm_properties( - "initialize_stake", - "Initialize Stake", ( - ("Staker", base58.encode(staker)), - ("Withdrawer", base58.encode(withdrawer)), - ("Unix Timestamp", str(unix_timestamp)), - ("Epoch", str(epoch)), - ("Custodian", base58.encode(custodian)), - ("Stake Account", base58.encode(uninitialized_stake_account)), - # TODO SOL: probably doesn't need to be displayed - ("Rent Sysvar", base58.encode(rent_sysvar)), - ), - ) + self.uninitialized_stake_account, + uninitialized_stake_account_type, + ) = self.accounts[0] + assert uninitialized_stake_account_type == ADDRESS_RW + + self.rent_sysvar, rent_sysvar_type = self.accounts[1] + assert rent_sysvar_type == ADDRESS_READ_ONLY + + def validate(self, signer_pub_key: bytes) -> None: + # TODO SOL: validation + pass + + def show(self) -> Awaitable[None]: + return confirm_properties( + "initialize_stake", + "Initialize Stake", + ( + ("Staker", base58.encode(self.staker)), + ("Withdrawer", base58.encode(self.withdrawer)), + ("Unix Timestamp", str(self.unix_timestamp)), + ("Epoch", str(self.epoch)), + ("Custodian", base58.encode(self.custodian)), + ("Stake Account", base58.encode(self.uninitialized_stake_account)), + # TODO SOL: probably doesn't need to be displayed + ("Rent Sysvar", base58.encode(self.rent_sysvar)), + ), + ) diff --git a/core/src/apps/solana/instructions/system_program.py b/core/src/apps/solana/instructions/system_program.py index d9c731db28..5ba96f7769 100644 --- a/core/src/apps/solana/instructions/system_program.py +++ b/core/src/apps/solana/instructions/system_program.py @@ -8,12 +8,12 @@ from trezor.wire import ProcessError from ..constants import ADDRESS_RW, ADDRESS_SIG, ADDRESS_SIG_READ_ONLY from ..parsing.utils import read_string -from . import SYSTEM_PROGRAM_ID +from . import SYSTEM_PROGRAM_ID, Instruction if TYPE_CHECKING: from typing import Awaitable - from ..types import Instruction + from ..types import RawInstruction INS_CREATE_ACCOUNT = 0 INS_TRANSFER = 2 @@ -21,221 +21,172 @@ INS_CREATE_ACCOUNT_WITH_SEED = 3 def handle_system_program_instruction( - instruction: Instruction, signer_pub_key: bytes + raw_instruction: RawInstruction, signer_pub_key: bytes ) -> Awaitable[None]: - program_id, _, data = instruction + program_id, _, data = raw_instruction assert base58.encode(program_id) == SYSTEM_PROGRAM_ID assert data.remaining_count() >= 4 + instruction = _get_instruction(raw_instruction) + + instruction.parse() + instruction.validate(signer_pub_key) + return instruction.show() + + +def _get_instruction(raw_instruction: RawInstruction) -> Instruction: + _, _, data = raw_instruction + instruction_id = int.from_bytes(data.read(4), "little") data.seek(0) if instruction_id == INS_CREATE_ACCOUNT: - return _handle_create_account(instruction, signer_pub_key) - if instruction_id == INS_TRANSFER: - return _handle_transfer(instruction, signer_pub_key) - if instruction_id == INS_CREATE_ACCOUNT_WITH_SEED: - return _handle_create_account_with_seed(instruction, signer_pub_key) + return CreateAccountInstruction(raw_instruction) + elif instruction_id == INS_TRANSFER: + return TransferInstruction(raw_instruction) + elif instruction_id == INS_CREATE_ACCOUNT_WITH_SEED: + return CreateAccountWithSeedInstruction(raw_instruction) else: # TODO SOL: blind signing raise ProcessError("Unknown system program instruction") -def _handle_create_account( - instruction: Instruction, signer_pub_key: bytes -) -> Awaitable[None]: - lamports, space, owner, funding_account, new_account = _parse_create_account( - instruction - ) - _validate_create_account(funding_account, signer_pub_key) - return _show_create_account(lamports, space, owner, funding_account, new_account) +class CreateAccountInstruction(Instruction): + PROGRAM_ID = SYSTEM_PROGRAM_ID + INSTRUCTION_ID = INS_CREATE_ACCOUNT + + lamports: int + space: int + owner: bytes + funding_account: bytes + created_account: bytes + + def parse(self) -> None: + assert self.data.remaining_count() == 52 + assert len(self.accounts) == 2 + + instruction_id = int.from_bytes(self.data.read(4), "little") + assert instruction_id == INS_CREATE_ACCOUNT + + self.lamports = int.from_bytes(self.data.read(8), "little") + self.space = int.from_bytes(self.data.read(8), "little") + self.owner = self.data.read(32) + + self.funding_account, funding_account_type = self.accounts[0] + assert funding_account_type == ADDRESS_SIG + + self.new_account, new_account_type = self.accounts[1] + assert new_account_type == ADDRESS_RW + + def validate(self, signer_pub_key: bytes) -> None: + if self.funding_account != signer_pub_key: + raise ProcessError("Invalid funding account") + + def show(self) -> Awaitable[None]: + return confirm_properties( + "create_account", + "Create Account", + ( + ("Lamports", str(self.lamports)), + ("Space", str(self.space)), + ("Owner", base58.encode(self.owner)), + ("Funding Account", base58.encode(self.funding_account)), + ("New Account", base58.encode(self.new_account)), + ), + ) -def _parse_create_account( - instruction: Instruction, -) -> tuple[int, int, bytes, bytes, bytes]: - _, accounts, data = instruction +class TransferInstruction(Instruction): + PROGRAM_ID = SYSTEM_PROGRAM_ID + INSTRUCTION_ID = INS_TRANSFER - assert data.remaining_count() == 52 - assert len(accounts) == 2 + amount: int + source: bytes + destination: bytes - instruction_id = int.from_bytes(data.read(4), "little") - assert instruction_id == INS_CREATE_ACCOUNT + def parse(self) -> None: + assert base58.encode(self.program_id) == self.PROGRAM_ID + assert self.data.remaining_count() == 12 + assert len(self.accounts) == 2 - lamports = int.from_bytes(data.read(8), "little") - space = int.from_bytes(data.read(8), "little") - owner = data.read(32) + instruction_id = int.from_bytes(self.data.read(4), "little") + assert instruction_id == self.INSTRUCTION_ID - funding_account, funding_account_type = accounts[0] - assert funding_account_type == ADDRESS_SIG + self.amount = int.from_bytes(self.data.read(8), "little") - new_account, new_account_type = accounts[1] - assert new_account_type == ADDRESS_RW + self.source, source_account_type = self.accounts[0] + assert source_account_type == ADDRESS_SIG - return lamports, space, owner, funding_account, new_account + self.destination, destination_account_type = self.accounts[1] + assert destination_account_type == ADDRESS_RW + + def validate(self, signer_pub_key: bytes) -> None: + if self.source != signer_pub_key: + raise ProcessError("Invalid source account") + + # TODO SOL: validate max amount? + + def show(self) -> Awaitable[None]: + return confirm_output( + base58.encode(self.destination), + f"{format_amount(self.amount, 8)} SOL", + br_code=ButtonRequestType.Other, + ) -def _validate_create_account(funding_account: bytes, signer_pub_key: bytes) -> None: - if funding_account != signer_pub_key: - raise ProcessError("Invalid funding account") +class CreateAccountWithSeedInstruction(Instruction): + PROGRAM_ID = SYSTEM_PROGRAM_ID + INSTRUCTION_ID = INS_CREATE_ACCOUNT_WITH_SEED + base: bytes + seed: str + lamports: int + space: int + owner: bytes + funding_account: bytes + created_account: bytes + base_account: bytes | None -def _show_create_account( - lamports: int, space: int, owner: bytes, funding_account: bytes, new_account: bytes -) -> Awaitable[None]: - return confirm_properties( - "create_account", - "Create Account", - ( - ("Lamports", str(lamports)), - ("Space", str(space)), - ("Owner", base58.encode(owner)), - ("Funding Account", base58.encode(funding_account)), - ("New Account", base58.encode(new_account)), - ), - ) + def parse(self) -> None: + assert len(self.accounts) == 2 + instruction_id = int.from_bytes(self.data.read(4), "little") + assert instruction_id == INS_CREATE_ACCOUNT_WITH_SEED -def _handle_transfer( - instruction: Instruction, signer_pub_key: bytes -) -> Awaitable[None]: - amount, source, destination = _parse_transfer(instruction) - _validate_transfer(source, signer_pub_key) - return _show_transfer(destination, amount) + self.base = self.data.read(32) + self.seed = read_string(self.data) + self.lamports = int.from_bytes(self.data.read(8), "little") + self.space = int.from_bytes(self.data.read(8), "little") + self.owner = self.data.read(32) + self.funding_account, funding_account_type = self.accounts[0] + assert funding_account_type == ADDRESS_SIG -def _parse_transfer(instruction: Instruction) -> tuple[int, bytes, bytes]: - _, accounts, data = instruction + self.created_account, created_account_type = self.accounts[1] + assert created_account_type == ADDRESS_RW - assert data.remaining_count() == 12 - assert len(accounts) == 2 + self.base_account = None + if len(self.accounts) == 3: + self.base_account, base_account_type = self.accounts[2] + assert base_account_type == ADDRESS_SIG_READ_ONLY - instruction_id = int.from_bytes(data.read(4), "little") - assert instruction_id == INS_TRANSFER + def validate(self, signer_pub_key: bytes) -> None: + if self.funding_account != signer_pub_key: + raise ProcessError("Invalid funding account") - amount = int.from_bytes(data.read(8), "little") + def show(self) -> Awaitable[None]: + props = [ + ("Base", base58.encode(self.base)), + ("Seed", self.seed), + ("Lamports", str(self.lamports)), + ("Space", str(self.space)), + ("Owner", base58.encode(self.owner)), + ("Funding Account", base58.encode(self.funding_account)), + ("Created Account", base58.encode(self.created_account)), + ] - source, source_account_type = accounts[0] - assert source_account_type == ADDRESS_SIG + if self.base_account: + props.append(("Base Account", base58.encode(self.base_account))) - destination, destination_account_type = accounts[1] - assert destination_account_type == ADDRESS_RW - - return amount, source, destination - - -def _validate_transfer(source: bytes, signer_pub_key: bytes): - if source != signer_pub_key: - raise ProcessError("Invalid source account") - - # TODO SOL: validate max amount? - - -def _show_transfer(destination: bytes, amount: int) -> Awaitable[None]: - return confirm_output( - base58.encode(destination), - f"{format_amount(amount, 8)} SOL", - br_code=ButtonRequestType.Other, - ) - - -def _handle_create_account_with_seed( - instruction: Instruction, signer_address: bytes -) -> Awaitable[None]: - ( - base, - seed, - lamports, - space, - owner, - funding_account, - created_account, - base_account, - ) = _parse_create_account_with_seed(instruction) - _validate_create_account_with_seed(funding_account, signer_address) - return _show_create_account_with_seed( - base, - seed, - lamports, - space, - owner, - funding_account, - created_account, - base_account, - ) - - -def _parse_create_account_with_seed( - instruction: Instruction, -) -> tuple[bytes, str, int, int, bytes, bytes, bytes, bytes | None]: - _, accounts, data = instruction - - # assert len(data) == 52 - assert len(accounts) == 2 - - instruction_id = int.from_bytes(data.read(4), "little") - assert instruction_id == INS_CREATE_ACCOUNT_WITH_SEED - - base = data.read(32) - seed = read_string(data) - lamports = int.from_bytes(data.read(8), "little") - space = int.from_bytes(data.read(8), "little") - owner = data.read(32) - - funding_account, funding_account_type = accounts[0] - assert funding_account_type == ADDRESS_SIG - - created_account, created_account_type = accounts[1] - assert created_account_type == ADDRESS_RW - - base_account = None - if len(accounts) == 3: - base_account, base_account_type = accounts[2] - assert base_account_type == ADDRESS_SIG_READ_ONLY - - return ( - base, - seed, - lamports, - space, - owner, - funding_account, - created_account, - base_account, - ) - - -def _validate_create_account_with_seed( - funding_account: bytes, signer_pub_key: bytes -) -> None: - # TODO SOL: pass for now since we don't have the proper mnemonic - pass - # if funding_account != signer_pub_key: - # raise ProcessError("Invalid funding account") - - -def _show_create_account_with_seed( - base: bytes, - seed: str, - lamports: int, - space: int, - owner: bytes, - funding_account: bytes, - created_account: bytes, - base_account: bytes | None, -) -> Awaitable[None]: - props = [ - ("Base", base58.encode(base)), - ("Seed", seed), - ("Lamports", str(lamports)), - ("Space", str(space)), - ("Owner", base58.encode(owner)), - ("Funding Account", base58.encode(funding_account)), - ("Created Account", base58.encode(created_account)), - ] - - if base_account: - props.append(("Base Account", base58.encode(base_account))) - - return confirm_properties("create_account", "Create Account", props) + return confirm_properties("create_account", "Create Account", props) diff --git a/core/src/apps/solana/parsing/parse.py b/core/src/apps/solana/parsing/parse.py index 25bbfe8e9e..0249e806f7 100644 --- a/core/src/apps/solana/parsing/parse.py +++ b/core/src/apps/solana/parsing/parse.py @@ -14,12 +14,12 @@ from .utils import read_compact_u16 if TYPE_CHECKING: from trezor.utils import BufferReader - from ..types import Address, Instruction + from ..types import Address, RawInstruction def parse( serialized_tx: BufferReader, -) -> tuple[list[Address], bytes, list[Instruction]]: +) -> tuple[list[Address], bytes, list[RawInstruction]]: # TODO SOL: signature parsing can be removed? # num_of_signatures = decode_length(serialized_tx) # assert num_of_signatures == 0 diff --git a/core/src/apps/solana/sign_tx.py b/core/src/apps/solana/sign_tx.py index 9c4e32e8d1..dd2061826a 100644 --- a/core/src/apps/solana/sign_tx.py +++ b/core/src/apps/solana/sign_tx.py @@ -26,7 +26,7 @@ async def sign_tx( signature = ed25519.sign(node.private_key(), serialized_tx) - addresses, blockhash, instructions = parse(BufferReader(serialized_tx)) + _, _, instructions = parse(BufferReader(serialized_tx)) signer_pub_key = seed.remove_ed25519_prefix(node.public_key()) await handle_instructions(instructions, signer_pub_key) diff --git a/core/src/apps/solana/types.py b/core/src/apps/solana/types.py index 9eb4451ceb..aef386c13c 100644 --- a/core/src/apps/solana/types.py +++ b/core/src/apps/solana/types.py @@ -6,4 +6,4 @@ if TYPE_CHECKING: Address = tuple[bytes, int] ProgramId = bytes Data = BufferReader - Instruction = tuple[ProgramId, list[Address], Data] + RawInstruction = tuple[ProgramId, list[Address], Data]