You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
trezor-firmware/core/src/apps/common/seed.py

185 lines
5.4 KiB

from storage import cache, device
from trezor import wire
from trezor.crypto import bip32, hashlib, hmac
from apps.common import HARDENED, mnemonic
from apps.common.passphrase import get as get_passphrase
# Annotation-only declarations: `if False:` means this block never runs, so
# `typing` / `typing_extensions` are never imported at runtime. In this file
# these names are referenced only from annotations and `# type:` comments.
if False:
    from typing import (
        Any,
        Awaitable,
        Callable,
        Dict,
        List,
        Sequence,
        Tuple,
        TypeVar,
    )
    from typing_extensions import Protocol

    # BIP-32 path: a list of integer child indices.
    Bip32Path = List[int]
    # SLIP-21 path: a list of byte-string labels.
    Slip21Path = List[bytes]
    # A path is either of the two kinds above.
    PathType = TypeVar("PathType", Bip32Path, Slip21Path)
    # (curve name, path prefix) pair; a Keychain only derives paths that start
    # with the prefix of one of its namespaces.
    Namespace = Tuple[str, PathType]
    T = TypeVar("T")

    class NodeType(Protocol[PathType]):
        # Structural interface of the nodes a Keychain caches and returns:
        # Slip21Node and the nodes produced by bip32.from_seed.
        def __del__(self) -> None:
            ...

        def derive_path(self, path: PathType) -> None:
            ...

        def clone(self: T) -> T:
            ...
class Slip21Node:
    """A SLIP-21 symmetric-key derivation node.

    ``data`` is expected to be a 64-byte HMAC-SHA512 output: bytes 0..32 are
    the key used to derive children, bytes 32..64 are this node's symmetric key.
    """

    def __init__(self, seed: bytes = None, data: bytes = None) -> None:
        """Create a master node from ``seed``, or wrap existing node ``data``.

        Exactly one of ``seed`` and ``data`` must be given.

        Raises:
            ValueError: if both or neither of ``seed`` and ``data`` are given.
        """
        # Explicit check rather than `assert`: asserts are stripped from
        # optimized builds, which would silently let ``data`` override ``seed``.
        if seed is not None and data is not None:
            raise ValueError("Specify exactly one of: seed, data")
        if data is not None:
            self.data = data
        elif seed is not None:
            self.data = hmac.new(b"Symmetric key seed", seed, hashlib.sha512).digest()
        else:
            raise ValueError("Specify exactly one of: seed, data")

    def __del__(self) -> None:
        """Drop the node's key material.

        Safe to call repeatedly. Keychain calls it explicitly, it may run
        again later, and it may also run after ``__init__`` raised before
        ``data`` was set.
        """
        if hasattr(self, "data"):
            del self.data

    def derive_path(self, path: Slip21Path) -> None:
        """Derive in place along ``path``.

        For each label:
        data = HMAC-SHA512(key=data[0:32], msg=b"\\x00" + label).
        """
        for label in path:
            h = hmac.new(self.data[0:32], b"\x00", hashlib.sha512)
            h.update(label)
            self.data = h.digest()

    def key(self) -> bytes:
        """Return this node's symmetric key (the second half of ``data``)."""
        return self.data[32:64]

    def clone(self) -> "Slip21Node":
        """Return an independent node holding the same ``data``."""
        return Slip21Node(data=self.data)
class Keychain:
    """Derives nodes from a seed, restricted to a fixed list of namespaces.

    Each namespace is a ``(curve, prefix)`` pair. A requested path must start
    with one of the prefixes. The root node for that namespace is derived once,
    cached in ``self.roots``, and cloned for every request.

    The keychain is also a context manager. Leaving the ``with`` block calls
    ``__del__``, which wipes the cached root nodes and drops the seed.
    """

    def __init__(self, seed: bytes, namespaces: Sequence[Namespace]) -> None:
        self.seed = seed
        self.namespaces = namespaces  # type: Sequence[Namespace]
        # namespace index -> root node already derived along that prefix
        self.roots = {}  # type: Dict[int, NodeType]

    def __del__(self) -> None:
        """Wipe the cached root nodes and drop the seed.

        Idempotent. ``__exit__`` calls this explicitly, so it can run again
        later (a second explicit call or a runtime finaliser). It can also run
        on an object whose ``__init__`` did not complete. Missing attributes
        are skipped instead of raising AttributeError.
        """
        roots = getattr(self, "roots", None)
        if roots is not None:
            for root in roots.values():
                root.__del__()
            del self.roots
        if hasattr(self, "seed"):
            del self.seed

    def match_path(self, path: PathType) -> Tuple[int, PathType]:
        """Find the namespace that ``path`` belongs to.

        Namespaces are tried in order. The first one whose prefix ``path``
        starts with wins.

        Returns:
            ``(index into self.namespaces, rest of path after that prefix)``.

        Raises:
            wire.DataError: if no namespace prefix matches, or if the matching
                namespace's curve name contains "ed25519" and ``path`` is not
                fully hardened.
        """
        for i, (curve, ns) in enumerate(self.namespaces):
            if path[: len(ns)] == ns:
                # For ed25519-based curves the whole path, prefix included,
                # must consist of hardened indices.
                if "ed25519" in curve and not _path_hardened(path):
                    raise wire.DataError("Forbidden key path")
                return i, path[len(ns) :]
        raise wire.DataError("Forbidden key path")

    def _new_root(self, curve: str) -> NodeType:
        """Create a master node from the seed.

        "slip21" selects a Slip21Node; any other curve name is passed to
        bip32.from_seed.
        """
        if curve == "slip21":
            return Slip21Node(self.seed)
        else:
            return bip32.from_seed(self.seed, curve)

    def derive(self, path: PathType) -> NodeType:
        """Return a new node derived along ``path``.

        The namespace root is derived on first use and cached. Every call
        works on a clone of it, so the caller owns (and may wipe) the result
        without affecting the cache.

        Raises:
            wire.DataError: if ``path`` is not allowed (see ``match_path``).
        """
        root_index, suffix = self.match_path(path)
        if root_index not in self.roots:
            curve, prefix = self.namespaces[root_index]
            root = self._new_root(curve)
            root.derive_path(prefix)
            self.roots[root_index] = root
        node = self.roots[root_index].clone()
        node.derive_path(suffix)
        return node

    def __enter__(self) -> "Keychain":
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, tb: Any) -> None:
        # Wipe on leaving the `with` block. Returning None lets any exception
        # from the block propagate.
        self.__del__()
