1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-01 20:08:26 +00:00

seed: add support for key namespaces

This commit is contained in:
Jan Pochyla 2018-11-14 19:58:07 +01:00
parent 5bc47fc567
commit 9ecd123bd5
4 changed files with 86 additions and 36 deletions

View File

@ -6,36 +6,39 @@ from apps.common.request_passphrase import protect_by_passphrase
class Keychain: class Keychain:
def __init__(self, root: bip32.HDNode): def __init__(self, path: list, root: bip32.HDNode):
self.path = path
self.root = root self.root = root
def derive(self, path: list) -> bip32.HDNode: def derive(self, node_path: list) -> bip32.HDNode:
self.validate_path(path) # check we are in the cardano namespace
prefix = node_path[: len(self.path)]
suffix = node_path[len(self.path) :]
if prefix != self.path:
raise wire.DataError("Forbidden key path")
# derive child node from the root
node = self.root.clone() node = self.root.clone()
for i in path: for i in suffix:
node.derive_cardano(i) node.derive_cardano(i)
return node return node
def validate_path(self, path: list) -> None:
if len(path) < 2 or len(path) > 5:
raise wire.ProcessError("Derivation path must be composed from 2-5 indices")
if path[0] != HARDENED | 44 or path[1] != HARDENED | 1815:
raise wire.ProcessError("This is not cardano derivation path")
async def get_keychain(ctx: wire.Context) -> Keychain: async def get_keychain(ctx: wire.Context) -> Keychain:
if not storage.is_initialized(): if not storage.is_initialized():
# device does not have any seed
raise wire.ProcessError("Device is not initialized") raise wire.ProcessError("Device is not initialized")
# acquire passphrase # derive the root node from mnemonic and passphrase
passphrase = cache.get_passphrase() passphrase = cache.get_passphrase()
if passphrase is None: if passphrase is None:
passphrase = await protect_by_passphrase(ctx) passphrase = await protect_by_passphrase(ctx)
cache.set_passphrase(passphrase) cache.set_passphrase(passphrase)
# compute the seed from mnemonic and passphrase
root = bip32.from_mnemonic_cardano(storage.get_mnemonic(), passphrase) root = bip32.from_mnemonic_cardano(storage.get_mnemonic(), passphrase)
keychain = Keychain(root) path = [HARDENED | 44, HARDENED | 1815]
# derive the namespaced root node
for i in path:
root.derive_cardano(i)
keychain = Keychain(path, root)
return keychain return keychain

View File

@ -4,45 +4,78 @@ from trezor.crypto import bip32, bip39
from apps.common import cache, storage from apps.common import cache, storage
from apps.common.request_passphrase import protect_by_passphrase from apps.common.request_passphrase import protect_by_passphrase
_DEFAULT_CURVE = "secp256k1" allow = list
class Keychain: class Keychain:
def __init__(self, seed: bytes): """
self.seed = seed Keychain provides an API for deriving HD keys from previously allowed
key-spaces.
"""
def derive(self, path: list, curve_name: str = _DEFAULT_CURVE) -> bip32.HDNode: def __init__(self, paths: list, roots: list):
node = bip32.from_seed(self.seed, curve_name) self.paths = paths
node.derive_path(path) self.roots = roots
def derive(self, node_path: list, curve_name: str = "secp256k1") -> bip32.HDNode:
# find the root node
root_index = 0
for curve, *path in self.paths:
prefix = node_path[: len(path)]
suffix = node_path[len(path) :]
if curve == curve_name and path == prefix:
break
root_index += 1
else:
raise wire.DataError("Forbidden key path")
# derive child node from the root
node = self.roots[root_index].clone()
node.derive_path(suffix)
return node return node
async def get_keychain(ctx: wire.Context) -> Keychain: async def get_keychain(ctx: wire.Context, paths: list = None) -> Keychain:
if not storage.is_initialized(): if not storage.is_initialized():
# device does not have any seed
raise wire.ProcessError("Device is not initialized") raise wire.ProcessError("Device is not initialized")
seed = cache.get_seed() seed = cache.get_seed()
if seed is None: if seed is None:
# acquire passphrase # derive seed from mnemonic and passphrase
passphrase = cache.get_passphrase() passphrase = cache.get_passphrase()
if passphrase is None: if passphrase is None:
passphrase = await protect_by_passphrase(ctx) passphrase = await protect_by_passphrase(ctx)
cache.set_passphrase(passphrase) cache.set_passphrase(passphrase)
# compute the seed from mnemonic and passphrase
seed = bip39.seed(storage.get_mnemonic(), passphrase) seed = bip39.seed(storage.get_mnemonic(), passphrase)
cache.set_seed(seed) cache.set_seed(seed)
keychain = Keychain(seed) if paths is None:
# allow the whole keyspace by default
paths = [
["curve25519"],
["ed25519"],
["ed25519-keccak"],
["nist256p1"],
["secp256k1"],
["secp256k1-decred"],
["secp256k1-groestl"],
["secp256k1-smart"],
]
# derive namespaced root nodes
roots = []
for curve, *path in paths:
node = bip32.from_seed(seed, curve)
node.derive_path(path)
roots.append(node)
keychain = Keychain(paths, roots)
return keychain return keychain
def derive_node_without_passphrase( def derive_node_without_passphrase(
path: list, curve_name: str = _DEFAULT_CURVE path: list, curve_name: str = "secp256k1"
) -> bip32.HDNode: ) -> bip32.HDNode:
if not storage.is_initialized(): if not storage.is_initialized():
# device does not have any seed
raise Exception("Device is not initialized") raise Exception("Device is not initialized")
seed = bip39.seed(storage.get_mnemonic(), "") seed = bip39.seed(storage.get_mnemonic(), "")
node = bip32.from_seed(seed, curve_name) node = bip32.from_seed(seed, curve_name)

