1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-07-26 08:29:26 +00:00

instructions as classes

This commit is contained in:
gabrielkerekes 2023-08-04 13:35:24 +02:00
parent 3851170040
commit f3ae77e1af
7 changed files with 244 additions and 230 deletions

View File

@ -365,14 +365,30 @@ apps.misc.sign_identity
import apps.misc.sign_identity import apps.misc.sign_identity
apps.solana apps.solana
import apps.solana import apps.solana
apps.solana.constants
import apps.solana.constants
apps.solana.get_address apps.solana.get_address
import apps.solana.get_address import apps.solana.get_address
apps.solana.get_public_key apps.solana.get_public_key
import apps.solana.get_public_key import apps.solana.get_public_key
apps.solana.helpers.paths apps.solana.instructions
import apps.solana.helpers.paths import apps.solana.instructions
apps.solana.instructions.stake_program
import apps.solana.instructions.stake_program
apps.solana.instructions.system_program
import apps.solana.instructions.system_program
apps.solana.parsing
import apps.solana.parsing
apps.solana.parsing.parse
import apps.solana.parsing.parse
apps.solana.parsing.parse_instructions
import apps.solana.parsing.parse_instructions
apps.solana.parsing.utils
import apps.solana.parsing.utils
apps.solana.sign_tx apps.solana.sign_tx
import apps.solana.sign_tx import apps.solana.sign_tx
apps.solana.types
import apps.solana.types
apps.workflow_handlers apps.workflow_handlers
import apps.workflow_handlers import apps.workflow_handlers

View File

@ -1,7 +1,7 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from ..types import Instruction from ..types import Address, Data, RawInstruction
SYSTEM_PROGRAM_ID = "11111111111111111111111111111111" SYSTEM_PROGRAM_ID = "11111111111111111111111111111111"
STAKE_PROGRAM_ID = "Stake11111111111111111111111111111111111111" STAKE_PROGRAM_ID = "Stake11111111111111111111111111111111111111"
@ -9,8 +9,27 @@ STAKE_PROGRAM_ID = "Stake11111111111111111111111111111111111111"
SYSTEM_TRANSFER_ID = 2 SYSTEM_TRANSFER_ID = 2
class Instruction:
program_id: bytes
accounts: list[Address]
data: Data
def __init__(self, raw_instruction: RawInstruction):
self.program_id, self.accounts, self.data = raw_instruction
def parse(self) -> None:
pass
def validate(self, signer_pub_key: bytes) -> None:
pass
async def show(self) -> None:
# TODO SOL: blind signing could be here?
pass
async def handle_instructions( async def handle_instructions(
instructions: list[Instruction], signer_pub_key: bytes instructions: list[RawInstruction], signer_pub_key: bytes
) -> None: ) -> None:
from trezor.crypto import base58 from trezor.crypto import base58
from trezor.wire import ProcessError from trezor.wire import ProcessError

View File

