Merge pull request #417 from vacuumlabs/cardano-improvements

Cardano improvements
pull/25/head
Tomas Susanka 5 years ago committed by GitHub
commit 20c97e85ad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,3 +1,4 @@
from trezor import log
from trezor.crypto import base58, crc, hashlib
from apps.cardano import cbor
@ -5,6 +6,14 @@ from apps.common import HARDENED
from apps.common.seed import remove_ed25519_prefix
def _encode_address_raw(address_data_encoded):
return base58.encode(
cbor.encode(
[cbor.Tagged(24, address_data_encoded), crc.crc32(address_data_encoded)]
)
)
def derive_address_and_node(keychain, path: list):
node = keychain.derive(path)
@ -16,12 +25,31 @@ def derive_address_and_node(keychain, path: list):
address_data = [address_root, address_attributes, address_type]
address_data_encoded = cbor.encode(address_data)
address = base58.encode(
cbor.encode(
[cbor.Tagged(24, address_data_encoded), crc.crc32(address_data_encoded)]
)
)
return (address, node)
return (_encode_address_raw(address_data_encoded), node)
def is_safe_output_address(address) -> bool:
"""
Determines whether it is safe to include the address as-is as
a tx output, preventing unintended side effects (e.g. CBOR injection)
"""
try:
address_hex = base58.decode(address)
address_unpacked = cbor.decode(address_hex)
except ValueError as e:
if __debug__:
log.exception(__name__, e)
return False
if not isinstance(address_unpacked, list) or len(address_unpacked) != 2:
return False
address_data_encoded = address_unpacked[0]
if not isinstance(address_data_encoded, bytes):
return False
return _encode_address_raw(address_data_encoded) == address
def validate_full_path(path: list) -> bool:

@ -1,10 +1,10 @@
from trezor import log, ui, wire
from trezor import log, wire
from trezor.messages.CardanoAddress import CardanoAddress
from apps.cardano import seed
from apps.cardano.address import derive_address_and_node, validate_full_path
from apps.cardano.layout import confirm_with_pagination
from apps.common import paths
from apps.common.layout import address_n_to_str, show_address, show_qr
async def get_address(ctx, msg):
@ -18,11 +18,12 @@ async def get_address(ctx, msg):
if __debug__:
log.exception(__name__, e)
raise wire.ProcessError("Deriving address failed")
if msg.show_display:
if not await confirm_with_pagination(
ctx, address, "Export address", icon=ui.ICON_SEND, icon_color=ui.GREEN
):
raise wire.ActionCancelled("Exporting cancelled")
desc = address_n_to_str(msg.address_n)
while True:
if await show_address(ctx, address, desc=desc):
break
if await show_qr(ctx, address, desc=desc):
break
return CardanoAddress(address=address)

