1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-18 05:28:40 +00:00

feat(core): add PoC of Solana instruction handling

This commit is contained in:
gabrielkerekes 2023-08-03 08:28:18 +02:00
parent 22b0e017e5
commit 3851170040
13 changed files with 544 additions and 10 deletions

View File

@ -0,0 +1,8 @@
from micropython import const
ADDRESS_SIZE = const(32)
ADDRESS_SIG = const(0)
ADDRESS_SIG_READ_ONLY = const(1)
ADDRESS_READ_ONLY = const(2)
ADDRESS_RW = const(3)

View File

@ -1,19 +1,21 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ubinascii import hexlify from ubinascii import hexlify
from apps.common import seed
from apps.common.keychain import auto_keychain from apps.common.keychain import auto_keychain
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import SolanaGetPublicKey, SolanaPublicKey from trezor.messages import SolanaGetPublicKey, SolanaPublicKey
from apps.common.keychain import Keychain
# TODO SOL: maybe only get_address is needed? # TODO SOL: maybe only get_address is needed?
@auto_keychain(__name__) @auto_keychain(__name__)
async def get_public_key( async def get_public_key(
msg: SolanaGetPublicKey, keychain: seed.Keychain msg: SolanaGetPublicKey, keychain: Keychain
) -> SolanaPublicKey: ) -> SolanaPublicKey:
from trezor.ui.layouts import show_pubkey from trezor.ui.layouts import show_pubkey
from trezor.messages import HDNodeType, SolanaPublicKey from trezor.messages import HDNodeType, SolanaPublicKey
from apps.common import seed
node = keychain.derive(msg.address_n) node = keychain.derive(msg.address_n)
@ -26,7 +28,7 @@ async def get_public_key(
) )
if msg.show_display: if msg.show_display:
await show_pubkey(hexlify(node.public_key).decode()) await show_pubkey(hexlify(node.public_key()).decode())
# TODO SOL: xpub? # TODO SOL: xpub?
return SolanaPublicKey(node=node_type, xpub=node.serialize_public(0)) return SolanaPublicKey(node=node_type, xpub=node.serialize_public(0))

View File

@ -0,0 +1,32 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from ..types import Instruction
SYSTEM_PROGRAM_ID = "11111111111111111111111111111111"
STAKE_PROGRAM_ID = "Stake11111111111111111111111111111111111111"
SYSTEM_TRANSFER_ID = 2
async def handle_instructions(
instructions: list[Instruction], signer_pub_key: bytes
) -> None:
from trezor.crypto import base58
from trezor.wire import ProcessError
from .system_program import handle_system_program_instruction
from .stake_program import handle_stake_program_instruction
for instruction in instructions:
program_id, _, _ = instruction
encoded_program_id = base58.encode(program_id)
if encoded_program_id == SYSTEM_PROGRAM_ID:
await handle_system_program_instruction(instruction, signer_pub_key)
elif encoded_program_id == STAKE_PROGRAM_ID:
await handle_stake_program_instruction(instruction, signer_pub_key)
else:
# TODO SOL: blind signing for unknown programs
raise ProcessError(f"Unknown program id: {encoded_program_id}")

View File