View File

@ -8,6 +8,7 @@ from apps.cardano.address import (
validate_full_path, validate_full_path,
derive_address_and_node derive_address_and_node
) )
from apps.cardano.seed import Keychain
from trezor.crypto import bip32 from trezor.crypto import bip32
@ -16,6 +17,9 @@ class TestCardanoAddress(unittest.TestCase):
mnemonic = "all all all all all all all all all all all all" mnemonic = "all all all all all all all all all all all all"
passphrase = "" passphrase = ""
node = bip32.from_mnemonic_cardano(mnemonic, passphrase) node = bip32.from_mnemonic_cardano(mnemonic, passphrase)
node.derive_cardano(0x80000000 | 44)
node.derive_cardano(0x80000000 | 1815)
keychain = Keychain([0x80000000 | 44, 0x80000000 | 1815], node)
addresses = [ addresses = [
"Ae2tdPwUPEZ98eHFwxSsPBDz73amioKpr58Vw85mP1tMkzq8siaftiejJ3j", "Ae2tdPwUPEZ98eHFwxSsPBDz73amioKpr58Vw85mP1tMkzq8siaftiejJ3j",
@ -25,7 +29,7 @@ class TestCardanoAddress(unittest.TestCase):
for i, expected in enumerate(addresses): for i, expected in enumerate(addresses):
# 44'/1815'/0'/0/i' # 44'/1815'/0'/0/i'
address, _ = derive_address_and_node(node, [0x80000000 | 44, 0x80000000 | 1815, 0x80000000, 0, 0x80000000 + i]) address, _ = derive_address_and_node(keychain, [0x80000000 | 44, 0x80000000 | 1815, 0x80000000, 0, 0x80000000 + i])
self.assertEqual(expected, address) self.assertEqual(expected, address)
nodes = [ nodes = [
@ -50,7 +54,7 @@ class TestCardanoAddress(unittest.TestCase):
] ]
for i, (priv, ext, pub, chain) in enumerate(nodes): for i, (priv, ext, pub, chain) in enumerate(nodes):
_, n = derive_address_and_node(node, [0x80000000 | 44, 0x80000000 | 1815, 0x80000000, 0, 0x80000000 + i]) _, n = derive_address_and_node(keychain, [0x80000000 | 44, 0x80000000 | 1815, 0x80000000, 0, 0x80000000 + i])
self.assertEqual(hexlify(n.private_key()), priv) self.assertEqual(hexlify(n.private_key()), priv)
self.assertEqual(hexlify(n.private_key_ext()), ext) self.assertEqual(hexlify(n.private_key_ext()), ext)
self.assertEqual(hexlify(seed.remove_ed25519_prefix(n.public_key())), pub) self.assertEqual(hexlify(seed.remove_ed25519_prefix(n.public_key())), pub)
@ -60,6 +64,9 @@ class TestCardanoAddress(unittest.TestCase):
mnemonic = "all all all all all all all all all all all all" mnemonic = "all all all all all all all all all all all all"
passphrase = "" passphrase = ""
node = bip32.from_mnemonic_cardano(mnemonic, passphrase) node = bip32.from_mnemonic_cardano(mnemonic, passphrase)
node.derive_cardano(0x80000000 | 44)
node.derive_cardano(0x80000000 | 1815)
keychain = Keychain([0x80000000 | 44, 0x80000000 | 1815], node)
addresses = [ addresses = [
"Ae2tdPwUPEZ5YUb8sM3eS8JqKgrRLzhiu71crfuH2MFtqaYr5ACNRdsswsZ", "Ae2tdPwUPEZ5YUb8sM3eS8JqKgrRLzhiu71crfuH2MFtqaYr5ACNRdsswsZ",
@ -69,7 +76,7 @@ class TestCardanoAddress(unittest.TestCase):
for i, expected in enumerate(addresses): for i, expected in enumerate(addresses):
# 44'/1815'/0'/0/i # 44'/1815'/0'/0/i
address, _ = derive_address_and_node(node, [0x80000000 | 44, 0x80000000 | 1815, 0x80000000, 0, i]) address, _ = derive_address_and_node(keychain, [0x80000000 | 44, 0x80000000 | 1815, 0x80000000, 0, i])
self.assertEqual(address, expected) self.assertEqual(address, expected)
nodes = [ nodes = [
@ -94,7 +101,7 @@ class TestCardanoAddress(unittest.TestCase):
] ]
for i, (priv, ext, pub, chain) in enumerate(nodes): for i, (priv, ext, pub, chain) in enumerate(nodes):
_, n = derive_address_and_node(node, [0x80000000 | 44, 0x80000000 | 1815, 0x80000000, 0, i]) _, n = derive_address_and_node(keychain, [0x80000000 | 44, 0x80000000 | 1815, 0x80000000, 0, i])
self.assertEqual(hexlify(n.private_key()), priv) self.assertEqual(hexlify(n.private_key()), priv)
self.assertEqual(hexlify(n.private_key_ext()), ext) self.assertEqual(hexlify(n.private_key_ext()), ext)
self.assertEqual(hexlify(seed.remove_ed25519_prefix(n.public_key())), pub) self.assertEqual(hexlify(seed.remove_ed25519_prefix(n.public_key())), pub)
@ -105,9 +112,12 @@ class TestCardanoAddress(unittest.TestCase):
mnemonic = "all all all all all all all all all all all all" mnemonic = "all all all all all all all all all all all all"
passphrase = "" passphrase = ""
node = bip32.from_mnemonic_cardano(mnemonic, passphrase) node = bip32.from_mnemonic_cardano(mnemonic, passphrase)
node.derive_cardano(0x80000000 | 44)
node.derive_cardano(0x80000000 | 1815)
keychain = Keychain([0x80000000 | 44, 0x80000000 | 1815], node)
# 44'/1815' # 44'/1815'
address, _ = derive_address_and_node(node, [0x80000000 | 44, 0x80000000 | 1815]) address, _ = derive_address_and_node(keychain, [0x80000000 | 44, 0x80000000 | 1815])
self.assertEqual(address, "Ae2tdPwUPEZ2FGHX3yCKPSbSgyuuTYgMxNq652zKopxT4TuWvEd8Utd92w3") self.assertEqual(address, "Ae2tdPwUPEZ2FGHX3yCKPSbSgyuuTYgMxNq652zKopxT4TuWvEd8Utd92w3")
priv, ext, pub, chain = ( priv, ext, pub, chain = (
@ -117,7 +127,7 @@ class TestCardanoAddress(unittest.TestCase):
b"02ac67c59a8b0264724a635774ca2c242afa10d7ab70e2bf0a8f7d4bb10f1f7a" b"02ac67c59a8b0264724a635774ca2c242afa10d7ab70e2bf0a8f7d4bb10f1f7a"
) )
_, n = derive_address_and_node(node, [0x80000000 | 44, 0x80000000 | 1815]) _, n = derive_address_and_node(keychain, [0x80000000 | 44, 0x80000000 | 1815])
self.assertEqual(hexlify(n.private_key()), priv) self.assertEqual(hexlify(n.private_key()), priv)
self.assertEqual(hexlify(n.private_key_ext()), ext) self.assertEqual(hexlify(n.private_key_ext()), ext)
self.assertEqual(hexlify(seed.remove_ed25519_prefix(n.public_key())), pub) self.assertEqual(hexlify(seed.remove_ed25519_prefix(n.public_key())), pub)

View File

@ -1,5 +1,6 @@
from common import * from common import *
from apps.cardano.seed import Keychain
from apps.cardano.get_public_key import _get_public_key from apps.cardano.get_public_key import _get_public_key
from trezor.crypto import bip32 from trezor.crypto import bip32
from ubinascii import hexlify from ubinascii import hexlify
@ -10,6 +11,9 @@ class TestCardanoGetPublicKey(unittest.TestCase):
mnemonic = "all all all all all all all all all all all all" mnemonic = "all all all all all all all all all all all all"
passphrase = "" passphrase = ""
node = bip32.from_mnemonic_cardano(mnemonic, passphrase) node = bip32.from_mnemonic_cardano(mnemonic, passphrase)
node.derive_cardano(0x80000000 | 44)
node.derive_cardano(0x80000000 | 1815)
keychain = Keychain([0x80000000 | 44, 0x80000000 | 1815], node)
derivation_paths = [ derivation_paths = [
[0x80000000 | 44, 0x80000000 | 1815, 0x80000000, 0, 0x80000000], [0x80000000 | 44, 0x80000000 | 1815, 0x80000000, 0, 0x80000000],
@ -40,7 +44,7 @@ class TestCardanoGetPublicKey(unittest.TestCase):
] ]
for index, derivation_path in enumerate(derivation_paths): for index, derivation_path in enumerate(derivation_paths):
key = _get_public_key(node, derivation_path) key = _get_public_key(keychain, derivation_path)
self.assertEqual(hexlify(key.node.public_key), public_keys[index]) self.assertEqual(hexlify(key.node.public_key), public_keys[index])
self.assertEqual(hexlify(key.node.chain_code), chain_codes[index]) self.assertEqual(hexlify(key.node.chain_code), chain_codes[index])