@ -5,68 +5,96 @@ from trezor.ui.layouts import confirm_properties
from trezor.wire import ProcessError from trezor.wire import ProcessError
from ..constants import ADDRESS_READ_ONLY, ADDRESS_RW from ..constants import ADDRESS_READ_ONLY, ADDRESS_RW
from . import STAKE_PROGRAM_ID from . import STAKE_PROGRAM_ID, Instruction
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Awaitable from typing import Awaitable
from ..types import Instruction from ..types import RawInstruction
INS_INITIALIZE_STAKE = 0 INS_INITIALIZE_STAKE = 0
def handle_stake_program_instruction( def handle_stake_program_instruction(
instruction: Instruction, signer_pub_key: bytes raw_instruction: RawInstruction, signer_pub_key: bytes
) -> Awaitable[None]: ) -> Awaitable[None]:
program_id, _, data = instruction program_id, _, data = raw_instruction
assert base58.encode(program_id) == STAKE_PROGRAM_ID assert base58.encode(program_id) == STAKE_PROGRAM_ID
assert data.remaining_count() >= 4 assert data.remaining_count() >= 4
instruction = _get_instruction(raw_instruction)
instruction.parse()
instruction.validate(signer_pub_key)
return instruction.show()
def _get_instruction(raw_instruction: RawInstruction) -> Instruction:
_, _, data = raw_instruction
instruction_id = int.from_bytes(data.read(4), "little") instruction_id = int.from_bytes(data.read(4), "little")
data.seek(0) data.seek(0)
if instruction_id == INS_INITIALIZE_STAKE: if instruction_id == INS_INITIALIZE_STAKE:
return _handle_initialize_stake(instruction, signer_pub_key) return InitializeStakeInstruction(raw_instruction)
else: else:
# TODO SOL: blind signing # TODO SOL: blind signing
raise ProcessError("Unknown stake program instruction") raise ProcessError("Unknown system program instruction")
def _handle_initialize_stake(instruction: Instruction, signer_address: bytes): class InitializeStakeInstruction(Instruction):
_, accounts, data = instruction PROGRAM_ID = STAKE_PROGRAM_ID
INSTRUCTION_ID = INS_INITIALIZE_STAKE
assert data.remaining_count() == 116 staker: bytes
assert len(accounts) == 2 withdrawer: bytes
unix_timestamp: int
epoch: int
custodian: bytes
uninitialized_stake_account: bytes
rent_sysvar: bytes
instruction_id = int.from_bytes(data.read(4), "little") def parse(self) -> None:
assert self.data.remaining_count() == 116
assert len(self.accounts) == 2
instruction_id = int.from_bytes(self.data.read(4), "little")
assert instruction_id == INS_INITIALIZE_STAKE assert instruction_id == INS_INITIALIZE_STAKE
# TODO SOL: validate staker, withdrawer, custodian # TODO SOL: validate staker, withdrawer, custodian
staker = data.read(32) self.staker = self.data.read(32)
withdrawer = data.read(32) self.withdrawer = self.data.read(32)
# TODO SOL: should be signed int but from_bytes doesn't take the third arg # TODO SOL: should be signed int but from_bytes doesn't take the third arg
unix_timestamp = int.from_bytes(data.read(8), "little") self.unix_timestamp = int.from_bytes(self.data.read(8), "little")
epoch = int.from_bytes(data.read(8), "little") self.epoch = int.from_bytes(self.data.read(8), "little")
custodian = data.read(32) self.custodian = self.data.read(32)
uninitialized_stake_account, uninitialized_stake_account_type = accounts[0] (
self.uninitialized_stake_account,
uninitialized_stake_account_type,
) = self.accounts[0]
assert uninitialized_stake_account_type == ADDRESS_RW assert uninitialized_stake_account_type == ADDRESS_RW
rent_sysvar, rent_sysvar_type = accounts[1] self.rent_sysvar, rent_sysvar_type = self.accounts[1]
assert rent_sysvar_type == ADDRESS_READ_ONLY assert rent_sysvar_type == ADDRESS_READ_ONLY
def validate(self, signer_pub_key: bytes) -> None:
# TODO SOL: validation
pass
def show(self) -> Awaitable[None]:
return confirm_properties( return confirm_properties(
"initialize_stake", "initialize_stake",
"Initialize Stake", "Initialize Stake",
( (
("Staker", base58.encode(staker)), ("Staker", base58.encode(self.staker)),
("Withdrawer", base58.encode(withdrawer)), ("Withdrawer", base58.encode(self.withdrawer)),
("Unix Timestamp", str(unix_timestamp)), ("Unix Timestamp", str(self.unix_timestamp)),
("Epoch", str(epoch)), ("Epoch", str(self.epoch)),
("Custodian", base58.encode(custodian)), ("Custodian", base58.encode(self.custodian)),
("Stake Account", base58.encode(uninitialized_stake_account)), ("Stake Account", base58.encode(self.uninitialized_stake_account)),
# TODO SOL: probably doesn't need to be displayed # TODO SOL: probably doesn't need to be displayed
("Rent Sysvar", base58.encode(rent_sysvar)), ("Rent Sysvar", base58.encode(self.rent_sysvar)),
), ),
) )

View File