@cache.stored_async(cache.APP_COMMON_SEED)
async def _get_seed(ctx: wire.Context) -> bytes:
    """Return the seed derived with the user's passphrase.

    Wrapped by ``cache.stored_async(cache.APP_COMMON_SEED)``. Presumably the
    result is cached under that key so the passphrase prompt is not repeated;
    confirm in ``storage.cache``.

    Raises:
        wire.NotInitialized: if the device has not been initialized.
    """
    if device.is_initialized():
        passphrase = await get_passphrase(ctx)
        return mnemonic.get_seed(passphrase)
    raise wire.NotInitialized("Device is not initialized")
@cache.stored(cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE)
def _get_seed_without_passphrase() -> bytes:
    """Return the seed derived without asking for a passphrase.

    ``mnemonic.get_seed`` is called with ``progress_bar=False``. Wrapped by
    ``cache.stored(cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE)``.

    Raises:
        Exception: if the device has not been initialized.
    """
    if device.is_initialized():
        return mnemonic.get_seed(progress_bar=False)
    raise Exception("Device is not initialized")
async def get_keychain(ctx: wire.Context, namespaces: Sequence[Namespace]) -> Keychain:
    """Build a Keychain over the passphrase-protected seed.

    The keychain only derives paths inside ``namespaces``.
    """
    return Keychain(await _get_seed(ctx), namespaces)
def derive_node_without_passphrase(
    path: Bip32Path, curve_name: str = "secp256k1"
) -> bip32.HDNode:
    """Derive the BIP-32 node at ``path`` from the seed without passphrase.

    The master node is built on ``curve_name`` (default "secp256k1").
    """
    node = bip32.from_seed(_get_seed_without_passphrase(), curve_name)
    node.derive_path(path)
    return node
def derive_slip21_node_without_passphrase(path: Slip21Path) -> Slip21Node:
    """Derive the SLIP-21 node at ``path`` from the seed without passphrase."""
    node = Slip21Node(_get_seed_without_passphrase())
    node.derive_path(path)
    return node
def remove_ed25519_prefix(pubkey: bytes) -> bytes:
    """Return ``pubkey`` without its first byte.

    The leading 0x01 prefix byte is not part of the actual public key. Its
    value is not checked; the first byte is always dropped.
    """
    prefix_length = 1
    return pubkey[prefix_length:]
def _path_hardened(path: list) -> bool:
return all(i & HARDENED for i in path)
# Annotation-only aliases (never executed, because of `if False:`) that
# describe the handler signatures used by with_slip44_keychain.
if False:
    from protobuf import MessageType

    # Incoming and outgoing message types of a handler.
    MsgIn = TypeVar("MsgIn", bound=MessageType)
    MsgOut = TypeVar("MsgOut", bound=MessageType)

    # Handler taking (ctx, msg) and returning an awaitable response.
    Handler = Callable[[wire.Context, MsgIn], Awaitable[MsgOut]]
    # Same as Handler, but also receiving a Keychain as a third argument.
    HandlerWithKeychain = Callable[[wire.Context, MsgIn, Keychain], Awaitable[MsgOut]]
def with_slip44_keychain(
    slip44: int, curve: str = "secp256k1", allow_testnet: bool = False
) -> Callable[[HandlerWithKeychain[MsgIn, MsgOut]], Handler[MsgIn, MsgOut]]:
    """Decorator factory that gives a handler a keychain for one SLIP-44 coin.

    The decorated handler takes ``(ctx, msg, keychain)``. The returned wrapper
    takes ``(ctx, msg)`` and builds a Keychain on ``curve`` that is limited to
    paths under m/44'/slip44'. When ``allow_testnet`` is set, paths under
    m/44'/1' are also allowed. The wrapper wipes the keychain (via its
    context manager) once the handler returns or raises.
    """
    purpose = 44 | HARDENED
    namespaces = [(curve, [purpose, slip44 | HARDENED])]
    if allow_testnet:
        # Coin type 1 is the SLIP-44 testnet coin type.
        namespaces.append((curve, [purpose, 1 | HARDENED]))

    def decorator(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]:
        async def wrapper(ctx: wire.Context, msg: MsgIn) -> MsgOut:
            # Keychain.__enter__ returns the keychain itself.
            with await get_keychain(ctx, namespaces) as keychain:
                return await func(ctx, msg, keychain)

        return wrapper

    return decorator