1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-28 00:58:09 +00:00

seed: use lazy seed derivation, wipe after the workflow ends

This commit is contained in:
Jan Pochyla 2018-12-13 13:53:53 +01:00
parent a0df8c74d5
commit 7730533dde
8 changed files with 44 additions and 54 deletions

View File

@ -13,14 +13,22 @@ class Keychain:
key-spaces. key-spaces.
""" """
def __init__(self, paths: list, roots: list): def __init__(self, seed: bytes, namespaces: list):
self.paths = paths self.seed = seed
self.roots = roots self.namespaces = namespaces
self.roots = [None] * len(namespaces)
def __del__(self):
for root in self.roots:
if root is not None:
root.__del__()
del self.roots
del self.seed
def derive(self, node_path: list, curve_name: str = "secp256k1") -> bip32.HDNode: def derive(self, node_path: list, curve_name: str = "secp256k1") -> bip32.HDNode:
# find the root node # find the root node index
root_index = 0 root_index = 0
for curve, *path in self.paths: for curve, *path in self.namespaces:
prefix = node_path[: len(path)] prefix = node_path[: len(path)]
suffix = node_path[len(path) :] suffix = node_path[len(path) :]
if curve == curve_name and path == prefix: if curve == curve_name and path == prefix:
@ -28,13 +36,21 @@ class Keychain:
root_index += 1 root_index += 1
else: else:
raise wire.DataError("Forbidden key path") raise wire.DataError("Forbidden key path")
# create the root node if not cached
root = self.roots[root_index]
if root is None:
root = bip32.from_seed(self.seed, curve_name)
root.derive_path(path)
self.roots[root_index] = root
# derive child node from the root # derive child node from the root
node = self.roots[root_index].clone() node = root.clone()
node.derive_path(suffix) node.derive_path(suffix)
return node return node
async def get_keychain(ctx: wire.Context, paths: list) -> Keychain: async def get_keychain(ctx: wire.Context, namespaces: list) -> Keychain:
if not storage.is_initialized(): if not storage.is_initialized():
raise wire.ProcessError("Device is not initialized") raise wire.ProcessError("Device is not initialized")
@ -48,14 +64,7 @@ async def get_keychain(ctx: wire.Context, paths: list) -> Keychain:
seed = bip39.seed(storage.get_mnemonic(), passphrase) seed = bip39.seed(storage.get_mnemonic(), passphrase)
cache.set_seed(seed) cache.set_seed(seed)
# derive namespaced root nodes keychain = Keychain(seed, namespaces)
roots = []
for curve, *path in paths:
node = bip32.from_seed(seed, curve)
node.derive_path(path)
roots.append(node)
keychain = Keychain(paths, roots)
return keychain return keychain

View File

@ -184,7 +184,10 @@ async def protobuf_workflow(ctx, reader, handler, *args):
async def keychain_workflow(ctx, req, namespace, handler, *args): async def keychain_workflow(ctx, req, namespace, handler, *args):
keychain = await seed.get_keychain(ctx, namespace) keychain = await seed.get_keychain(ctx, namespace)
args += (keychain,) args += (keychain,)
return await handler(ctx, req, *args) try:
return await handler(ctx, req, *args)
finally:
keychain.__del__()
def import_workflow(ctx, req, pkgname, modname, *args): def import_workflow(ctx, req, pkgname, modname, *args):

View File