@ -0,0 +1,72 @@
from typing import TYPE_CHECKING
from trezor.crypto import base58
from trezor.ui.layouts import confirm_properties
from trezor.wire import ProcessError
from ..constants import ADDRESS_READ_ONLY, ADDRESS_RW
from . import STAKE_PROGRAM_ID
if TYPE_CHECKING:
from typing import Awaitable
from ..types import Instruction
INS_INITIALIZE_STAKE = 0
def handle_stake_program_instruction(
instruction: Instruction, signer_pub_key: bytes
) -> Awaitable[None]:
program_id, _, data = instruction
assert base58.encode(program_id) == STAKE_PROGRAM_ID
assert data.remaining_count() >= 4
instruction_id = int.from_bytes(data.read(4), "little")
data.seek(0)
if instruction_id == INS_INITIALIZE_STAKE:
return _handle_initialize_stake(instruction, signer_pub_key)
else:
# TODO SOL: blind signing
raise ProcessError("Unknown stake program instruction")
def _handle_initialize_stake(instruction: Instruction, signer_address: bytes):
_, accounts, data = instruction
assert data.remaining_count() == 116
assert len(accounts) == 2
instruction_id = int.from_bytes(data.read(4), "little")
assert instruction_id == INS_INITIALIZE_STAKE
# TODO SOL: validate staker, withdrawer, custodian
staker = data.read(32)
withdrawer = data.read(32)
# TODO SOL: should be signed int but from_bytes doesn't take the third arg
unix_timestamp = int.from_bytes(data.read(8), "little")
epoch = int.from_bytes(data.read(8), "little")
custodian = data.read(32)
uninitialized_stake_account, uninitialized_stake_account_type = accounts[0]
assert uninitialized_stake_account_type == ADDRESS_RW
rent_sysvar, rent_sysvar_type = accounts[1]
assert rent_sysvar_type == ADDRESS_READ_ONLY
return confirm_properties(
"initialize_stake",
"Initialize Stake",
(
("Staker", base58.encode(staker)),
("Withdrawer", base58.encode(withdrawer)),
("Unix Timestamp", str(unix_timestamp)),
("Epoch", str(epoch)),
("Custodian", base58.encode(custodian)),
("Stake Account", base58.encode(uninitialized_stake_account)),
# TODO SOL: probably doesn't need to be displayed
("Rent Sysvar", base58.encode(rent_sysvar)),
),
)

View File

