You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
trezor-firmware/core/src/apps/cardano/sign_tx.py

252 lines
8.0 KiB

from micropython import const
from trezor import log, wire
from trezor.crypto import base58, hashlib
from trezor.crypto.curve import ed25519
from trezor.messages.CardanoSignedTx import CardanoSignedTx
from trezor.messages.CardanoTxAck import CardanoTxAck
from trezor.messages.CardanoTxRequest import CardanoTxRequest
from apps.cardano import CURVE, seed
from apps.cardano.address import (
derive_address_and_node,
is_safe_output_address,
validate_full_path,
)
from apps.cardano.layout import confirm_sending, confirm_transaction, progress
from apps.common import cbor
from apps.common.paths import validate_path
from apps.common.seed import remove_ed25519_prefix
from apps.homescreen.homescreen import display_homescreen
# Exclusive upper bound for the last element (address index) of a path that
# is_change() still accepts as change. It should be large enough for normal
# use while still allowing the correct bip32 path to be brute-forced quickly.
MAX_CHANGE_ADDRESS_INDEX = const(1000000)
# Number of leading path elements that is_change() requires an output path
# to share with every input path.
ACCOUNT_PREFIX_DEPTH = const(2)
# Protocol magic -> network name shown to the user; Transaction falls back
# to "Unknown" for any magic that is not listed here.
KNOWN_PROTOCOL_MAGICS = {764824073: "Mainnet", 1097911063: "Testnet"}
# Addresses from the external chain are accepted as possible change
# addresses as well, hence the chain element only has to be below 2.
def is_change(output, inputs):
    """Tell whether the derivation path `output` counts as change.

    `output` is a sequence of path elements and `inputs` an iterable of
    objects exposing an `address_n` path. The result is True only when,
    for every input, all of the following hold:

    * the first ACCOUNT_PREFIX_DEPTH elements of both paths are equal,
    * the second-to-last element of `output` is below 2,
    * the last element of `output` is below MAX_CHANGE_ADDRESS_INDEX.

    With no inputs at all the result is True, because there is nothing
    to contradict it.
    """
    for tx_input in inputs:
        input_path = tx_input.address_n
        if output[:ACCOUNT_PREFIX_DEPTH] != input_path[:ACCOUNT_PREFIX_DEPTH]:
            return False
        if output[-2] >= 2:
            return False
        if output[-1] >= MAX_CHANGE_ADDRESS_INDEX:
            return False
    return True
async def show_tx(
    ctx,
    outputs: list,
    outcoins: list,
    fee: int,
    network_name: str,
    raw_inputs: list,
    raw_outputs: list,
) -> bool:
    """Ask the user to confirm each displayed output, then the whole transaction.

    `outputs` and `outcoins` are parallel lists holding the address and the
    amount of every output that has to be shown. `fee` and `network_name`
    are passed to the final confirmation together with the sum of
    `outcoins`.

    `raw_inputs` and `raw_outputs` are kept in the signature for existing
    callers but are deliberately not consulted. sign_tx() passes
    Transaction.output_addresses as `outputs`, a list that already leaves
    out every output given by a derivation path, so a position in `outputs`
    does not identify the same output in `raw_outputs`. Looking up
    `raw_outputs[index]` could pair an external destination with the path
    of an unrelated change output and skip its confirmation.

    Returns True when the user confirmed everything, False as soon as any
    confirmation is declined.
    """
    for index, address in enumerate(outputs):
        # Index outcoins explicitly: a length mismatch must fail loudly
        # rather than silently drop a confirmation.
        if not await confirm_sending(ctx, outcoins[index], address):
            return False

    total_amount = sum(outcoins)
    if not await confirm_transaction(ctx, total_amount, fee, network_name):
        return False
    return True
async def request_transaction(ctx, tx_req: CardanoTxRequest, index: int):
    """Request the transaction number `index` from the host.

    Stores `index` in `tx_req.tx_index`, sends the request through
    `ctx.call` and returns the CardanoTxAck response.
    """
    tx_req.tx_index = index
    response = await ctx.call(tx_req, CardanoTxAck)
    return response