@ -1,7 +1,7 @@
from common import * from common import *
from trezor.utils import chunks from trezor.utils import chunks
from trezor.crypto import bip32, bip39 from trezor.crypto import bip39
from trezor.messages.SignTx import SignTx from trezor.messages.SignTx import SignTx
from trezor.messages.TxInputType import TxInputType from trezor.messages.TxInputType import TxInputType
from trezor.messages.TxOutputType import TxOutputType from trezor.messages.TxOutputType import TxOutputType
@ -25,9 +25,7 @@ class TestSignSegwitTxNativeP2WPKH(unittest.TestCase):
def test_send_native_p2wpkh(self): def test_send_native_p2wpkh(self):
coin = coins.by_name('Testnet') coin = coins.by_name('Testnet')
seed = bip39.seed(' '.join(['all'] * 12), '') seed = bip39.seed(' '.join(['all'] * 12), '')
root = bip32.from_seed(seed, 'secp256k1')
inp1 = TxInputType( inp1 = TxInputType(
# 49'/1'/0'/0/0" - tb1qqzv60m9ajw8drqulta4ld4gfx0rdh82un5s65s # 49'/1'/0'/0/0" - tb1qqzv60m9ajw8drqulta4ld4gfx0rdh82un5s65s
@ -114,7 +112,7 @@ class TestSignSegwitTxNativeP2WPKH(unittest.TestCase):
)), )),
] ]
keychain = Keychain([[coin.curve_name]], [root]) keychain = Keychain(seed, [[coin.curve_name]])
signer = signing.sign_tx(tx, keychain) signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2): for request, response in chunks(messages, 2):
self.assertEqual(signer.send(request), response) self.assertEqual(signer.send(request), response)
@ -124,9 +122,7 @@ class TestSignSegwitTxNativeP2WPKH(unittest.TestCase):
def test_send_native_p2wpkh_change(self): def test_send_native_p2wpkh_change(self):
coin = coins.by_name('Testnet') coin = coins.by_name('Testnet')
seed = bip39.seed(' '.join(['all'] * 12), '') seed = bip39.seed(' '.join(['all'] * 12), '')
root = bip32.from_seed(seed, 'secp256k1')
inp1 = TxInputType( inp1 = TxInputType(
# 49'/1'/0'/0/0" - tb1qqzv60m9ajw8drqulta4ld4gfx0rdh82un5s65s # 49'/1'/0'/0/0" - tb1qqzv60m9ajw8drqulta4ld4gfx0rdh82un5s65s
@ -211,7 +207,7 @@ class TestSignSegwitTxNativeP2WPKH(unittest.TestCase):
)), )),
] ]
keychain = Keychain([[coin.curve_name]], [root]) keychain = Keychain(seed, [[coin.curve_name]])
signer = signing.sign_tx(tx, keychain) signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2): for request, response in chunks(messages, 2):
self.assertEqual(signer.send(request), response) self.assertEqual(signer.send(request), response)

View File

@ -1,7 +1,7 @@
from common import * from common import *
from trezor.utils import chunks from trezor.utils import chunks
from trezor.crypto import bip32, bip39 from trezor.crypto import bip39
from trezor.messages.SignTx import SignTx from trezor.messages.SignTx import SignTx
from trezor.messages.TxInputType import TxInputType from trezor.messages.TxInputType import TxInputType
from trezor.messages.TxOutputType import TxOutputType from trezor.messages.TxOutputType import TxOutputType
@ -25,9 +25,7 @@ class TestSignSegwitTxNativeP2WPKH_GRS(unittest.TestCase):
def test_send_native_p2wpkh(self): def test_send_native_p2wpkh(self):
coin = coins.by_name('Groestlcoin Testnet') coin = coins.by_name('Groestlcoin Testnet')
seed = bip39.seed(' '.join(['all'] * 12), '') seed = bip39.seed(' '.join(['all'] * 12), '')
root = bip32.from_seed(seed, coin.curve_name)
inp1 = TxInputType( inp1 = TxInputType(
# 84'/1'/0'/0/0" - tgrs1qkvwu9g3k2pdxewfqr7syz89r3gj557l3ued7ja # 84'/1'/0'/0/0" - tgrs1qkvwu9g3k2pdxewfqr7syz89r3gj557l3ued7ja
@ -111,7 +109,7 @@ class TestSignSegwitTxNativeP2WPKH_GRS(unittest.TestCase):
)), )),
] ]
keychain = Keychain([[coin.curve_name]], [root]) keychain = Keychain(seed, [[coin.curve_name]])
signer = signing.sign_tx(tx, keychain) signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2): for request, response in chunks(messages, 2):
self.assertEqual(signer.send(request), response) self.assertEqual(signer.send(request), response)
@ -121,9 +119,7 @@ class TestSignSegwitTxNativeP2WPKH_GRS(unittest.TestCase):
def test_send_native_p2wpkh_change(self): def test_send_native_p2wpkh_change(self):
coin = coins.by_name('Groestlcoin Testnet') coin = coins.by_name('Groestlcoin Testnet')
seed = bip39.seed(' '.join(['all'] * 12), '') seed = bip39.seed(' '.join(['all'] * 12), '')
root = bip32.from_seed(seed, coin.curve_name)
inp1 = TxInputType( inp1 = TxInputType(
# 84'/1'/0'/0/0" - tgrs1qkvwu9g3k2pdxewfqr7syz89r3gj557l3ued7ja # 84'/1'/0'/0/0" - tgrs1qkvwu9g3k2pdxewfqr7syz89r3gj557l3ued7ja
@ -205,7 +201,7 @@ class TestSignSegwitTxNativeP2WPKH_GRS(unittest.TestCase):
)), )),
] ]
keychain = Keychain([[coin.curve_name]], [root]) keychain = Keychain(seed, [[coin.curve_name]])
signer = signing.sign_tx(tx, keychain) signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2): for request, response in chunks(messages, 2):
self.assertEqual(signer.send(request), response) self.assertEqual(signer.send(request), response)

