1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-09 23:11:10 +00:00
trezor-firmware/core/tests/test_apps.common.keychain.py

177 lines
5.6 KiB
Python
Raw Normal View History

# flake8: noqa: F403,F405
from common import * # isort:skip
from mock_storage import mock_storage
from storage import cache, cache_codec, cache_common
from trezor import wire
from trezor.crypto import bip39
from trezor.enums import SafetyCheckLevel
from trezor.wire import context
from trezor.wire.codec.codec_context import CodecContext
from apps.common import safety_checks
from apps.common.keychain import Keychain, LRUCache, get_keychain, with_slip44_keychain
from apps.common.paths import PATTERN_SEP5, PathSchema
class TestKeychain(unittest.TestCase):
    """Tests for ``apps.common.keychain``.

    Covers schema-based path verification, seed-backed keychain creation via
    ``get_keychain``, the ``with_slip44_keychain`` decorator, and ``LRUCache``.
    """

    def setUpClass(self):
        # Install a CodecContext for the duration of the class; the tests below
        # call context.cache_set(), which presumably writes into this context.
        # NOTE(review): defined as a plain method rather than a @classmethod,
        # presumably because this project's unittest (from `common`) calls it
        # on an instance — confirm before changing the signature.
        context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))

    def tearDownClass(self):
        # Remove the context installed by setUpClass.
        context.CURRENT_CONTEXT = None

    def setUp(self):
        # Every test starts from a freshly started session.
        cache_codec.start_session()

    def tearDown(self):
        cache.clear_all()

    @mock_storage
    def test_verify_path(self):
        kc = Keychain(
            b"",
            "secp256k1",
            (
                PathSchema.parse("m/44'/coin_type'", slip44_id=134),
                PathSchema.parse("m/44'/coin_type'", slip44_id=11),
            ),
        )

        accepted = ([H_(44), H_(134)], [H_(44), H_(11)])
        rejected = (
            [H_(44), 134],  # path does not match
            [44, 134],  # path does not match (non-hardened items)
            [H_(44), H_(13)],  # invalid second item
        )

        for good in accepted:
            kc.verify_path(good)
        for bad in rejected:
            with self.assertRaises(wire.DataError):
                kc.verify_path(bad)

        # Once safety checks are relaxed, every path — including the ones
        # rejected above — must verify without raising.
        safety_checks.apply_setting(SafetyCheckLevel.PromptTemporarily)
        for candidate in accepted + rejected:
            kc.verify_path(candidate)

    def test_verify_path_special_ed25519(self):
        kc = Keychain(
            b"",
            "ed25519-keccak",
            [PathSchema.parse("m/44'/coin_type'/*", slip44_id=134)],
        )
        # Fully hardened path: accepted.
        kc.verify_path([H_(44), H_(134)])
        # A non-hardened component must be rejected with this ed25519-style
        # curve.
        with self.assertRaises(wire.DataError):
            kc.verify_path([H_(44), H_(134), 1])

    def test_no_schemas(self):
        # A keychain built with an empty schema list must reject every path.
        kc = Keychain(b"", "secp256k1", [])
        for candidate in (
            [],
            [1, 2, 3, 4],
            [H_(44), H_(11)],
            [H_(44), H_(11), 12],
        ):
            with self.assertRaises(wire.DataError):
                kc.verify_path(candidate)

    def test_get_keychain(self):
        mnemonic = " ".join(["all"] * 12)
        context.cache_set(cache_common.APP_COMMON_SEED, bip39.seed(mnemonic, ""))

        kc = await_result(
            get_keychain("secp256k1", [PathSchema.parse("m/44'/1'", 0)])
        )

        # A path matching the schema derives a node.
        self.assertIsNotNone(kc.derive([H_(44), H_(1)]))
        # A path outside the schema is refused.
        with self.assertRaises(wire.DataError):
            kc.derive([44])

    def test_with_slip44(self):
        mnemonic = " ".join(["all"] * 12)
        context.cache_set(cache_common.APP_COMMON_SEED, bip39.seed(mnemonic, ""))

        coin = 42
        valid_path = [H_(44), H_(coin), H_(0)]
        invalid_path = [H_(44), H_(99), H_(0)]
        testnet_path = [H_(44), H_(1), H_(0)]

        def expect_derivable(kc, *paths):
            # Every given path must derive to a non-None node.
            for p in paths:
                self.assertIsNotNone(kc.derive(p))

        def expect_rejected(kc, *paths):
            # Every given path must make derive() raise DataError.
            for p in paths:
                with self.assertRaises(wire.DataError):
                    kc.derive(p)

        @with_slip44_keychain(PATTERN_SEP5, slip44_id=coin)
        async def func_id_only(msg, keychain):
            # Defaults: the coin's own id and the testnet id (1) are allowed.
            expect_derivable(keychain, valid_path, testnet_path)
            expect_rejected(keychain, invalid_path)

        @with_slip44_keychain(PATTERN_SEP5, slip44_id=coin, allow_testnet=False)
        async def func_disallow_testnet(msg, keychain):
            # allow_testnet=False: only the coin's own id is allowed.
            expect_derivable(keychain, valid_path)
            expect_rejected(keychain, testnet_path, invalid_path)

        @with_slip44_keychain(PATTERN_SEP5, slip44_id=coin, curve="ed25519")
        async def func_with_curve(msg, keychain):
            # The requested curve is passed through to the keychain.
            self.assertEqual(keychain.curve, "ed25519")
            expect_derivable(keychain, valid_path, testnet_path)
            expect_rejected(keychain, invalid_path)

        # Run the decorated handlers one after another, in declaration order.
        for handler in (func_id_only, func_disallow_testnet, func_with_curve):
            await_result(handler(None))

    def test_lru_cache(self):
        class Deletable:
            # Flags whether __del__ has been invoked on the instance.
            def __init__(self):
                self.deleted = False

            def __del__(self):
                self.deleted = True

        # Named `lru` so it does not shadow the imported `storage.cache` module.
        lru = LRUCache(10)

        obj_a = Deletable()
        self.assertIsNone(lru.get("a"))
        lru.insert("a", obj_a)
        self.assertIs(lru.get("a"), obj_a)

        # Ten fresh inserts into a size-10 cache push "a" out.
        objects = [(n, Deletable()) for n in range(10)]
        for k, v in objects:
            lru.insert(k, v)
        self.assertIsNone(lru.get("a"))
        # obj_a is still referenced here, so presumably LRUCache invokes
        # __del__ on evicted values explicitly — confirm in keychain.py.
        self.assertTrue(obj_a.deleted)

        lru.insert("a", obj_a)
        for k, v in objects[:-1]:
            # objects should have been evicted in insertion order
            self.assertIsNone(lru.get(k))
            self.assertTrue(v.deleted)
            lru.insert(k, v)

        # Touch "a" so it becomes the most recently used entry.
        self.assertIs(lru.get("a"), obj_a)
        last_key, last_obj = objects[-1]
        lru.insert(last_key, last_obj)
        # "a" was used recently, so this insert must not evict it.
        self.assertIs(lru.get("a"), obj_a)
# Allow running this test module directly instead of through a test runner.
if __name__ == "__main__":
    unittest.main()