async def sign_tx(ctx, msg):
    """Handle a Cardano signing request and return a CardanoSignedTx.

    Steps, in order:

    1. Request `msg.transactions_count` transactions from the host. Each one
       is hashed (blake2b, 32 bytes) and CBOR-decoded; every input whose
       `prev_hash` equals that hash is marked as attested and the amount of
       the output it references is added to the input total.
    2. Fail if any input was left unattested.
    3. Validate the path of every input, then serialise and sign.
    4. Ask the user to confirm; the signed transaction is only returned
       after confirmation.

    Raises wire.ProcessError when an input has no matching transaction or
    when a ValueError occurs while loading/signing, and
    wire.ActionCancelled when the user declines.
    """
    keychain = await seed.get_keychain(ctx)

    progress.init(msg.transactions_count, "Loading data")

    try:
        # attested[k] becomes True once the transaction that input k
        # spends from has been received
        attested = [False] * len(msg.inputs)
        input_coins_sum = 0

        # request transactions
        request = CardanoTxRequest()
        for tx_index in range(msg.transactions_count):
            progress.advance()
            ack = await request_transaction(ctx, request, tx_index)
            received_hash = hashlib.blake2b(
                data=bytes(ack.transaction), outlen=32
            ).digest()
            received_tx = cbor.decode(ack.transaction)

            for position, tx_input in enumerate(msg.inputs):
                if attested[position] or tx_input.prev_hash != received_hash:
                    continue
                attested[position] = True
                spent_outputs = received_tx[1]
                input_coins_sum += spent_outputs[tx_input.prev_index][1]

        # report the first input for which no transaction was supplied
        for position, is_attested in enumerate(attested):
            if not is_attested:
                raise wire.ProcessError(
                    "No tx data sent for input " + str(position)
                )

        transaction = Transaction(
            msg.inputs, msg.outputs, keychain, msg.protocol_magic, input_coins_sum
        )

        # clear progress bar
        display_homescreen()

        for tx_input in msg.inputs:
            await validate_path(
                ctx, validate_full_path, keychain, tx_input.address_n, CURVE
            )

        # sign the transaction bundle and prepare the result
        signed_body, signed_hash = transaction.serialise_tx()
        tx = CardanoSignedTx(tx_body=signed_body, tx_hash=signed_hash)

    except ValueError as e:
        if __debug__:
            log.exception(__name__, e)
        raise wire.ProcessError("Signing failed")

    # display the transaction in UI
    confirmed = await show_tx(
        ctx,
        transaction.output_addresses,
        transaction.outgoing_coins,
        transaction.fee,
        transaction.network_name,
        transaction.inputs,
        transaction.outputs,
    )
    if not confirmed:
        raise wire.ActionCancelled("Signing cancelled")

    return tx