@ -3,10 +3,77 @@ from micropython import const
from trezor import ui
from trezor.messages import ButtonRequestType, MessageType
from trezor.messages.ButtonRequest import ButtonRequest
from trezor.ui.confirm import CONFIRMED, ConfirmDialog
from trezor.ui.confirm import CONFIRMED, ConfirmDialog, HoldToConfirmDialog
from trezor.ui.scroll import Scrollpage, animate_swipe, paginate
from trezor.ui.text import Text
from trezor.utils import chunks
from trezor.utils import chunks, format_amount
def format_coin_amount(amount):
return "%s %s" % (format_amount(amount, 6), "ADA")
async def confirm_sending(ctx, amount, to):
to_lines = list(chunks(to, 17))
t1 = Text("Confirm transaction", ui.ICON_SEND, icon_color=ui.GREEN)
t1.normal("Confirm sending:")
t1.bold(format_coin_amount(amount))
t1.normal("to:")
t1.bold(to_lines[0])
pages = [t1]
LINES_PER_PAGE = 4
if len(to_lines) > 1:
to_pages = list(chunks(to_lines[1:], LINES_PER_PAGE))
for page in to_pages:
t = Text("Confirm transaction", ui.ICON_SEND, icon_color=ui.GREEN)
for line in page:
t.bold(line)
pages.append(t)
await ctx.call(ButtonRequest(code=ButtonRequestType.Other), MessageType.ButtonAck)
paginator = paginate(create_renderer(ConfirmDialog), len(pages), const(0), pages)
return await ctx.wait(paginator) == CONFIRMED
async def confirm_transaction(ctx, amount, fee, network_name):
t1 = Text("Confirm transaction", ui.ICON_SEND, icon_color=ui.GREEN)
t1.normal("Total amount:")
t1.bold(format_coin_amount(amount))
t1.normal("including fee:")
t1.bold(format_coin_amount(fee))
t2 = Text("Confirm transaction", ui.ICON_SEND, icon_color=ui.GREEN)
t2.normal("Network:")
t2.bold(network_name)
pages = [t1, t2]
await ctx.call(ButtonRequest(code=ButtonRequestType.Other), MessageType.ButtonAck)
paginator = paginate(
create_renderer(HoldToConfirmDialog), len(pages), const(0), pages
)
return await ctx.wait(paginator) == CONFIRMED
def create_renderer(confirmation_wrapper):
@ui.layout
async def page_renderer(page: int, page_count: int, pages: list):
# for some reason page index can be equal to page count
if page >= page_count:
page = page_count - 1
content = Scrollpage(pages[page], page, page_count)
if page + 1 >= page_count:
return await confirmation_wrapper(content)
else:
content.render()
await animate_swipe()
return page_renderer
async def confirm_with_pagination(

@ -1,74 +1,63 @@
from trezor import log, ui, wire
from micropython import const
from trezor import log, wire
from trezor.crypto import base58, hashlib
from trezor.crypto.curve import ed25519
from trezor.messages.CardanoSignedTx import CardanoSignedTx
from trezor.messages.CardanoTxRequest import CardanoTxRequest
from trezor.messages.MessageType import CardanoTxAck
from trezor.ui.text import BR
from apps.cardano import cbor, seed
from apps.cardano.address import derive_address_and_node, validate_full_path
from apps.cardano.layout import confirm_with_pagination, progress
from apps.common.layout import address_n_to_str, split_address
from apps.cardano.address import (
derive_address_and_node,
is_safe_output_address,
validate_full_path,
)
from apps.cardano.layout import confirm_sending, confirm_transaction, progress
from apps.common.paths import validate_path
from apps.common.seed import remove_ed25519_prefix
from apps.homescreen.homescreen import display_homescreen
# the maximum allowed change address. this should be large enough for normal
# use and still allow to quickly brute-force the correct bip32 path
MAX_CHANGE_ADDRESS_INDEX = const(1000000)
ACCOUNT_PREFIX_DEPTH = const(2)
KNOWN_PROTOCOL_MAGICS = {764824073: "Mainnet", 1097911063: "Testnet"}
# we consider addresses from the external chain as possible change addresses as well
def is_change(output, inputs):
for input in inputs:
inp = input.address_n
if (
not output[:ACCOUNT_PREFIX_DEPTH] == inp[:ACCOUNT_PREFIX_DEPTH]
or not output[-2] < 2
or not output[-1] < MAX_CHANGE_ADDRESS_INDEX
):
return False
return True
async def show_tx(
ctx,
outputs: list,
outcoins: list,
change_derivation_paths: list,
change_coins: list,
fee: float,
tx_size: float,
fee: int,
network_name: str,
raw_inputs: list,
raw_outputs: list,
) -> bool:
lines = ("%s ADA" % _micro_ada_to_ada(fee), BR, "Tx size:", "%s bytes" % tx_size)
if not await confirm_with_pagination(
ctx, lines, "Confirm fee", ui.ICON_SEND, ui.GREEN
):
return False
if not await confirm_with_pagination(
ctx, "%s network" % network_name, "Confirm network", ui.ICON_SEND, ui.GREEN
):
return False
for index, output in enumerate(outputs):
if not await confirm_with_pagination(
ctx, output, "Confirm output", ui.ICON_SEND, ui.GREEN
):
return False
if is_change(raw_outputs[index].address_n, raw_inputs):
continue
if not await confirm_with_pagination(
ctx,
"%s ADA" % _micro_ada_to_ada(outcoins[index]),
"Confirm amount",
ui.ICON_SEND,
ui.GREEN,
):
if not await confirm_sending(ctx, outcoins[index], output):
return False
for index, change in enumerate(change_derivation_paths):
if not await confirm_with_pagination(
ctx,
list(split_address(address_n_to_str(change))),
"Confirm change",
ui.ICON_SEND,
ui.GREEN,
):
return False
if not await confirm_with_pagination(
ctx,
"%s ADA" % _micro_ada_to_ada(change_coins[index]),
"Confirm amount",
ui.ICON_SEND,
ui.GREEN,
):
return False
total_amount = sum(outcoins)
if not await confirm_transaction(ctx, total_amount, fee, network_name):
return False
return True
@ -100,7 +89,7 @@ async def sign_tx(ctx, msg):
# sign the transaction bundle and prepare the result
transaction = Transaction(
msg.inputs, msg.outputs, transactions, keychain, msg.network
msg.inputs, msg.outputs, transactions, keychain, msg.protocol_magic
)
tx_body, tx_hash = transaction.serialise_tx()
tx = CardanoSignedTx(tx_body=tx_body, tx_hash=tx_hash)
@ -115,24 +104,24 @@ async def sign_tx(ctx, msg):
ctx,
transaction.output_addresses,
transaction.outgoing_coins,
transaction.change_derivation_paths,
transaction.change_coins,
transaction.fee,
len(tx_body),
transaction.network_name,
transaction.inputs,
transaction.outputs,
):
raise wire.ActionCancelled("Signing cancelled")
return tx
def _micro_ada_to_ada(amount: float) -> float:
return amount / 1000000
class Transaction:
def __init__(
self, inputs: list, outputs: list, transactions: list, keychain, network: int
self,
inputs: list,
outputs: list,
transactions: list,
keychain,
protocol_magic: int,
):
self.inputs = inputs
self.outputs = outputs
@ -140,14 +129,9 @@ class Transaction:
self.keychain = keychain
# attributes have to be always empty in current Cardano
self.attributes = {}
if network == 1:
self.network_name = "Testnet"
self.network_magic = b"\x01\x1a\x41\x70\xcb\x17\x58\x20"
elif network == 2:
self.network_name = "Mainnet"
self.network_magic = b"\x01\x1a\x2d\x96\x4a\x09\x58\x20"
else:
raise wire.ProcessError("Unknown network index %d" % network)
self.network_name = KNOWN_PROTOCOL_MAGICS.get(protocol_magic, "Unknown")
self.protocol_magic = protocol_magic
def _process_inputs(self):
input_coins = []
@ -204,6 +188,8 @@ class Transaction:
raise wire.ProcessError(
"Each output must have address or address_n field!"
)
if not is_safe_output_address(output.address):
raise wire.ProcessError("Invalid output address!")
outgoing_coins.append(output.amount)
output_addresses.append(output.address)
@ -217,7 +203,9 @@ class Transaction:
def _build_witnesses(self, tx_aux_hash: str):
witnesses = []
for index, node in enumerate(self.nodes):
message = self.network_magic + tx_aux_hash
message = (
b"\x01" + cbor.encode(self.protocol_magic) + b"\x58\x20" + tx_aux_hash
)
signature = ed25519.sign_ext(
node.private_key(), node.private_key_ext(), message
)

@ -20,12 +20,12 @@ class CardanoSignTx(p.MessageType):
inputs: List[CardanoTxInputType] = None,
outputs: List[CardanoTxOutputType] = None,
transactions_count: int = None,
network: int = None,
protocol_magic: int = None,
) -> None:
self.inputs = inputs if inputs is not None else []
self.outputs = outputs if outputs is not None else []
self.transactions_count = transactions_count
self.network = network
self.protocol_magic = protocol_magic
@classmethod
def get_fields(cls):
@ -33,5 +33,5 @@ class CardanoSignTx(p.MessageType):
1: ('inputs', CardanoTxInputType, p.FLAG_REPEATED),
2: ('outputs', CardanoTxOutputType, p.FLAG_REPEATED),
3: ('transactions_count', p.UVarintType, 0),
4: ('network', p.UVarintType, 0),
5: ('protocol_magic', p.UVarintType, 0),
}

@ -1 +1 @@
Subproject commit 877778fc93db87544df64e4fc1fa91deee1c8334
Subproject commit b947aa3d45af58605f8f090078a2de3d5878571d
Loading…
Cancel
Save