View File

@ -1,7 +1,7 @@
from common import * from common import *
from trezor.utils import chunks from trezor.utils import chunks
from trezor.crypto import bip32, bip39 from trezor.crypto import bip39
from trezor.messages.SignTx import SignTx from trezor.messages.SignTx import SignTx
from trezor.messages.TxInputType import TxInputType from trezor.messages.TxInputType import TxInputType
from trezor.messages.TxOutputType import TxOutputType from trezor.messages.TxOutputType import TxOutputType
@ -25,9 +25,7 @@ class TestSignSegwitTxP2WPKHInP2SH(unittest.TestCase):
def test_send_p2wpkh_in_p2sh(self): def test_send_p2wpkh_in_p2sh(self):
coin = coins.by_name('Testnet') coin = coins.by_name('Testnet')
seed = bip39.seed(' '.join(['all'] * 12), '') seed = bip39.seed(' '.join(['all'] * 12), '')
root = bip32.from_seed(seed, 'secp256k1')
inp1 = TxInputType( inp1 = TxInputType(
# 49'/1'/0'/1/0" - 2N1LGaGg836mqSQqiuUBLfcyGBhyZbremDX # 49'/1'/0'/1/0" - 2N1LGaGg836mqSQqiuUBLfcyGBhyZbremDX
@ -111,7 +109,7 @@ class TestSignSegwitTxP2WPKHInP2SH(unittest.TestCase):
)), )),
] ]
keychain = Keychain([[coin.curve_name]], [root]) keychain = Keychain(seed, [[coin.curve_name]])
signer = signing.sign_tx(tx, keychain) signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2): for request, response in chunks(messages, 2):
self.assertEqual(signer.send(request), response) self.assertEqual(signer.send(request), response)
@ -121,9 +119,7 @@ class TestSignSegwitTxP2WPKHInP2SH(unittest.TestCase):
def test_send_p2wpkh_in_p2sh_change(self): def test_send_p2wpkh_in_p2sh_change(self):
coin = coins.by_name('Testnet') coin = coins.by_name('Testnet')
seed = bip39.seed(' '.join(['all'] * 12), '') seed = bip39.seed(' '.join(['all'] * 12), '')
root = bip32.from_seed(seed, 'secp256k1')
inp1 = TxInputType( inp1 = TxInputType(
# 49'/1'/0'/1/0" - 2N1LGaGg836mqSQqiuUBLfcyGBhyZbremDX # 49'/1'/0'/1/0" - 2N1LGaGg836mqSQqiuUBLfcyGBhyZbremDX
@ -215,7 +211,7 @@ class TestSignSegwitTxP2WPKHInP2SH(unittest.TestCase):
)), )),
] ]
keychain = Keychain([[coin.curve_name]], [root]) keychain = Keychain(seed, [[coin.curve_name]])
signer = signing.sign_tx(tx, keychain) signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2): for request, response in chunks(messages, 2):
self.assertEqual(signer.send(request), response) self.assertEqual(signer.send(request), response)
@ -227,9 +223,7 @@ class TestSignSegwitTxP2WPKHInP2SH(unittest.TestCase):
def test_send_p2wpkh_in_p2sh_attack_amount(self): def test_send_p2wpkh_in_p2sh_attack_amount(self):
coin = coins.by_name('Testnet') coin = coins.by_name('Testnet')
seed = bip39.seed(' '.join(['all'] * 12), '') seed = bip39.seed(' '.join(['all'] * 12), '')
root = bip32.from_seed(seed, 'secp256k1')
inp1 = TxInputType( inp1 = TxInputType(
# 49'/1'/0'/1/0" - 2N1LGaGg836mqSQqiuUBLfcyGBhyZbremDX # 49'/1'/0'/1/0" - 2N1LGaGg836mqSQqiuUBLfcyGBhyZbremDX
@ -325,7 +319,7 @@ class TestSignSegwitTxP2WPKHInP2SH(unittest.TestCase):
TxRequest(request_type=TXFINISHED, details=None) TxRequest(request_type=TXFINISHED, details=None)
] ]
keychain = Keychain([[coin.curve_name]], [root]) keychain = Keychain(seed, [[coin.curve_name]])
signer = signing.sign_tx(tx, keychain) signer = signing.sign_tx(tx, keychain)
i = 0 i = 0
messages_count = int(len(messages) / 2) messages_count = int(len(messages) / 2)

View File