class Transaction:
    """Collects the data of one transaction and serialises/signs it.

    After serialise_tx() has run, the instance additionally exposes
    `output_addresses`/`outgoing_coins` (outputs given by address),
    `change_addresses`/`change_coins`/`change_derivation_paths` (outputs
    given by derivation path) and `fee`.
    """

    def __init__(
        self,
        inputs: list,
        outputs: list,
        keychain,
        protocol_magic: int,
        input_coins_sum: int,
    ):
        """Store the request data.

        `input_coins_sum` is the total amount carried by `inputs`; it is
        only used to derive the fee.
        """
        self.inputs = inputs
        self.outputs = outputs
        self.keychain = keychain
        # attributes have to be always empty in current Cardano
        self.attributes = {}
        self.network_name = KNOWN_PROTOCOL_MAGICS.get(protocol_magic, "Unknown")
        self.protocol_magic = protocol_magic
        self.input_coins_sum = input_coins_sum

    def _process_outputs(self):
        """Split `self.outputs` into change outputs and outgoing outputs.

        An output with a non-empty `address_n` is treated as change and its
        address is derived from the keychain. Any other output must carry
        an `address` accepted by is_safe_output_address().

        Raises wire.ProcessError for an output that has neither field or
        whose address is rejected.
        """
        change_addresses = []
        change_derivation_paths = []
        change_coins = []
        output_addresses = []
        outgoing_coins = []

        for tx_output in self.outputs:
            if tx_output.address_n:
                derived, _ = derive_address_and_node(
                    self.keychain, tx_output.address_n
                )
                change_addresses.append(derived)
                change_derivation_paths.append(tx_output.address_n)
                change_coins.append(tx_output.amount)
                continue

            if tx_output.address is None:
                raise wire.ProcessError(
                    "Each output must have address or address_n field!"
                )
            if not is_safe_output_address(tx_output.address):
                raise wire.ProcessError("Invalid output address!")

            outgoing_coins.append(tx_output.amount)
            output_addresses.append(tx_output.address)

        self.change_addresses = change_addresses
        self.output_addresses = output_addresses
        self.outgoing_coins = outgoing_coins
        self.change_coins = change_coins
        self.change_derivation_paths = change_derivation_paths

    def _build_witnesses(self, tx_aux_hash: bytes):
        """Return one witness per input, each signing `tx_aux_hash`.

        A witness is `[input type or 0, Tagged(24, cbor([xpub, signature]))]`
        where xpub is the node's public key (ed25519 prefix removed)
        followed by its chain code.
        """
        # The signed message is the same for every input, so build it once.
        # b"\x58\x20" is the CBOR header of a 32-byte byte string, i.e. the
        # hash that follows it.
        # NOTE(review): the leading b"\x01" is presumably the signing tag
        # for transactions - confirm against the protocol specification.
        message = (
            b"\x01" + cbor.encode(self.protocol_magic) + b"\x58\x20" + tx_aux_hash
        )

        witnesses = []
        for tx_input in self.inputs:
            _, node = derive_address_and_node(self.keychain, tx_input.address_n)
            signature = ed25519.sign_ext(
                node.private_key(), node.private_key_ext(), message
            )
            extended_public_key = (
                remove_ed25519_prefix(node.public_key()) + node.chain_code()
            )
            witnesses.append(
                [
                    (tx_input.type or 0),
                    cbor.Tagged(24, cbor.encode([extended_public_key, signature])),
                ]
            )
        return witnesses

    @staticmethod
    def compute_fee(input_coins_sum: int, outgoing_coins: list, change_coins: list):
        """Return the input total minus all outgoing and change amounts."""
        return input_coins_sum - sum(outgoing_coins) - sum(change_coins)

    def serialise_tx(self):
        """Serialise and sign the transaction.

        Outgoing outputs are emitted first, followed by change outputs.
        Also stores the resulting fee in `self.fee`.

        Returns `(tx_body, tx_hash)`: the CBOR encoding of
        `[tx_aux, witnesses]` and the 32-byte blake2b hash of the encoded
        `tx_aux`.
        """
        self._process_outputs()

        inputs_cbor = cbor.IndefiniteLengthArray(
            [
                [
                    (tx_input.type or 0),
                    cbor.Tagged(
                        24, cbor.encode([tx_input.prev_hash, tx_input.prev_index])
                    ),
                ]
                for tx_input in self.inputs
            ]
        )

        outgoing_cbor = [
            [cbor.Raw(base58.decode(address)), self.outgoing_coins[index]]
            for index, address in enumerate(self.output_addresses)
        ]
        change_cbor = [
            [cbor.Raw(base58.decode(address)), self.change_coins[index]]
            for index, address in enumerate(self.change_addresses)
        ]
        outputs_cbor = cbor.IndefiniteLengthArray(outgoing_cbor + change_cbor)

        tx_aux_cbor = [inputs_cbor, outputs_cbor, self.attributes]
        tx_hash = hashlib.blake2b(data=cbor.encode(tx_aux_cbor), outlen=32).digest()

        witnesses = self._build_witnesses(tx_hash)
        tx_body = cbor.encode([tx_aux_cbor, witnesses])

        self.fee = self.compute_fee(
            self.input_coins_sum, self.outgoing_coins, self.change_coins
        )

        return tx_body, tx_hash