@ -0,0 +1,241 @@
from typing import TYPE_CHECKING
from trezor.crypto import base58
from trezor.enums import ButtonRequestType
from trezor.strings import format_amount
from trezor.ui.layouts import confirm_output, confirm_properties
from trezor.wire import ProcessError
from ..constants import ADDRESS_RW, ADDRESS_SIG, ADDRESS_SIG_READ_ONLY
from ..parsing.utils import read_string
from . import SYSTEM_PROGRAM_ID
if TYPE_CHECKING:
from typing import Awaitable
from ..types import Instruction
INS_CREATE_ACCOUNT = 0
INS_TRANSFER = 2
INS_CREATE_ACCOUNT_WITH_SEED = 3
def handle_system_program_instruction(
instruction: Instruction, signer_pub_key: bytes
) -> Awaitable[None]:
program_id, _, data = instruction
assert base58.encode(program_id) == SYSTEM_PROGRAM_ID
assert data.remaining_count() >= 4
instruction_id = int.from_bytes(data.read(4), "little")
data.seek(0)
if instruction_id == INS_CREATE_ACCOUNT:
return _handle_create_account(instruction, signer_pub_key)
if instruction_id == INS_TRANSFER:
return _handle_transfer(instruction, signer_pub_key)
if instruction_id == INS_CREATE_ACCOUNT_WITH_SEED:
return _handle_create_account_with_seed(instruction, signer_pub_key)
else:
# TODO SOL: blind signing
raise ProcessError("Unknown system program instruction")
def _handle_create_account(
instruction: Instruction, signer_pub_key: bytes
) -> Awaitable[None]:
lamports, space, owner, funding_account, new_account = _parse_create_account(
instruction
)
_validate_create_account(funding_account, signer_pub_key)
return _show_create_account(lamports, space, owner, funding_account, new_account)
def _parse_create_account(
instruction: Instruction,
) -> tuple[int, int, bytes, bytes, bytes]:
_, accounts, data = instruction
assert data.remaining_count() == 52
assert len(accounts) == 2
instruction_id = int.from_bytes(data.read(4), "little")
assert instruction_id == INS_CREATE_ACCOUNT
lamports = int.from_bytes(data.read(8), "little")
space = int.from_bytes(data.read(8), "little")
owner = data.read(32)
funding_account, funding_account_type = accounts[0]
assert funding_account_type == ADDRESS_SIG
new_account, new_account_type = accounts[1]
assert new_account_type == ADDRESS_RW
return lamports, space, owner, funding_account, new_account
def _validate_create_account(funding_account: bytes, signer_pub_key: bytes) -> None:
if funding_account != signer_pub_key:
raise ProcessError("Invalid funding account")
def _show_create_account(
lamports: int, space: int, owner: bytes, funding_account: bytes, new_account: bytes
) -> Awaitable[None]:
return confirm_properties(
"create_account",
"Create Account",
(
("Lamports", str(lamports)),
("Space", str(space)),
("Owner", base58.encode(owner)),
("Funding Account", base58.encode(funding_account)),
("New Account", base58.encode(new_account)),
),
)
def _handle_transfer(
instruction: Instruction, signer_pub_key: bytes
) -> Awaitable[None]:
amount, source, destination = _parse_transfer(instruction)
_validate_transfer(source, signer_pub_key)
return _show_transfer(destination, amount)
def _parse_transfer(instruction: Instruction) -> tuple[int, bytes, bytes]:
_, accounts, data = instruction
assert data.remaining_count() == 12
assert len(accounts) == 2
instruction_id = int.from_bytes(data.read(4), "little")
assert instruction_id == INS_TRANSFER
amount = int.from_bytes(data.read(8), "little")
source, source_account_type = accounts[0]
assert source_account_type == ADDRESS_SIG
destination, destination_account_type = accounts[1]
assert destination_account_type == ADDRESS_RW
return amount, source, destination
def _validate_transfer(source: bytes, signer_pub_key: bytes):
if source != signer_pub_key:
raise ProcessError("Invalid source account")
# TODO SOL: validate max amount?
def _show_transfer(destination: bytes, amount: int) -> Awaitable[None]:
return confirm_output(
base58.encode(destination),
f"{format_amount(amount, 8)} SOL",
br_code=ButtonRequestType.Other,
)
def _handle_create_account_with_seed(
instruction: Instruction, signer_address: bytes
) -> Awaitable[None]:
(
base,
seed,
lamports,
space,
owner,
funding_account,
created_account,
base_account,
) = _parse_create_account_with_seed(instruction)
_validate_create_account_with_seed(funding_account, signer_address)
return _show_create_account_with_seed(
base,
seed,
lamports,
space,
owner,
funding_account,
created_account,
base_account,
)
def _parse_create_account_with_seed(
instruction: Instruction,
) -> tuple[bytes, str, int, int, bytes, bytes, bytes, bytes | None]:
_, accounts, data = instruction
# assert len(data) == 52
assert len(accounts) == 2
instruction_id = int.from_bytes(data.read(4), "little")
assert instruction_id == INS_CREATE_ACCOUNT_WITH_SEED
base = data.read(32)
seed = read_string(data)
lamports = int.from_bytes(data.read(8), "little")
space = int.from_bytes(data.read(8), "little")
owner = data.read(32)
funding_account, funding_account_type = accounts[0]
assert funding_account_type == ADDRESS_SIG
created_account, created_account_type = accounts[1]
assert created_account_type == ADDRESS_RW
base_account = None
if len(accounts) == 3:
base_account, base_account_type = accounts[2]
assert base_account_type == ADDRESS_SIG_READ_ONLY
return (
base,
seed,
lamports,
space,
owner,
funding_account,
created_account,
base_account,
)
def _validate_create_account_with_seed(
funding_account: bytes, signer_pub_key: bytes
) -> None:
# TODO SOL: pass for now since we don't have the proper mnemonic
pass
# if funding_account != signer_pub_key:
# raise ProcessError("Invalid funding account")
def _show_create_account_with_seed(
base: bytes,
seed: str,
lamports: int,
space: int,
owner: bytes,
funding_account: bytes,
created_account: bytes,
base_account: bytes | None,
) -> Awaitable[None]:
props = [
("Base", base58.encode(base)),
("Seed", seed),
("Lamports", str(lamports)),
("Space", str(space)),
("Owner", base58.encode(owner)),
("Funding Account", base58.encode(funding_account)),
("Created Account", base58.encode(created_account)),
]
if base_account:
props.append(("Base Account", base58.encode(base_account)))
return confirm_properties("create_account", "Create Account", props)