@ -1,7 +1,7 @@
from common import * from common import *
from trezor.utils import chunks from trezor.utils import chunks
from trezor.crypto import bip32, bip39 from trezor.crypto import bip39
from trezor.messages.SignTx import SignTx from trezor.messages.SignTx import SignTx
from trezor.messages.TxInputType import TxInputType from trezor.messages.TxInputType import TxInputType
from trezor.messages.TxOutputType import TxOutputType from trezor.messages.TxOutputType import TxOutputType
@ -25,9 +25,7 @@ class TestSignSegwitTxP2WPKHInP2SH_GRS(unittest.TestCase):
def test_send_p2wpkh_in_p2sh(self): def test_send_p2wpkh_in_p2sh(self):
coin = coins.by_name('Groestlcoin Testnet') coin = coins.by_name('Groestlcoin Testnet')
seed = bip39.seed(' '.join(['all'] * 12), '') seed = bip39.seed(' '.join(['all'] * 12), '')
root = bip32.from_seed(seed, coin.curve_name)
inp1 = TxInputType( inp1 = TxInputType(
# 49'/1'/0'/1/0" - 2N1LGaGg836mqSQqiuUBLfcyGBhyZYBtBZ7 # 49'/1'/0'/1/0" - 2N1LGaGg836mqSQqiuUBLfcyGBhyZYBtBZ7
@ -111,7 +109,7 @@ class TestSignSegwitTxP2WPKHInP2SH_GRS(unittest.TestCase):
)), )),
] ]
keychain = Keychain([[coin.curve_name]], [root]) keychain = Keychain(seed, [[coin.curve_name]])
signer = signing.sign_tx(tx, keychain) signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2): for request, response in chunks(messages, 2):
self.assertEqual(signer.send(request), response) self.assertEqual(signer.send(request), response)
@ -121,9 +119,7 @@ class TestSignSegwitTxP2WPKHInP2SH_GRS(unittest.TestCase):
def test_send_p2wpkh_in_p2sh_change(self): def test_send_p2wpkh_in_p2sh_change(self):
coin = coins.by_name('Groestlcoin Testnet') coin = coins.by_name('Groestlcoin Testnet')
seed = bip39.seed(' '.join(['all'] * 12), '') seed = bip39.seed(' '.join(['all'] * 12), '')
root = bip32.from_seed(seed, coin.curve_name)
inp1 = TxInputType( inp1 = TxInputType(
# 49'/1'/0'/1/0" - 2N1LGaGg836mqSQqiuUBLfcyGBhyZYBtBZ7 # 49'/1'/0'/1/0" - 2N1LGaGg836mqSQqiuUBLfcyGBhyZYBtBZ7
@ -214,7 +210,7 @@ class TestSignSegwitTxP2WPKHInP2SH_GRS(unittest.TestCase):
)), )),
] ]
keychain = Keychain([[coin.curve_name]], [root]) keychain = Keychain(seed, [[coin.curve_name]])
signer = signing.sign_tx(tx, keychain) signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2): for request, response in chunks(messages, 2):
self.assertEqual(signer.send(request), response) self.assertEqual(signer.send(request), response)

View File

@ -97,9 +97,7 @@ class TestSignTx(unittest.TestCase):
] ]
seed = bip39.seed('alcohol woman abuse must during monitor noble actual mixed trade anger aisle', '') seed = bip39.seed('alcohol woman abuse must during monitor noble actual mixed trade anger aisle', '')
root = bip32.from_seed(seed, 'secp256k1') keychain = Keychain(seed, [[coin_bitcoin.curve_name]])
keychain = Keychain([[coin_bitcoin.curve_name]], [root])
signer = signing.sign_tx(tx, keychain) signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2): for request, response in chunks(messages, 2):

View File

@ -1,7 +1,7 @@
from common import * from common import *
from trezor.utils import chunks from trezor.utils import chunks
from trezor.crypto import bip32, bip39 from trezor.crypto import bip39
from trezor.messages.SignTx import SignTx from trezor.messages.SignTx import SignTx
from trezor.messages.TxInputType import TxInputType from trezor.messages.TxInputType import TxInputType
from trezor.messages.TxOutputType import TxOutputType from trezor.messages.TxOutputType import TxOutputType
@ -86,9 +86,7 @@ class TestSignTx_GRS(unittest.TestCase):
] ]
seed = bip39.seed(' '.join(['all'] * 12), '') seed = bip39.seed(' '.join(['all'] * 12), '')
root = bip32.from_seed(seed, coin.curve_name) keychain = Keychain(seed, [[coin.curve_name]])
keychain = Keychain([[coin.curve_name]], [root])
signer = signing.sign_tx(tx, keychain) signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2): for request, response in chunks(messages, 2):
self.assertEqual(signer.send(request), response) self.assertEqual(signer.send(request), response)