mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-19 04:48:12 +00:00
core: refactor keychain to only support one curve at a time
also make a cleaner distinction between keychain, seed, path This enables using `unsafe_prompts`, because with the original code, if there was no namespace match, we wouldn't know which curve to use. For ease of implementation, we use a LRU cache for derived keys, instead of the original design "one cache entry per namespace". SLIP21 is now treated completely separately, via `slip21_namespaces` and `derive_slip21` method. If more slip21-like things come in the future, we can instead hang them on the keychain: put a per-curve Keychain object accessible by `keychain[curve_name].derive()`, and the majority usecase will just pass around `keychain[curve_name]` instead of having to specify the curve in every `derive()` call. Or alternately we'll just specify the curve in every `derive()` call, whichever seems more appropriate.
This commit is contained in:
parent
fa757f4b7f
commit
ff4ec2185e
@ -7,8 +7,6 @@ from apps.common import mnemonic
|
|||||||
from apps.common.passphrase import get as get_passphrase
|
from apps.common.passphrase import get as get_passphrase
|
||||||
|
|
||||||
if False:
|
if False:
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
from apps.common.seed import Bip32Path, MsgIn, MsgOut, Handler, HandlerWithKeychain
|
from apps.common.seed import Bip32Path, MsgIn, MsgOut, Handler, HandlerWithKeychain
|
||||||
|
|
||||||
|
|
||||||
@ -18,15 +16,14 @@ class Keychain:
|
|||||||
def __init__(self, root: bip32.HDNode) -> None:
|
def __init__(self, root: bip32.HDNode) -> None:
|
||||||
self.root = root
|
self.root = root
|
||||||
|
|
||||||
def match_path(self, path: Bip32Path) -> Tuple[int, Bip32Path]:
|
def verify_path(self, path: Bip32Path) -> None:
|
||||||
if path[: len(SEED_NAMESPACE)] != SEED_NAMESPACE:
|
if path[: len(SEED_NAMESPACE)] != SEED_NAMESPACE:
|
||||||
raise wire.DataError("Forbidden key path")
|
raise wire.DataError("Forbidden key path")
|
||||||
return 0, path[len(SEED_NAMESPACE) :]
|
|
||||||
|
|
||||||
def derive(self, node_path: Bip32Path) -> bip32.HDNode:
|
def derive(self, node_path: Bip32Path) -> bip32.HDNode:
|
||||||
_, suffix = self.match_path(node_path)
|
|
||||||
# derive child node from the root
|
# derive child node from the root
|
||||||
node = self.root.clone()
|
node = self.root.clone()
|
||||||
|
suffix = node_path[len(SEED_NAMESPACE) :]
|
||||||
for i in suffix:
|
for i in suffix:
|
||||||
node.derive_cardano(i)
|
node.derive_cardano(i)
|
||||||
return node
|
return node
|
||||||
|
181
core/src/apps/common/keychain.py
Normal file
181
core/src/apps/common/keychain.py
Normal file
@ -0,0 +1,181 @@
|
|||||||
|
from storage import device
|
||||||
|
from trezor import wire
|
||||||
|
from trezor.crypto import bip32
|
||||||
|
|
||||||
|
from . import HARDENED, paths
|
||||||
|
from .seed import Slip21Node, get_seed
|
||||||
|
|
||||||
|
if False:
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Awaitable,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
TypeVar,
|
||||||
|
)
|
||||||
|
from typing_extensions import Protocol
|
||||||
|
|
||||||
|
from protobuf import MessageType
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
class NodeProtocol(Protocol[paths.PathType]):
|
||||||
|
def derive_path(self, path: paths.PathType) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
def clone(self: T) -> T:
|
||||||
|
...
|
||||||
|
|
||||||
|
def __del__(self) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
NodeType = TypeVar("NodeType", bound=NodeProtocol)
|
||||||
|
|
||||||
|
MsgIn = TypeVar("MsgIn", bound=MessageType)
|
||||||
|
MsgOut = TypeVar("MsgOut", bound=MessageType)
|
||||||
|
|
||||||
|
Handler = Callable[[wire.Context, MsgIn], Awaitable[MsgOut]]
|
||||||
|
HandlerWithKeychain = Callable[[wire.Context, MsgIn, "Keychain"], Awaitable[MsgOut]]
|
||||||
|
|
||||||
|
class Deletable(Protocol):
|
||||||
|
def __del__(self) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
FORBIDDEN_KEY_PATH = wire.DataError("Forbidden key path")
|
||||||
|
|
||||||
|
|
||||||
|
class LRUCache:
|
||||||
|
def __init__(self, size: int) -> None:
|
||||||
|
self.size = size
|
||||||
|
self.cache_keys = [] # type: List[Any]
|
||||||
|
self.cache = {} # type: Dict[Any, Deletable]
|
||||||
|
|
||||||
|
def insert(self, key: Any, value: Deletable) -> None:
|
||||||
|
if key in self.cache_keys:
|
||||||
|
self.cache_keys.remove(key)
|
||||||
|
self.cache_keys.insert(0, key)
|
||||||
|
self.cache[key] = value
|
||||||
|
|
||||||
|
if len(self.cache_keys) > self.size:
|
||||||
|
dropped_key = self.cache_keys.pop()
|
||||||
|
self.cache[dropped_key].__del__()
|
||||||
|
del self.cache[dropped_key]
|
||||||
|
|
||||||
|
def get(self, key: Any) -> Any:
|
||||||
|
if key not in self.cache:
|
||||||
|
return None
|
||||||
|
|
||||||
|
self.cache_keys.remove(key)
|
||||||
|
self.cache_keys.insert(0, key)
|
||||||
|
return self.cache[key]
|
||||||
|
|
||||||
|
def __del__(self) -> None:
|
||||||
|
for value in self.cache.values():
|
||||||
|
value.__del__()
|
||||||
|
self.cache.clear()
|
||||||
|
self.cache_keys.clear()
|
||||||
|
del self.cache
|
||||||
|
|
||||||
|
|
||||||
|
class Keychain:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
seed: bytes,
|
||||||
|
curve: str,
|
||||||
|
namespaces: Sequence[paths.Bip32Path],
|
||||||
|
slip21_namespaces: Sequence[paths.Slip21Path] = (),
|
||||||
|
) -> None:
|
||||||
|
self.seed = seed
|
||||||
|
self.curve = curve
|
||||||
|
self.namespaces = namespaces
|
||||||
|
self.slip21_namespaces = slip21_namespaces
|
||||||
|
|
||||||
|
self._cache = LRUCache(10)
|
||||||
|
|
||||||
|
def __del__(self) -> None:
|
||||||
|
self._cache.__del__()
|
||||||
|
del self._cache
|
||||||
|
del self.seed
|
||||||
|
|
||||||
|
def verify_path(self, path: paths.Bip32Path) -> None:
|
||||||
|
if "ed25519" in self.curve and not paths.path_is_hardened(path):
|
||||||
|
raise FORBIDDEN_KEY_PATH
|
||||||
|
|
||||||
|
if device.unsafe_prompts_allowed():
|
||||||
|
return
|
||||||
|
|
||||||
|
if any(ns == path[: len(ns)] for ns in self.namespaces):
|
||||||
|
return
|
||||||
|
|
||||||
|
raise FORBIDDEN_KEY_PATH
|
||||||
|
|
||||||
|
def _derive_with_cache(
|
||||||
|
self, prefix_len: int, path: paths.PathType, new_root: Callable[[], NodeType],
|
||||||
|
) -> NodeType:
|
||||||
|
cached_prefix = tuple(path[:prefix_len])
|
||||||
|
cached_root = self._cache.get(cached_prefix) # type: Optional[NodeType]
|
||||||
|
if cached_root is None:
|
||||||
|
cached_root = new_root()
|
||||||
|
cached_root.derive_path(cached_prefix)
|
||||||
|
self._cache.insert(cached_prefix, cached_root)
|
||||||
|
|
||||||
|
node = cached_root.clone()
|
||||||
|
node.derive_path(path[prefix_len:])
|
||||||
|
return node
|
||||||
|
|
||||||
|
def derive(self, path: paths.Bip32Path) -> bip32.HDNode:
|
||||||
|
self.verify_path(path)
|
||||||
|
return self._derive_with_cache(
|
||||||
|
prefix_len=3,
|
||||||
|
path=path,
|
||||||
|
new_root=lambda: bip32.from_seed(self.seed, self.curve),
|
||||||
|
)
|
||||||
|
|
||||||
|
def derive_slip21(self, path: paths.Slip21Path) -> Slip21Node:
|
||||||
|
if not device.unsafe_prompts_allowed() and not any(
|
||||||
|
ns == path[: len(ns)] for ns in self.slip21_namespaces
|
||||||
|
):
|
||||||
|
raise FORBIDDEN_KEY_PATH
|
||||||
|
|
||||||
|
return self._derive_with_cache(
|
||||||
|
prefix_len=1, path=path, new_root=lambda: Slip21Node(seed=self.seed),
|
||||||
|
)
|
||||||
|
|
||||||
|
def __enter__(self) -> "Keychain":
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type: Any, exc_val: Any, tb: Any) -> None:
|
||||||
|
self.__del__()
|
||||||
|
|
||||||
|
|
||||||
|
async def get_keychain(
|
||||||
|
ctx: wire.Context,
|
||||||
|
curve: str,
|
||||||
|
namespaces: Sequence[paths.Bip32Path],
|
||||||
|
slip21_namespaces: Sequence[paths.Slip21Path] = (),
|
||||||
|
) -> Keychain:
|
||||||
|
seed = await get_seed(ctx)
|
||||||
|
keychain = Keychain(seed, curve, namespaces, slip21_namespaces)
|
||||||
|
return keychain
|
||||||
|
|
||||||
|
|
||||||
|
def with_slip44_keychain(
|
||||||
|
slip44: int, curve: str = "secp256k1", allow_testnet: bool = False
|
||||||
|
) -> Callable[[HandlerWithKeychain[MsgIn, MsgOut]], Handler[MsgIn, MsgOut]]:
|
||||||
|
namespaces = [[44 | HARDENED, slip44 | HARDENED]]
|
||||||
|
if allow_testnet:
|
||||||
|
namespaces.append([44 | HARDENED, 1 | HARDENED])
|
||||||
|
|
||||||
|
def decorator(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]:
|
||||||
|
async def wrapper(ctx: wire.Context, msg: MsgIn) -> MsgOut:
|
||||||
|
keychain = await get_keychain(ctx, curve, namespaces)
|
||||||
|
with keychain:
|
||||||
|
return await func(ctx, msg, keychain)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return decorator
|
@ -4,29 +4,35 @@ from trezor import ui
|
|||||||
from trezor.messages import ButtonRequestType
|
from trezor.messages import ButtonRequestType
|
||||||
from trezor.ui.text import Text
|
from trezor.ui.text import Text
|
||||||
|
|
||||||
from apps.common import HARDENED
|
from . import HARDENED
|
||||||
from apps.common.confirm import require_confirm
|
from .confirm import require_confirm
|
||||||
|
|
||||||
if False:
|
if False:
|
||||||
from typing import Any, Callable, List
|
from typing import Any, Callable, List, Sequence, TypeVar
|
||||||
from trezor import wire
|
from trezor import wire
|
||||||
from apps.common import seed
|
|
||||||
|
# XXX this is a circular import, but it's only for typing
|
||||||
|
from .keychain import Keychain
|
||||||
|
|
||||||
|
Bip32Path = Sequence[int]
|
||||||
|
Slip21Path = Sequence[bytes]
|
||||||
|
PathType = TypeVar("PathType", Bip32Path, Slip21Path)
|
||||||
|
|
||||||
|
|
||||||
async def validate_path(
|
async def validate_path(
|
||||||
ctx: wire.Context,
|
ctx: wire.Context,
|
||||||
validate_func: Callable[..., bool],
|
validate_func: Callable[..., bool],
|
||||||
keychain: seed.Keychain,
|
keychain: Keychain,
|
||||||
path: List[int],
|
path: List[int],
|
||||||
curve: str,
|
curve: str,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
keychain.match_path(path)
|
keychain.verify_path(path)
|
||||||
if not validate_func(path, **kwargs):
|
if not validate_func(path, **kwargs):
|
||||||
await show_path_warning(ctx, path)
|
await show_path_warning(ctx, path)
|
||||||
|
|
||||||
|
|
||||||
async def show_path_warning(ctx: wire.Context, path: List[int]) -> None:
|
async def show_path_warning(ctx: wire.Context, path: Bip32Path) -> None:
|
||||||
text = Text("Confirm path", ui.ICON_WRONG, ui.RED)
|
text = Text("Confirm path", ui.ICON_WRONG, ui.RED)
|
||||||
text.normal("Path")
|
text.normal("Path")
|
||||||
text.mono(*break_address_n_to_lines(path))
|
text.mono(*break_address_n_to_lines(path))
|
||||||
@ -35,7 +41,7 @@ async def show_path_warning(ctx: wire.Context, path: List[int]) -> None:
|
|||||||
await require_confirm(ctx, text, ButtonRequestType.UnknownDerivationPath)
|
await require_confirm(ctx, text, ButtonRequestType.UnknownDerivationPath)
|
||||||
|
|
||||||
|
|
||||||
def validate_path_for_get_public_key(path: list, slip44_id: int) -> bool:
|
def validate_path_for_get_public_key(path: Bip32Path, slip44_id: int) -> bool:
|
||||||
"""
|
"""
|
||||||
Checks if path has at least three hardened items and slip44 id matches.
|
Checks if path has at least three hardened items and slip44 id matches.
|
||||||
The path is allowed to have more than three items, but all the following
|
The path is allowed to have more than three items, but all the following
|
||||||
@ -61,7 +67,11 @@ def is_hardened(i: int) -> bool:
|
|||||||
return bool(i & HARDENED)
|
return bool(i & HARDENED)
|
||||||
|
|
||||||
|
|
||||||
def break_address_n_to_lines(address_n: list) -> list:
|
def path_is_hardened(address_n: Bip32Path) -> bool:
|
||||||
|
return all(is_hardened(n) for n in address_n)
|
||||||
|
|
||||||
|
|
||||||
|
def break_address_n_to_lines(address_n: Bip32Path) -> List[str]:
|
||||||
def path_item(i: int) -> str:
|
def path_item(i: int) -> str:
|
||||||
if i & HARDENED:
|
if i & HARDENED:
|
||||||
return str(i ^ HARDENED) + "'"
|
return str(i ^ HARDENED) + "'"
|
||||||
|
@ -2,39 +2,11 @@ from storage import cache, device
|
|||||||
from trezor import wire
|
from trezor import wire
|
||||||
from trezor.crypto import bip32, hashlib, hmac
|
from trezor.crypto import bip32, hashlib, hmac
|
||||||
|
|
||||||
from apps.common import HARDENED, mnemonic
|
from . import mnemonic
|
||||||
from apps.common.passphrase import get as get_passphrase
|
from .passphrase import get as get_passphrase
|
||||||
|
|
||||||
if False:
|
if False:
|
||||||
from typing import (
|
from .paths import Bip32Path, Slip21Path
|
||||||
Any,
|
|
||||||
Awaitable,
|
|
||||||
Callable,
|
|
||||||
Dict,
|
|
||||||
List,
|
|
||||||
Sequence,
|
|
||||||
Tuple,
|
|
||||||
TypeVar,
|
|
||||||
)
|
|
||||||
from typing_extensions import Protocol
|
|
||||||
|
|
||||||
Bip32Path = List[int]
|
|
||||||
Slip21Path = List[bytes]
|
|
||||||
PathType = TypeVar("PathType", Bip32Path, Slip21Path)
|
|
||||||
|
|
||||||
Namespace = Tuple[str, PathType]
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
class NodeType(Protocol[PathType]):
|
|
||||||
def __del__(self) -> None:
|
|
||||||
...
|
|
||||||
|
|
||||||
def derive_path(self, path: PathType) -> None:
|
|
||||||
...
|
|
||||||
|
|
||||||
def clone(self: T) -> T:
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class Slip21Node:
|
class Slip21Node:
|
||||||
@ -63,55 +35,8 @@ class Slip21Node:
|
|||||||
return Slip21Node(data=self.data)
|
return Slip21Node(data=self.data)
|
||||||
|
|
||||||
|
|
||||||
class Keychain:
|
|
||||||
def __init__(self, seed: bytes, namespaces: Sequence[Namespace]) -> None:
|
|
||||||
self.seed = seed
|
|
||||||
self.namespaces = namespaces # type: Sequence[Namespace]
|
|
||||||
self.roots = {} # type: Dict[int, NodeType]
|
|
||||||
|
|
||||||
def __del__(self) -> None:
|
|
||||||
for root in self.roots.values():
|
|
||||||
root.__del__()
|
|
||||||
del self.roots
|
|
||||||
del self.seed
|
|
||||||
|
|
||||||
def match_path(self, path: PathType) -> Tuple[int, PathType]:
|
|
||||||
for i, (curve, ns) in enumerate(self.namespaces):
|
|
||||||
if path[: len(ns)] == ns:
|
|
||||||
if "ed25519" in curve and not _path_hardened(path):
|
|
||||||
raise wire.DataError("Forbidden key path")
|
|
||||||
return i, path[len(ns) :]
|
|
||||||
|
|
||||||
raise wire.DataError("Forbidden key path")
|
|
||||||
|
|
||||||
def _new_root(self, curve: str) -> NodeType:
|
|
||||||
if curve == "slip21":
|
|
||||||
return Slip21Node(self.seed)
|
|
||||||
else:
|
|
||||||
return bip32.from_seed(self.seed, curve)
|
|
||||||
|
|
||||||
def derive(self, path: PathType) -> NodeType:
|
|
||||||
root_index, suffix = self.match_path(path)
|
|
||||||
|
|
||||||
if root_index not in self.roots:
|
|
||||||
curve, prefix = self.namespaces[root_index]
|
|
||||||
root = self._new_root(curve)
|
|
||||||
root.derive_path(prefix)
|
|
||||||
self.roots[root_index] = root
|
|
||||||
|
|
||||||
node = self.roots[root_index].clone()
|
|
||||||
node.derive_path(suffix)
|
|
||||||
return node
|
|
||||||
|
|
||||||
def __enter__(self) -> "Keychain":
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, exc_type: Any, exc_val: Any, tb: Any) -> None:
|
|
||||||
self.__del__()
|
|
||||||
|
|
||||||
|
|
||||||
@cache.stored_async(cache.APP_COMMON_SEED)
|
@cache.stored_async(cache.APP_COMMON_SEED)
|
||||||
async def _get_seed(ctx: wire.Context) -> bytes:
|
async def get_seed(ctx: wire.Context) -> bytes:
|
||||||
if not device.is_initialized():
|
if not device.is_initialized():
|
||||||
raise wire.NotInitialized("Device is not initialized")
|
raise wire.NotInitialized("Device is not initialized")
|
||||||
passphrase = await get_passphrase(ctx)
|
passphrase = await get_passphrase(ctx)
|
||||||
@ -125,12 +50,6 @@ def _get_seed_without_passphrase() -> bytes:
|
|||||||
return mnemonic.get_seed(progress_bar=False)
|
return mnemonic.get_seed(progress_bar=False)
|
||||||
|
|
||||||
|
|
||||||
async def get_keychain(ctx: wire.Context, namespaces: Sequence[Namespace]) -> Keychain:
|
|
||||||
seed = await _get_seed(ctx)
|
|
||||||
keychain = Keychain(seed, namespaces)
|
|
||||||
return keychain
|
|
||||||
|
|
||||||
|
|
||||||
def derive_node_without_passphrase(
|
def derive_node_without_passphrase(
|
||||||
path: Bip32Path, curve_name: str = "secp256k1"
|
path: Bip32Path, curve_name: str = "secp256k1"
|
||||||
) -> bip32.HDNode:
|
) -> bip32.HDNode:
|
||||||
@ -150,35 +69,3 @@ def derive_slip21_node_without_passphrase(path: Slip21Path) -> Slip21Node:
|
|||||||
def remove_ed25519_prefix(pubkey: bytes) -> bytes:
|
def remove_ed25519_prefix(pubkey: bytes) -> bytes:
|
||||||
# 0x01 prefix is not part of the actual public key, hence removed
|
# 0x01 prefix is not part of the actual public key, hence removed
|
||||||
return pubkey[1:]
|
return pubkey[1:]
|
||||||
|
|
||||||
|
|
||||||
def _path_hardened(path: list) -> bool:
|
|
||||||
return all(i & HARDENED for i in path)
|
|
||||||
|
|
||||||
|
|
||||||
if False:
|
|
||||||
from protobuf import MessageType
|
|
||||||
|
|
||||||
MsgIn = TypeVar("MsgIn", bound=MessageType)
|
|
||||||
MsgOut = TypeVar("MsgOut", bound=MessageType)
|
|
||||||
|
|
||||||
Handler = Callable[[wire.Context, MsgIn], Awaitable[MsgOut]]
|
|
||||||
HandlerWithKeychain = Callable[[wire.Context, MsgIn, Keychain], Awaitable[MsgOut]]
|
|
||||||
|
|
||||||
|
|
||||||
def with_slip44_keychain(
|
|
||||||
slip44: int, curve: str = "secp256k1", allow_testnet: bool = False
|
|
||||||
) -> Callable[[HandlerWithKeychain[MsgIn, MsgOut]], Handler[MsgIn, MsgOut]]:
|
|
||||||
namespaces = [(curve, [44 | HARDENED, slip44 | HARDENED])]
|
|
||||||
if allow_testnet:
|
|
||||||
namespaces.append((curve, [44 | HARDENED, 1 | HARDENED]))
|
|
||||||
|
|
||||||
def decorator(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]:
|
|
||||||
async def wrapper(ctx: wire.Context, msg: MsgIn) -> MsgOut:
|
|
||||||
keychain = await get_keychain(ctx, namespaces)
|
|
||||||
with keychain:
|
|
||||||
return await func(ctx, msg, keychain)
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
Loading…
Reference in New Issue
Block a user