From 9ecd123bd51bc4a80e4486c87dc7ba4dc55939f3 Mon Sep 17 00:00:00 2001 From: Jan Pochyla Date: Wed, 14 Nov 2018 19:58:07 +0100 Subject: [PATCH] seed: add support for key namespaces --- src/apps/cardano/seed.py | 33 ++++++------ src/apps/common/seed.py | 61 +++++++++++++++++------ tests/test_apps.cardano.address.py | 22 +++++--- tests/test_apps.cardano.get_public_key.py | 6 ++- 4 files changed, 86 insertions(+), 36 deletions(-) diff --git a/src/apps/cardano/seed.py b/src/apps/cardano/seed.py index 3bcb39bce..00f342fd4 100644 --- a/src/apps/cardano/seed.py +++ b/src/apps/cardano/seed.py @@ -6,36 +6,39 @@ from apps.common.request_passphrase import protect_by_passphrase class Keychain: - def __init__(self, root: bip32.HDNode): + def __init__(self, path: list, root: bip32.HDNode): + self.path = path self.root = root - def derive(self, path: list) -> bip32.HDNode: - self.validate_path(path) + def derive(self, node_path: list) -> bip32.HDNode: + # check we are in the cardano namespace + prefix = node_path[: len(self.path)] + suffix = node_path[len(self.path) :] + if prefix != self.path: + raise wire.DataError("Forbidden key path") + # derive child node from the root node = self.root.clone() - for i in path: + for i in suffix: node.derive_cardano(i) return node - def validate_path(self, path: list) -> None: - if len(path) < 2 or len(path) > 5: - raise wire.ProcessError("Derivation path must be composed from 2-5 indices") - if path[0] != HARDENED | 44 or path[1] != HARDENED | 1815: - raise wire.ProcessError("This is not cardano derivation path") - async def get_keychain(ctx: wire.Context) -> Keychain: if not storage.is_initialized(): - # device does not have any seed raise wire.ProcessError("Device is not initialized") - # acquire passphrase + # derive the root node from mnemonic and passphrase passphrase = cache.get_passphrase() if passphrase is None: passphrase = await protect_by_passphrase(ctx) cache.set_passphrase(passphrase) - - # compute the seed from mnemonic and passphrase root = bip32.from_mnemonic_cardano(storage.get_mnemonic(), passphrase) - keychain = Keychain(root) + path = [HARDENED | 44, HARDENED | 1815] + + # derive the namespaced root node + for i in path: + root.derive_cardano(i) + + keychain = Keychain(path, root) return keychain diff --git a/src/apps/common/seed.py b/src/apps/common/seed.py index 2578352ba..ea6fa96d2 100644 --- a/src/apps/common/seed.py +++ b/src/apps/common/seed.py @@ -4,45 +4,78 @@ from trezor.crypto import bip32, bip39 from apps.common import cache, storage from apps.common.request_passphrase import protect_by_passphrase -_DEFAULT_CURVE = "secp256k1" +allow = list class Keychain: - def __init__(self, seed: bytes): - self.seed = seed + """ + Keychain provides an API for deriving HD keys from previously allowed + key-spaces. + """ - def derive(self, path: list, curve_name: str = _DEFAULT_CURVE) -> bip32.HDNode: - node = bip32.from_seed(self.seed, curve_name) - node.derive_path(path) + def __init__(self, paths: list, roots: list): + self.paths = paths + self.roots = roots + + def derive(self, node_path: list, curve_name: str = "secp256k1") -> bip32.HDNode: + # find the root node + root_index = 0 + for curve, *path in self.paths: + prefix = node_path[: len(path)] + suffix = node_path[len(path) :] + if curve == curve_name and path == prefix: + break + root_index += 1 + else: + raise wire.DataError("Forbidden key path") + # derive child node from the root + node = self.roots[root_index].clone() + node.derive_path(suffix) return node -async def get_keychain(ctx: wire.Context) -> Keychain: +async def get_keychain(ctx: wire.Context, paths: list = None) -> Keychain: if not storage.is_initialized(): - # device does not have any seed raise wire.ProcessError("Device is not initialized") seed = cache.get_seed() if seed is None: - # acquire passphrase + # derive seed from mnemonic and passphrase passphrase = cache.get_passphrase() if passphrase is None: passphrase = await protect_by_passphrase(ctx) cache.set_passphrase(passphrase) - - # compute the seed from mnemonic and passphrase seed = bip39.seed(storage.get_mnemonic(), passphrase) cache.set_seed(seed) - keychain = Keychain(seed) + if paths is None: + # allow the whole keyspace by default + paths = [ + ["curve25519"], + ["ed25519"], + ["ed25519-keccak"], + ["nist256p1"], + ["secp256k1"], + ["secp256k1-decred"], + ["secp256k1-groestl"], + ["secp256k1-smart"], + ] + + # derive namespaced root nodes + roots = [] + for curve, *path in paths: + node = bip32.from_seed(seed, curve) + node.derive_path(path) + roots.append(node) + + keychain = Keychain(paths, roots) return keychain def derive_node_without_passphrase( - path: list, curve_name: str = _DEFAULT_CURVE + path: list, curve_name: str = "secp256k1" ) -> bip32.HDNode: if not storage.is_initialized(): - # device does not have any seed raise Exception("Device is not initialized") seed = bip39.seed(storage.get_mnemonic(), "") node = bip32.from_seed(seed, curve_name) diff --git a/tests/test_apps.cardano.address.py b/tests/test_apps.cardano.address.py index 23cac3f62..ee6fd3e2a 100644 --- a/tests/test_apps.cardano.address.py +++ b/tests/test_apps.cardano.address.py @@ -8,6 +8,7 @@ from apps.cardano.address import ( validate_full_path, derive_address_and_node ) +from apps.cardano.seed import Keychain from trezor.crypto import bip32 @@ -16,6 +17,9 @@ class TestCardanoAddress(unittest.TestCase): mnemonic = "all all all all all all all all all all all all" passphrase = "" node = bip32.from_mnemonic_cardano(mnemonic, passphrase) + node.derive_cardano(0x80000000 | 44) + node.derive_cardano(0x80000000 | 1815) + keychain = Keychain([0x80000000 | 44, 0x80000000 | 1815], node) addresses = [ "Ae2tdPwUPEZ98eHFwxSsPBDz73amioKpr58Vw85mP1tMkzq8siaftiejJ3j", @@ -25,7 +29,7 @@ class TestCardanoAddress(unittest.TestCase): for i, expected in enumerate(addresses): # 44'/1815'/0'/0/i' - address, _ = derive_address_and_node(node, [0x80000000 | 44, 0x80000000 | 1815, 0x80000000, 0, 0x80000000 + i]) + address, _ = derive_address_and_node(keychain, [0x80000000 | 44, 0x80000000 | 1815, 0x80000000, 0, 0x80000000 + i]) self.assertEqual(expected, address) nodes = [ @@ -50,7 +54,7 @@ class TestCardanoAddress(unittest.TestCase): ] for i, (priv, ext, pub, chain) in enumerate(nodes): - _, n = derive_address_and_node(node, [0x80000000 | 44, 0x80000000 | 1815, 0x80000000, 0, 0x80000000 + i]) + _, n = derive_address_and_node(keychain, [0x80000000 | 44, 0x80000000 | 1815, 0x80000000, 0, 0x80000000 + i]) self.assertEqual(hexlify(n.private_key()), priv) self.assertEqual(hexlify(n.private_key_ext()), ext) self.assertEqual(hexlify(seed.remove_ed25519_prefix(n.public_key())), pub) @@ -60,6 +64,9 @@ class TestCardanoAddress(unittest.TestCase): mnemonic = "all all all all all all all all all all all all" passphrase = "" node = bip32.from_mnemonic_cardano(mnemonic, passphrase) + node.derive_cardano(0x80000000 | 44) + node.derive_cardano(0x80000000 | 1815) + keychain = Keychain([0x80000000 | 44, 0x80000000 | 1815], node) addresses = [ "Ae2tdPwUPEZ5YUb8sM3eS8JqKgrRLzhiu71crfuH2MFtqaYr5ACNRdsswsZ", @@ -69,7 +76,7 @@ class TestCardanoAddress(unittest.TestCase): for i, expected in enumerate(addresses): # 44'/1815'/0'/0/i - address, _ = derive_address_and_node(node, [0x80000000 | 44, 0x80000000 | 1815, 0x80000000, 0, i]) + address, _ = derive_address_and_node(keychain, [0x80000000 | 44, 0x80000000 | 1815, 0x80000000, 0, i]) self.assertEqual(address, expected) nodes = [ @@ -94,7 +101,7 @@ class TestCardanoAddress(unittest.TestCase): ] for i, (priv, ext, pub, chain) in enumerate(nodes): - _, n = derive_address_and_node(node, [0x80000000 | 44, 0x80000000 | 1815, 0x80000000, 0, i]) + _, n = derive_address_and_node(keychain, [0x80000000 | 44, 0x80000000 | 1815, 0x80000000, 0, i]) self.assertEqual(hexlify(n.private_key()), priv) self.assertEqual(hexlify(n.private_key_ext()), ext) self.assertEqual(hexlify(seed.remove_ed25519_prefix(n.public_key())), pub) @@ -105,9 +112,12 @@ class TestCardanoAddress(unittest.TestCase): mnemonic = "all all all all all all all all all all all all" passphrase = "" node = bip32.from_mnemonic_cardano(mnemonic, passphrase) + node.derive_cardano(0x80000000 | 44) + node.derive_cardano(0x80000000 | 1815) + keychain = Keychain([0x80000000 | 44, 0x80000000 | 1815], node) # 44'/1815' - address, _ = derive_address_and_node(node, [0x80000000 | 44, 0x80000000 | 1815]) + address, _ = derive_address_and_node(keychain, [0x80000000 | 44, 0x80000000 | 1815]) self.assertEqual(address, "Ae2tdPwUPEZ2FGHX3yCKPSbSgyuuTYgMxNq652zKopxT4TuWvEd8Utd92w3") priv, ext, pub, chain = ( @@ -117,7 +127,7 @@ class TestCardanoAddress(unittest.TestCase): b"02ac67c59a8b0264724a635774ca2c242afa10d7ab70e2bf0a8f7d4bb10f1f7a" ) - _, n = derive_address_and_node(node, [0x80000000 | 44, 0x80000000 | 1815]) + _, n = derive_address_and_node(keychain, [0x80000000 | 44, 0x80000000 | 1815]) self.assertEqual(hexlify(n.private_key()), priv) self.assertEqual(hexlify(n.private_key_ext()), ext) self.assertEqual(hexlify(seed.remove_ed25519_prefix(n.public_key())), pub) diff --git a/tests/test_apps.cardano.get_public_key.py b/tests/test_apps.cardano.get_public_key.py index e9b7fdfa4..33cdc4e8f 100644 --- a/tests/test_apps.cardano.get_public_key.py +++ b/tests/test_apps.cardano.get_public_key.py @@ -1,5 +1,6 @@ from common import * +from apps.cardano.seed import Keychain from apps.cardano.get_public_key import _get_public_key from trezor.crypto import bip32 from ubinascii import hexlify @@ -10,6 +11,9 @@ class TestCardanoGetPublicKey(unittest.TestCase): mnemonic = "all all all all all all all all all all all all" passphrase = "" node = bip32.from_mnemonic_cardano(mnemonic, passphrase) + node.derive_cardano(0x80000000 | 44) + node.derive_cardano(0x80000000 | 1815) + keychain = Keychain([0x80000000 | 44, 0x80000000 | 1815], node) derivation_paths = [ [0x80000000 | 44, 0x80000000 | 1815, 0x80000000, 0, 0x80000000], @@ -40,7 +44,7 @@ class TestCardanoGetPublicKey(unittest.TestCase): ] for index, derivation_path in enumerate(derivation_paths): - key = _get_public_key(node, derivation_path) + key = _get_public_key(keychain, derivation_path) self.assertEqual(hexlify(key.node.public_key), public_keys[index]) self.assertEqual(hexlify(key.node.chain_code), chain_codes[index])