1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-31 18:40:56 +00:00

core: refactor keychain API, introduce SLIP44 decorator

This commit is contained in:
matejcik 2020-04-20 11:37:09 +02:00 committed by matejcik
parent 8c4cb58098
commit 7541d529a3
2 changed files with 114 additions and 88 deletions

View File

@ -21,7 +21,7 @@ async def validate_path(
curve: str,
**kwargs: Any,
) -> None:
keychain.validate_path(path, curve)
keychain.match_path(path)
if not validate_func(path, **kwargs):
await show_path_warning(ctx, path)
@ -58,10 +58,7 @@ def validate_path_for_get_public_key(path: list, slip44_id: int) -> bool:
def is_hardened(i: int) -> bool:
if i & HARDENED:
return True
else:
return False
return bool(i & HARDENED)
def break_address_n_to_lines(address_n: list) -> list:

View File

@ -2,26 +2,56 @@ import storage
from storage import cache
from trezor import wire
from trezor.crypto import bip32, hashlib, hmac
from trezor.crypto.curve import secp256k1
from apps.common import HARDENED, mnemonic
from apps.common.passphrase import get as get_passphrase
if False:
from typing import List, Union
from typing import (
Any,
Awaitable,
Callable,
Dict,
List,
Sequence,
Tuple,
TypeVar,
)
from typing_extensions import Protocol
Bip32Path = List[int]
Slip21Path = List[bytes]
PathType = TypeVar("PathType", Bip32Path, Slip21Path)
Namespace = Tuple[str, PathType]
T = TypeVar("T")
class NodeType(Protocol[PathType]):
def __del__(self) -> None:
...
def derive_path(self, path: PathType) -> None:
...
def clone(self: T) -> T:
...
class Slip21Node:
def __init__(self, seed: bytes = None) -> None:
if seed is not None:
def __init__(self, seed: bytes = None, data: bytes = None) -> None:
assert seed is None or data is None, "Specify exactly one of: seed, data"
if data is not None:
self.data = data
elif seed is not None:
self.data = hmac.new(b"Symmetric key seed", seed, hashlib.sha512).digest()
else:
self.data = b""
raise ValueError # neither seed nor data specified
def __del__(self) -> None:
del self.data
def derive_path(self, path: list) -> None:
def derive_path(self, path: Slip21Path) -> None:
for label in path:
h = hmac.new(self.data[0:32], b"\x00", hashlib.sha512)
h.update(label)
@ -31,117 +61,88 @@ class Slip21Node:
return self.data[32:64]
def clone(self) -> "Slip21Node":
node = Slip21Node()
node.data = self.data
return node
return Slip21Node(data=self.data)
class Keychain:
"""
Keychain provides an API for deriving HD keys from previously allowed
key-spaces.
"""
def __init__(self, seed: bytes, namespaces: list):
def __init__(self, seed: bytes, namespaces: Sequence[Namespace]) -> None:
self.seed = seed
self.namespaces = namespaces
self.roots = [None] * len(
namespaces
) # type: List[Union[bip32.HDNode, Slip21Node, None]]
self.namespaces = namespaces # type: Sequence[Namespace]
self.roots = {} # type: Dict[int, NodeType]
def __del__(self) -> None:
for root in self.roots:
if root is not None and hasattr(root, "__del__"):
root.__del__()
for root in self.roots.values():
root.__del__()
del self.roots
del self.seed
def validate_path(self, checked_path: list, checked_curve: str) -> None:
for curve, *path in self.namespaces:
if path == checked_path[: len(path)] and curve == checked_curve:
if "ed25519" in curve and not _path_hardened(checked_path):
break
return
def match_path(self, path: PathType) -> Tuple[int, PathType]:
for i, (curve, ns) in enumerate(self.namespaces):
if path[: len(ns)] == ns:
if "ed25519" in curve and not _path_hardened(path):
raise wire.DataError("Forbidden key path")
return i, path[len(ns) :]
raise wire.DataError("Forbidden key path")
def derive(
self, node_path: list, curve_name: str = "secp256k1"
) -> Union[bip32.HDNode, Slip21Node]:
if "ed25519" in curve_name and not _path_hardened(node_path):
raise wire.DataError("Forbidden key path")
# find the root node index
root_index = 0
for curve, *path in self.namespaces:
prefix = node_path[: len(path)]
suffix = node_path[len(path) :]
if curve == curve_name and path == prefix:
break
root_index += 1
def _new_root(self, curve: str) -> NodeType:
if curve == "slip21":
return Slip21Node(self.seed)
else:
raise wire.DataError("Forbidden key path")
return bip32.from_seed(self.seed, curve)
# create the root node if not cached
root = self.roots[root_index]
if root is None:
if curve_name != "slip21":
root = bip32.from_seed(self.seed, curve_name)
else:
root = Slip21Node(self.seed)
root.derive_path(path)
def derive(self, path: PathType) -> NodeType:
root_index, suffix = self.match_path(path)
if root_index not in self.roots:
curve, prefix = self.namespaces[root_index]
root = self._new_root(curve)
root.derive_path(prefix)
self.roots[root_index] = root
# derive child node from the root
node = root.clone()
node = self.roots[root_index].clone()
node.derive_path(suffix)
return node
def derive_slip77_blinding_private_key(self, script: bytes) -> bytes:
"""Following the derivation by Elements/Liquid."""
master_node = self.derive(node_path=[b"SLIP-0077"], curve_name="slip21")
assert isinstance(master_node, Slip21Node)
return hmac.new(
key=master_node.key(), msg=script, digestmod=hashlib.sha256
).digest()
def __enter__(self) -> "Keychain":
return self
def derive_slip77_blinding_public_key(self, script: bytes) -> bytes:
private_key = self.derive_slip77_blinding_private_key(script)
return secp256k1.publickey(private_key)
def __exit__(self, exc_type: Any, exc_val: Any, tb: Any) -> None:
self.__del__()
async def get_keychain(ctx: wire.Context, namespaces: list) -> Keychain:
@cache.stored_async(cache.APP_COMMON_SEED)
async def _get_seed(ctx: wire.Context) -> bytes:
if not storage.is_initialized():
raise wire.NotInitialized("Device is not initialized")
seed = cache.get(cache.APP_COMMON_SEED)
if seed is None:
passphrase = await get_passphrase(ctx)
seed = mnemonic.get_seed(passphrase)
cache.set(cache.APP_COMMON_SEED, seed)
passphrase = await get_passphrase(ctx)
return mnemonic.get_seed(passphrase)
@cache.stored(cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE)
def _get_seed_without_passphrase() -> bytes:
if not storage.is_initialized():
raise Exception("Device is not initialized")
return mnemonic.get_seed(progress_bar=False)
async def get_keychain(ctx: wire.Context, namespaces: Sequence[Namespace]) -> Keychain:
seed = await _get_seed(ctx)
keychain = Keychain(seed, namespaces)
return keychain
def derive_node_without_passphrase(
path: list, curve_name: str = "secp256k1"
path: Bip32Path, curve_name: str = "secp256k1"
) -> bip32.HDNode:
if not storage.is_initialized():
raise Exception("Device is not initialized")
seed = cache.get(cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE)
if seed is None:
seed = mnemonic.get_seed(progress_bar=False)
cache.set(cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE, seed)
seed = _get_seed_without_passphrase()
node = bip32.from_seed(seed, curve_name)
node.derive_path(path)
return node
def derive_slip21_node_without_passphrase(path: list) -> Slip21Node:
if not storage.is_initialized():
raise Exception("Device is not initialized")
seed = cache.get(cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE)
if seed is None:
seed = mnemonic.get_seed(progress_bar=False)
cache.set(cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE, seed)
def derive_slip21_node_without_passphrase(path: Slip21Path) -> Slip21Node:
seed = _get_seed_without_passphrase()
node = Slip21Node(seed)
node.derive_path(path)
return node
@ -154,3 +155,31 @@ def remove_ed25519_prefix(pubkey: bytes) -> bytes:
def _path_hardened(path: list) -> bool:
return all(i & HARDENED for i in path)
if False:
from protobuf import MessageType
MsgIn = TypeVar("MsgIn", bound=MessageType)
MsgOut = TypeVar("MsgOut", bound=MessageType)
Handler = Callable[[wire.Context, MsgIn], Awaitable[MsgOut]]
HandlerWithKeychain = Callable[[wire.Context, MsgIn, Keychain], Awaitable[MsgOut]]
def with_slip44_keychain(
slip44: int, curve: str = "secp256k1", allow_testnet: bool = False
) -> Callable[[HandlerWithKeychain[MsgIn, MsgOut]], Handler[MsgIn, MsgOut]]:
namespaces = [(curve, [44 | HARDENED, slip44 | HARDENED])]
if allow_testnet:
namespaces.append((curve, [44 | HARDENED, 1 | HARDENED]))
def decorator(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]:
async def wrapper(ctx: wire.Context, msg: MsgIn) -> MsgOut:
keychain = await get_keychain(ctx, namespaces)
with keychain:
return await func(ctx, msg, keychain)
return wrapper
return decorator