You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
trezor-firmware/src/apps/cardano/sign_tx.py

272 lines
8.7 KiB

from micropython import const
from trezor import log, wire
from trezor.crypto import base58, hashlib
from trezor.crypto.curve import ed25519
from trezor.messages.CardanoSignedTx import CardanoSignedTx
from trezor.messages.CardanoTxRequest import CardanoTxRequest
from trezor.messages.MessageType import CardanoTxAck
from apps.cardano import CURVE, cbor, seed
from apps.cardano.address import (
derive_address_and_node,
is_safe_output_address,
validate_full_path,
)
from apps.cardano.layout import confirm_sending, confirm_transaction, progress
from apps.common.paths import validate_path
from apps.common.seed import remove_ed25519_prefix
from apps.homescreen.homescreen import display_homescreen
# the maximum allowed change address index. this should be large enough for
# normal use while still allowing the correct bip32 path to be brute-forced quickly
MAX_CHANGE_ADDRESS_INDEX = const(1000000)
ACCOUNT_PREFIX_DEPTH = const(2)
KNOWN_PROTOCOL_MAGICS = {764824073: "Mainnet", 1097911063: "Testnet"}
# we consider addresses from the external chain as possible change addresses as well
def is_change(output, inputs):
    """Tell whether the derivation path `output` qualifies as a change path.

    `output` is a sequence of path components; every element of `inputs` must
    expose an `address_n` path. The result is True only when, for each input:

    * the first ACCOUNT_PREFIX_DEPTH components of both paths are equal,
    * the second-to-last component of `output` is below 2, and
    * the last component of `output` is below MAX_CHANGE_ADDRESS_INDEX.

    With no inputs at all the result is True, because no check can fail.
    """
    for tx_input in inputs:
        input_path = tx_input.address_n

        # the output has to share the leading path components with every input
        if output[:ACCOUNT_PREFIX_DEPTH] != input_path[:ACCOUNT_PREFIX_DEPTH]:
            return False

        # second-to-last component may only be 0 or 1
        if output[-2] >= 2:
            return False

        # keep the address index small enough to be searchable
        if output[-1] >= MAX_CHANGE_ADDRESS_INDEX:
            return False

    return True
async def show_tx(
    ctx,
    outputs: list,
    outcoins: list,
    fee: int,
    network_name: str,
    raw_inputs: list,
    raw_outputs: list,
) -> bool:
    """Ask the user to confirm every outgoing payment and then the whole transaction.

    Args:
        ctx: wire context forwarded to the confirmation layouts.
        outputs: destination addresses of the outputs that were given by an
            explicit address (i.e. not by a derivation path).
        outcoins: amounts matching `outputs` position by position.
        fee: transaction fee shown on the final confirmation screen.
        network_name: network label shown on the final confirmation screen.
        raw_inputs: unused; kept so existing callers keep working.
        raw_outputs: unused; kept so existing callers keep working.

    Returns:
        True when every screen was confirmed, False as soon as one is rejected.
    """
    # `outputs`/`outcoins` contain only the explicit-address outputs, whereas
    # `raw_outputs` contains ALL outputs (including path-derived change), so
    # the two lists do not share indices. Looking up raw_outputs[index] here
    # used to pair a payment with an unrelated change output and silently skip
    # its confirmation. An output without a derivation path can never be
    # change, so every entry is shown to the user.
    for index, output in enumerate(outputs):
        if not await confirm_sending(ctx, outcoins[index], output):
            return False

    total_amount = sum(outcoins)
    if not await confirm_transaction(ctx, total_amount, fee, network_name):
        return False

    return True
async def request_transaction(ctx, tx_req: CardanoTxRequest, index: int):
    """Request the transaction at position `index` from the host.

    The shared `tx_req` message is updated in place with the requested index
    and sent through `ctx.call`, which expects a CardanoTxAck in response.
    The response message is returned as-is.
    """
    tx_req.tx_index = index
    tx_ack = await ctx.call(tx_req, CardanoTxAck)
    return tx_ack
