diff --git a/core/tests/test_apps.common.keychain.py b/core/tests/test_apps.common.keychain.py new file mode 100644 index 000000000..dce92a155 --- /dev/null +++ b/core/tests/test_apps.common.keychain.py @@ -0,0 +1,165 @@ +from common import * + +from mock_storage import mock_storage + +from storage import cache +import storage.device +from apps.common import HARDENED +from apps.common.paths import path_is_hardened +from apps.common.keychain import LRUCache, Keychain, with_slip44_keychain, get_keychain +from trezor import wire +from trezor.crypto import bip39 + + +class TestKeychain(unittest.TestCase): + @mock_storage + def test_verify_path(self): + n = [ + [44 | HARDENED, 134 | HARDENED], + [44 | HARDENED, 11 | HARDENED], + ] + keychain = Keychain(b"", "secp256k1", n) + + correct = ( + [44 | HARDENED, 134 | HARDENED], + [44 | HARDENED, 11 | HARDENED], + [44 | HARDENED, 11 | HARDENED, 12], + ) + for path in correct: + keychain.verify_path(path) + + fails = ( + [44 | HARDENED, 134], # path does not match + [44, 134], # path does not match (non-hardened items) + [44 | HARDENED, 13 | HARDENED], # invalid second item + ) + for f in fails: + with self.assertRaises(wire.DataError): + keychain.verify_path(f) + + # turn off restrictions + storage.device.set_unsafe_prompts_allowed(True) + for path in correct + fails: + keychain.verify_path(path) + + def test_verify_path_special_ed25519(self): + n = [[44 | HARDENED, 134 | HARDENED]] + k = Keychain(b"", "ed25519-keccak", n) + + # OK case + k.verify_path([44 | HARDENED, 134 | HARDENED]) + + # failing case: non-hardened component with ed25519-like derivation + with self.assertRaises(wire.DataError): + k.verify_path([44 | HARDENED, 134 | HARDENED, 1]) + + def test_verify_path_empty_namespace(self): + k = Keychain(b"", "secp256k1", [[]]) + correct = ( + [], + [1, 2, 3, 4], + [44 | HARDENED, 11 | HARDENED], + [44 | HARDENED, 11 | HARDENED, 12], + ) + for c in correct: + k.verify_path(c) + + def test_get_keychain(self): + seed = bip39.seed(" ".join(["all"] * 12), "") + cache.start_session() + cache.set(cache.APP_COMMON_SEED, seed) + + namespaces = [[44 | HARDENED]] + keychain = await_result( + get_keychain(wire.DUMMY_CONTEXT, "secp256k1", namespaces) + ) + + # valid path: + self.assertIsNotNone(keychain.derive([44 | HARDENED, 1 | HARDENED])) + + # invalid path: + with self.assertRaises(wire.DataError): + keychain.derive([44]) + + def test_with_slip44(self): + seed = bip39.seed(" ".join(["all"] * 12), "") + cache.start_session() + cache.set(cache.APP_COMMON_SEED, seed) + + slip44_id = 42 + valid_path = [44 | HARDENED, slip44_id | HARDENED] + invalid_path = [44 | HARDENED, 99 | HARDENED] + testnet_path = [44 | HARDENED, 1 | HARDENED] + + def check_valid_paths(keychain, *paths): + for path in paths: + self.assertIsNotNone(keychain.derive(path)) + + def check_invalid_paths(keychain, *paths): + for path in paths: + self.assertRaises(wire.DataError, keychain.derive, path) + + @with_slip44_keychain(slip44_id) + async def func_id_only(ctx, msg, keychain): + check_valid_paths(keychain, valid_path) + check_invalid_paths(keychain, testnet_path, invalid_path) + + @with_slip44_keychain(slip44_id, allow_testnet=True) + async def func_allow_testnet(ctx, msg, keychain): + check_valid_paths(keychain, valid_path, testnet_path) + check_invalid_paths(keychain, invalid_path) + + @with_slip44_keychain(slip44_id, curve="ed25519") + async def func_with_curve(ctx, msg, keychain): + self.assertEqual(keychain.curve, "ed25519") + check_valid_paths(keychain, valid_path) + check_invalid_paths(keychain, testnet_path, invalid_path) + + await_result(func_id_only(wire.DUMMY_CONTEXT, None)) + await_result(func_allow_testnet(wire.DUMMY_CONTEXT, None)) + await_result(func_with_curve(wire.DUMMY_CONTEXT, None)) + + def test_lru_cache(self): + class Deletable: + def __init__(self): + self.deleted = False + + def __del__(self): + self.deleted = True + + cache = LRUCache(10) + + obj_a = Deletable() + self.assertIsNone(cache.get("a")) + cache.insert("a", obj_a) + + self.assertIs(cache.get("a"), obj_a) + + # test eviction + objects = [(i, Deletable()) for i in range(10)] + for key, obj in objects: + cache.insert(key, obj) + + # object A should have been evicted + self.assertIsNone(cache.get("a")) + self.assertTrue(obj_a.deleted) + + cache.insert("a", obj_a) + for key, obj in objects[:-1]: + # objects should have been evicted in insertion order + self.assertIsNone(cache.get(key)) + self.assertTrue(obj.deleted) + cache.insert(key, obj) + + # use "a" object + self.assertIs(cache.get("a"), obj_a) + # insert last object + key, obj = objects[-1] + cache.insert(key, obj) + + # "a" is recently used so should not be evicted now + self.assertIs(cache.get("a"), obj_a) + + +if __name__ == "__main__": + unittest.main() diff --git a/core/tests/test_apps.common.paths.py b/core/tests/test_apps.common.paths.py index 431e82c3c..1b85e00f3 100644 --- a/core/tests/test_apps.common.paths.py +++ b/core/tests/test_apps.common.paths.py @@ -1,6 +1,6 @@ from common import * from apps.common import HARDENED -from apps.common.paths import validate_path_for_get_public_key, is_hardened +from apps.common.paths import validate_path_for_get_public_key, is_hardened, path_is_hardened class TestPaths(unittest.TestCase): @@ -14,6 +14,14 @@ class TestPaths(unittest.TestCase): self.assertFalse(is_hardened(0)) self.assertFalse(is_hardened(99999)) + def test_path_is_hardened(self): + self.assertTrue(path_is_hardened([44 | HARDENED, 1 | HARDENED, 0 | HARDENED])) + self.assertTrue(path_is_hardened([0 | HARDENED, ])) + + self.assertFalse(path_is_hardened([44, 44 | HARDENED, 0 | HARDENED])) + self.assertFalse(path_is_hardened([0, ])) + self.assertFalse(path_is_hardened([44 | HARDENED, 1 | HARDENED, 0 | HARDENED, 0 | HARDENED, 0])) + def test_path_for_get_public_key(self): # 44'/41'/0' self.assertTrue(validate_path_for_get_public_key([44 | HARDENED, 41 | HARDENED, 0 | HARDENED], 41)) diff --git a/core/tests/test_apps.common.seed.py b/core/tests/test_apps.common.seed.py index cee415652..a1190ebc7 100644 --- a/core/tests/test_apps.common.seed.py +++ b/core/tests/test_apps.common.seed.py @@ -1,83 +1,16 @@ from common import * -from storage import cache -from apps.common import HARDENED -from apps.common.seed import Keychain, Slip21Node, _path_hardened, get_keychain, with_slip44_keychain +from apps.common.keychain import Keychain +from apps.common.seed import Slip21Node from trezor import wire from trezor.crypto import bip39 -class TestKeychain(unittest.TestCase): - - def test_match_path(self): - n = [ - ("ed25519", [44 | HARDENED, 134 | HARDENED]), - ("secp256k1", [44 | HARDENED, 11 | HARDENED]), - ] - k = Keychain(b"", n) - - correct = ( - ([44 | HARDENED, 134 | HARDENED], "ed25519"), - ([44 | HARDENED, 11 | HARDENED], "secp256k1"), - ([44 | HARDENED, 11 | HARDENED, 12], "secp256k1"), - ) - for path, curve in correct: - i, suffix = k.match_path(path) - ns_curve, ns = k.namespaces[i] - self.assertEqual(curve, ns_curve) - - fails = [ - [44 | HARDENED, 134], # path does not match - [44, 134], # path does not match (non-hardened items) - [44 | HARDENED, 134 | HARDENED, 123], # non-hardened item in ed25519 ns - [44 | HARDENED, 13 | HARDENED], # invalid second item - ] - for f in fails: - with self.assertRaises(wire.DataError): - k.match_path(f) - - def test_match_path_special_ed25519(self): - n = [ - ("ed25519-keccak", [44 | HARDENED, 134 | HARDENED]), - ] - k = Keychain(b"", n) - - correct = ( - [44 | HARDENED, 134 | HARDENED], - ) - for c in correct: - k.match_path(c) - - fails = [ - [44 | HARDENED, 134 | HARDENED, 1], - ] - for f in fails: - with self.assertRaises(wire.DataError): - k.match_path(f) - - def test_match_path_empty_namespace(self): - k = Keychain(b"", [("secp256k1", [])]) - correct = ( - [], - [1, 2, 3, 4], - [44 | HARDENED, 11 | HARDENED], - [44 | HARDENED, 11 | HARDENED, 12], - ) - for c in correct: - k.match_path(c) - - def test_path_hardened(self): - self.assertTrue(_path_hardened([44 | HARDENED, 1 | HARDENED, 0 | HARDENED])) - self.assertTrue(_path_hardened([0 | HARDENED, ])) - - self.assertFalse(_path_hardened([44, 44 | HARDENED, 0 | HARDENED])) - self.assertFalse(_path_hardened([0, ])) - self.assertFalse(_path_hardened([44 | HARDENED, 1 | HARDENED, 0 | HARDENED, 0 | HARDENED, 0])) - +class TestSeed(unittest.TestCase): def test_slip21(self): seed = bip39.seed(' '.join(['all'] * 12), '') node1 = Slip21Node(seed) node2 = node1.clone() - keychain = Keychain(seed, [("slip21", [b"SLIP-0021"])]) + keychain = Keychain(seed, "", [], slip21_namespaces=[[b"SLIP-0021"]]) # Key(m) KEY_M = unhexlify(b"dbf12b44133eaab506a740f6565cc117228cbf1dd70635cfa8ddfdc9af734756") @@ -87,87 +20,25 @@ class TestKeychain(unittest.TestCase): KEY_M_SLIP0021 = unhexlify(b"1d065e3ac1bbe5c7fad32cf2305f7d709dc070d672044a19e610c77cdf33de0d") node1.derive_path([b"SLIP-0021"]) self.assertEqual(node1.key(), KEY_M_SLIP0021) - keychain.match_path([b"SLIP-0021"]) - self.assertEqual(keychain.derive([b"SLIP-0021"]).key(), KEY_M_SLIP0021) + self.assertEqual(keychain.derive_slip21([b"SLIP-0021"]).key(), KEY_M_SLIP0021) # Key(m/"SLIP-0021"/"Master encryption key") KEY_M_SLIP0021_MEK = unhexlify(b"ea163130e35bbafdf5ddee97a17b39cef2be4b4f390180d65b54cf05c6a82fde") node1.derive_path([b"Master encryption key"]) self.assertEqual(node1.key(), KEY_M_SLIP0021_MEK) - keychain.match_path([b"SLIP-0021", b"Master encryption key"]) - self.assertEqual(keychain.derive([b"SLIP-0021", b"Master encryption key"]).key(), KEY_M_SLIP0021_MEK) + self.assertEqual(keychain.derive_slip21([b"SLIP-0021", b"Master encryption key"]).key(), KEY_M_SLIP0021_MEK) # Key(m/"SLIP-0021"/"Authentication key") KEY_M_SLIP0021_AK = unhexlify(b"47194e938ab24cc82bfa25f6486ed54bebe79c40ae2a5a32ea6db294d81861a6") node2.derive_path([b"SLIP-0021", b"Authentication key"]) self.assertEqual(node2.key(), KEY_M_SLIP0021_AK) - keychain.match_path([b"SLIP-0021", b"Authentication key"]) - self.assertEqual(keychain.derive([b"SLIP-0021", b"Authentication key"]).key(), KEY_M_SLIP0021_AK) + self.assertEqual(keychain.derive_slip21([b"SLIP-0021", b"Authentication key"]).key(), KEY_M_SLIP0021_AK) # Forbidden paths. with self.assertRaises(wire.DataError): - self.assertFalse(keychain.match_path([])) - with self.assertRaises(wire.DataError): - self.assertFalse(keychain.match_path([b"SLIP-9999", b"Authentication key"])) - with self.assertRaises(wire.DataError): - keychain.derive([b"SLIP-9999", b"Authentication key"]).key() - - def test_get_keychain(self): - seed = bip39.seed(' '.join(['all'] * 12), '') - cache.start_session() - cache.set(cache.APP_COMMON_SEED, seed) - - namespaces = [("secp256k1", [44 | HARDENED])] - keychain = await_result(get_keychain(wire.DUMMY_CONTEXT, namespaces)) - - # valid path: - self.assertIsNotNone(keychain.derive([44 | HARDENED, 1 | HARDENED])) - - # invalid path: + keychain.derive_slip21([]) with self.assertRaises(wire.DataError): - keychain.derive([44]) - - def test_with_slip44(self): - seed = bip39.seed(' '.join(['all'] * 12), '') - cache.start_session() - cache.set(cache.APP_COMMON_SEED, seed) - - slip44_id = 42 - valid_path = [44 | HARDENED, slip44_id | HARDENED] - invalid_path = [44 | HARDENED, 99 | HARDENED] - testnet_path = [44 | HARDENED, 1 | HARDENED] - - def check_valid_paths(keychain, *paths): - for path in paths: - self.assertIsNotNone(keychain.derive(path)) - - def check_invalid_paths(keychain, *paths): - for path in paths: - self.assertRaises(wire.DataError, keychain.derive, path) - - @with_slip44_keychain(slip44_id) - async def func_id_only(ctx, msg, keychain): - check_valid_paths(keychain, valid_path) - check_invalid_paths(keychain, testnet_path, invalid_path) - - @with_slip44_keychain(slip44_id, allow_testnet=True) - async def func_allow_testnet(ctx, msg, keychain): - check_valid_paths(keychain, valid_path, testnet_path) - check_invalid_paths(keychain, invalid_path) - - @with_slip44_keychain(slip44_id, curve="ed25519") - async def func_with_curve(ctx, msg, keychain): - check_valid_paths(keychain, valid_path) - check_invalid_paths(keychain, testnet_path, invalid_path) - - i, _ = keychain.match_path(valid_path) - ns_curve, ns = keychain.namespaces[i] - self.assertEqual(ns_curve, "ed25519") - - await_result(func_id_only(wire.DUMMY_CONTEXT, None)) - await_result(func_allow_testnet(wire.DUMMY_CONTEXT, None)) - await_result(func_with_curve(wire.DUMMY_CONTEXT, None)) - + keychain.derive_slip21([b"SLIP-9999", b"Authentication key"]) if __name__ == '__main__':