1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-23 23:08:14 +00:00

cardano: optimize transaction signing

squash merge. closes #94 and #130

Squashed commit of the following:

commit 1adc68051df863f5ce3767f6594d91fc9ad7bb16
Author: refi93 <rafael.korbas@gmail.com>
Date:   Tue May 7 00:13:07 2019 +0200

    cr fix

commit 833b831754aa3dfce4630e9d64e7fb6c5147ceb1
Author: refi93 <rafael.korbas@gmail.com>
Date:   Mon May 6 21:59:49 2019 +0200

    cr fix style

commit 2962e697c993d3df9a53dcc6cebc9148c1d701f5
Author: Matúš Juran <matus.juran@vacuumlabs.com>
Date:   Thu May 2 16:57:55 2019 +0200

    cardano: optimize transaction signing

    Optimize the storage of previous transactions. Instead of passing a list
    of transactions to the Transaction object, verify all inputs beforehand.
    Stop creating helper lists when serializing the transaction. This allows
    to process a few more inputs.
This commit is contained in:
matejcik 2019-05-07 17:12:27 +02:00
parent 17a7a92b7f
commit afdfd5fbdb

View File

@ -73,13 +73,33 @@ async def sign_tx(ctx, msg):
progress.init(msg.transactions_count, "Loading data")
try:
attested = len(msg.inputs) * [False]
input_coins_sum = 0
# request transactions
transactions = []
tx_req = CardanoTxRequest()
for index in range(msg.transactions_count):
progress.advance()
tx_ack = await request_transaction(ctx, tx_req, index)
transactions.append(tx_ack.transaction)
tx_hash = hashlib.blake2b(
data=bytes(tx_ack.transaction), outlen=32
).digest()
tx_decoded = cbor.decode(tx_ack.transaction)
for i, input in enumerate(msg.inputs):
if not attested[i] and input.prev_hash == tx_hash:
attested[i] = True
outputs = tx_decoded[1]
amount = outputs[input.prev_index][1]
input_coins_sum += amount
if not all(attested):
raise wire.ProcessError(
"No tx data sent for input " + str(attested.index(False))
)
transaction = Transaction(
msg.inputs, msg.outputs, keychain, msg.protocol_magic, input_coins_sum
)
# clear progress bar
display_homescreen()
@ -88,9 +108,6 @@ async def sign_tx(ctx, msg):
await validate_path(ctx, validate_full_path, keychain, i.address_n, CURVE)
# sign the transaction bundle and prepare the result
transaction = Transaction(
msg.inputs, msg.outputs, transactions, keychain, msg.protocol_magic
)
tx_body, tx_hash = transaction.serialise_tx()
tx = CardanoSignedTx(tx_body=tx_body, tx_hash=tx_hash)
@ -119,56 +136,19 @@ class Transaction:
self,
inputs: list,
outputs: list,
transactions: list,
keychain,
protocol_magic: int,
input_coins_sum: int,
):
self.inputs = inputs
self.outputs = outputs
self.transactions = transactions
self.keychain = keychain
# attributes have to be always empty in current Cardano
self.attributes = {}
self.network_name = KNOWN_PROTOCOL_MAGICS.get(protocol_magic, "Unknown")
self.protocol_magic = protocol_magic
def _process_inputs(self):
input_coins = []
input_hashes = []
output_indexes = []
types = []
tx_data = {}
for raw_transaction in self.transactions:
tx_hash = hashlib.blake2b(data=bytes(raw_transaction), outlen=32).digest()
tx_data[tx_hash] = cbor.decode(raw_transaction)
for input in self.inputs:
input_hashes.append(input.prev_hash)
output_indexes.append(input.prev_index)
types.append(input.type or 0)
nodes = []
for input in self.inputs:
_, node = derive_address_and_node(self.keychain, input.address_n)
nodes.append(node)
for index, output_index in enumerate(output_indexes):
tx_hash = bytes(input_hashes[index])
if tx_hash in tx_data:
tx = tx_data[tx_hash]
outputs = tx[1]
amount = outputs[output_index][1]
input_coins.append(amount)
else:
raise wire.ProcessError("No tx data sent for input " + str(index))
self.input_coins = input_coins
self.nodes = nodes
self.types = types
self.input_hashes = input_hashes
self.output_indexes = output_indexes
self.input_coins_sum = input_coins_sum
def _process_outputs(self):
change_addresses = []
@ -202,7 +182,8 @@ class Transaction:
def _build_witnesses(self, tx_aux_hash: str):
witnesses = []
for index, node in enumerate(self.nodes):
for input in self.inputs:
_, node = derive_address_and_node(self.keychain, input.address_n)
message = (
b"\x01" + cbor.encode(self.protocol_magic) + b"\x58\x20" + tx_aux_hash
)
@ -214,7 +195,7 @@ class Transaction:
)
witnesses.append(
[
self.types[index],
(input.type or 0),
cbor.Tagged(24, cbor.encode([extended_public_key, signature])),
]
)
@ -222,8 +203,7 @@ class Transaction:
return witnesses
@staticmethod
def compute_fee(input_coins: list, outgoing_coins: list, change_coins: list):
input_coins_sum = sum(input_coins)
def compute_fee(input_coins_sum: int, outgoing_coins: list, change_coins: list):
outgoing_coins_sum = sum(outgoing_coins)
change_coins_sum = sum(change_coins)
@ -231,15 +211,14 @@ class Transaction:
def serialise_tx(self):
self._process_inputs()
self._process_outputs()
inputs_cbor = []
for i, output_index in enumerate(self.output_indexes):
for input in self.inputs:
inputs_cbor.append(
[
self.types[i],
cbor.Tagged(24, cbor.encode([self.input_hashes[i], output_index])),
(input.type or 0),
cbor.Tagged(24, cbor.encode([input.prev_hash, input.prev_index])),
]
)
@ -265,7 +244,7 @@ class Transaction:
tx_body = cbor.encode([tx_aux_cbor, witnesses])
self.fee = self.compute_fee(
self.input_coins, self.outgoing_coins, self.change_coins
self.input_coins_sum, self.outgoing_coins, self.change_coins
)
return tx_body, tx_hash