mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-22 07:28:10 +00:00
Merge pull request #417 from vacuumlabs/cardano-improvements
Cardano improvements
This commit is contained in:
commit
20c97e85ad
@ -1,3 +1,4 @@
|
||||
from trezor import log
|
||||
from trezor.crypto import base58, crc, hashlib
|
||||
|
||||
from apps.cardano import cbor
|
||||
@ -5,6 +6,14 @@ from apps.common import HARDENED
|
||||
from apps.common.seed import remove_ed25519_prefix
|
||||
|
||||
|
||||
def _encode_address_raw(address_data_encoded):
|
||||
return base58.encode(
|
||||
cbor.encode(
|
||||
[cbor.Tagged(24, address_data_encoded), crc.crc32(address_data_encoded)]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def derive_address_and_node(keychain, path: list):
|
||||
node = keychain.derive(path)
|
||||
|
||||
@ -16,12 +25,31 @@ def derive_address_and_node(keychain, path: list):
|
||||
address_data = [address_root, address_attributes, address_type]
|
||||
address_data_encoded = cbor.encode(address_data)
|
||||
|
||||
address = base58.encode(
|
||||
cbor.encode(
|
||||
[cbor.Tagged(24, address_data_encoded), crc.crc32(address_data_encoded)]
|
||||
)
|
||||
)
|
||||
return (address, node)
|
||||
return (_encode_address_raw(address_data_encoded), node)
|
||||
|
||||
|
||||
def is_safe_output_address(address) -> bool:
|
||||
"""
|
||||
Determines whether it is safe to include the address as-is as
|
||||
a tx output, preventing unintended side effects (e.g. CBOR injection)
|
||||
"""
|
||||
try:
|
||||
address_hex = base58.decode(address)
|
||||
address_unpacked = cbor.decode(address_hex)
|
||||
except ValueError as e:
|
||||
if __debug__:
|
||||
log.exception(__name__, e)
|
||||
return False
|
||||
|
||||
if not isinstance(address_unpacked, list) or len(address_unpacked) != 2:
|
||||
return False
|
||||
|
||||
address_data_encoded = address_unpacked[0]
|
||||
|
||||
if not isinstance(address_data_encoded, bytes):
|
||||
return False
|
||||
|
||||
return _encode_address_raw(address_data_encoded) == address
|
||||
|
||||
|
||||
def validate_full_path(path: list) -> bool:
|
||||
|
@ -1,10 +1,10 @@
|
||||
from trezor import log, ui, wire
|
||||
from trezor import log, wire
|
||||
from trezor.messages.CardanoAddress import CardanoAddress
|
||||
|
||||
from apps.cardano import seed
|
||||
from apps.cardano.address import derive_address_and_node, validate_full_path
|
||||
from apps.cardano.layout import confirm_with_pagination
|
||||
from apps.common import paths
|
||||
from apps.common.layout import address_n_to_str, show_address, show_qr
|
||||
|
||||
|
||||
async def get_address(ctx, msg):
|
||||
@ -18,11 +18,12 @@ async def get_address(ctx, msg):
|
||||
if __debug__:
|
||||
log.exception(__name__, e)
|
||||
raise wire.ProcessError("Deriving address failed")
|
||||
|
||||
if msg.show_display:
|
||||
if not await confirm_with_pagination(
|
||||
ctx, address, "Export address", icon=ui.ICON_SEND, icon_color=ui.GREEN
|
||||
):
|
||||
raise wire.ActionCancelled("Exporting cancelled")
|
||||
desc = address_n_to_str(msg.address_n)
|
||||
while True:
|
||||
if await show_address(ctx, address, desc=desc):
|
||||
break
|
||||
if await show_qr(ctx, address, desc=desc):
|
||||
break
|
||||
|
||||
return CardanoAddress(address=address)
|
||||
|
@ -3,10 +3,77 @@ from micropython import const
|
||||
from trezor import ui
|
||||
from trezor.messages import ButtonRequestType, MessageType
|
||||
from trezor.messages.ButtonRequest import ButtonRequest
|
||||
from trezor.ui.confirm import CONFIRMED, ConfirmDialog
|
||||
from trezor.ui.confirm import CONFIRMED, ConfirmDialog, HoldToConfirmDialog
|
||||
from trezor.ui.scroll import Scrollpage, animate_swipe, paginate
|
||||
from trezor.ui.text import Text
|
||||
from trezor.utils import chunks
|
||||
from trezor.utils import chunks, format_amount
|
||||
|
||||
|
||||
def format_coin_amount(amount):
|
||||
return "%s %s" % (format_amount(amount, 6), "ADA")
|
||||
|
||||
|
||||
async def confirm_sending(ctx, amount, to):
|
||||
to_lines = list(chunks(to, 17))
|
||||
|
||||
t1 = Text("Confirm transaction", ui.ICON_SEND, icon_color=ui.GREEN)
|
||||
t1.normal("Confirm sending:")
|
||||
t1.bold(format_coin_amount(amount))
|
||||
t1.normal("to:")
|
||||
t1.bold(to_lines[0])
|
||||
pages = [t1]
|
||||
|
||||
LINES_PER_PAGE = 4
|
||||
if len(to_lines) > 1:
|
||||
to_pages = list(chunks(to_lines[1:], LINES_PER_PAGE))
|
||||
for page in to_pages:
|
||||
t = Text("Confirm transaction", ui.ICON_SEND, icon_color=ui.GREEN)
|
||||
for line in page:
|
||||
t.bold(line)
|
||||
pages.append(t)
|
||||
|
||||
await ctx.call(ButtonRequest(code=ButtonRequestType.Other), MessageType.ButtonAck)
|
||||
|
||||
paginator = paginate(create_renderer(ConfirmDialog), len(pages), const(0), pages)
|
||||
return await ctx.wait(paginator) == CONFIRMED
|
||||
|
||||
|
||||
async def confirm_transaction(ctx, amount, fee, network_name):
|
||||
t1 = Text("Confirm transaction", ui.ICON_SEND, icon_color=ui.GREEN)
|
||||
t1.normal("Total amount:")
|
||||
t1.bold(format_coin_amount(amount))
|
||||
t1.normal("including fee:")
|
||||
t1.bold(format_coin_amount(fee))
|
||||
|
||||
t2 = Text("Confirm transaction", ui.ICON_SEND, icon_color=ui.GREEN)
|
||||
t2.normal("Network:")
|
||||
t2.bold(network_name)
|
||||
|
||||
pages = [t1, t2]
|
||||
|
||||
await ctx.call(ButtonRequest(code=ButtonRequestType.Other), MessageType.ButtonAck)
|
||||
|
||||
paginator = paginate(
|
||||
create_renderer(HoldToConfirmDialog), len(pages), const(0), pages
|
||||
)
|
||||
return await ctx.wait(paginator) == CONFIRMED
|
||||
|
||||
|
||||
def create_renderer(confirmation_wrapper):
|
||||
@ui.layout
|
||||
async def page_renderer(page: int, page_count: int, pages: list):
|
||||
# for some reason page index can be equal to page count
|
||||
if page >= page_count:
|
||||
page = page_count - 1
|
||||
|
||||
content = Scrollpage(pages[page], page, page_count)
|
||||
if page + 1 >= page_count:
|
||||
return await confirmation_wrapper(content)
|
||||
else:
|
||||
content.render()
|
||||
await animate_swipe()
|
||||
|
||||
return page_renderer
|
||||
|
||||
|
||||
async def confirm_with_pagination(
|
||||
|
@ -1,73 +1,62 @@
|
||||
from trezor import log, ui, wire
|
||||
from micropython import const
|
||||
|
||||
from trezor import log, wire
|
||||
from trezor.crypto import base58, hashlib
|
||||
from trezor.crypto.curve import ed25519
|
||||
from trezor.messages.CardanoSignedTx import CardanoSignedTx
|
||||
from trezor.messages.CardanoTxRequest import CardanoTxRequest
|
||||
from trezor.messages.MessageType import CardanoTxAck
|
||||
from trezor.ui.text import BR
|
||||
|
||||
from apps.cardano import cbor, seed
|
||||
from apps.cardano.address import derive_address_and_node, validate_full_path
|
||||
from apps.cardano.layout import confirm_with_pagination, progress
|
||||
from apps.common.layout import address_n_to_str, split_address
|
||||
from apps.cardano.address import (
|
||||
derive_address_and_node,
|
||||
is_safe_output_address,
|
||||
validate_full_path,
|
||||
)
|
||||
from apps.cardano.layout import confirm_sending, confirm_transaction, progress
|
||||
from apps.common.paths import validate_path
|
||||
from apps.common.seed import remove_ed25519_prefix
|
||||
from apps.homescreen.homescreen import display_homescreen
|
||||
|
||||
# the maximum allowed change address. this should be large enough for normal
|
||||
# use and still allow to quickly brute-force the correct bip32 path
|
||||
MAX_CHANGE_ADDRESS_INDEX = const(1000000)
|
||||
ACCOUNT_PREFIX_DEPTH = const(2)
|
||||
|
||||
KNOWN_PROTOCOL_MAGICS = {764824073: "Mainnet", 1097911063: "Testnet"}
|
||||
|
||||
|
||||
# we consider addresses from the external chain as possible change addresses as well
|
||||
def is_change(output, inputs):
|
||||
for input in inputs:
|
||||
inp = input.address_n
|
||||
if (
|
||||
not output[:ACCOUNT_PREFIX_DEPTH] == inp[:ACCOUNT_PREFIX_DEPTH]
|
||||
or not output[-2] < 2
|
||||
or not output[-1] < MAX_CHANGE_ADDRESS_INDEX
|
||||
):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
async def show_tx(
|
||||
ctx,
|
||||
outputs: list,
|
||||
outcoins: list,
|
||||
change_derivation_paths: list,
|
||||
change_coins: list,
|
||||
fee: float,
|
||||
tx_size: float,
|
||||
fee: int,
|
||||
network_name: str,
|
||||
raw_inputs: list,
|
||||
raw_outputs: list,
|
||||
) -> bool:
|
||||
lines = ("%s ADA" % _micro_ada_to_ada(fee), BR, "Tx size:", "%s bytes" % tx_size)
|
||||
if not await confirm_with_pagination(
|
||||
ctx, lines, "Confirm fee", ui.ICON_SEND, ui.GREEN
|
||||
):
|
||||
return False
|
||||
|
||||
if not await confirm_with_pagination(
|
||||
ctx, "%s network" % network_name, "Confirm network", ui.ICON_SEND, ui.GREEN
|
||||
):
|
||||
return False
|
||||
|
||||
for index, output in enumerate(outputs):
|
||||
if not await confirm_with_pagination(
|
||||
ctx, output, "Confirm output", ui.ICON_SEND, ui.GREEN
|
||||
):
|
||||
if is_change(raw_outputs[index].address_n, raw_inputs):
|
||||
continue
|
||||
|
||||
if not await confirm_sending(ctx, outcoins[index], output):
|
||||
return False
|
||||
|
||||
if not await confirm_with_pagination(
|
||||
ctx,
|
||||
"%s ADA" % _micro_ada_to_ada(outcoins[index]),
|
||||
"Confirm amount",
|
||||
ui.ICON_SEND,
|
||||
ui.GREEN,
|
||||
):
|
||||
return False
|
||||
|
||||
for index, change in enumerate(change_derivation_paths):
|
||||
if not await confirm_with_pagination(
|
||||
ctx,
|
||||
list(split_address(address_n_to_str(change))),
|
||||
"Confirm change",
|
||||
ui.ICON_SEND,
|
||||
ui.GREEN,
|
||||
):
|
||||
return False
|
||||
|
||||
if not await confirm_with_pagination(
|
||||
ctx,
|
||||
"%s ADA" % _micro_ada_to_ada(change_coins[index]),
|
||||
"Confirm amount",
|
||||
ui.ICON_SEND,
|
||||
ui.GREEN,
|
||||
):
|
||||
total_amount = sum(outcoins)
|
||||
if not await confirm_transaction(ctx, total_amount, fee, network_name):
|
||||
return False
|
||||
|
||||
return True
|
||||
@ -100,7 +89,7 @@ async def sign_tx(ctx, msg):
|
||||
|
||||
# sign the transaction bundle and prepare the result
|
||||
transaction = Transaction(
|
||||
msg.inputs, msg.outputs, transactions, keychain, msg.network
|
||||
msg.inputs, msg.outputs, transactions, keychain, msg.protocol_magic
|
||||
)
|
||||
tx_body, tx_hash = transaction.serialise_tx()
|
||||
tx = CardanoSignedTx(tx_body=tx_body, tx_hash=tx_hash)
|
||||
@ -115,24 +104,24 @@ async def sign_tx(ctx, msg):
|
||||
ctx,
|
||||
transaction.output_addresses,
|
||||
transaction.outgoing_coins,
|
||||
transaction.change_derivation_paths,
|
||||
transaction.change_coins,
|
||||
transaction.fee,
|
||||
len(tx_body),
|
||||
transaction.network_name,
|
||||
transaction.inputs,
|
||||
transaction.outputs,
|
||||
):
|
||||
raise wire.ActionCancelled("Signing cancelled")
|
||||
|
||||
return tx
|
||||
|
||||
|
||||
def _micro_ada_to_ada(amount: float) -> float:
|
||||
return amount / 1000000
|
||||
|
||||
|
||||
class Transaction:
|
||||
def __init__(
|
||||
self, inputs: list, outputs: list, transactions: list, keychain, network: int
|
||||
self,
|
||||
inputs: list,
|
||||
outputs: list,
|
||||
transactions: list,
|
||||
keychain,
|
||||
protocol_magic: int,
|
||||
):
|
||||
self.inputs = inputs
|
||||
self.outputs = outputs
|
||||
@ -140,14 +129,9 @@ class Transaction:
|
||||
self.keychain = keychain
|
||||
# attributes have to be always empty in current Cardano
|
||||
self.attributes = {}
|
||||
if network == 1:
|
||||
self.network_name = "Testnet"
|
||||
self.network_magic = b"\x01\x1a\x41\x70\xcb\x17\x58\x20"
|
||||
elif network == 2:
|
||||
self.network_name = "Mainnet"
|
||||
self.network_magic = b"\x01\x1a\x2d\x96\x4a\x09\x58\x20"
|
||||
else:
|
||||
raise wire.ProcessError("Unknown network index %d" % network)
|
||||
|
||||
self.network_name = KNOWN_PROTOCOL_MAGICS.get(protocol_magic, "Unknown")
|
||||
self.protocol_magic = protocol_magic
|
||||
|
||||
def _process_inputs(self):
|
||||
input_coins = []
|
||||
@ -204,6 +188,8 @@ class Transaction:
|
||||
raise wire.ProcessError(
|
||||
"Each output must have address or address_n field!"
|
||||
)
|
||||
if not is_safe_output_address(output.address):
|
||||
raise wire.ProcessError("Invalid output address!")
|
||||
|
||||
outgoing_coins.append(output.amount)
|
||||
output_addresses.append(output.address)
|
||||
@ -217,7 +203,9 @@ class Transaction:
|
||||
def _build_witnesses(self, tx_aux_hash: str):
|
||||
witnesses = []
|
||||
for index, node in enumerate(self.nodes):
|
||||
message = self.network_magic + tx_aux_hash
|
||||
message = (
|
||||
b"\x01" + cbor.encode(self.protocol_magic) + b"\x58\x20" + tx_aux_hash
|
||||
)
|
||||
signature = ed25519.sign_ext(
|
||||
node.private_key(), node.private_key_ext(), message
|
||||
)
|
||||
|
@ -20,12 +20,12 @@ class CardanoSignTx(p.MessageType):
|
||||
inputs: List[CardanoTxInputType] = None,
|
||||
outputs: List[CardanoTxOutputType] = None,
|
||||
transactions_count: int = None,
|
||||
network: int = None,
|
||||
protocol_magic: int = None,
|
||||
) -> None:
|
||||
self.inputs = inputs if inputs is not None else []
|
||||
self.outputs = outputs if outputs is not None else []
|
||||
self.transactions_count = transactions_count
|
||||
self.network = network
|
||||
self.protocol_magic = protocol_magic
|
||||
|
||||
@classmethod
|
||||
def get_fields(cls):
|
||||
@ -33,5 +33,5 @@ class CardanoSignTx(p.MessageType):
|
||||
1: ('inputs', CardanoTxInputType, p.FLAG_REPEATED),
|
||||
2: ('outputs', CardanoTxOutputType, p.FLAG_REPEATED),
|
||||
3: ('transactions_count', p.UVarintType, 0),
|
||||
4: ('network', p.UVarintType, 0),
|
||||
5: ('protocol_magic', p.UVarintType, 0),
|
||||
}
|
||||
|
2
vendor/trezor-common
vendored
2
vendor/trezor-common
vendored
@ -1 +1 @@
|
||||
Subproject commit 877778fc93db87544df64e4fc1fa91deee1c8334
|
||||
Subproject commit b947aa3d45af58605f8f090078a2de3d5878571d
|
Loading…
Reference in New Issue
Block a user