diff --git a/core/src/apps/solana/constants.py b/core/src/apps/solana/constants.py new file mode 100644 index 000000000..cb5cfdf7f --- /dev/null +++ b/core/src/apps/solana/constants.py @@ -0,0 +1,8 @@ +from micropython import const + +ADDRESS_SIZE = const(32) + +ADDRESS_SIG = const(0) +ADDRESS_SIG_READ_ONLY = const(1) +ADDRESS_READ_ONLY = const(2) +ADDRESS_RW = const(3) diff --git a/core/src/apps/solana/get_public_key.py b/core/src/apps/solana/get_public_key.py index 287315c75..efc197873 100644 --- a/core/src/apps/solana/get_public_key.py +++ b/core/src/apps/solana/get_public_key.py @@ -1,19 +1,21 @@ from typing import TYPE_CHECKING from ubinascii import hexlify +from apps.common import seed from apps.common.keychain import auto_keychain if TYPE_CHECKING: from trezor.messages import SolanaGetPublicKey, SolanaPublicKey + from apps.common.keychain import Keychain + # TODO SOL: maybe only get_address is needed? @auto_keychain(__name__) async def get_public_key( - msg: SolanaGetPublicKey, keychain: seed.Keychain + msg: SolanaGetPublicKey, keychain: Keychain ) -> SolanaPublicKey: from trezor.ui.layouts import show_pubkey from trezor.messages import HDNodeType, SolanaPublicKey - from apps.common import seed node = keychain.derive(msg.address_n) @@ -26,7 +28,7 @@ async def get_public_key( ) if msg.show_display: - await show_pubkey(hexlify(node.public_key).decode()) + await show_pubkey(hexlify(node.public_key()).decode()) # TODO SOL: xpub? return SolanaPublicKey(node=node_type, xpub=node.serialize_public(0)) diff --git a/core/src/apps/solana/instructions/__init__.py b/core/src/apps/solana/instructions/__init__.py new file mode 100644 index 000000000..3455bf69c --- /dev/null +++ b/core/src/apps/solana/instructions/__init__.py @@ -0,0 +1,32 @@ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ..types import Instruction + +SYSTEM_PROGRAM_ID = "11111111111111111111111111111111" +STAKE_PROGRAM_ID = "Stake11111111111111111111111111111111111111" + +SYSTEM_TRANSFER_ID = 2 + + +async def handle_instructions( + instructions: list[Instruction], signer_pub_key: bytes +) -> None: + from trezor.crypto import base58 + from trezor.wire import ProcessError + + from .system_program import handle_system_program_instruction + from .stake_program import handle_stake_program_instruction + + for instruction in instructions: + program_id, _, _ = instruction + + encoded_program_id = base58.encode(program_id) + + if encoded_program_id == SYSTEM_PROGRAM_ID: + await handle_system_program_instruction(instruction, signer_pub_key) + elif encoded_program_id == STAKE_PROGRAM_ID: + await handle_stake_program_instruction(instruction, signer_pub_key) + else: + # TODO SOL: blind signing for unknown programs + raise ProcessError(f"Unknown program id: {encoded_program_id}") diff --git a/core/src/apps/solana/instructions/stake_program.py b/core/src/apps/solana/instructions/stake_program.py new file mode 100644 index 000000000..fd204b0bc --- /dev/null +++ b/core/src/apps/solana/instructions/stake_program.py @@ -0,0 +1,72 @@ +from typing import TYPE_CHECKING + +from trezor.crypto import base58 +from trezor.ui.layouts import confirm_properties +from trezor.wire import ProcessError + +from ..constants import ADDRESS_READ_ONLY, ADDRESS_RW +from . import STAKE_PROGRAM_ID + +if TYPE_CHECKING: + from typing import Awaitable + + from ..types import Instruction + +INS_INITIALIZE_STAKE = 0 + + +def handle_stake_program_instruction( + instruction: Instruction, signer_pub_key: bytes +) -> Awaitable[None]: + program_id, _, data = instruction + + assert base58.encode(program_id) == STAKE_PROGRAM_ID + assert data.remaining_count() >= 4 + + instruction_id = int.from_bytes(data.read(4), "little") + data.seek(0) + + if instruction_id == INS_INITIALIZE_STAKE: + return _handle_initialize_stake(instruction, signer_pub_key) + else: + # TODO SOL: blind signing + raise ProcessError("Unknown stake program instruction") + + +def _handle_initialize_stake(instruction: Instruction, signer_address: bytes): + _, accounts, data = instruction + + assert data.remaining_count() == 116 + assert len(accounts) == 2 + + instruction_id = int.from_bytes(data.read(4), "little") + assert instruction_id == INS_INITIALIZE_STAKE + + # TODO SOL: validate staker, withdrawer, custodian + staker = data.read(32) + withdrawer = data.read(32) + # TODO SOL: should be signed int but from_bytes doesn't take the third arg + unix_timestamp = int.from_bytes(data.read(8), "little") + epoch = int.from_bytes(data.read(8), "little") + custodian = data.read(32) + + uninitialized_stake_account, uninitialized_stake_account_type = accounts[0] + assert uninitialized_stake_account_type == ADDRESS_RW + + rent_sysvar, rent_sysvar_type = accounts[1] + assert rent_sysvar_type == ADDRESS_READ_ONLY + + return confirm_properties( + "initialize_stake", + "Initialize Stake", + ( + ("Staker", base58.encode(staker)), + ("Withdrawer", base58.encode(withdrawer)), + ("Unix Timestamp", str(unix_timestamp)), + ("Epoch", str(epoch)), + ("Custodian", base58.encode(custodian)), + ("Stake Account", base58.encode(uninitialized_stake_account)), + # TODO SOL: probably doesn't need to be displayed + ("Rent Sysvar", base58.encode(rent_sysvar)), + ), + ) diff --git a/core/src/apps/solana/instructions/system_program.py b/core/src/apps/solana/instructions/system_program.py new file mode 100644 index 000000000..d9c731db2 --- /dev/null +++ b/core/src/apps/solana/instructions/system_program.py @@ -0,0 +1,241 @@ +from typing import TYPE_CHECKING + +from trezor.crypto import base58 +from trezor.enums import ButtonRequestType +from trezor.strings import format_amount +from trezor.ui.layouts import confirm_output, confirm_properties +from trezor.wire import ProcessError + +from ..constants import ADDRESS_RW, ADDRESS_SIG, ADDRESS_SIG_READ_ONLY +from ..parsing.utils import read_string +from . import SYSTEM_PROGRAM_ID + +if TYPE_CHECKING: + from typing import Awaitable + + from ..types import Instruction + +INS_CREATE_ACCOUNT = 0 +INS_TRANSFER = 2 +INS_CREATE_ACCOUNT_WITH_SEED = 3 + + +def handle_system_program_instruction( + instruction: Instruction, signer_pub_key: bytes +) -> Awaitable[None]: + program_id, _, data = instruction + + assert base58.encode(program_id) == SYSTEM_PROGRAM_ID + assert data.remaining_count() >= 4 + + instruction_id = int.from_bytes(data.read(4), "little") + data.seek(0) + + if instruction_id == INS_CREATE_ACCOUNT: + return _handle_create_account(instruction, signer_pub_key) + if instruction_id == INS_TRANSFER: + return _handle_transfer(instruction, signer_pub_key) + if instruction_id == INS_CREATE_ACCOUNT_WITH_SEED: + return _handle_create_account_with_seed(instruction, signer_pub_key) + else: + # TODO SOL: blind signing + raise ProcessError("Unknown system program instruction") + + +def _handle_create_account( + instruction: Instruction, signer_pub_key: bytes +) -> Awaitable[None]: + lamports, space, owner, funding_account, new_account = _parse_create_account( + instruction + ) + _validate_create_account(funding_account, signer_pub_key) + return _show_create_account(lamports, space, owner, funding_account, new_account) + + +def _parse_create_account( + instruction: Instruction, +) -> tuple[int, int, bytes, bytes, bytes]: + _, accounts, data = instruction + + assert data.remaining_count() == 52 + assert len(accounts) == 2 + + instruction_id = int.from_bytes(data.read(4), "little") + assert instruction_id == INS_CREATE_ACCOUNT + + lamports = int.from_bytes(data.read(8), "little") + space = int.from_bytes(data.read(8), "little") + owner = data.read(32) + + funding_account, funding_account_type = accounts[0] + assert funding_account_type == ADDRESS_SIG + + new_account, new_account_type = accounts[1] + assert new_account_type == ADDRESS_RW + + return lamports, space, owner, funding_account, new_account + + +def _validate_create_account(funding_account: bytes, signer_pub_key: bytes) -> None: + if funding_account != signer_pub_key: + raise ProcessError("Invalid funding account") + + +def _show_create_account( + lamports: int, space: int, owner: bytes, funding_account: bytes, new_account: bytes +) -> Awaitable[None]: + return confirm_properties( + "create_account", + "Create Account", + ( + ("Lamports", str(lamports)), + ("Space", str(space)), + ("Owner", base58.encode(owner)), + ("Funding Account", base58.encode(funding_account)), + ("New Account", base58.encode(new_account)), + ), + ) + + +def _handle_transfer( + instruction: Instruction, signer_pub_key: bytes +) -> Awaitable[None]: + amount, source, destination = _parse_transfer(instruction) + _validate_transfer(source, signer_pub_key) + return _show_transfer(destination, amount) + + +def _parse_transfer(instruction: Instruction) -> tuple[int, bytes, bytes]: + _, accounts, data = instruction + + assert data.remaining_count() == 12 + assert len(accounts) == 2 + + instruction_id = int.from_bytes(data.read(4), "little") + assert instruction_id == INS_TRANSFER + + amount = int.from_bytes(data.read(8), "little") + + source, source_account_type = accounts[0] + assert source_account_type == ADDRESS_SIG + + destination, destination_account_type = accounts[1] + assert destination_account_type == ADDRESS_RW + + return amount, source, destination + + +def _validate_transfer(source: bytes, signer_pub_key: bytes): + if source != signer_pub_key: + raise ProcessError("Invalid source account") + + # TODO SOL: validate max amount? + + +def _show_transfer(destination: bytes, amount: int) -> Awaitable[None]: + return confirm_output( + base58.encode(destination), + f"{format_amount(amount, 8)} SOL", + br_code=ButtonRequestType.Other, + ) + + +def _handle_create_account_with_seed( + instruction: Instruction, signer_address: bytes +) -> Awaitable[None]: + ( + base, + seed, + lamports, + space, + owner, + funding_account, + created_account, + base_account, + ) = _parse_create_account_with_seed(instruction) + _validate_create_account_with_seed(funding_account, signer_address) + return _show_create_account_with_seed( + base, + seed, + lamports, + space, + owner, + funding_account, + created_account, + base_account, + ) + + +def _parse_create_account_with_seed( + instruction: Instruction, +) -> tuple[bytes, str, int, int, bytes, bytes, bytes, bytes | None]: + _, accounts, data = instruction + + # assert len(data) == 52 + assert len(accounts) == 2 + + instruction_id = int.from_bytes(data.read(4), "little") + assert instruction_id == INS_CREATE_ACCOUNT_WITH_SEED + + base = data.read(32) + seed = read_string(data) + lamports = int.from_bytes(data.read(8), "little") + space = int.from_bytes(data.read(8), "little") + owner = data.read(32) + + funding_account, funding_account_type = accounts[0] + assert funding_account_type == ADDRESS_SIG + + created_account, created_account_type = accounts[1] + assert created_account_type == ADDRESS_RW + + base_account = None + if len(accounts) == 3: + base_account, base_account_type = accounts[2] + assert base_account_type == ADDRESS_SIG_READ_ONLY + + return ( + base, + seed, + lamports, + space, + owner, + funding_account, + created_account, + base_account, + ) + + +def _validate_create_account_with_seed( + funding_account: bytes, signer_pub_key: bytes +) -> None: + # TODO SOL: pass for now since we don't have the proper mnemonic + pass + # if funding_account != signer_pub_key: + # raise ProcessError("Invalid funding account") + + +def _show_create_account_with_seed( + base: bytes, + seed: str, + lamports: int, + space: int, + owner: bytes, + funding_account: bytes, + created_account: bytes, + base_account: bytes | None, +) -> Awaitable[None]: + props = [ + ("Base", base58.encode(base)), + ("Seed", seed), + ("Lamports", str(lamports)), + ("Space", str(space)), + ("Owner", base58.encode(owner)), + ("Funding Account", base58.encode(funding_account)), + ("Created Account", base58.encode(created_account)), + ] + + if base_account: + props.append(("Base Account", base58.encode(base_account))) + + return confirm_properties("create_account", "Create Account", props) diff --git a/core/src/apps/solana/parsing/__init__.py b/core/src/apps/solana/parsing/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/core/src/apps/solana/parsing/parse.py b/core/src/apps/solana/parsing/parse.py new file mode 100644 index 000000000..25bbfe8e9 --- /dev/null +++ b/core/src/apps/solana/parsing/parse.py @@ -0,0 +1,101 @@ +# 01000103c80f8b50107e9f3e3c16a661b8c806df454a6deb293d5e8730a9d28f2f4998c68f41927b2e58cbc31ed3aa5163a7b8ca4eb5590e8dc1dc682426cd2895aa9c0a00000000000000000000000000000000000000000000000000000000000000001aea57c9906a7cad656ff61b3893abda63f4b6b210c939855e7ab6e54049213d01020200010c02000000002d310100000000 +from typing import TYPE_CHECKING + +from ..constants import ( + ADDRESS_READ_ONLY, + ADDRESS_RW, + ADDRESS_SIG, + ADDRESS_SIG_READ_ONLY, + ADDRESS_SIZE, +) +from .parse_instructions import parse_instructions +from .utils import read_compact_u16 + +if TYPE_CHECKING: + from trezor.utils import BufferReader + + from ..types import Address, Instruction + + +def parse( + serialized_tx: BufferReader, +) -> tuple[list[Address], bytes, list[Instruction]]: + # TODO SOL: signature parsing can be removed? + # num_of_signatures = decode_length(serialized_tx) + # assert num_of_signatures == 0 + # signatures = [] + # for i in range(num_of_signatures): + # signature = serialized_tx.read(64) + # signatures.append(signature) + + ( + num_required_signatures, + num_signature_read_only_addresses, + num_read_only_addresses, + ) = _parse_header(serialized_tx) + + addresses = _parse_addresses( + serialized_tx, + num_required_signatures, + num_signature_read_only_addresses, + num_read_only_addresses, + ) + + blockhash = bytes(serialized_tx.read(32)) + + instructions = parse_instructions(serialized_tx, addresses) + + assert serialized_tx.remaining_count() == 0 + + return addresses, blockhash, instructions + + +def _parse_header(serialized_tx: BufferReader) -> tuple[int, int, int]: + num_required_signatures = int.from_bytes(serialized_tx.read(1), "big") + num_signature_read_only_addresses = int.from_bytes(serialized_tx.read(1), "big") + num_read_only_addresses = int.from_bytes(serialized_tx.read(1), "big") + + return ( + num_required_signatures, + num_signature_read_only_addresses, + num_read_only_addresses, + ) + + +def _parse_addresses( + serialized_tx: BufferReader, + num_required_signatures: int, + num_signature_read_only_addresses: int, + num_read_only_addresses: int, +) -> list[tuple[bytes, int]]: + num_of_addresses = read_compact_u16(serialized_tx) + + assert ( + num_of_addresses + >= num_required_signatures + + num_signature_read_only_addresses + + num_read_only_addresses + ) + + addresses: list[tuple[bytes, int]] = [] + for i in range(num_of_addresses): + assert ADDRESS_SIZE <= serialized_tx.remaining_count() + + address = serialized_tx.read(ADDRESS_SIZE) + if i < num_required_signatures: + type = ADDRESS_SIG + elif i < num_required_signatures + num_signature_read_only_addresses: + type = ADDRESS_SIG_READ_ONLY + elif ( + i + < num_required_signatures + + num_signature_read_only_addresses + + num_read_only_addresses + ): + type = ADDRESS_RW + else: + type = ADDRESS_READ_ONLY + + addresses.append((address, type)) + + return addresses diff --git a/core/src/apps/solana/parsing/parse_instructions.py b/core/src/apps/solana/parsing/parse_instructions.py new file mode 100644 index 000000000..b01c937ed --- /dev/null +++ b/core/src/apps/solana/parsing/parse_instructions.py @@ -0,0 +1,35 @@ +from trezor.utils import BufferReader + +from .utils import read_compact_u16 + + +def parse_instructions( + serialized_tx: BufferReader, addresses: list[tuple[bytes, int]] +) -> list[tuple[bytes, list[tuple[bytes, int]], BufferReader]]: + num_of_instructions = read_compact_u16(serialized_tx) + + instructions: list[tuple[bytes, list[tuple[bytes, int]], BufferReader]] = [] + for _ in range(num_of_instructions): + program_index = serialized_tx.get() + assert program_index < len(addresses) + + program_id = addresses[program_index][0] + + num_of_accounts = read_compact_u16(serialized_tx) + + instruction_accounts: list[tuple[bytes, int]] = [] + for _ in range(num_of_accounts): + assert serialized_tx.remaining_count() > 0 + account_index = serialized_tx.get() + assert account_index < len(addresses) + + account = addresses[account_index] + + instruction_accounts.append(account) + + data_length = read_compact_u16(serialized_tx) + data = BufferReader(serialized_tx.read(data_length)) + + instructions.append((program_id, instruction_accounts, data)) + + return instructions diff --git a/core/src/apps/solana/parsing/utils.py b/core/src/apps/solana/parsing/utils.py new file mode 100644 index 000000000..187a8809f --- /dev/null +++ b/core/src/apps/solana/parsing/utils.py @@ -0,0 +1,28 @@ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from trezor.utils import BufferReader + + +def read_compact_u16(reader: BufferReader) -> int: + """Return the decoded value. BufferReader is advanced by the number of bytes read""" + value = size = 0 + while size < reader.remaining_count(): + elem = reader.get() + value |= (elem & 0x7F) << (size * 7) + size += 1 + if (elem & 0x80) == 0: + break + return value + + +def read_string(data: BufferReader) -> str: + """ + Reads a string from the buffer. The string is prefixed with its + length in the first 4 bytes and a 4 byte padding. + BufferReader is advanced by the number of bytes read. + """ + length = int.from_bytes(data.read(4), "little") + # padding + data.read(4) + return data.read(length).decode() diff --git a/core/src/apps/solana/sign_tx.py b/core/src/apps/solana/sign_tx.py index bd95dbe8e..9c4e32e8d 100644 --- a/core/src/apps/solana/sign_tx.py +++ b/core/src/apps/solana/sign_tx.py @@ -1,12 +1,9 @@ from typing import TYPE_CHECKING -from trezor.crypto import base58 - from apps.common.keychain import auto_keychain if TYPE_CHECKING: from trezor.messages import SolanaSignTx, SolanaSignedTx - from apps.common.keychain import Keychain @@ -17,6 +14,10 @@ async def sign_tx( ) -> SolanaSignedTx: from trezor.crypto.curve import ed25519 from trezor.messages import SolanaSignedTx + from apps.common import seed + from .parsing.parse import parse + from .instructions import handle_instructions + from trezor.utils import BufferReader signer_path = msg.signer_path_n serialized_tx = msg.serialized_tx @@ -25,5 +26,12 @@ async def sign_tx( signature = ed25519.sign(node.private_key(), serialized_tx) + addresses, blockhash, instructions = parse(BufferReader(serialized_tx)) + + signer_pub_key = seed.remove_ed25519_prefix(node.public_key()) + await handle_instructions(instructions, signer_pub_key) + + # TODO SOL: final confirmation screen, include blockhash + # TODO SOL: only one signature per request? return SolanaSignedTx(serialized_tx=serialized_tx, signature=signature) diff --git a/core/src/apps/solana/types.py b/core/src/apps/solana/types.py new file mode 100644 index 000000000..9eb4451ce --- /dev/null +++ b/core/src/apps/solana/types.py @@ -0,0 +1,9 @@ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from trezor.utils import BufferReader + + Address = tuple[bytes, int] + ProgramId = bytes + Data = BufferReader + Instruction = tuple[ProgramId, list[Address], Data] diff --git a/tests/device_tests/solana/test_address.py b/tests/device_tests/solana/test_address.py index 9062bc260..a86283af4 100644 --- a/tests/device_tests/solana/test_address.py +++ b/tests/device_tests/solana/test_address.py @@ -36,9 +36,7 @@ def test_solana_get_address(client: Client, parameters, result): client.init_device(new_session=True) actual_result = get_address( - client, - address_n=parse_path(parameters["path"]), - show_display=True + client, address_n=parse_path(parameters["path"]), show_display=True ) assert actual_result.address == result["expected_address"] diff --git a/tests/device_tests/solana/test_sign_tx.py b/tests/device_tests/solana/test_sign_tx.py index 1a4f0d1da..5e1187654 100644 --- a/tests/device_tests/solana/test_sign_tx.py +++ b/tests/device_tests/solana/test_sign_tx.py @@ -38,7 +38,7 @@ def test_solana_sign_tx(client: Client, parameters, result): actual_result = sign_tx( client, signer_path_n=parse_path(parameters["signer_path"]), - serialized_tx=bytes.fromhex(parameters["serialized_tx"]) + serialized_tx=bytes.fromhex(parameters["serialized_tx"]), ) assert actual_result.signature == bytes.fromhex(result["expected_signature"])