1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-24 23:38:09 +00:00

core/cardano: use caching decorators and new Keychain API for Cardano as well

This commit is contained in:
matejcik 2020-04-20 11:37:47 +02:00 committed by matejcik
parent 7541d529a3
commit fd9e945308
6 changed files with 52 additions and 45 deletions

View File

@ -7,9 +7,8 @@ from apps.common import paths
from apps.common.layout import address_n_to_str, show_address, show_qr from apps.common.layout import address_n_to_str, show_address, show_qr
async def get_address(ctx, msg): @seed.with_keychain
keychain = await seed.get_keychain(ctx) async def get_address(ctx, msg, keychain: seed.Keychain):
await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n, CURVE) await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n, CURVE)
try: try:

View File

@ -10,9 +10,8 @@ from apps.common import layout, paths
from apps.common.seed import remove_ed25519_prefix from apps.common.seed import remove_ed25519_prefix
async def get_public_key(ctx, msg): @seed.with_keychain
keychain = await seed.get_keychain(ctx) async def get_public_key(ctx, msg, keychain: seed.Keychain):
await paths.validate_path( await paths.validate_path(
ctx, ctx,
paths.validate_path_for_get_public_key, paths.validate_path_for_get_public_key,

View File

@ -3,54 +3,64 @@ from storage import cache
from trezor import wire from trezor import wire
from trezor.crypto import bip32 from trezor.crypto import bip32
from apps.cardano import CURVE, SEED_NAMESPACE from apps.cardano import SEED_NAMESPACE
from apps.common import mnemonic from apps.common import mnemonic
from apps.common.passphrase import get as get_passphrase from apps.common.passphrase import get as get_passphrase
if False:
from typing import Tuple
from apps.common.seed import Bip32Path, MsgIn, MsgOut, Handler, HandlerWithKeychain
class Keychain: class Keychain:
def __init__(self, path: list, root: bip32.HDNode): """Cardano keychain hard-coded to SEED_NAMESPACE."""
self.path = path
def __init__(self, root: bip32.HDNode) -> None:
self.root = root self.root = root
def validate_path(self, checked_path: list, checked_curve: str): def match_path(self, path: Bip32Path) -> Tuple[int, Bip32Path]:
if checked_curve != CURVE or checked_path[:2] != SEED_NAMESPACE: if path[: len(SEED_NAMESPACE)] != SEED_NAMESPACE:
raise wire.DataError("Forbidden key path") raise wire.DataError("Forbidden key path")
return 0, path[len(SEED_NAMESPACE) :]
def derive(self, node_path: list) -> bip32.HDNode: def derive(self, node_path: Bip32Path) -> bip32.HDNode:
# check we are in the cardano namespace _, suffix = self.match_path(node_path)
prefix = node_path[: len(self.path)]
suffix = node_path[len(self.path) :]
if prefix != self.path:
raise wire.DataError("Forbidden key path")
# derive child node from the root # derive child node from the root
node = self.root.clone() node = self.root.clone()
for i in suffix: for i in suffix:
node.derive_cardano(i) node.derive_cardano(i)
return node return node
# XXX the root node remains in session cache so we should not delete it
# def __del__(self) -> None:
# self.root.__del__()
@cache.stored_async(cache.APP_CARDANO_ROOT)
async def get_keychain(ctx: wire.Context) -> Keychain: async def get_keychain(ctx: wire.Context) -> Keychain:
root = cache.get(cache.APP_CARDANO_ROOT)
if not storage.is_initialized(): if not storage.is_initialized():
raise wire.NotInitialized("Device is not initialized") raise wire.NotInitialized("Device is not initialized")
if root is None: passphrase = await get_passphrase(ctx)
passphrase = await get_passphrase(ctx) if mnemonic.is_bip39():
if mnemonic.is_bip39(): # derive the root node from mnemonic and passphrase
# derive the root node from mnemonic and passphrase root = bip32.from_mnemonic_cardano(mnemonic.get_secret().decode(), passphrase)
root = bip32.from_mnemonic_cardano( else:
mnemonic.get_secret().decode(), passphrase seed = mnemonic.get_seed(passphrase)
) root = bip32.from_seed(seed, "ed25519 cardano seed")
else:
seed = mnemonic.get_seed(passphrase)
root = bip32.from_seed(seed, "ed25519 cardano seed")
# derive the namespaced root node # derive the namespaced root node
for i in SEED_NAMESPACE: for i in SEED_NAMESPACE:
root.derive_cardano(i) root.derive_cardano(i)
storage.cache.set(cache.APP_CARDANO_ROOT, root)
keychain = Keychain(SEED_NAMESPACE, root) keychain = Keychain(root)
return keychain return keychain
def with_keychain(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]:
async def wrapper(ctx: wire.Context, msg: MsgIn) -> MsgOut:
keychain = await get_keychain(ctx)
return await func(ctx, msg, keychain)
return wrapper

View File

@ -63,9 +63,8 @@ async def request_transaction(ctx, tx_req: CardanoTxRequest, index: int):
return await ctx.call(tx_req, CardanoTxAck) return await ctx.call(tx_req, CardanoTxAck)
async def sign_tx(ctx, msg): @seed.with_keychain
keychain = await seed.get_keychain(ctx) async def sign_tx(ctx, msg, keychain: seed.Keychain):
progress.init(msg.transactions_count, "Loading data") progress.init(msg.transactions_count, "Loading data")
try: try:

View File

@ -21,7 +21,7 @@ class TestCardanoAddress(unittest.TestCase):
node = bip32.from_mnemonic_cardano(mnemonic, passphrase) node = bip32.from_mnemonic_cardano(mnemonic, passphrase)
node.derive_cardano(0x80000000 | 44) node.derive_cardano(0x80000000 | 44)
node.derive_cardano(0x80000000 | 1815) node.derive_cardano(0x80000000 | 1815)
keychain = Keychain([0x80000000 | 44, 0x80000000 | 1815], node) keychain = Keychain(node)
addresses = [ addresses = [
"Ae2tdPwUPEZ98eHFwxSsPBDz73amioKpr58Vw85mP1tMkzq8siaftiejJ3j", "Ae2tdPwUPEZ98eHFwxSsPBDz73amioKpr58Vw85mP1tMkzq8siaftiejJ3j",
@ -68,7 +68,7 @@ class TestCardanoAddress(unittest.TestCase):
node = bip32.from_mnemonic_cardano(mnemonic, passphrase) node = bip32.from_mnemonic_cardano(mnemonic, passphrase)
node.derive_cardano(0x80000000 | 44) node.derive_cardano(0x80000000 | 44)
node.derive_cardano(0x80000000 | 1815) node.derive_cardano(0x80000000 | 1815)
keychain = Keychain([0x80000000 | 44, 0x80000000 | 1815], node) keychain = Keychain(node)
addresses = [ addresses = [
"Ae2tdPwUPEZ5YUb8sM3eS8JqKgrRLzhiu71crfuH2MFtqaYr5ACNRdsswsZ", "Ae2tdPwUPEZ5YUb8sM3eS8JqKgrRLzhiu71crfuH2MFtqaYr5ACNRdsswsZ",
@ -116,7 +116,7 @@ class TestCardanoAddress(unittest.TestCase):
node = bip32.from_mnemonic_cardano(mnemonic, passphrase) node = bip32.from_mnemonic_cardano(mnemonic, passphrase)
node.derive_cardano(0x80000000 | 44) node.derive_cardano(0x80000000 | 44)
node.derive_cardano(0x80000000 | 1815) node.derive_cardano(0x80000000 | 1815)
keychain = Keychain([0x80000000 | 44, 0x80000000 | 1815], node) keychain = Keychain(node)
# 44'/1815' # 44'/1815'
address, _ = derive_address_and_node(keychain, [0x80000000 | 44, 0x80000000 | 1815]) address, _ = derive_address_and_node(keychain, [0x80000000 | 44, 0x80000000 | 1815])
@ -207,7 +207,7 @@ class TestCardanoAddress(unittest.TestCase):
# Check derived nodes and addresses. # Check derived nodes and addresses.
node.derive_cardano(0x80000000 | 44) node.derive_cardano(0x80000000 | 44)
node.derive_cardano(0x80000000 | 1815) node.derive_cardano(0x80000000 | 1815)
keychain = Keychain([0x80000000 | 44, 0x80000000 | 1815], node) keychain = Keychain(node)
nodes = [ nodes = [
( (
@ -271,7 +271,7 @@ class TestCardanoAddress(unittest.TestCase):
# Check derived nodes and addresses. # Check derived nodes and addresses.
node.derive_cardano(0x80000000 | 44) node.derive_cardano(0x80000000 | 44)
node.derive_cardano(0x80000000 | 1815) node.derive_cardano(0x80000000 | 1815)
keychain = Keychain([0x80000000 | 44, 0x80000000 | 1815], node) keychain = Keychain(node)
nodes = [ nodes = [
( (

View File

@ -13,7 +13,7 @@ class TestCardanoGetPublicKey(unittest.TestCase):
node = bip32.from_mnemonic_cardano(mnemonic, passphrase) node = bip32.from_mnemonic_cardano(mnemonic, passphrase)
node.derive_cardano(0x80000000 | 44) node.derive_cardano(0x80000000 | 44)
node.derive_cardano(0x80000000 | 1815) node.derive_cardano(0x80000000 | 1815)
keychain = Keychain([0x80000000 | 44, 0x80000000 | 1815], node) keychain = Keychain(node)
derivation_paths = [ derivation_paths = [
[0x80000000 | 44, 0x80000000 | 1815, 0x80000000, 0, 0x80000000], [0x80000000 | 44, 0x80000000 | 1815, 0x80000000, 0, 0x80000000],
@ -67,7 +67,7 @@ class TestCardanoGetPublicKey(unittest.TestCase):
node.derive_cardano(0x80000000 | 44) node.derive_cardano(0x80000000 | 44)
node.derive_cardano(0x80000000 | 1815) node.derive_cardano(0x80000000 | 1815)
keychain = Keychain([0x80000000 | 44, 0x80000000 | 1815], node) keychain = Keychain(node)
# 44'/1815'/0'/0/i # 44'/1815'/0'/0/i
derivation_paths = [ derivation_paths = [
@ -118,7 +118,7 @@ class TestCardanoGetPublicKey(unittest.TestCase):
node.derive_cardano(0x80000000 | 44) node.derive_cardano(0x80000000 | 44)
node.derive_cardano(0x80000000 | 1815) node.derive_cardano(0x80000000 | 1815)
keychain = Keychain([0x80000000 | 44, 0x80000000 | 1815], node) keychain = Keychain(node)
# 44'/1815'/0'/0/i # 44'/1815'/0'/0/i
derivation_paths = [ derivation_paths = [