View File

View File

@ -0,0 +1,101 @@
# 01000103c80f8b50107e9f3e3c16a661b8c806df454a6deb293d5e8730a9d28f2f4998c68f41927b2e58cbc31ed3aa5163a7b8ca4eb5590e8dc1dc682426cd2895aa9c0a00000000000000000000000000000000000000000000000000000000000000001aea57c9906a7cad656ff61b3893abda63f4b6b210c939855e7ab6e54049213d01020200010c02000000002d310100000000
from typing import TYPE_CHECKING
from ..constants import (
ADDRESS_READ_ONLY,
ADDRESS_RW,
ADDRESS_SIG,
ADDRESS_SIG_READ_ONLY,
ADDRESS_SIZE,
)
from .parse_instructions import parse_instructions
from .utils import read_compact_u16
if TYPE_CHECKING:
from trezor.utils import BufferReader
from ..types import Address, Instruction
def parse(
serialized_tx: BufferReader,
) -> tuple[list[Address], bytes, list[Instruction]]:
# TODO SOL: signature parsing can be removed?
# num_of_signatures = decode_length(serialized_tx)
# assert num_of_signatures == 0
# signatures = []
# for i in range(num_of_signatures):
# signature = serialized_tx.read(64)
# signatures.append(signature)
(
num_required_signatures,
num_signature_read_only_addresses,
num_read_only_addresses,
) = _parse_header(serialized_tx)
addresses = _parse_addresses(
serialized_tx,
num_required_signatures,
num_signature_read_only_addresses,
num_read_only_addresses,
)
blockhash = bytes(serialized_tx.read(32))
instructions = parse_instructions(serialized_tx, addresses)
assert serialized_tx.remaining_count() == 0
return addresses, blockhash, instructions
def _parse_header(serialized_tx: BufferReader) -> tuple[int, int, int]:
num_required_signatures = int.from_bytes(serialized_tx.read(1), "big")
num_signature_read_only_addresses = int.from_bytes(serialized_tx.read(1), "big")
num_read_only_addresses = int.from_bytes(serialized_tx.read(1), "big")
return (
num_required_signatures,
num_signature_read_only_addresses,
num_read_only_addresses,
)
def _parse_addresses(
serialized_tx: BufferReader,
num_required_signatures: int,
num_signature_read_only_addresses: int,
num_read_only_addresses: int,
) -> list[tuple[bytes, int]]:
num_of_addresses = read_compact_u16(serialized_tx)
assert (
num_of_addresses
>= num_required_signatures
+ num_signature_read_only_addresses
+ num_read_only_addresses
)
addresses: list[tuple[bytes, int]] = []
for i in range(num_of_addresses):
assert ADDRESS_SIZE <= serialized_tx.remaining_count()
address = serialized_tx.read(ADDRESS_SIZE)
if i < num_required_signatures:
type = ADDRESS_SIG
elif i < num_required_signatures + num_signature_read_only_addresses:
type = ADDRESS_SIG_READ_ONLY
elif (
i
< num_required_signatures
+ num_signature_read_only_addresses
+ num_read_only_addresses
):
type = ADDRESS_RW
else:
type = ADDRESS_READ_ONLY
addresses.append((address, type))
return addresses

View File

@ -0,0 +1,35 @@
from trezor.utils import BufferReader
from .utils import read_compact_u16
def parse_instructions(
serialized_tx: BufferReader, addresses: list[tuple[bytes, int]]
) -> list[tuple[bytes, list[tuple[bytes, int]], BufferReader]]:
num_of_instructions = read_compact_u16(serialized_tx)
instructions: list[tuple[bytes, list[tuple[bytes, int]], BufferReader]] = []
for _ in range(num_of_instructions):
program_index = serialized_tx.get()
assert program_index < len(addresses)
program_id = addresses[program_index][0]
num_of_accounts = read_compact_u16(serialized_tx)
instruction_accounts: list[tuple[bytes, int]] = []
for _ in range(num_of_accounts):
assert serialized_tx.remaining_count() > 0
account_index = serialized_tx.get()
assert account_index < len(addresses)
account = addresses[account_index]
instruction_accounts.append(account)
data_length = read_compact_u16(serialized_tx)
data = BufferReader(serialized_tx.read(data_length))
instructions.append((program_id, instruction_accounts, data))
return instructions