@ -8,12 +8,12 @@ from trezor.wire import ProcessError
from ..constants import ADDRESS_RW, ADDRESS_SIG, ADDRESS_SIG_READ_ONLY from ..constants import ADDRESS_RW, ADDRESS_SIG, ADDRESS_SIG_READ_ONLY
from ..parsing.utils import read_string from ..parsing.utils import read_string
from . import SYSTEM_PROGRAM_ID from . import SYSTEM_PROGRAM_ID, Instruction
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Awaitable from typing import Awaitable
from ..types import Instruction from ..types import RawInstruction
INS_CREATE_ACCOUNT = 0 INS_CREATE_ACCOUNT = 0
INS_TRANSFER = 2 INS_TRANSFER = 2
@ -21,221 +21,172 @@ INS_CREATE_ACCOUNT_WITH_SEED = 3
def handle_system_program_instruction( def handle_system_program_instruction(
instruction: Instruction, signer_pub_key: bytes raw_instruction: RawInstruction, signer_pub_key: bytes
) -> Awaitable[None]: ) -> Awaitable[None]:
program_id, _, data = instruction program_id, _, data = raw_instruction
assert base58.encode(program_id) == SYSTEM_PROGRAM_ID assert base58.encode(program_id) == SYSTEM_PROGRAM_ID
assert data.remaining_count() >= 4 assert data.remaining_count() >= 4
instruction = _get_instruction(raw_instruction)
instruction.parse()
instruction.validate(signer_pub_key)
return instruction.show()
def _get_instruction(raw_instruction: RawInstruction) -> Instruction:
_, _, data = raw_instruction
instruction_id = int.from_bytes(data.read(4), "little") instruction_id = int.from_bytes(data.read(4), "little")
data.seek(0) data.seek(0)
if instruction_id == INS_CREATE_ACCOUNT: if instruction_id == INS_CREATE_ACCOUNT:
return _handle_create_account(instruction, signer_pub_key) return CreateAccountInstruction(raw_instruction)
if instruction_id == INS_TRANSFER: elif instruction_id == INS_TRANSFER:
return _handle_transfer(instruction, signer_pub_key) return TransferInstruction(raw_instruction)
if instruction_id == INS_CREATE_ACCOUNT_WITH_SEED: elif instruction_id == INS_CREATE_ACCOUNT_WITH_SEED:
return _handle_create_account_with_seed(instruction, signer_pub_key) return CreateAccountWithSeedInstruction(raw_instruction)
else: else:
# TODO SOL: blind signing # TODO SOL: blind signing
raise ProcessError("Unknown system program instruction") raise ProcessError("Unknown system program instruction")
def _handle_create_account( class CreateAccountInstruction(Instruction):
instruction: Instruction, signer_pub_key: bytes PROGRAM_ID = SYSTEM_PROGRAM_ID
) -> Awaitable[None]: INSTRUCTION_ID = INS_CREATE_ACCOUNT
lamports, space, owner, funding_account, new_account = _parse_create_account(
instruction
)
_validate_create_account(funding_account, signer_pub_key)
return _show_create_account(lamports, space, owner, funding_account, new_account)
lamports: int
space: int
owner: bytes
funding_account: bytes
created_account: bytes
def _parse_create_account( def parse(self) -> None:
instruction: Instruction, assert self.data.remaining_count() == 52
) -> tuple[int, int, bytes, bytes, bytes]: assert len(self.accounts) == 2
_, accounts, data = instruction
assert data.remaining_count() == 52 instruction_id = int.from_bytes(self.data.read(4), "little")
assert len(accounts) == 2
instruction_id = int.from_bytes(data.read(4), "little")
assert instruction_id == INS_CREATE_ACCOUNT assert instruction_id == INS_CREATE_ACCOUNT
lamports = int.from_bytes(data.read(8), "little") self.lamports = int.from_bytes(self.data.read(8), "little")
space = int.from_bytes(data.read(8), "little") self.space = int.from_bytes(self.data.read(8), "little")
owner = data.read(32) self.owner = self.data.read(32)
funding_account, funding_account_type = accounts[0] self.funding_account, funding_account_type = self.accounts[0]
assert funding_account_type == ADDRESS_SIG assert funding_account_type == ADDRESS_SIG
new_account, new_account_type = accounts[1] self.new_account, new_account_type = self.accounts[1]
assert new_account_type == ADDRESS_RW assert new_account_type == ADDRESS_RW
return lamports, space, owner, funding_account, new_account def validate(self, signer_pub_key: bytes) -> None:
if self.funding_account != signer_pub_key:
def _validate_create_account(funding_account: bytes, signer_pub_key: bytes) -> None:
if funding_account != signer_pub_key:
raise ProcessError("Invalid funding account") raise ProcessError("Invalid funding account")
def show(self) -> Awaitable[None]:
def _show_create_account(
lamports: int, space: int, owner: bytes, funding_account: bytes, new_account: bytes
) -> Awaitable[None]:
return confirm_properties( return confirm_properties(
"create_account", "create_account",
"Create Account", "Create Account",
( (
("Lamports", str(lamports)), ("Lamports", str(self.lamports)),
("Space", str(space)), ("Space", str(self.space)),
("Owner", base58.encode(owner)), ("Owner", base58.encode(self.owner)),
("Funding Account", base58.encode(funding_account)), ("Funding Account", base58.encode(self.funding_account)),
("New Account", base58.encode(new_account)), ("New Account", base58.encode(self.new_account)),
), ),
) )
def _handle_transfer( class TransferInstruction(Instruction):
instruction: Instruction, signer_pub_key: bytes PROGRAM_ID = SYSTEM_PROGRAM_ID
) -> Awaitable[None]: INSTRUCTION_ID = INS_TRANSFER
amount, source, destination = _parse_transfer(instruction)
_validate_transfer(source, signer_pub_key)
return _show_transfer(destination, amount)
amount: int
source: bytes
destination: bytes
def _parse_transfer(instruction: Instruction) -> tuple[int, bytes, bytes]: def parse(self) -> None:
_, accounts, data = instruction assert base58.encode(self.program_id) == self.PROGRAM_ID
assert self.data.remaining_count() == 12
assert len(self.accounts) == 2
assert data.remaining_count() == 12 instruction_id = int.from_bytes(self.data.read(4), "little")
assert len(accounts) == 2 assert instruction_id == self.INSTRUCTION_ID
instruction_id = int.from_bytes(data.read(4), "little") self.amount = int.from_bytes(self.data.read(8), "little")
assert instruction_id == INS_TRANSFER
amount = int.from_bytes(data.read(8), "little") self.source, source_account_type = self.accounts[0]
source, source_account_type = accounts[0]
assert source_account_type == ADDRESS_SIG assert source_account_type == ADDRESS_SIG
destination, destination_account_type = accounts[1] self.destination, destination_account_type = self.accounts[1]
assert destination_account_type == ADDRESS_RW assert destination_account_type == ADDRESS_RW
return amount, source, destination def validate(self, signer_pub_key: bytes) -> None:
if self.source != signer_pub_key:
def _validate_transfer(source: bytes, signer_pub_key: bytes):
if source != signer_pub_key:
raise ProcessError("Invalid source account") raise ProcessError("Invalid source account")
# TODO SOL: validate max amount? # TODO SOL: validate max amount?
def show(self) -> Awaitable[None]:
def _show_transfer(destination: bytes, amount: int) -> Awaitable[None]:
return confirm_output( return confirm_output(
base58.encode(destination), base58.encode(self.destination),
f"{format_amount(amount, 8)} SOL", f"{format_amount(self.amount, 8)} SOL",
br_code=ButtonRequestType.Other, br_code=ButtonRequestType.Other,
) )
def _handle_create_account_with_seed( class CreateAccountWithSeedInstruction(Instruction):
instruction: Instruction, signer_address: bytes PROGRAM_ID = SYSTEM_PROGRAM_ID
) -> Awaitable[None]: INSTRUCTION_ID = INS_CREATE_ACCOUNT_WITH_SEED
(
base,
seed,
lamports,
space,
owner,
funding_account,
created_account,
base_account,
) = _parse_create_account_with_seed(instruction)
_validate_create_account_with_seed(funding_account, signer_address)
return _show_create_account_with_seed(
base,
seed,
lamports,
space,
owner,
funding_account,
created_account,
base_account,
)
base: bytes
seed: str
lamports: int
space: int
owner: bytes
funding_account: bytes
created_account: bytes
base_account: bytes | None
def _parse_create_account_with_seed( def parse(self) -> None:
instruction: Instruction, assert len(self.accounts) == 2
) -> tuple[bytes, str, int, int, bytes, bytes, bytes, bytes | None]:
_, accounts, data = instruction
# assert len(data) == 52 instruction_id = int.from_bytes(self.data.read(4), "little")
assert len(accounts) == 2
instruction_id = int.from_bytes(data.read(4), "little")
assert instruction_id == INS_CREATE_ACCOUNT_WITH_SEED assert instruction_id == INS_CREATE_ACCOUNT_WITH_SEED
base = data.read(32) self.base = self.data.read(32)
seed = read_string(data) self.seed = read_string(self.data)
lamports = int.from_bytes(data.read(8), "little") self.lamports = int.from_bytes(self.data.read(8), "little")
space = int.from_bytes(data.read(8), "little") self.space = int.from_bytes(self.data.read(8), "little")
owner = data.read(32) self.owner = self.data.read(32)
funding_account, funding_account_type = accounts[0] self.funding_account, funding_account_type = self.accounts[0]
assert funding_account_type == ADDRESS_SIG assert funding_account_type == ADDRESS_SIG
created_account, created_account_type = accounts[1] self.created_account, created_account_type = self.accounts[1]
assert created_account_type == ADDRESS_RW assert created_account_type == ADDRESS_RW
base_account = None self.base_account = None
if len(accounts) == 3: if len(self.accounts) == 3:
base_account, base_account_type = accounts[2] self.base_account, base_account_type = self.accounts[2]
assert base_account_type == ADDRESS_SIG_READ_ONLY assert base_account_type == ADDRESS_SIG_READ_ONLY
return ( def validate(self, signer_pub_key: bytes) -> None:
base, if self.funding_account != signer_pub_key:
seed, raise ProcessError("Invalid funding account")
lamports,
space,
owner,
funding_account,
created_account,
base_account,
)
def show(self) -> Awaitable[None]:
def _validate_create_account_with_seed(
funding_account: bytes, signer_pub_key: bytes
) -> None:
# TODO SOL: pass for now since we don't have the proper mnemonic
pass
# if funding_account != signer_pub_key:
# raise ProcessError("Invalid funding account")
def _show_create_account_with_seed(
base: bytes,
seed: str,
lamports: int,
space: int,
owner: bytes,
funding_account: bytes,
created_account: bytes,
base_account: bytes | None,
) -> Awaitable[None]:
props = [ props = [
("Base", base58.encode(base)), ("Base", base58.encode(self.base)),
("Seed", seed), ("Seed", self.seed),
("Lamports", str(lamports)), ("Lamports", str(self.lamports)),
("Space", str(space)), ("Space", str(self.space)),
("Owner", base58.encode(owner)), ("Owner", base58.encode(self.owner)),
("Funding Account", base58.encode(funding_account)), ("Funding Account", base58.encode(self.funding_account)),
("Created Account", base58.encode(created_account)), ("Created Account", base58.encode(self.created_account)),
] ]
if base_account: if self.base_account:
props.append(("Base Account", base58.encode(base_account))) props.append(("Base Account", base58.encode(self.base_account)))
return confirm_properties("create_account", "Create Account", props) return confirm_properties("create_account", "Create Account", props)

