1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-19 04:48:12 +00:00

core: refactor keychain to only support one curve at a time

also make a cleaner distinction between keychain, seed, path

This enables using `unsafe_prompts`, because with the original code, if
there was no namespace match, we wouldn't know which curve to use.

For ease of implementation, we use a LRU cache for derived keys,
instead of the original design "one cache entry per namespace".

SLIP21 is now treated completely separately, via `slip21_namespaces` and
`derive_slip21` method.
If more slip21-like things come in the future, we can instead hang them
on the keychain: put a per-curve Keychain object accessible by
`keychain[curve_name].derive()`, and the majority usecase will just pass
around `keychain[curve_name]` instead of having to specify the curve in
every `derive()` call.

Or alternately we'll just specify the curve in every `derive()` call,
whichever seems more appropriate.
This commit is contained in:
matejcik 2020-07-14 12:15:05 +02:00 committed by matejcik
parent fa757f4b7f
commit ff4ec2185e
4 changed files with 206 additions and 131 deletions

View File

@ -7,8 +7,6 @@ from apps.common import mnemonic
from apps.common.passphrase import get as get_passphrase from apps.common.passphrase import get as get_passphrase
if False: if False:
from typing import Tuple
from apps.common.seed import Bip32Path, MsgIn, MsgOut, Handler, HandlerWithKeychain from apps.common.seed import Bip32Path, MsgIn, MsgOut, Handler, HandlerWithKeychain
@ -18,15 +16,14 @@ class Keychain:
def __init__(self, root: bip32.HDNode) -> None: def __init__(self, root: bip32.HDNode) -> None:
self.root = root self.root = root
def match_path(self, path: Bip32Path) -> Tuple[int, Bip32Path]: def verify_path(self, path: Bip32Path) -> None:
if path[: len(SEED_NAMESPACE)] != SEED_NAMESPACE: if path[: len(SEED_NAMESPACE)] != SEED_NAMESPACE:
raise wire.DataError("Forbidden key path") raise wire.DataError("Forbidden key path")
return 0, path[len(SEED_NAMESPACE) :]
def derive(self, node_path: Bip32Path) -> bip32.HDNode: def derive(self, node_path: Bip32Path) -> bip32.HDNode:
_, suffix = self.match_path(node_path)
# derive child node from the root # derive child node from the root
node = self.root.clone() node = self.root.clone()
suffix = node_path[len(SEED_NAMESPACE) :]
for i in suffix: for i in suffix:
node.derive_cardano(i) node.derive_cardano(i)
return node return node

View File

