1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-03 20:11:00 +00:00

core: refactor keychain API, introduce SLIP44 decorator

This commit is contained in:
matejcik 2020-04-20 11:37:09 +02:00 committed by matejcik
parent 8c4cb58098
commit 7541d529a3
2 changed files with 114 additions and 88 deletions

View File

@ -21,7 +21,7 @@ async def validate_path(
curve: str, curve: str,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
keychain.validate_path(path, curve) keychain.match_path(path)
if not validate_func(path, **kwargs): if not validate_func(path, **kwargs):
await show_path_warning(ctx, path) await show_path_warning(ctx, path)
@ -58,10 +58,7 @@ def validate_path_for_get_public_key(path: list, slip44_id: int) -> bool:
def is_hardened(i: int) -> bool: def is_hardened(i: int) -> bool:
if i & HARDENED: return bool(i & HARDENED)
return True
else:
return False
def break_address_n_to_lines(address_n: list) -> list: def break_address_n_to_lines(address_n: list) -> list:

View File

@ -2,26 +2,56 @@ import storage
from storage import cache from storage import cache
from trezor import wire from trezor import wire
from trezor.crypto import bip32, hashlib, hmac from trezor.crypto import bip32, hashlib, hmac
from trezor.crypto.curve import secp256k1
from apps.common import HARDENED, mnemonic from apps.common import HARDENED, mnemonic
from apps.common.passphrase import get as get_passphrase from apps.common.passphrase import get as get_passphrase
if False: if False:
from typing import List, Union from typing import (
Any,
Awaitable,
Callable,
Dict,
List,
Sequence,
Tuple,
TypeVar,
)
from typing_extensions import Protocol
Bip32Path = List[int]
Slip21Path = List[bytes]
PathType = TypeVar("PathType", Bip32Path, Slip21Path)
Namespace = Tuple[str, PathType]
T = TypeVar("T")
class NodeType(Protocol[PathType]):
def __del__(self) -> None:
...
def derive_path(self, path: PathType) -> None:
...
def clone(self: T) -> T:
...
class Slip21Node: class Slip21Node:
def __init__(self, seed: bytes = None) -> None: def __init__(self, seed: bytes = None, data: bytes = None) -> None:
if seed is not None: assert seed is None or data is None, "Specify exactly one of: seed, data"
if data is not None:
self.data = data
elif seed is not None:
self.data = hmac.new(b"Symmetric key seed", seed, hashlib.sha512).digest() self.data = hmac.new(b"Symmetric key seed", seed, hashlib.sha512).digest()
else: else:
self.data = b"" raise ValueError # neither seed nor data specified
def __del__(self) -> None: def __del__(self) -> None:
del self.data del self.data
def derive_path(self, path: list) -> None: def derive_path(self, path: Slip21Path) -> None:
for label in path: for label in path:
h = hmac.new(self.data[0:32], b"\x00", hashlib.sha512) h = hmac.new(self.data[0:32], b"\x00", hashlib.sha512)
h.update(label) h.update(label)
@ -31,117 +61,88 @@ class Slip21Node:
return self.data[32:64] return self.data[32:64]
def clone(self) -> "Slip21Node": def clone(self) -> "Slip21Node":
node = Slip21Node() return Slip21Node(data=self.data)
node.data = self.data
return node
class Keychain: class Keychain:
""" def __init__(self, seed: bytes, namespaces: Sequence[Namespace]) -> None:
Keychain provides an API for deriving HD keys from previously allowed
key-spaces.
"""
def __init__(self, seed: bytes, namespaces: list):
self.seed = seed self.seed = seed
self.namespaces = namespaces self.namespaces = namespaces # type: Sequence[Namespace]
self.roots = [None] * len( self.roots = {} # type: Dict[int, NodeType]
namespaces
) # type: List[Union[bip32.HDNode, Slip21Node, None]]
def __del__(self) -> None: def __del__(self) -> None:
for root in self.roots: for root in self.roots.values():
if root is not None and hasattr(root, "__del__"): root.__del__()
root.__del__()
del self.roots del self.roots
del self.seed del self.seed
def validate_path(self, checked_path: list, checked_curve: str) -> None: def match_path(self, path: PathType) -> Tuple[int, PathType]:
for curve, *path in self.namespaces: for i, (curve, ns) in enumerate(self.namespaces):
if path == checked_path[: len(path)] and curve == checked_curve: if path[: len(ns)] == ns:
if "ed25519" in curve and not _path_hardened(checked_path): if "ed25519" in curve and not _path_hardened(path):
break raise wire.DataError("Forbidden key path")
return return i, path[len(ns) :]
raise wire.DataError("Forbidden key path") raise wire.DataError("Forbidden key path")
def derive( def _new_root(self, curve: str) -> NodeType:
self, node_path: list, curve_name: str = "secp256k1" if curve == "slip21":
) -> Union[bip32.HDNode, Slip21Node]: return Slip21Node(self.seed)
if "ed25519" in curve_name and not _path_hardened(node_path):
raise wire.DataError("Forbidden key path")
# find the root node index
root_index = 0
for curve, *path in self.namespaces:
prefix = node_path[: len(path)]
suffix = node_path[len(path) :]
if curve == curve_name and path == prefix:
break
root_index += 1
else: else:
raise wire.DataError("Forbidden key path") return bip32.from_seed(self.seed, curve)
# create the root node if not cached def derive(self, path: PathType) -> NodeType:
root = self.roots[root_index] root_index, suffix = self.match_path(path)
if root is None:
if curve_name != "slip21": if root_index not in self.roots:
root = bip32.from_seed(self.seed, curve_name) curve, prefix = self.namespaces[root_index]
else: root = self._new_root(curve)
root = Slip21Node(self.seed) root.derive_path(prefix)
root.derive_path(path)
self.roots[root_index] = root self.roots[root_index] = root
# derive child node from the root node = self.roots[root_index].clone()
node = root.clone()
node.derive_path(suffix) node.derive_path(suffix)
return node return node
def derive_slip77_blinding_private_key(self, script: bytes) -> bytes: def __enter__(self) -> "Keychain":
"""Following the derivation by Elements/Liquid.""" return self
master_node = self.derive(node_path=[b"SLIP-0077"], curve_name="slip21")
assert isinstance(master_node, Slip21Node)
return hmac.new(
key=master_node.key(), msg=script, digestmod=hashlib.sha256
).digest()
def derive_slip77_blinding_public_key(self, script: bytes) -> bytes: def __exit__(self, exc_type: Any, exc_val: Any, tb: Any) -> None:
private_key = self.derive_slip77_blinding_private_key(script) self.__del__()
return secp256k1.publickey(private_key)
async def get_keychain(ctx: wire.Context, namespaces: list) -> Keychain: @cache.stored_async(cache.APP_COMMON_SEED)
async def _get_seed(ctx: wire.Context) -> bytes:
if not storage.is_initialized(): if not storage.is_initialized():
raise wire.NotInitialized("Device is not initialized") raise wire.NotInitialized("Device is not initialized")
seed = cache.get(cache.APP_COMMON_SEED) passphrase = await get_passphrase(ctx)
if seed is None: return mnemonic.get_seed(passphrase)
passphrase = await get_passphrase(ctx)
seed = mnemonic.get_seed(passphrase)
cache.set(cache.APP_COMMON_SEED, seed) @cache.stored(cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE)
def _get_seed_without_passphrase() -> bytes:
if not storage.is_initialized():
raise Exception("Device is not initialized")
return mnemonic.get_seed(progress_bar=False)
async def get_keychain(ctx: wire.Context, namespaces: Sequence[Namespace]) -> Keychain:
seed = await _get_seed(ctx)
keychain = Keychain(seed, namespaces) keychain = Keychain(seed, namespaces)
return keychain return keychain
def derive_node_without_passphrase( def derive_node_without_passphrase(
path: list, curve_name: str = "secp256k1" path: Bip32Path, curve_name: str = "secp256k1"
) -> bip32.HDNode: ) -> bip32.HDNode:
if not storage.is_initialized(): seed = _get_seed_without_passphrase()
raise Exception("Device is not initialized")
seed = cache.get(cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE)
if seed is None:
seed = mnemonic.get_seed(progress_bar=False)
cache.set(cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE, seed)
node = bip32.from_seed(seed, curve_name) node = bip32.from_seed(seed, curve_name)
node.derive_path(path) node.derive_path(path)
return node return node
def derive_slip21_node_without_passphrase(path: list) -> Slip21Node: def derive_slip21_node_without_passphrase(path: Slip21Path) -> Slip21Node:
if not storage.is_initialized(): seed = _get_seed_without_passphrase()
raise Exception("Device is not initialized")
seed = cache.get(cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE)
if seed is None:
seed = mnemonic.get_seed(progress_bar=False)
cache.set(cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE, seed)
node = Slip21Node(seed) node = Slip21Node(seed)
node.derive_path(path) node.derive_path(path)
return node return node
@ -154,3 +155,31 @@ def remove_ed25519_prefix(pubkey: bytes) -> bytes:
def _path_hardened(path: list) -> bool: def _path_hardened(path: list) -> bool:
return all(i & HARDENED for i in path) return all(i & HARDENED for i in path)
if False:
from protobuf import MessageType
MsgIn = TypeVar("MsgIn", bound=MessageType)
MsgOut = TypeVar("MsgOut", bound=MessageType)
Handler = Callable[[wire.Context, MsgIn], Awaitable[MsgOut]]
HandlerWithKeychain = Callable[[wire.Context, MsgIn, Keychain], Awaitable[MsgOut]]
def with_slip44_keychain(
slip44: int, curve: str = "secp256k1", allow_testnet: bool = False
) -> Callable[[HandlerWithKeychain[MsgIn, MsgOut]], Handler[MsgIn, MsgOut]]:
namespaces = [(curve, [44 | HARDENED, slip44 | HARDENED])]
if allow_testnet:
namespaces.append((curve, [44 | HARDENED, 1 | HARDENED]))
def decorator(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]:
async def wrapper(ctx: wire.Context, msg: MsgIn) -> MsgOut:
keychain = await get_keychain(ctx, namespaces)
with keychain:
return await func(ctx, msg, keychain)
return wrapper
return decorator