View File

@ -0,0 +1,28 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from trezor.utils import BufferReader
def read_compact_u16(reader: BufferReader) -> int:
"""Return the decoded value. BufferReader is advanced by the number of bytes read"""
value = size = 0
while size < reader.remaining_count():
elem = reader.get()
value |= (elem & 0x7F) << (size * 7)
size += 1
if (elem & 0x80) == 0:
break
return value
def read_string(data: BufferReader) -> str:
"""
Reads a string from the buffer. The string is prefixed with its
length in the first 4 bytes and a 4 byte padding.
BufferReader is advanced by the number of bytes read.
"""
length = int.from_bytes(data.read(4), "little")
# padding
data.read(4)
return data.read(length).decode()

View File

@ -1,12 +1,9 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor.crypto import base58
from apps.common.keychain import auto_keychain from apps.common.keychain import auto_keychain
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import SolanaSignTx, SolanaSignedTx from trezor.messages import SolanaSignTx, SolanaSignedTx
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
@ -17,6 +14,10 @@ async def sign_tx(
) -> SolanaSignedTx: ) -> SolanaSignedTx:
from trezor.crypto.curve import ed25519 from trezor.crypto.curve import ed25519
from trezor.messages import SolanaSignedTx from trezor.messages import SolanaSignedTx
from apps.common import seed
from .parsing.parse import parse
from .instructions import handle_instructions
from trezor.utils import BufferReader
signer_path = msg.signer_path_n signer_path = msg.signer_path_n
serialized_tx = msg.serialized_tx serialized_tx = msg.serialized_tx
@ -25,5 +26,12 @@ async def sign_tx(
signature = ed25519.sign(node.private_key(), serialized_tx) signature = ed25519.sign(node.private_key(), serialized_tx)
addresses, blockhash, instructions = parse(BufferReader(serialized_tx))
signer_pub_key = seed.remove_ed25519_prefix(node.public_key())
await handle_instructions(instructions, signer_pub_key)
# TODO SOL: final confirmation screen, include blockhash
# TODO SOL: only one signature per request? # TODO SOL: only one signature per request?
return SolanaSignedTx(serialized_tx=serialized_tx, signature=signature) return SolanaSignedTx(serialized_tx=serialized_tx, signature=signature)

View File

@ -0,0 +1,9 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from trezor.utils import BufferReader
Address = tuple[bytes, int]
ProgramId = bytes
Data = BufferReader
Instruction = tuple[ProgramId, list[Address], Data]

View File

@ -36,9 +36,7 @@ def test_solana_get_address(client: Client, parameters, result):
client.init_device(new_session=True) client.init_device(new_session=True)
actual_result = get_address( actual_result = get_address(
client, client, address_n=parse_path(parameters["path"]), show_display=True
address_n=parse_path(parameters["path"]),
show_display=True
) )
assert actual_result.address == result["expected_address"] assert actual_result.address == result["expected_address"]

View File

@ -38,7 +38,7 @@ def test_solana_sign_tx(client: Client, parameters, result):
actual_result = sign_tx( actual_result = sign_tx(
client, client,
signer_path_n=parse_path(parameters["signer_path"]), signer_path_n=parse_path(parameters["signer_path"]),
serialized_tx=bytes.fromhex(parameters["serialized_tx"]) serialized_tx=bytes.fromhex(parameters["serialized_tx"]),
) )
assert actual_result.signature == bytes.fromhex(result["expected_signature"]) assert actual_result.signature == bytes.fromhex(result["expected_signature"])