mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-11 07:50:57 +00:00
feat(core): implement BIP-32 path schemas
This commit is contained in:
parent
e611a4a110
commit
7fe5c804ff
@ -1,7 +1,9 @@
|
|||||||
|
import sys
|
||||||
|
|
||||||
from trezor import wire
|
from trezor import wire
|
||||||
from trezor.crypto import bip32
|
from trezor.crypto import bip32
|
||||||
|
|
||||||
from . import HARDENED, paths, safety_checks
|
from . import paths, safety_checks
|
||||||
from .seed import Slip21Node, get_seed
|
from .seed import Slip21Node, get_seed
|
||||||
|
|
||||||
if False:
|
if False:
|
||||||
@ -10,10 +12,12 @@ if False:
|
|||||||
Awaitable,
|
Awaitable,
|
||||||
Callable,
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
|
Iterable,
|
||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
Tuple,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
|
Union,
|
||||||
)
|
)
|
||||||
from typing_extensions import Protocol
|
from typing_extensions import Protocol
|
||||||
|
|
||||||
@ -85,13 +89,13 @@ class Keychain:
|
|||||||
self,
|
self,
|
||||||
seed: bytes,
|
seed: bytes,
|
||||||
curve: str,
|
curve: str,
|
||||||
namespaces: Sequence[paths.Bip32Path],
|
schemas: Iterable[paths.PathSchemaType],
|
||||||
slip21_namespaces: Sequence[paths.Slip21Path] = (),
|
slip21_namespaces: Iterable[paths.Slip21Path] = (),
|
||||||
) -> None:
|
) -> None:
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
self.curve = curve
|
self.curve = curve
|
||||||
self.namespaces = namespaces
|
self.schemas = tuple(schemas)
|
||||||
self.slip21_namespaces = slip21_namespaces
|
self.slip21_namespaces = tuple(slip21_namespaces)
|
||||||
|
|
||||||
self._cache = LRUCache(10)
|
self._cache = LRUCache(10)
|
||||||
|
|
||||||
@ -107,11 +111,14 @@ class Keychain:
|
|||||||
if not safety_checks.is_strict():
|
if not safety_checks.is_strict():
|
||||||
return
|
return
|
||||||
|
|
||||||
if any(ns == path[: len(ns)] for ns in self.namespaces):
|
if self.is_in_keychain(path):
|
||||||
return
|
return
|
||||||
|
|
||||||
raise FORBIDDEN_KEY_PATH
|
raise FORBIDDEN_KEY_PATH
|
||||||
|
|
||||||
|
def is_in_keychain(self, path: paths.Bip32Path) -> bool:
|
||||||
|
return any(schema.match(path) for schema in self.schemas)
|
||||||
|
|
||||||
def _derive_with_cache(
|
def _derive_with_cache(
|
||||||
self,
|
self,
|
||||||
prefix_len: int,
|
prefix_len: int,
|
||||||
@ -159,27 +166,53 @@ class Keychain:
|
|||||||
async def get_keychain(
|
async def get_keychain(
|
||||||
ctx: wire.Context,
|
ctx: wire.Context,
|
||||||
curve: str,
|
curve: str,
|
||||||
namespaces: Sequence[paths.Bip32Path],
|
schemas: Iterable[paths.PathSchemaType],
|
||||||
slip21_namespaces: Sequence[paths.Slip21Path] = (),
|
slip21_namespaces: Iterable[paths.Slip21Path] = (),
|
||||||
) -> Keychain:
|
) -> Keychain:
|
||||||
seed = await get_seed(ctx)
|
seed = await get_seed(ctx)
|
||||||
keychain = Keychain(seed, curve, namespaces, slip21_namespaces)
|
keychain = Keychain(seed, curve, schemas, slip21_namespaces)
|
||||||
return keychain
|
return keychain
|
||||||
|
|
||||||
|
|
||||||
def with_slip44_keychain(
|
def with_slip44_keychain(
|
||||||
slip44: int, curve: str = "secp256k1", allow_testnet: bool = False
|
*patterns: str,
|
||||||
|
slip44_id: int,
|
||||||
|
curve: str = "secp256k1",
|
||||||
|
allow_testnet: bool = True,
|
||||||
) -> Callable[[HandlerWithKeychain[MsgIn, MsgOut]], Handler[MsgIn, MsgOut]]:
|
) -> Callable[[HandlerWithKeychain[MsgIn, MsgOut]], Handler[MsgIn, MsgOut]]:
|
||||||
namespaces = [[44 | HARDENED, slip44 | HARDENED]]
|
if not patterns:
|
||||||
|
raise ValueError # specify a pattern
|
||||||
|
|
||||||
if allow_testnet:
|
if allow_testnet:
|
||||||
namespaces.append([44 | HARDENED, 1 | HARDENED])
|
slip44_ids: Union[int, Tuple[int, int]] = (slip44_id, 1)
|
||||||
|
else:
|
||||||
|
slip44_ids = slip44_id
|
||||||
|
|
||||||
|
schemas = []
|
||||||
|
for pattern in patterns:
|
||||||
|
schemas.append(paths.PathSchema(pattern=pattern, slip44_id=slip44_ids))
|
||||||
|
|
||||||
def decorator(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]:
|
def decorator(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]:
|
||||||
async def wrapper(ctx: wire.Context, msg: MsgIn) -> MsgOut:
|
async def wrapper(ctx: wire.Context, msg: MsgIn) -> MsgOut:
|
||||||
keychain = await get_keychain(ctx, curve, namespaces)
|
keychain = await get_keychain(ctx, curve, schemas)
|
||||||
with keychain:
|
with keychain:
|
||||||
return await func(ctx, msg, keychain)
|
return await func(ctx, msg, keychain)
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def auto_keychain(
|
||||||
|
modname: str, allow_testnet: bool = True
|
||||||
|
) -> Callable[[HandlerWithKeychain[MsgIn, MsgOut]], Handler[MsgIn, MsgOut]]:
|
||||||
|
rdot = modname.rfind(".")
|
||||||
|
parent_modname = modname[:rdot]
|
||||||
|
parent_module = sys.modules[parent_modname]
|
||||||
|
|
||||||
|
pattern = getattr(parent_module, "PATTERN")
|
||||||
|
curve = getattr(parent_module, "CURVE")
|
||||||
|
slip44_id = getattr(parent_module, "SLIP44_ID")
|
||||||
|
return with_slip44_keychain(
|
||||||
|
pattern, slip44_id=slip44_id, curve=curve, allow_testnet=allow_testnet
|
||||||
|
)
|
||||||
|
@ -8,7 +8,17 @@ from . import HARDENED
|
|||||||
from .confirm import require_confirm
|
from .confirm import require_confirm
|
||||||
|
|
||||||
if False:
|
if False:
|
||||||
from typing import Any, Callable, List, Sequence, TypeVar
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Collection,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Sequence,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
from typing_extensions import Protocol
|
||||||
from trezor import wire
|
from trezor import wire
|
||||||
|
|
||||||
# XXX this is a circular import, but it's only for typing
|
# XXX this is a circular import, but it's only for typing
|
||||||
@ -18,17 +28,226 @@ if False:
|
|||||||
Slip21Path = Sequence[bytes]
|
Slip21Path = Sequence[bytes]
|
||||||
PathType = TypeVar("PathType", Bip32Path, Slip21Path)
|
PathType = TypeVar("PathType", Bip32Path, Slip21Path)
|
||||||
|
|
||||||
|
class PathSchemaType(Protocol):
|
||||||
|
def match(self, path: Bip32Path) -> bool:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class _FastInclusiveRange(range):
|
||||||
|
"""Inclusive range with a fast membership test, suitable for PathSchema use.
|
||||||
|
|
||||||
|
Micropython's `range` does not implement the `__contains__` method. This makes
|
||||||
|
checking whether `x in range(BIG_NUMBER)` slow. This class fixes the problem.
|
||||||
|
|
||||||
|
In addition, convenience modifications have been made:
|
||||||
|
* both `min` and `max` belong to the range (so `stop == max + 1`).
|
||||||
|
* both `min` and `max` must be set, `step` is not allowed (we don't need it and it
|
||||||
|
would make the `__contains__` method more complex)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, min: int, max: int) -> None:
|
||||||
|
super().__init__(min, max + 1)
|
||||||
|
|
||||||
|
def __contains__(self, x: object) -> bool:
|
||||||
|
if not isinstance(x, int):
|
||||||
|
return False
|
||||||
|
return self.start <= x < self.stop
|
||||||
|
|
||||||
|
|
||||||
|
class PathSchema:
|
||||||
|
"""General BIP-32 path schema.
|
||||||
|
|
||||||
|
Loosely based on the BIP-32 path template proposal [1].
|
||||||
|
|
||||||
|
Each path component can be one of the following:
|
||||||
|
- constant, e.g., `7`
|
||||||
|
- list of constants, e.g., `[1,2,3]`
|
||||||
|
- range, e.g., `[0-19]`
|
||||||
|
|
||||||
|
Brackets are recommended but not enforced.
|
||||||
|
|
||||||
|
The following substitutions are available:
|
||||||
|
- `coin_type` is substituted with the coin's SLIP-44 identifier
|
||||||
|
- `account` is substituted with `[0-100]`, Trezor's default range of accounts
|
||||||
|
- `change` is substituted with `[0,1]`
|
||||||
|
- `address_index` is substituted with `[0-1000000]`, Trezor's default range of
|
||||||
|
addresses
|
||||||
|
|
||||||
|
Hardened flag is indicated by an apostrophe and applies to the whole path component.
|
||||||
|
It is impossible to specify both hardened and non-hardened values for the same
|
||||||
|
component.
|
||||||
|
|
||||||
|
See examples of valid path formats below and in `apps.bitcoin.keychain`.
|
||||||
|
|
||||||
|
E.g. the following are equivalent definitions of a BIP-84 schema:
|
||||||
|
|
||||||
|
m/84'/coin_type'/[0-100]'/[0,1]/[0-1000000]
|
||||||
|
m/84'/coin_type'/0-100'/0,1/0-1000000
|
||||||
|
m/84'/coin_type'/account'/change/address_index
|
||||||
|
|
||||||
|
Adding an asterisk at the end of the pattern acts as a wildcard for zero or more
|
||||||
|
path components:
|
||||||
|
- m/* can be followed by any number of _unhardened_ path components
|
||||||
|
- m/*' can be followed by any number of _hardened_ path components
|
||||||
|
- m/** can be followed by any number of _any_ path components
|
||||||
|
|
||||||
|
The following is a BIP-44 generic `GetPublicKey` schema:
|
||||||
|
|
||||||
|
m/44'/coin_type'/account'/*
|
||||||
|
|
||||||
|
The asterisk expression can only appear at end of pattern.
|
||||||
|
|
||||||
|
[1] https://github.com/dgpv/bip32_template_parse_tplaplus_spec/blob/master/bip-path-templates.mediawiki
|
||||||
|
"""
|
||||||
|
|
||||||
|
REPLACEMENTS = {
|
||||||
|
"account": "0-100",
|
||||||
|
"change": "0,1",
|
||||||
|
"address_index": "0-1000000",
|
||||||
|
}
|
||||||
|
|
||||||
|
WILDCARD_RANGES = {
|
||||||
|
"*": _FastInclusiveRange(0, HARDENED - 1),
|
||||||
|
"*'": _FastInclusiveRange(HARDENED, 0xFFFF_FFFF),
|
||||||
|
"**": _FastInclusiveRange(0, 0xFFFF_FFFF),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, pattern: str, slip44_id: Union[int, Iterable[int]]) -> None:
|
||||||
|
if not pattern.startswith("m/"):
|
||||||
|
raise ValueError # unsupported path template
|
||||||
|
components = pattern[2:].split("/")
|
||||||
|
|
||||||
|
if isinstance(slip44_id, int):
|
||||||
|
slip44_id = (slip44_id,)
|
||||||
|
|
||||||
|
self.schema: List[Collection[int]] = []
|
||||||
|
self.trailing_components: Collection[int] = ()
|
||||||
|
|
||||||
|
for component in components:
|
||||||
|
if component in self.WILDCARD_RANGES:
|
||||||
|
if len(self.schema) != len(components) - 1:
|
||||||
|
# every component should have resulted in extending self.schema
|
||||||
|
# so if self.schema does not have the appropriate length (yet),
|
||||||
|
# the asterisk is not the last item
|
||||||
|
raise ValueError # asterisk is not last item of pattern
|
||||||
|
|
||||||
|
self.trailing_components = self.WILDCARD_RANGES[component]
|
||||||
|
break
|
||||||
|
|
||||||
|
# figure out if the component is hardened
|
||||||
|
if component[-1] == "'":
|
||||||
|
component = component[:-1]
|
||||||
|
parse: Callable[[Any], int] = lambda s: int(s) | HARDENED # noqa: E731
|
||||||
|
else:
|
||||||
|
parse = int
|
||||||
|
|
||||||
|
# strip brackets
|
||||||
|
if component[0] == "[" and component[-1] == "]":
|
||||||
|
component = component[1:-1]
|
||||||
|
|
||||||
|
# optionally replace a keyword
|
||||||
|
component = self.REPLACEMENTS.get(component, component)
|
||||||
|
|
||||||
|
if "-" in component:
|
||||||
|
# parse as a range
|
||||||
|
a, b = [parse(s) for s in component.split("-", 1)]
|
||||||
|
self.schema.append(_FastInclusiveRange(a, b))
|
||||||
|
|
||||||
|
elif "," in component:
|
||||||
|
# parse as a list of values
|
||||||
|
self.schema.append(set(parse(s) for s in component.split(",")))
|
||||||
|
|
||||||
|
elif component == "coin_type":
|
||||||
|
# substitute SLIP-44 ids
|
||||||
|
self.schema.append(set(parse(s) for s in slip44_id))
|
||||||
|
|
||||||
|
else:
|
||||||
|
# plain constant
|
||||||
|
self.schema.append((parse(component),))
|
||||||
|
|
||||||
|
def match(self, path: Bip32Path) -> bool:
|
||||||
|
# The path must not be _shorter_ than schema. It may be longer.
|
||||||
|
if len(path) < len(self.schema):
|
||||||
|
return False
|
||||||
|
|
||||||
|
path_iter = iter(path)
|
||||||
|
# iterate over length of schema, consuming path components
|
||||||
|
for expected in self.schema:
|
||||||
|
value = next(path_iter)
|
||||||
|
if value not in expected:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# iterate over remaining path components
|
||||||
|
for value in path_iter:
|
||||||
|
if value not in self.trailing_components:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
if __debug__:
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
components = ["m"]
|
||||||
|
|
||||||
|
def unharden(item: int) -> int:
|
||||||
|
return item ^ (item & HARDENED)
|
||||||
|
|
||||||
|
for component in self.schema:
|
||||||
|
if isinstance(component, range):
|
||||||
|
a, b = component.start, component.stop - 1
|
||||||
|
components.append(
|
||||||
|
"[{}-{}]{}".format(
|
||||||
|
unharden(a), unharden(b), "'" if a & HARDENED else ""
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
component_str = ",".join(str(unharden(i)) for i in component)
|
||||||
|
if len(component) > 1:
|
||||||
|
component_str = "[" + component_str + "]"
|
||||||
|
if next(iter(component)) & HARDENED:
|
||||||
|
component_str += "'"
|
||||||
|
components.append(component_str)
|
||||||
|
|
||||||
|
if self.trailing_components:
|
||||||
|
for key, val in self.WILDCARD_RANGES.items():
|
||||||
|
if self.trailing_components is val:
|
||||||
|
components.append(key)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
components.append("???")
|
||||||
|
|
||||||
|
return "<schema:" + "/".join(components) + ">"
|
||||||
|
|
||||||
|
|
||||||
|
class _AlwaysMatchingSchema:
|
||||||
|
@staticmethod
|
||||||
|
def match(path: Bip32Path) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class _NeverMatchingSchema:
|
||||||
|
@staticmethod
|
||||||
|
def match(path: Bip32Path) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
# type objects _AlwaysMatchingSchema and _NeverMatching schema conform to the
|
||||||
|
# PathSchemaType protocol, but mypy fails to recognize this due to:
|
||||||
|
# https://github.com/python/mypy/issues/4536,
|
||||||
|
# hence the following trickery
|
||||||
|
AlwaysMatchingSchema: PathSchemaType = _AlwaysMatchingSchema # type: ignore
|
||||||
|
NeverMatchingSchema: PathSchemaType = _NeverMatchingSchema # type: ignore
|
||||||
|
|
||||||
|
PATTERN_BIP44 = "m/44'/coin_type'/account'/change/address_index"
|
||||||
|
PATTERN_BIP44_PUBKEY = "m/44'/coin_type'/account'/*"
|
||||||
|
PATTERN_SEP5 = "m/44'/coin_type'/account'"
|
||||||
|
|
||||||
|
|
||||||
async def validate_path(
|
async def validate_path(
|
||||||
ctx: wire.Context,
|
ctx: wire.Context, keychain: Keychain, path: Bip32Path, *additional_checks: bool
|
||||||
validate_func: Callable[..., bool],
|
|
||||||
keychain: Keychain,
|
|
||||||
path: List[int],
|
|
||||||
curve: str,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
keychain.verify_path(path)
|
keychain.verify_path(path)
|
||||||
if not validate_func(path, **kwargs):
|
if not keychain.is_in_keychain(path) or not all(additional_checks):
|
||||||
await show_path_warning(ctx, path)
|
await show_path_warning(ctx, path)
|
||||||
|
|
||||||
|
|
||||||
@ -41,28 +260,6 @@ async def show_path_warning(ctx: wire.Context, path: Bip32Path) -> None:
|
|||||||
await require_confirm(ctx, text, ButtonRequestType.UnknownDerivationPath)
|
await require_confirm(ctx, text, ButtonRequestType.UnknownDerivationPath)
|
||||||
|
|
||||||
|
|
||||||
def validate_path_for_get_public_key(path: Bip32Path, slip44_id: int) -> bool:
|
|
||||||
"""
|
|
||||||
Checks if path has at least three hardened items and slip44 id matches.
|
|
||||||
The path is allowed to have more than three items, but all the following
|
|
||||||
items have to be non-hardened.
|
|
||||||
"""
|
|
||||||
length = len(path)
|
|
||||||
if length < 3 or length > 5:
|
|
||||||
return False
|
|
||||||
if path[0] != 44 | HARDENED:
|
|
||||||
return False
|
|
||||||
if path[1] != slip44_id | HARDENED:
|
|
||||||
return False
|
|
||||||
if path[2] < HARDENED or path[2] > 20 | HARDENED:
|
|
||||||
return False
|
|
||||||
if length > 3 and is_hardened(path[3]):
|
|
||||||
return False
|
|
||||||
if length > 4 and is_hardened(path[4]):
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def is_hardened(i: int) -> bool:
|
def is_hardened(i: int) -> bool:
|
||||||
return bool(i & HARDENED)
|
return bool(i & HARDENED)
|
||||||
|
|
||||||
@ -71,15 +268,19 @@ def path_is_hardened(address_n: Bip32Path) -> bool:
|
|||||||
return all(is_hardened(n) for n in address_n)
|
return all(is_hardened(n) for n in address_n)
|
||||||
|
|
||||||
|
|
||||||
def break_address_n_to_lines(address_n: Bip32Path) -> List[str]:
|
def address_n_to_str(address_n: Bip32Path) -> str:
|
||||||
def path_item(i: int) -> str:
|
def path_item(i: int) -> str:
|
||||||
if i & HARDENED:
|
if i & HARDENED:
|
||||||
return str(i ^ HARDENED) + "'"
|
return str(i ^ HARDENED) + "'"
|
||||||
else:
|
else:
|
||||||
return str(i)
|
return str(i)
|
||||||
|
|
||||||
|
return "m/" + "/".join([path_item(i) for i in address_n])
|
||||||
|
|
||||||
|
|
||||||
|
def break_address_n_to_lines(address_n: Bip32Path) -> List[str]:
|
||||||
lines = []
|
lines = []
|
||||||
path_str = "m/" + "/".join([path_item(i) for i in address_n])
|
path_str = address_n_to_str(address_n)
|
||||||
|
|
||||||
per_line = const(17)
|
per_line = const(17)
|
||||||
while len(path_str) > per_line:
|
while len(path_str) > per_line:
|
||||||
|
@ -4,8 +4,8 @@ from mock_storage import mock_storage
|
|||||||
|
|
||||||
from storage import cache
|
from storage import cache
|
||||||
import storage.device
|
import storage.device
|
||||||
from apps.common import HARDENED, safety_checks
|
from apps.common import safety_checks
|
||||||
from apps.common.paths import path_is_hardened
|
from apps.common.paths import PATTERN_SEP5, PathSchema, path_is_hardened
|
||||||
from apps.common.keychain import LRUCache, Keychain, with_slip44_keychain, get_keychain
|
from apps.common.keychain import LRUCache, Keychain, with_slip44_keychain, get_keychain
|
||||||
from trezor import wire
|
from trezor import wire
|
||||||
from trezor.crypto import bip39
|
from trezor.crypto import bip39
|
||||||
@ -13,26 +13,31 @@ from trezor.messages import SafetyCheckLevel
|
|||||||
|
|
||||||
|
|
||||||
class TestKeychain(unittest.TestCase):
|
class TestKeychain(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
cache.start_session()
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
cache.clear_all()
|
||||||
|
|
||||||
@mock_storage
|
@mock_storage
|
||||||
def test_verify_path(self):
|
def test_verify_path(self):
|
||||||
n = [
|
schemas = (
|
||||||
[44 | HARDENED, 134 | HARDENED],
|
PathSchema("m/44'/coin_type'", slip44_id=134),
|
||||||
[44 | HARDENED, 11 | HARDENED],
|
PathSchema("m/44'/coin_type'", slip44_id=11),
|
||||||
]
|
)
|
||||||
keychain = Keychain(b"", "secp256k1", n)
|
keychain = Keychain(b"", "secp256k1", schemas)
|
||||||
|
|
||||||
correct = (
|
correct = (
|
||||||
[44 | HARDENED, 134 | HARDENED],
|
[H_(44), H_(134)],
|
||||||
[44 | HARDENED, 11 | HARDENED],
|
[H_(44), H_(11)],
|
||||||
[44 | HARDENED, 11 | HARDENED, 12],
|
|
||||||
)
|
)
|
||||||
for path in correct:
|
for path in correct:
|
||||||
keychain.verify_path(path)
|
keychain.verify_path(path)
|
||||||
|
|
||||||
fails = (
|
fails = (
|
||||||
[44 | HARDENED, 134], # path does not match
|
[H_(44), 134], # path does not match
|
||||||
[44, 134], # path does not match (non-hardened items)
|
[44, 134], # path does not match (non-hardened items)
|
||||||
[44 | HARDENED, 13 | HARDENED], # invalid second item
|
[H_(44), H_(13)], # invalid second item
|
||||||
)
|
)
|
||||||
for f in fails:
|
for f in fails:
|
||||||
with self.assertRaises(wire.DataError):
|
with self.assertRaises(wire.DataError):
|
||||||
@ -42,43 +47,40 @@ class TestKeychain(unittest.TestCase):
|
|||||||
safety_checks.apply_setting(SafetyCheckLevel.PromptTemporarily)
|
safety_checks.apply_setting(SafetyCheckLevel.PromptTemporarily)
|
||||||
for path in correct + fails:
|
for path in correct + fails:
|
||||||
keychain.verify_path(path)
|
keychain.verify_path(path)
|
||||||
# turn on restrictions
|
|
||||||
safety_checks.apply_setting(SafetyCheckLevel.Strict)
|
|
||||||
|
|
||||||
def test_verify_path_special_ed25519(self):
|
def test_verify_path_special_ed25519(self):
|
||||||
n = [[44 | HARDENED, 134 | HARDENED]]
|
schema = PathSchema("m/44'/coin_type'/*", slip44_id=134)
|
||||||
k = Keychain(b"", "ed25519-keccak", n)
|
k = Keychain(b"", "ed25519-keccak", [schema])
|
||||||
|
|
||||||
# OK case
|
# OK case
|
||||||
k.verify_path([44 | HARDENED, 134 | HARDENED])
|
k.verify_path([H_(44), H_(134)])
|
||||||
|
|
||||||
# failing case: non-hardened component with ed25519-like derivation
|
# failing case: non-hardened component with ed25519-like derivation
|
||||||
with self.assertRaises(wire.DataError):
|
with self.assertRaises(wire.DataError):
|
||||||
k.verify_path([44 | HARDENED, 134 | HARDENED, 1])
|
k.verify_path([H_(44), H_(134), 1])
|
||||||
|
|
||||||
def test_verify_path_empty_namespace(self):
|
def test_no_schemas(self):
|
||||||
k = Keychain(b"", "secp256k1", [[]])
|
k = Keychain(b"", "secp256k1", [])
|
||||||
correct = (
|
paths = (
|
||||||
[],
|
[],
|
||||||
[1, 2, 3, 4],
|
[1, 2, 3, 4],
|
||||||
[44 | HARDENED, 11 | HARDENED],
|
[H_(44), H_(11)],
|
||||||
[44 | HARDENED, 11 | HARDENED, 12],
|
[H_(44), H_(11), 12],
|
||||||
)
|
)
|
||||||
for c in correct:
|
for path in paths:
|
||||||
k.verify_path(c)
|
self.assertRaises(wire.DataError, k.verify_path, path)
|
||||||
|
|
||||||
def test_get_keychain(self):
|
def test_get_keychain(self):
|
||||||
seed = bip39.seed(" ".join(["all"] * 12), "")
|
seed = bip39.seed(" ".join(["all"] * 12), "")
|
||||||
cache.start_session()
|
|
||||||
cache.set(cache.APP_COMMON_SEED, seed)
|
cache.set(cache.APP_COMMON_SEED, seed)
|
||||||
|
|
||||||
namespaces = [[44 | HARDENED]]
|
schema = PathSchema("m/44'/1'", 0)
|
||||||
keychain = await_result(
|
keychain = await_result(
|
||||||
get_keychain(wire.DUMMY_CONTEXT, "secp256k1", namespaces)
|
get_keychain(wire.DUMMY_CONTEXT, "secp256k1", [schema])
|
||||||
)
|
)
|
||||||
|
|
||||||
# valid path:
|
# valid path:
|
||||||
self.assertIsNotNone(keychain.derive([44 | HARDENED, 1 | HARDENED]))
|
self.assertIsNotNone(keychain.derive([H_(44), H_(1)]))
|
||||||
|
|
||||||
# invalid path:
|
# invalid path:
|
||||||
with self.assertRaises(wire.DataError):
|
with self.assertRaises(wire.DataError):
|
||||||
@ -86,13 +88,12 @@ class TestKeychain(unittest.TestCase):
|
|||||||
|
|
||||||
def test_with_slip44(self):
|
def test_with_slip44(self):
|
||||||
seed = bip39.seed(" ".join(["all"] * 12), "")
|
seed = bip39.seed(" ".join(["all"] * 12), "")
|
||||||
cache.start_session()
|
|
||||||
cache.set(cache.APP_COMMON_SEED, seed)
|
cache.set(cache.APP_COMMON_SEED, seed)
|
||||||
|
|
||||||
slip44_id = 42
|
slip44_id = 42
|
||||||
valid_path = [44 | HARDENED, slip44_id | HARDENED]
|
valid_path = [H_(44), H_(slip44_id), H_(0)]
|
||||||
invalid_path = [44 | HARDENED, 99 | HARDENED]
|
invalid_path = [H_(44), H_(99), H_(0)]
|
||||||
testnet_path = [44 | HARDENED, 1 | HARDENED]
|
testnet_path = [H_(44), H_(1), H_(0)]
|
||||||
|
|
||||||
def check_valid_paths(keychain, *paths):
|
def check_valid_paths(keychain, *paths):
|
||||||
for path in paths:
|
for path in paths:
|
||||||
@ -102,24 +103,24 @@ class TestKeychain(unittest.TestCase):
|
|||||||
for path in paths:
|
for path in paths:
|
||||||
self.assertRaises(wire.DataError, keychain.derive, path)
|
self.assertRaises(wire.DataError, keychain.derive, path)
|
||||||
|
|
||||||
@with_slip44_keychain(slip44_id)
|
@with_slip44_keychain(PATTERN_SEP5, slip44_id=slip44_id)
|
||||||
async def func_id_only(ctx, msg, keychain):
|
async def func_id_only(ctx, msg, keychain):
|
||||||
check_valid_paths(keychain, valid_path)
|
|
||||||
check_invalid_paths(keychain, testnet_path, invalid_path)
|
|
||||||
|
|
||||||
@with_slip44_keychain(slip44_id, allow_testnet=True)
|
|
||||||
async def func_allow_testnet(ctx, msg, keychain):
|
|
||||||
check_valid_paths(keychain, valid_path, testnet_path)
|
check_valid_paths(keychain, valid_path, testnet_path)
|
||||||
check_invalid_paths(keychain, invalid_path)
|
check_invalid_paths(keychain, invalid_path)
|
||||||
|
|
||||||
@with_slip44_keychain(slip44_id, curve="ed25519")
|
@with_slip44_keychain(PATTERN_SEP5, slip44_id=slip44_id, allow_testnet=False)
|
||||||
async def func_with_curve(ctx, msg, keychain):
|
async def func_disallow_testnet(ctx, msg, keychain):
|
||||||
self.assertEqual(keychain.curve, "ed25519")
|
|
||||||
check_valid_paths(keychain, valid_path)
|
check_valid_paths(keychain, valid_path)
|
||||||
check_invalid_paths(keychain, testnet_path, invalid_path)
|
check_invalid_paths(keychain, testnet_path, invalid_path)
|
||||||
|
|
||||||
|
@with_slip44_keychain(PATTERN_SEP5, slip44_id=slip44_id, curve="ed25519")
|
||||||
|
async def func_with_curve(ctx, msg, keychain):
|
||||||
|
self.assertEqual(keychain.curve, "ed25519")
|
||||||
|
check_valid_paths(keychain, valid_path, testnet_path)
|
||||||
|
check_invalid_paths(keychain, invalid_path)
|
||||||
|
|
||||||
await_result(func_id_only(wire.DUMMY_CONTEXT, None))
|
await_result(func_id_only(wire.DUMMY_CONTEXT, None))
|
||||||
await_result(func_allow_testnet(wire.DUMMY_CONTEXT, None))
|
await_result(func_disallow_testnet(wire.DUMMY_CONTEXT, None))
|
||||||
await_result(func_with_curve(wire.DUMMY_CONTEXT, None))
|
await_result(func_with_curve(wire.DUMMY_CONTEXT, None))
|
||||||
|
|
||||||
def test_lru_cache(self):
|
def test_lru_cache(self):
|
||||||
|
@ -1,54 +1,163 @@
|
|||||||
from common import *
|
from common import *
|
||||||
from apps.common import HARDENED
|
from trezor.utils import ensure
|
||||||
from apps.common.paths import validate_path_for_get_public_key, is_hardened, path_is_hardened
|
from apps.common.paths import *
|
||||||
|
|
||||||
|
|
||||||
class TestPaths(unittest.TestCase):
|
class TestPaths(unittest.TestCase):
|
||||||
|
|
||||||
def test_is_hardened(self):
|
def test_is_hardened(self):
|
||||||
self.assertTrue(is_hardened(44 | HARDENED))
|
self.assertTrue(is_hardened(H_(44)))
|
||||||
self.assertTrue(is_hardened(0 | HARDENED))
|
self.assertTrue(is_hardened(H_(0)))
|
||||||
self.assertTrue(is_hardened(99999 | HARDENED))
|
self.assertTrue(is_hardened(H_(99999)))
|
||||||
|
|
||||||
self.assertFalse(is_hardened(44))
|
self.assertFalse(is_hardened(44))
|
||||||
self.assertFalse(is_hardened(0))
|
self.assertFalse(is_hardened(0))
|
||||||
self.assertFalse(is_hardened(99999))
|
self.assertFalse(is_hardened(99999))
|
||||||
|
|
||||||
def test_path_is_hardened(self):
|
def test_path_is_hardened(self):
|
||||||
self.assertTrue(path_is_hardened([44 | HARDENED, 1 | HARDENED, 0 | HARDENED]))
|
self.assertTrue(path_is_hardened([H_(44), H_(1), H_(0)]))
|
||||||
self.assertTrue(path_is_hardened([0 | HARDENED, ]))
|
self.assertTrue(path_is_hardened([H_(0)]))
|
||||||
|
|
||||||
self.assertFalse(path_is_hardened([44, 44 | HARDENED, 0 | HARDENED]))
|
self.assertFalse(path_is_hardened([44, H_(44), H_(0)]))
|
||||||
self.assertFalse(path_is_hardened([0, ]))
|
self.assertFalse(path_is_hardened([0]))
|
||||||
self.assertFalse(path_is_hardened([44 | HARDENED, 1 | HARDENED, 0 | HARDENED, 0 | HARDENED, 0]))
|
self.assertFalse(path_is_hardened([H_(44), H_(1), H_(0), H_(0), 0]))
|
||||||
|
|
||||||
def test_path_for_get_public_key(self):
|
|
||||||
# 44'/41'/0'
|
|
||||||
self.assertTrue(validate_path_for_get_public_key([44 | HARDENED, 41 | HARDENED, 0 | HARDENED], 41))
|
|
||||||
# 44'/111'/0'
|
|
||||||
self.assertTrue(validate_path_for_get_public_key([44 | HARDENED, 111 | HARDENED, 0 | HARDENED], 111))
|
|
||||||
# 44'/0'/0'/0
|
|
||||||
self.assertTrue(validate_path_for_get_public_key([44 | HARDENED, 0 | HARDENED, 0 | HARDENED, 0], 0))
|
|
||||||
# 44'/0'/0'/0/0
|
|
||||||
self.assertTrue(validate_path_for_get_public_key([44 | HARDENED, 0 | HARDENED, 0 | HARDENED, 0, 0], 0))
|
|
||||||
|
|
||||||
# 44'/41'
|
|
||||||
self.assertFalse(validate_path_for_get_public_key([44 | HARDENED, 41 | HARDENED], 41))
|
|
||||||
# 44'/41'/0
|
|
||||||
self.assertFalse(validate_path_for_get_public_key([44 | HARDENED, 41 | HARDENED, 0], 41))
|
|
||||||
# 44'/41'/0' slip44 mismatch
|
|
||||||
self.assertFalse(validate_path_for_get_public_key([44 | HARDENED, 41 | HARDENED, 0 | HARDENED], 99))
|
|
||||||
# # 44'/41'/0'/0'
|
|
||||||
self.assertFalse(validate_path_for_get_public_key([44 | HARDENED, 41 | HARDENED, 0 | HARDENED, 0 | HARDENED], 41))
|
|
||||||
# # 44'/41'/0'/0'/0
|
|
||||||
self.assertFalse(validate_path_for_get_public_key([44 | HARDENED, 41 | HARDENED, 0 | HARDENED, 0 | HARDENED, 0], 41))
|
|
||||||
# # 44'/41'/0'/0'/0'
|
|
||||||
self.assertFalse(validate_path_for_get_public_key([44 | HARDENED, 41 | HARDENED, 0 | HARDENED, 0 | HARDENED, 0 | HARDENED], 41))
|
|
||||||
# # 44'/41'/0'/0/0/0
|
|
||||||
self.assertFalse(validate_path_for_get_public_key([44 | HARDENED, 41 | HARDENED, 0 | HARDENED, 0, 0, 0], 41))
|
|
||||||
# # 44'/41'/0'/0/0'
|
|
||||||
self.assertFalse(validate_path_for_get_public_key([44 | HARDENED, 41 | HARDENED, 0 | HARDENED, 0, 0 | HARDENED], 41))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
class TestPathSchemas(unittest.TestCase):
|
||||||
|
def assertMatch(self, schema, path):
|
||||||
|
self.assertTrue(
|
||||||
|
schema.match(path),
|
||||||
|
"Expected schema {!r} to match path {}".format(
|
||||||
|
schema, address_n_to_str(path)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def assertMismatch(self, schema, path):
|
||||||
|
self.assertFalse(
|
||||||
|
schema.match(path),
|
||||||
|
"Expected schema {!r} to not match path {}".format(
|
||||||
|
schema, address_n_to_str(path)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def assertEqualSchema(self, schema_a, schema_b):
|
||||||
|
def is_equal(a, b):
|
||||||
|
if isinstance(a, range) and isinstance(b, range):
|
||||||
|
return a.start == b.start and a.step == b.step and a.stop == b.stop
|
||||||
|
return a == b
|
||||||
|
|
||||||
|
ensure(
|
||||||
|
all(is_equal(a, b) for a, b in zip(schema_a.schema, schema_b.schema))
|
||||||
|
and is_equal(schema_a.trailing_components, schema_b.trailing_components),
|
||||||
|
"Schemas differ:\nA = {!r}\nB = {!r}".format(schema_a, schema_b),
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_always_never_matching(self):
|
||||||
|
paths = [
|
||||||
|
[],
|
||||||
|
[0],
|
||||||
|
[H_(0)],
|
||||||
|
[44],
|
||||||
|
[H_(44)],
|
||||||
|
[H_(44), H_(0), H_(0), 0, 0],
|
||||||
|
[H_(44), H_(0), H_(0), H_(0), H_(0)],
|
||||||
|
[H_(44), H_(0), H_(0), H_(0), H_(0)] * 10,
|
||||||
|
]
|
||||||
|
for path in paths:
|
||||||
|
self.assertMatch(SCHEMA_ANY_PATH, path)
|
||||||
|
self.assertMismatch(SCHEMA_NO_MATCH, path)
|
||||||
|
|
||||||
|
def test_pattern_fixed(self):
|
||||||
|
pattern = "m/44'/0'/0'/0/0"
|
||||||
|
schema = PathSchema(pattern, 0)
|
||||||
|
|
||||||
|
self.assertMatch(schema, [H_(44), H_(0), H_(0), 0, 0])
|
||||||
|
|
||||||
|
paths = [
|
||||||
|
[],
|
||||||
|
[0],
|
||||||
|
[H_(0)],
|
||||||
|
[44],
|
||||||
|
[H_(44)],
|
||||||
|
[44, 0, 0, 0, 0],
|
||||||
|
[H_(44), H_(0), H_(0), H_(0), H_(0)],
|
||||||
|
[H_(44), H_(0), H_(0), H_(0), H_(0)] * 10,
|
||||||
|
]
|
||||||
|
for path in paths:
|
||||||
|
self.assertMismatch(schema, path)
|
||||||
|
|
||||||
|
def test_ranges_sets(self):
|
||||||
|
pattern_ranges = "m/44'/[100-109]'/[0-20]"
|
||||||
|
pattern_sets = "m/44'/[100,105,109]'/[0,10,20]"
|
||||||
|
schema_ranges = PathSchema(pattern_ranges, 0)
|
||||||
|
schema_sets = PathSchema(pattern_sets, 0)
|
||||||
|
|
||||||
|
paths_good = [
|
||||||
|
[H_(44), H_(100), 0],
|
||||||
|
[H_(44), H_(100), 10],
|
||||||
|
[H_(44), H_(100), 20],
|
||||||
|
[H_(44), H_(105), 0],
|
||||||
|
[H_(44), H_(105), 10],
|
||||||
|
[H_(44), H_(105), 20],
|
||||||
|
[H_(44), H_(109), 0],
|
||||||
|
[H_(44), H_(109), 10],
|
||||||
|
[H_(44), H_(109), 20],
|
||||||
|
]
|
||||||
|
for path in paths_good:
|
||||||
|
self.assertMatch(schema_ranges, path)
|
||||||
|
self.assertMatch(schema_sets, path)
|
||||||
|
|
||||||
|
paths_bad = [
|
||||||
|
[H_(44), H_(100)],
|
||||||
|
[H_(44), H_(100), 0, 0],
|
||||||
|
[H_(44), 100, 0],
|
||||||
|
[H_(44), 100, H_(0)],
|
||||||
|
[H_(44), H_(99), 0],
|
||||||
|
[H_(44), H_(110), 0],
|
||||||
|
[H_(44), H_(100), 21],
|
||||||
|
]
|
||||||
|
for path in paths_bad:
|
||||||
|
self.assertMismatch(schema_ranges, path)
|
||||||
|
self.assertMismatch(schema_sets, path)
|
||||||
|
|
||||||
|
self.assertMatch(schema_ranges, [H_(44), H_(104), 19])
|
||||||
|
self.assertMismatch(schema_sets, [H_(44), H_(104), 19])
|
||||||
|
|
||||||
|
def test_brackets(self):
|
||||||
|
pattern_a = "m/[0]'/[0-5]'/[0,1,2]'/[0]/[0-5]/[0,1,2]"
|
||||||
|
pattern_b = "m/0'/0-5'/0,1,2'/0/0-5/0,1,2"
|
||||||
|
schema_a = PathSchema(pattern_a, 0)
|
||||||
|
schema_b = PathSchema(pattern_b, 0)
|
||||||
|
self.assertEqualSchema(schema_a, schema_b)
|
||||||
|
|
||||||
|
def test_wildcard(self):
|
||||||
|
pattern = "m/44'/0'/*"
|
||||||
|
schema = PathSchema(pattern, 0)
|
||||||
|
|
||||||
|
paths_good = [
|
||||||
|
[H_(44), H_(0)],
|
||||||
|
[H_(44), H_(0), 0],
|
||||||
|
[H_(44), H_(0), 0, 1, 2, 3, 4, 5, 6, 7, 8],
|
||||||
|
]
|
||||||
|
for path in paths_good:
|
||||||
|
self.assertMatch(schema, path)
|
||||||
|
|
||||||
|
paths_bad = [
|
||||||
|
[H_(44)],
|
||||||
|
[H_(44), H_(0), H_(0)],
|
||||||
|
[H_(44), H_(0), 0, 1, 2, 3, 4, 5, 6, 7, H_(8)],
|
||||||
|
]
|
||||||
|
for path in paths_bad:
|
||||||
|
self.assertMismatch(schema, path)
|
||||||
|
|
||||||
|
def test_substitutes(self):
|
||||||
|
pattern_sub = "m/44'/coin_type'/account'/change/address_index"
|
||||||
|
pattern_plain = "m/44'/19'/0-100'/0,1/0-1000000"
|
||||||
|
schema_sub = PathSchema(pattern_sub, slip44_id=19)
|
||||||
|
# use wrong slip44 id to ensure it doesn't affect anything
|
||||||
|
schema_plain = PathSchema(pattern_plain, slip44_id=0)
|
||||||
|
|
||||||
|
self.assertEqualSchema(schema_sub, schema_plain)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user