async def sign_tx(ctx, msg):
    """Handle a Cardano signing request and return the signed transaction.

    Steps, in order: obtain the keychain, fetch `msg.transactions_count` raw
    transactions from the host, validate the derivation path of every input,
    build and sign the transaction, and finally ask the user to confirm it.

    Raises:
        wire.ProcessError: when a ValueError occurs while loading, validating
            or signing.
        wire.ActionCancelled: when the user rejects a confirmation screen.
    """
    keychain = await seed.get_keychain(ctx)

    progress.init(msg.transactions_count, "Loading data")

    try:
        # pull the raw transactions from the host one by one, reusing a
        # single request message
        request = CardanoTxRequest()
        raw_transactions = []
        for tx_index in range(msg.transactions_count):
            progress.advance()
            ack = await request_transaction(ctx, request, tx_index)
            raw_transactions.append(ack.transaction)

        # loading is done, get rid of the progress bar
        display_homescreen()

        for tx_input in msg.inputs:
            await validate_path(
                ctx, validate_full_path, keychain, tx_input.address_n, CURVE
            )

        # build and sign the transaction, then wrap the result for the host
        transaction = Transaction(
            msg.inputs, msg.outputs, raw_transactions, keychain, msg.protocol_magic
        )
        body, digest = transaction.serialise_tx()
        tx = CardanoSignedTx(tx_body=body, tx_hash=digest)

    except ValueError as e:
        if __debug__:
            log.exception(__name__, e)
        raise wire.ProcessError("Signing failed")

    # the signed result is only handed back once the user has approved it
    confirmed = await show_tx(
        ctx,
        transaction.output_addresses,
        transaction.outgoing_coins,
        transaction.fee,
        transaction.network_name,
        transaction.inputs,
        transaction.outputs,
    )
    if not confirmed:
        raise wire.ActionCancelled("Signing cancelled")

    return tx
