mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-03 20:11:00 +00:00
core: refactor keychain API, introduce SLIP44 decorator
This commit is contained in:
parent
8c4cb58098
commit
7541d529a3
@ -21,7 +21,7 @@ async def validate_path(
|
|||||||
curve: str,
|
curve: str,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
keychain.validate_path(path, curve)
|
keychain.match_path(path)
|
||||||
if not validate_func(path, **kwargs):
|
if not validate_func(path, **kwargs):
|
||||||
await show_path_warning(ctx, path)
|
await show_path_warning(ctx, path)
|
||||||
|
|
||||||
@ -58,10 +58,7 @@ def validate_path_for_get_public_key(path: list, slip44_id: int) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def is_hardened(i: int) -> bool:
|
def is_hardened(i: int) -> bool:
|
||||||
if i & HARDENED:
|
return bool(i & HARDENED)
|
||||||
return True
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def break_address_n_to_lines(address_n: list) -> list:
|
def break_address_n_to_lines(address_n: list) -> list:
|
||||||
|
@ -2,26 +2,56 @@ import storage
|
|||||||
from storage import cache
|
from storage import cache
|
||||||
from trezor import wire
|
from trezor import wire
|
||||||
from trezor.crypto import bip32, hashlib, hmac
|
from trezor.crypto import bip32, hashlib, hmac
|
||||||
from trezor.crypto.curve import secp256k1
|
|
||||||
|
|
||||||
from apps.common import HARDENED, mnemonic
|
from apps.common import HARDENED, mnemonic
|
||||||
from apps.common.passphrase import get as get_passphrase
|
from apps.common.passphrase import get as get_passphrase
|
||||||
|
|
||||||
if False:
|
if False:
|
||||||
from typing import List, Union
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Awaitable,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Sequence,
|
||||||
|
Tuple,
|
||||||
|
TypeVar,
|
||||||
|
)
|
||||||
|
from typing_extensions import Protocol
|
||||||
|
|
||||||
|
Bip32Path = List[int]
|
||||||
|
Slip21Path = List[bytes]
|
||||||
|
PathType = TypeVar("PathType", Bip32Path, Slip21Path)
|
||||||
|
|
||||||
|
Namespace = Tuple[str, PathType]
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
class NodeType(Protocol[PathType]):
|
||||||
|
def __del__(self) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
def derive_path(self, path: PathType) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
def clone(self: T) -> T:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
class Slip21Node:
|
class Slip21Node:
|
||||||
def __init__(self, seed: bytes = None) -> None:
|
def __init__(self, seed: bytes = None, data: bytes = None) -> None:
|
||||||
if seed is not None:
|
assert seed is None or data is None, "Specify exactly one of: seed, data"
|
||||||
|
if data is not None:
|
||||||
|
self.data = data
|
||||||
|
elif seed is not None:
|
||||||
self.data = hmac.new(b"Symmetric key seed", seed, hashlib.sha512).digest()
|
self.data = hmac.new(b"Symmetric key seed", seed, hashlib.sha512).digest()
|
||||||
else:
|
else:
|
||||||
self.data = b""
|
raise ValueError # neither seed nor data specified
|
||||||
|
|
||||||
def __del__(self) -> None:
|
def __del__(self) -> None:
|
||||||
del self.data
|
del self.data
|
||||||
|
|
||||||
def derive_path(self, path: list) -> None:
|
def derive_path(self, path: Slip21Path) -> None:
|
||||||
for label in path:
|
for label in path:
|
||||||
h = hmac.new(self.data[0:32], b"\x00", hashlib.sha512)
|
h = hmac.new(self.data[0:32], b"\x00", hashlib.sha512)
|
||||||
h.update(label)
|
h.update(label)
|
||||||
@ -31,117 +61,88 @@ class Slip21Node:
|
|||||||
return self.data[32:64]
|
return self.data[32:64]
|
||||||
|
|
||||||
def clone(self) -> "Slip21Node":
|
def clone(self) -> "Slip21Node":
|
||||||
node = Slip21Node()
|
return Slip21Node(data=self.data)
|
||||||
node.data = self.data
|
|
||||||
return node
|
|
||||||
|
|
||||||
|
|
||||||
class Keychain:
|
class Keychain:
|
||||||
"""
|
def __init__(self, seed: bytes, namespaces: Sequence[Namespace]) -> None:
|
||||||
Keychain provides an API for deriving HD keys from previously allowed
|
|
||||||
key-spaces.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, seed: bytes, namespaces: list):
|
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
self.namespaces = namespaces
|
self.namespaces = namespaces # type: Sequence[Namespace]
|
||||||
self.roots = [None] * len(
|
self.roots = {} # type: Dict[int, NodeType]
|
||||||
namespaces
|
|
||||||
) # type: List[Union[bip32.HDNode, Slip21Node, None]]
|
|
||||||
|
|
||||||
def __del__(self) -> None:
|
def __del__(self) -> None:
|
||||||
for root in self.roots:
|
for root in self.roots.values():
|
||||||
if root is not None and hasattr(root, "__del__"):
|
root.__del__()
|
||||||
root.__del__()
|
|
||||||
del self.roots
|
del self.roots
|
||||||
del self.seed
|
del self.seed
|
||||||
|
|
||||||
def validate_path(self, checked_path: list, checked_curve: str) -> None:
|
def match_path(self, path: PathType) -> Tuple[int, PathType]:
|
||||||
for curve, *path in self.namespaces:
|
for i, (curve, ns) in enumerate(self.namespaces):
|
||||||
if path == checked_path[: len(path)] and curve == checked_curve:
|
if path[: len(ns)] == ns:
|
||||||
if "ed25519" in curve and not _path_hardened(checked_path):
|
if "ed25519" in curve and not _path_hardened(path):
|
||||||
break
|
raise wire.DataError("Forbidden key path")
|
||||||
return
|
return i, path[len(ns) :]
|
||||||
|
|
||||||
raise wire.DataError("Forbidden key path")
|
raise wire.DataError("Forbidden key path")
|
||||||
|
|
||||||
def derive(
|
def _new_root(self, curve: str) -> NodeType:
|
||||||
self, node_path: list, curve_name: str = "secp256k1"
|
if curve == "slip21":
|
||||||
) -> Union[bip32.HDNode, Slip21Node]:
|
return Slip21Node(self.seed)
|
||||||
if "ed25519" in curve_name and not _path_hardened(node_path):
|
|
||||||
raise wire.DataError("Forbidden key path")
|
|
||||||
|
|
||||||
# find the root node index
|
|
||||||
root_index = 0
|
|
||||||
for curve, *path in self.namespaces:
|
|
||||||
prefix = node_path[: len(path)]
|
|
||||||
suffix = node_path[len(path) :]
|
|
||||||
if curve == curve_name and path == prefix:
|
|
||||||
break
|
|
||||||
root_index += 1
|
|
||||||
else:
|
else:
|
||||||
raise wire.DataError("Forbidden key path")
|
return bip32.from_seed(self.seed, curve)
|
||||||
|
|
||||||
# create the root node if not cached
|
def derive(self, path: PathType) -> NodeType:
|
||||||
root = self.roots[root_index]
|
root_index, suffix = self.match_path(path)
|
||||||
if root is None:
|
|
||||||
if curve_name != "slip21":
|
if root_index not in self.roots:
|
||||||
root = bip32.from_seed(self.seed, curve_name)
|
curve, prefix = self.namespaces[root_index]
|
||||||
else:
|
root = self._new_root(curve)
|
||||||
root = Slip21Node(self.seed)
|
root.derive_path(prefix)
|
||||||
root.derive_path(path)
|
|
||||||
self.roots[root_index] = root
|
self.roots[root_index] = root
|
||||||
|
|
||||||
# derive child node from the root
|
node = self.roots[root_index].clone()
|
||||||
node = root.clone()
|
|
||||||
node.derive_path(suffix)
|
node.derive_path(suffix)
|
||||||
return node
|
return node
|
||||||
|
|
||||||
def derive_slip77_blinding_private_key(self, script: bytes) -> bytes:
|
def __enter__(self) -> "Keychain":
|
||||||
"""Following the derivation by Elements/Liquid."""
|
return self
|
||||||
master_node = self.derive(node_path=[b"SLIP-0077"], curve_name="slip21")
|
|
||||||
assert isinstance(master_node, Slip21Node)
|
|
||||||
return hmac.new(
|
|
||||||
key=master_node.key(), msg=script, digestmod=hashlib.sha256
|
|
||||||
).digest()
|
|
||||||
|
|
||||||
def derive_slip77_blinding_public_key(self, script: bytes) -> bytes:
|
def __exit__(self, exc_type: Any, exc_val: Any, tb: Any) -> None:
|
||||||
private_key = self.derive_slip77_blinding_private_key(script)
|
self.__del__()
|
||||||
return secp256k1.publickey(private_key)
|
|
||||||
|
|
||||||
|
|
||||||
async def get_keychain(ctx: wire.Context, namespaces: list) -> Keychain:
|
@cache.stored_async(cache.APP_COMMON_SEED)
|
||||||
|
async def _get_seed(ctx: wire.Context) -> bytes:
|
||||||
if not storage.is_initialized():
|
if not storage.is_initialized():
|
||||||
raise wire.NotInitialized("Device is not initialized")
|
raise wire.NotInitialized("Device is not initialized")
|
||||||
seed = cache.get(cache.APP_COMMON_SEED)
|
passphrase = await get_passphrase(ctx)
|
||||||
if seed is None:
|
return mnemonic.get_seed(passphrase)
|
||||||
passphrase = await get_passphrase(ctx)
|
|
||||||
seed = mnemonic.get_seed(passphrase)
|
|
||||||
cache.set(cache.APP_COMMON_SEED, seed)
|
@cache.stored(cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE)
|
||||||
|
def _get_seed_without_passphrase() -> bytes:
|
||||||
|
if not storage.is_initialized():
|
||||||
|
raise Exception("Device is not initialized")
|
||||||
|
return mnemonic.get_seed(progress_bar=False)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_keychain(ctx: wire.Context, namespaces: Sequence[Namespace]) -> Keychain:
|
||||||
|
seed = await _get_seed(ctx)
|
||||||
keychain = Keychain(seed, namespaces)
|
keychain = Keychain(seed, namespaces)
|
||||||
return keychain
|
return keychain
|
||||||
|
|
||||||
|
|
||||||
def derive_node_without_passphrase(
|
def derive_node_without_passphrase(
|
||||||
path: list, curve_name: str = "secp256k1"
|
path: Bip32Path, curve_name: str = "secp256k1"
|
||||||
) -> bip32.HDNode:
|
) -> bip32.HDNode:
|
||||||
if not storage.is_initialized():
|
seed = _get_seed_without_passphrase()
|
||||||
raise Exception("Device is not initialized")
|
|
||||||
seed = cache.get(cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE)
|
|
||||||
if seed is None:
|
|
||||||
seed = mnemonic.get_seed(progress_bar=False)
|
|
||||||
cache.set(cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE, seed)
|
|
||||||
node = bip32.from_seed(seed, curve_name)
|
node = bip32.from_seed(seed, curve_name)
|
||||||
node.derive_path(path)
|
node.derive_path(path)
|
||||||
return node
|
return node
|
||||||
|
|
||||||
|
|
||||||
def derive_slip21_node_without_passphrase(path: list) -> Slip21Node:
|
def derive_slip21_node_without_passphrase(path: Slip21Path) -> Slip21Node:
|
||||||
if not storage.is_initialized():
|
seed = _get_seed_without_passphrase()
|
||||||
raise Exception("Device is not initialized")
|
|
||||||
seed = cache.get(cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE)
|
|
||||||
if seed is None:
|
|
||||||
seed = mnemonic.get_seed(progress_bar=False)
|
|
||||||
cache.set(cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE, seed)
|
|
||||||
node = Slip21Node(seed)
|
node = Slip21Node(seed)
|
||||||
node.derive_path(path)
|
node.derive_path(path)
|
||||||
return node
|
return node
|
||||||
@ -154,3 +155,31 @@ def remove_ed25519_prefix(pubkey: bytes) -> bytes:
|
|||||||
|
|
||||||
def _path_hardened(path: list) -> bool:
|
def _path_hardened(path: list) -> bool:
|
||||||
return all(i & HARDENED for i in path)
|
return all(i & HARDENED for i in path)
|
||||||
|
|
||||||
|
|
||||||
|
if False:
|
||||||
|
from protobuf import MessageType
|
||||||
|
|
||||||
|
MsgIn = TypeVar("MsgIn", bound=MessageType)
|
||||||
|
MsgOut = TypeVar("MsgOut", bound=MessageType)
|
||||||
|
|
||||||
|
Handler = Callable[[wire.Context, MsgIn], Awaitable[MsgOut]]
|
||||||
|
HandlerWithKeychain = Callable[[wire.Context, MsgIn, Keychain], Awaitable[MsgOut]]
|
||||||
|
|
||||||
|
|
||||||
|
def with_slip44_keychain(
|
||||||
|
slip44: int, curve: str = "secp256k1", allow_testnet: bool = False
|
||||||
|
) -> Callable[[HandlerWithKeychain[MsgIn, MsgOut]], Handler[MsgIn, MsgOut]]:
|
||||||
|
namespaces = [(curve, [44 | HARDENED, slip44 | HARDENED])]
|
||||||
|
if allow_testnet:
|
||||||
|
namespaces.append((curve, [44 | HARDENED, 1 | HARDENED]))
|
||||||
|
|
||||||
|
def decorator(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]:
|
||||||
|
async def wrapper(ctx: wire.Context, msg: MsgIn) -> MsgOut:
|
||||||
|
keychain = await get_keychain(ctx, namespaces)
|
||||||
|
with keychain:
|
||||||
|
return await func(ctx, msg, keychain)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
Loading…
Reference in New Issue
Block a user