@ -0,0 +1,181 @@
from storage import device
from trezor import wire
from trezor.crypto import bip32
from . import HARDENED, paths
from .seed import Slip21Node, get_seed
if False:
from typing import (
Any,
Awaitable,
Callable,
Dict,
List,
Optional,
Sequence,
TypeVar,
)
from typing_extensions import Protocol
from protobuf import MessageType
T = TypeVar("T")
class NodeProtocol(Protocol[paths.PathType]):
def derive_path(self, path: paths.PathType) -> None:
...
def clone(self: T) -> T:
...
def __del__(self) -> None:
...
NodeType = TypeVar("NodeType", bound=NodeProtocol)
MsgIn = TypeVar("MsgIn", bound=MessageType)
MsgOut = TypeVar("MsgOut", bound=MessageType)
Handler = Callable[[wire.Context, MsgIn], Awaitable[MsgOut]]
HandlerWithKeychain = Callable[[wire.Context, MsgIn, "Keychain"], Awaitable[MsgOut]]
class Deletable(Protocol):
def __del__(self) -> None:
...
FORBIDDEN_KEY_PATH = wire.DataError("Forbidden key path")
class LRUCache:
def __init__(self, size: int) -> None:
self.size = size
self.cache_keys = [] # type: List[Any]
self.cache = {} # type: Dict[Any, Deletable]
def insert(self, key: Any, value: Deletable) -> None:
if key in self.cache_keys:
self.cache_keys.remove(key)
self.cache_keys.insert(0, key)
self.cache[key] = value
if len(self.cache_keys) > self.size:
dropped_key = self.cache_keys.pop()
self.cache[dropped_key].__del__()
del self.cache[dropped_key]
def get(self, key: Any) -> Any:
if key not in self.cache:
return None
self.cache_keys.remove(key)
self.cache_keys.insert(0, key)
return self.cache[key]
def __del__(self) -> None:
for value in self.cache.values():
value.__del__()
self.cache.clear()
self.cache_keys.clear()
del self.cache
class Keychain:
def __init__(
self,
seed: bytes,
curve: str,
namespaces: Sequence[paths.Bip32Path],
slip21_namespaces: Sequence[paths.Slip21Path] = (),
) -> None:
self.seed = seed
self.curve = curve
self.namespaces = namespaces
self.slip21_namespaces = slip21_namespaces
self._cache = LRUCache(10)
def __del__(self) -> None:
self._cache.__del__()
del self._cache
del self.seed
def verify_path(self, path: paths.Bip32Path) -> None:
if "ed25519" in self.curve and not paths.path_is_hardened(path):
raise FORBIDDEN_KEY_PATH
if device.unsafe_prompts_allowed():
return
if any(ns == path[: len(ns)] for ns in self.namespaces):
return
raise FORBIDDEN_KEY_PATH
def _derive_with_cache(
self, prefix_len: int, path: paths.PathType, new_root: Callable[[], NodeType],
) -> NodeType:
cached_prefix = tuple(path[:prefix_len])
cached_root = self._cache.get(cached_prefix) # type: Optional[NodeType]
if cached_root is None:
cached_root = new_root()
cached_root.derive_path(cached_prefix)
self._cache.insert(cached_prefix, cached_root)
node = cached_root.clone()
node.derive_path(path[prefix_len:])
return node
def derive(self, path: paths.Bip32Path) -> bip32.HDNode:
self.verify_path(path)
return self._derive_with_cache(
prefix_len=3,
path=path,
new_root=lambda: bip32.from_seed(self.seed, self.curve),
)
def derive_slip21(self, path: paths.Slip21Path) -> Slip21Node:
if not device.unsafe_prompts_allowed() and not any(
ns == path[: len(ns)] for ns in self.slip21_namespaces
):
raise FORBIDDEN_KEY_PATH
return self._derive_with_cache(
prefix_len=1, path=path, new_root=lambda: Slip21Node(seed=self.seed),
)
def __enter__(self) -> "Keychain":
return self
def __exit__(self, exc_type: Any, exc_val: Any, tb: Any) -> None:
self.__del__()
async def get_keychain(
ctx: wire.Context,
curve: str,
namespaces: Sequence[paths.Bip32Path],
slip21_namespaces: Sequence[paths.Slip21Path] = (),
) -> Keychain:
seed = await get_seed(ctx)
keychain = Keychain(seed, curve, namespaces, slip21_namespaces)
return keychain
def with_slip44_keychain(
slip44: int, curve: str = "secp256k1", allow_testnet: bool = False
) -> Callable[[HandlerWithKeychain[MsgIn, MsgOut]], Handler[MsgIn, MsgOut]]:
namespaces = [[44 | HARDENED, slip44 | HARDENED]]
if allow_testnet:
namespaces.append([44 | HARDENED, 1 | HARDENED])
def decorator(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]:
async def wrapper(ctx: wire.Context, msg: MsgIn) -> MsgOut:
keychain = await get_keychain(ctx, curve, namespaces)
with keychain:
return await func(ctx, msg, keychain)
return wrapper
return decorator

View File