class Transaction:
    """Assembles, hashes and signs a Cardano transaction.

    The instance is fed the inputs and outputs of the signing request together
    with the raw previous transactions the inputs refer to. Calling
    `serialise_tx()` fills in the derived attributes (`input_coins`,
    `output_addresses`, `outgoing_coins`, `change_coins`, `fee`, ...) and
    returns the CBOR-encoded signed transaction and its hash.
    """

    def __init__(
        self,
        inputs: list,
        outputs: list,
        transactions: list,
        keychain,
        protocol_magic: int,
    ):
        self.inputs = inputs
        self.outputs = outputs
        self.transactions = transactions
        self.keychain = keychain
        # attributes have to be always empty in current Cardano
        self.attributes = {}
        # unrecognised magics are still accepted, only labelled "Unknown"
        self.network_name = KNOWN_PROTOCOL_MAGICS.get(protocol_magic, "Unknown")
        self.protocol_magic = protocol_magic

    def _process_inputs(self):
        """Resolve every input to its amount and signing node.

        Sets `input_coins`, `nodes`, `types`, `input_hashes` and
        `output_indexes`, all ordered like `self.inputs`.

        Raises:
            wire.ProcessError: when the transaction an input refers to was not
                supplied, or when the referenced output does not exist in it.
        """
        # index the supplied raw transactions by their 32-byte blake2b hash
        tx_data = {}
        for raw_transaction in self.transactions:
            tx_hash = hashlib.blake2b(data=bytes(raw_transaction), outlen=32).digest()
            tx_data[tx_hash] = cbor.decode(raw_transaction)

        input_hashes = [tx_input.prev_hash for tx_input in self.inputs]
        output_indexes = [tx_input.prev_index for tx_input in self.inputs]
        types = [tx_input.type or 0 for tx_input in self.inputs]

        nodes = []
        for tx_input in self.inputs:
            _, node = derive_address_and_node(self.keychain, tx_input.address_n)
            nodes.append(node)

        input_coins = []
        for index, output_index in enumerate(output_indexes):
            tx_hash = bytes(input_hashes[index])
            if tx_hash not in tx_data:
                raise wire.ProcessError("No tx data sent for input " + str(index))

            # element 1 of the decoded transaction is its list of outputs,
            # each output being [address, amount]
            prev_outputs = tx_data[tx_hash][1]
            # prev_index comes from the host; reject an out-of-range value
            # explicitly instead of letting an IndexError escape (or a
            # negative index silently pick the wrong output)
            if not 0 <= output_index < len(prev_outputs):
                raise wire.ProcessError(
                    "Invalid output index for input " + str(index)
                )
            input_coins.append(prev_outputs[output_index][1])

        self.input_coins = input_coins
        self.nodes = nodes
        self.types = types
        self.input_hashes = input_hashes
        self.output_indexes = output_indexes

    def _process_outputs(self):
        """Split the outputs into path-derived (change) and explicit ones.

        Sets `change_addresses`, `change_derivation_paths`, `change_coins`,
        `output_addresses` and `outgoing_coins`.

        Raises:
            wire.ProcessError: when an output has neither `address_n` nor
                `address`, or when its address fails `is_safe_output_address`.
        """
        change_addresses = []
        change_derivation_paths = []
        output_addresses = []
        outgoing_coins = []
        change_coins = []

        for output in self.outputs:
            if output.address_n:
                # NOTE(review): every output carrying a derivation path is
                # booked as change here without checking the path itself --
                # confirm this is validated elsewhere.
                address, _ = derive_address_and_node(self.keychain, output.address_n)
                change_addresses.append(address)
                change_derivation_paths.append(output.address_n)
                change_coins.append(output.amount)
            else:
                if output.address is None:
                    raise wire.ProcessError(
                        "Each output must have address or address_n field!"
                    )
                if not is_safe_output_address(output.address):
                    raise wire.ProcessError("Invalid output address!")

                outgoing_coins.append(output.amount)
                output_addresses.append(output.address)

        self.change_addresses = change_addresses
        self.output_addresses = output_addresses
        self.outgoing_coins = outgoing_coins
        self.change_coins = change_coins
        self.change_derivation_paths = change_derivation_paths

    def _build_witnesses(self, tx_aux_hash: bytes):
        """Sign `tx_aux_hash` with the node of every input.

        Returns a list with one `[type, Tagged(24, cbor([xpub, signature]))]`
        entry per input, ordered like `self.nodes`.
        """
        # The signed message is identical for all inputs, so build it once.
        # b"\x58\x20" is the CBOR header of a 32-byte byte string; the leading
        # b"\x01" is presumably the protocol's "sign transaction" tag -- TODO
        # confirm against the Cardano specification.
        message = (
            b"\x01" + cbor.encode(self.protocol_magic) + b"\x58\x20" + tx_aux_hash
        )

        witnesses = []
        for index, node in enumerate(self.nodes):
            signature = ed25519.sign_ext(
                node.private_key(), node.private_key_ext(), message
            )
            extended_public_key = (
                remove_ed25519_prefix(node.public_key()) + node.chain_code()
            )
            witnesses.append(
                [
                    self.types[index],
                    cbor.Tagged(24, cbor.encode([extended_public_key, signature])),
                ]
            )

        return witnesses

    @staticmethod
    def compute_fee(input_coins: list, outgoing_coins: list, change_coins: list):
        """Return the fee: total inputs minus outgoing and change amounts."""
        input_coins_sum = sum(input_coins)
        outgoing_coins_sum = sum(outgoing_coins)
        change_coins_sum = sum(change_coins)

        return input_coins_sum - outgoing_coins_sum - change_coins_sum

    def serialise_tx(self):
        """Build, hash and sign the transaction.

        Returns:
            A `(tx_body, tx_hash)` tuple: the CBOR encoding of
            `[tx_aux, witnesses]` and the 32-byte blake2b hash of the encoded
            `tx_aux`. As a side effect `self.fee` and the attributes set by
            `_process_inputs` / `_process_outputs` become available.
        """
        self._process_inputs()
        self._process_outputs()

        inputs_cbor = []
        for i, output_index in enumerate(self.output_indexes):
            inputs_cbor.append(
                [
                    self.types[i],
                    cbor.Tagged(24, cbor.encode([self.input_hashes[i], output_index])),
                ]
            )
        inputs_cbor = cbor.IndefiniteLengthArray(inputs_cbor)

        # explicit outputs are serialised first, change outputs after them;
        # the order is part of the signed data and must not change
        outputs_cbor = []
        for index, address in enumerate(self.output_addresses):
            outputs_cbor.append(
                [cbor.Raw(base58.decode(address)), self.outgoing_coins[index]]
            )
        for index, address in enumerate(self.change_addresses):
            outputs_cbor.append(
                [cbor.Raw(base58.decode(address)), self.change_coins[index]]
            )
        outputs_cbor = cbor.IndefiniteLengthArray(outputs_cbor)

        tx_aux_cbor = [inputs_cbor, outputs_cbor, self.attributes]
        tx_hash = hashlib.blake2b(data=cbor.encode(tx_aux_cbor), outlen=32).digest()

        witnesses = self._build_witnesses(tx_hash)
        tx_body = cbor.encode([tx_aux_cbor, witnesses])

        self.fee = self.compute_fee(
            self.input_coins, self.outgoing_coins, self.change_coins
        )

        return tx_body, tx_hash