View File

@ -14,12 +14,12 @@ from .utils import read_compact_u16
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.utils import BufferReader from trezor.utils import BufferReader
from ..types import Address, Instruction from ..types import Address, RawInstruction
def parse( def parse(
serialized_tx: BufferReader, serialized_tx: BufferReader,
) -> tuple[list[Address], bytes, list[Instruction]]: ) -> tuple[list[Address], bytes, list[RawInstruction]]:
# TODO SOL: signature parsing can be removed? # TODO SOL: signature parsing can be removed?
# num_of_signatures = decode_length(serialized_tx) # num_of_signatures = decode_length(serialized_tx)
# assert num_of_signatures == 0 # assert num_of_signatures == 0

View File

@ -26,7 +26,7 @@ async def sign_tx(
signature = ed25519.sign(node.private_key(), serialized_tx) signature = ed25519.sign(node.private_key(), serialized_tx)
addresses, blockhash, instructions = parse(BufferReader(serialized_tx)) _, _, instructions = parse(BufferReader(serialized_tx))
signer_pub_key = seed.remove_ed25519_prefix(node.public_key()) signer_pub_key = seed.remove_ed25519_prefix(node.public_key())
await handle_instructions(instructions, signer_pub_key) await handle_instructions(instructions, signer_pub_key)

View File

@ -6,4 +6,4 @@ if TYPE_CHECKING:
Address = tuple[bytes, int] Address = tuple[bytes, int]
ProgramId = bytes ProgramId = bytes
Data = BufferReader Data = BufferReader
Instruction = tuple[ProgramId, list[Address], Data] RawInstruction = tuple[ProgramId, list[Address], Data]