diff --git a/src/apps/common/seed.py b/src/apps/common/seed.py index 6a2c084a7..57c5cbfce 100644 --- a/src/apps/common/seed.py +++ b/src/apps/common/seed.py @@ -28,7 +28,7 @@ class Keychain: def validate_path(self, checked_path: list, checked_curve: str): for curve, *path in self.namespaces: if path == checked_path[: len(path)] and curve == checked_curve: - if curve == "ed25519" and not _path_hardened(checked_path): + if "ed25519" in curve and not _path_hardened(checked_path): break return raise wire.DataError("Forbidden key path") diff --git a/tests/test_apps.common.seed.py b/tests/test_apps.common.seed.py index 0a3eeee22..f8b309735 100644 --- a/tests/test_apps.common.seed.py +++ b/tests/test_apps.common.seed.py @@ -33,6 +33,25 @@ class TestKeychain(unittest.TestCase): with self.assertRaises(wire.DataError): k.validate_path(*f) + def test_validate_path_special_ed25519(self): + n = [ + ["ed25519-keccak", 44 | HARDENED, 134 | HARDENED], + ] + k = Keychain(b"", n) + + correct = ( + ([44 | HARDENED, 134 | HARDENED], "ed25519-keccak"), + ) + for c in correct: + self.assertEqual(None, k.validate_path(*c)) + + fails = [ + ([44 | HARDENED, 134 | HARDENED, 1], "ed25519-keccak"), + ] + for f in fails: + with self.assertRaises(wire.DataError): + k.validate_path(*f) + def test_validate_path_empty_namespace(self): k = Keychain(b"", [["secp256k1"]]) correct = (