@ -4,29 +4,35 @@ from trezor import ui
from trezor.messages import ButtonRequestType from trezor.messages import ButtonRequestType
from trezor.ui.text import Text from trezor.ui.text import Text
from apps.common import HARDENED from . import HARDENED
from apps.common.confirm import require_confirm from .confirm import require_confirm
if False: if False:
from typing import Any, Callable, List from typing import Any, Callable, List, Sequence, TypeVar
from trezor import wire from trezor import wire
from apps.common import seed
# XXX this is a circular import, but it's only for typing
from .keychain import Keychain
Bip32Path = Sequence[int]
Slip21Path = Sequence[bytes]
PathType = TypeVar("PathType", Bip32Path, Slip21Path)
async def validate_path( async def validate_path(
ctx: wire.Context, ctx: wire.Context,
validate_func: Callable[..., bool], validate_func: Callable[..., bool],
keychain: seed.Keychain, keychain: Keychain,
path: List[int], path: List[int],
curve: str, curve: str,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
keychain.match_path(path) keychain.verify_path(path)
if not validate_func(path, **kwargs): if not validate_func(path, **kwargs):
await show_path_warning(ctx, path) await show_path_warning(ctx, path)
async def show_path_warning(ctx: wire.Context, path: List[int]) -> None: async def show_path_warning(ctx: wire.Context, path: Bip32Path) -> None:
text = Text("Confirm path", ui.ICON_WRONG, ui.RED) text = Text("Confirm path", ui.ICON_WRONG, ui.RED)
text.normal("Path") text.normal("Path")
text.mono(*break_address_n_to_lines(path)) text.mono(*break_address_n_to_lines(path))
@ -35,7 +41,7 @@ async def show_path_warning(ctx: wire.Context, path: List[int]) -> None:
await require_confirm(ctx, text, ButtonRequestType.UnknownDerivationPath) await require_confirm(ctx, text, ButtonRequestType.UnknownDerivationPath)
def validate_path_for_get_public_key(path: list, slip44_id: int) -> bool: def validate_path_for_get_public_key(path: Bip32Path, slip44_id: int) -> bool:
""" """
Checks if path has at least three hardened items and slip44 id matches. Checks if path has at least three hardened items and slip44 id matches.
The path is allowed to have more than three items, but all the following The path is allowed to have more than three items, but all the following
@ -61,7 +67,11 @@ def is_hardened(i: int) -> bool:
return bool(i & HARDENED) return bool(i & HARDENED)
def break_address_n_to_lines(address_n: list) -> list: def path_is_hardened(address_n: Bip32Path) -> bool:
return all(is_hardened(n) for n in address_n)
def break_address_n_to_lines(address_n: Bip32Path) -> List[str]:
def path_item(i: int) -> str: def path_item(i: int) -> str:
if i & HARDENED: if i & HARDENED:
return str(i ^ HARDENED) + "'" return str(i ^ HARDENED) + "'"

View File

@ -2,39 +2,11 @@ from storage import cache, device
from trezor import wire from trezor import wire
from trezor.crypto import bip32, hashlib, hmac from trezor.crypto import bip32, hashlib, hmac
from apps.common import HARDENED, mnemonic from . import mnemonic
from apps.common.passphrase import get as get_passphrase from .passphrase import get as get_passphrase
if False: if False:
from typing import ( from .paths import Bip32Path, Slip21Path
Any,
Awaitable,
Callable,
Dict,
List,
Sequence,
Tuple,
TypeVar,
)
from typing_extensions import Protocol
Bip32Path = List[int]
Slip21Path = List[bytes]
PathType = TypeVar("PathType", Bip32Path, Slip21Path)
Namespace = Tuple[str, PathType]
T = TypeVar("T")
class NodeType(Protocol[PathType]):
def __del__(self) -> None:
...
def derive_path(self, path: PathType) -> None:
...
def clone(self: T) -> T:
...
class Slip21Node: class Slip21Node:
@ -63,55 +35,8 @@ class Slip21Node:
return Slip21Node(data=self.data) return Slip21Node(data=self.data)
class Keychain:
def __init__(self, seed: bytes, namespaces: Sequence[Namespace]) -> None:
self.seed = seed
self.namespaces = namespaces # type: Sequence[Namespace]
self.roots = {} # type: Dict[int, NodeType]
def __del__(self) -> None:
for root in self.roots.values():
root.__del__()
del self.roots
del self.seed
def match_path(self, path: PathType) -> Tuple[int, PathType]:
for i, (curve, ns) in enumerate(self.namespaces):
if path[: len(ns)] == ns:
if "ed25519" in curve and not _path_hardened(path):
raise wire.DataError("Forbidden key path")
return i, path[len(ns) :]
raise wire.DataError("Forbidden key path")
def _new_root(self, curve: str) -> NodeType:
if curve == "slip21":
return Slip21Node(self.seed)
else:
return bip32.from_seed(self.seed, curve)
def derive(self, path: PathType) -> NodeType:
root_index, suffix = self.match_path(path)
if root_index not in self.roots:
curve, prefix = self.namespaces[root_index]
root = self._new_root(curve)
root.derive_path(prefix)
self.roots[root_index] = root
node = self.roots[root_index].clone()
node.derive_path(suffix)
return node
def __enter__(self) -> "Keychain":
return self
def __exit__(self, exc_type: Any, exc_val: Any, tb: Any) -> None:
self.__del__()
@cache.stored_async(cache.APP_COMMON_SEED) @cache.stored_async(cache.APP_COMMON_SEED)
async def _get_seed(ctx: wire.Context) -> bytes: async def get_seed(ctx: wire.Context) -> bytes:
if not device.is_initialized(): if not device.is_initialized():
raise wire.NotInitialized("Device is not initialized") raise wire.NotInitialized("Device is not initialized")
passphrase = await get_passphrase(ctx) passphrase = await get_passphrase(ctx)
@ -125,12 +50,6 @@ def _get_seed_without_passphrase() -> bytes:
return mnemonic.get_seed(progress_bar=False) return mnemonic.get_seed(progress_bar=False)
async def get_keychain(ctx: wire.Context, namespaces: Sequence[Namespace]) -> Keychain:
seed = await _get_seed(ctx)
keychain = Keychain(seed, namespaces)
return keychain
def derive_node_without_passphrase( def derive_node_without_passphrase(
path: Bip32Path, curve_name: str = "secp256k1" path: Bip32Path, curve_name: str = "secp256k1"
) -> bip32.HDNode: ) -> bip32.HDNode:
@ -150,35 +69,3 @@ def derive_slip21_node_without_passphrase(path: Slip21Path) -> Slip21Node:
def remove_ed25519_prefix(pubkey: bytes) -> bytes: def remove_ed25519_prefix(pubkey: bytes) -> bytes:
# 0x01 prefix is not part of the actual public key, hence removed # 0x01 prefix is not part of the actual public key, hence removed
return pubkey[1:] return pubkey[1:]
def _path_hardened(path: list) -> bool:
return all(i & HARDENED for i in path)
if False:
from protobuf import MessageType
MsgIn = TypeVar("MsgIn", bound=MessageType)
MsgOut = TypeVar("MsgOut", bound=MessageType)
Handler = Callable[[wire.Context, MsgIn], Awaitable[MsgOut]]
HandlerWithKeychain = Callable[[wire.Context, MsgIn, Keychain], Awaitable[MsgOut]]
def with_slip44_keychain(
slip44: int, curve: str = "secp256k1", allow_testnet: bool = False
) -> Callable[[HandlerWithKeychain[MsgIn, MsgOut]], Handler[MsgIn, MsgOut]]:
namespaces = [(curve, [44 | HARDENED, slip44 | HARDENED])]
if allow_testnet:
namespaces.append((curve, [44 | HARDENED, 1 | HARDENED]))
def decorator(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]:
async def wrapper(ctx: wire.Context, msg: MsgIn) -> MsgOut:
keychain = await get_keychain(ctx, namespaces)
with keychain:
return await func(ctx, msg, keychain)
return wrapper
return decorator