mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-11 07:50:57 +00:00
feat(core): implement BIP-32 path schemas
This commit is contained in:
parent
e611a4a110
commit
7fe5c804ff
@ -1,7 +1,9 @@
|
||||
import sys
|
||||
|
||||
from trezor import wire
|
||||
from trezor.crypto import bip32
|
||||
|
||||
from . import HARDENED, paths, safety_checks
|
||||
from . import paths, safety_checks
|
||||
from .seed import Slip21Node, get_seed
|
||||
|
||||
if False:
|
||||
@ -10,10 +12,12 @@ if False:
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from typing_extensions import Protocol
|
||||
|
||||
@ -85,13 +89,13 @@ class Keychain:
|
||||
self,
|
||||
seed: bytes,
|
||||
curve: str,
|
||||
namespaces: Sequence[paths.Bip32Path],
|
||||
slip21_namespaces: Sequence[paths.Slip21Path] = (),
|
||||
schemas: Iterable[paths.PathSchemaType],
|
||||
slip21_namespaces: Iterable[paths.Slip21Path] = (),
|
||||
) -> None:
|
||||
self.seed = seed
|
||||
self.curve = curve
|
||||
self.namespaces = namespaces
|
||||
self.slip21_namespaces = slip21_namespaces
|
||||
self.schemas = tuple(schemas)
|
||||
self.slip21_namespaces = tuple(slip21_namespaces)
|
||||
|
||||
self._cache = LRUCache(10)
|
||||
|
||||
@ -107,11 +111,14 @@ class Keychain:
|
||||
if not safety_checks.is_strict():
|
||||
return
|
||||
|
||||
if any(ns == path[: len(ns)] for ns in self.namespaces):
|
||||
if self.is_in_keychain(path):
|
||||
return
|
||||
|
||||
raise FORBIDDEN_KEY_PATH
|
||||
|
||||
def is_in_keychain(self, path: paths.Bip32Path) -> bool:
|
||||
return any(schema.match(path) for schema in self.schemas)
|
||||
|
||||
def _derive_with_cache(
|
||||
self,
|
||||
prefix_len: int,
|
||||
@ -159,27 +166,53 @@ class Keychain:
|
||||
async def get_keychain(
|
||||
ctx: wire.Context,
|
||||
curve: str,
|
||||
namespaces: Sequence[paths.Bip32Path],
|
||||
slip21_namespaces: Sequence[paths.Slip21Path] = (),
|
||||
schemas: Iterable[paths.PathSchemaType],
|
||||
slip21_namespaces: Iterable[paths.Slip21Path] = (),
|
||||
) -> Keychain:
|
||||
seed = await get_seed(ctx)
|
||||
keychain = Keychain(seed, curve, namespaces, slip21_namespaces)
|
||||
keychain = Keychain(seed, curve, schemas, slip21_namespaces)
|
||||
return keychain
|
||||
|
||||
|
||||
def with_slip44_keychain(
|
||||
slip44: int, curve: str = "secp256k1", allow_testnet: bool = False
|
||||
*patterns: str,
|
||||
slip44_id: int,
|
||||
curve: str = "secp256k1",
|
||||
allow_testnet: bool = True,
|
||||
) -> Callable[[HandlerWithKeychain[MsgIn, MsgOut]], Handler[MsgIn, MsgOut]]:
|
||||
namespaces = [[44 | HARDENED, slip44 | HARDENED]]
|
||||
if not patterns:
|
||||
raise ValueError # specify a pattern
|
||||
|
||||
if allow_testnet:
|
||||
namespaces.append([44 | HARDENED, 1 | HARDENED])
|
||||
slip44_ids: Union[int, Tuple[int, int]] = (slip44_id, 1)
|
||||
else:
|
||||
slip44_ids = slip44_id
|
||||
|
||||
schemas = []
|
||||
for pattern in patterns:
|
||||
schemas.append(paths.PathSchema(pattern=pattern, slip44_id=slip44_ids))
|
||||
|
||||
def decorator(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]:
|
||||
async def wrapper(ctx: wire.Context, msg: MsgIn) -> MsgOut:
|
||||
keychain = await get_keychain(ctx, curve, namespaces)
|
||||
keychain = await get_keychain(ctx, curve, schemas)
|
||||
with keychain:
|
||||
return await func(ctx, msg, keychain)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def auto_keychain(
|
||||
modname: str, allow_testnet: bool = True
|
||||
) -> Callable[[HandlerWithKeychain[MsgIn, MsgOut]], Handler[MsgIn, MsgOut]]:
|
||||
rdot = modname.rfind(".")
|
||||
parent_modname = modname[:rdot]
|
||||
parent_module = sys.modules[parent_modname]
|
||||
|
||||
pattern = getattr(parent_module, "PATTERN")
|
||||
curve = getattr(parent_module, "CURVE")
|
||||
slip44_id = getattr(parent_module, "SLIP44_ID")
|
||||
return with_slip44_keychain(
|
||||
pattern, slip44_id=slip44_id, curve=curve, allow_testnet=allow_testnet
|
||||
)
|
||||
|
@ -8,7 +8,17 @@ from . import HARDENED
|
||||
from .confirm import require_confirm
|
||||
|
||||
if False:
|
||||
from typing import Any, Callable, List, Sequence, TypeVar
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Collection,
|
||||
Iterable,
|
||||
List,
|
||||
Sequence,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from typing_extensions import Protocol
|
||||
from trezor import wire
|
||||
|
||||
# XXX this is a circular import, but it's only for typing
|
||||
@ -18,17 +28,226 @@ if False:
|
||||
Slip21Path = Sequence[bytes]
|
||||
PathType = TypeVar("PathType", Bip32Path, Slip21Path)
|
||||
|
||||
class PathSchemaType(Protocol):
|
||||
def match(self, path: Bip32Path) -> bool:
|
||||
...
|
||||
|
||||
|
||||
class _FastInclusiveRange(range):
|
||||
"""Inclusive range with a fast membership test, suitable for PathSchema use.
|
||||
|
||||
Micropython's `range` does not implement the `__contains__` method. This makes
|
||||
checking whether `x in range(BIG_NUMBER)` slow. This class fixes the problem.
|
||||
|
||||
In addition, convenience modifications have been made:
|
||||
* both `min` and `max` belong to the range (so `stop == max + 1`).
|
||||
* both `min` and `max` must be set, `step` is not allowed (we don't need it and it
|
||||
would make the `__contains__` method more complex)
|
||||
"""
|
||||
|
||||
def __init__(self, min: int, max: int) -> None:
|
||||
super().__init__(min, max + 1)
|
||||
|
||||
def __contains__(self, x: object) -> bool:
|
||||
if not isinstance(x, int):
|
||||
return False
|
||||
return self.start <= x < self.stop
|
||||
|
||||
|
||||
class PathSchema:
|
||||
"""General BIP-32 path schema.
|
||||
|
||||
Loosely based on the BIP-32 path template proposal [1].
|
||||
|
||||
Each path component can be one of the following:
|
||||
- constant, e.g., `7`
|
||||
- list of constants, e.g., `[1,2,3]`
|
||||
- range, e.g., `[0-19]`
|
||||
|
||||
Brackets are recommended but not enforced.
|
||||
|
||||
The following substitutions are available:
|
||||
- `coin_type` is substituted with the coin's SLIP-44 identifier
|
||||
- `account` is substituted with `[0-100]`, Trezor's default range of accounts
|
||||
- `change` is substituted with `[0,1]`
|
||||
- `address_index` is substituted with `[0-1000000]`, Trezor's default range of
|
||||
addresses
|
||||
|
||||
Hardened flag is indicated by an apostrophe and applies to the whole path component.
|
||||
It is impossible to specify both hardened and non-hardened values for the same
|
||||
component.
|
||||
|
||||
See examples of valid path formats below and in `apps.bitcoin.keychain`.
|
||||
|
||||
E.g. the following are equivalent definitions of a BIP-84 schema:
|
||||
|
||||
m/84'/coin_type'/[0-100]'/[0,1]/[0-1000000]
|
||||
m/84'/coin_type'/0-100'/0,1/0-1000000
|
||||
m/84'/coin_type'/account'/change/address_index
|
||||
|
||||
Adding an asterisk at the end of the pattern acts as a wildcard for zero or more
|
||||
path components:
|
||||
- m/* can be followed by any number of _unhardened_ path components
|
||||
- m/*' can be followed by any number of _hardened_ path components
|
||||
- m/** can be followed by any number of _any_ path components
|
||||
|
||||
The following is a BIP-44 generic `GetPublicKey` schema:
|
||||
|
||||
m/44'/coin_type'/account'/*
|
||||
|
||||
The asterisk expression can only appear at end of pattern.
|
||||
|
||||
[1] https://github.com/dgpv/bip32_template_parse_tplaplus_spec/blob/master/bip-path-templates.mediawiki
|
||||
"""
|
||||
|
||||
REPLACEMENTS = {
|
||||
"account": "0-100",
|
||||
"change": "0,1",
|
||||
"address_index": "0-1000000",
|
||||
}
|
||||
|
||||
WILDCARD_RANGES = {
|
||||
"*": _FastInclusiveRange(0, HARDENED - 1),
|
||||
"*'": _FastInclusiveRange(HARDENED, 0xFFFF_FFFF),
|
||||
"**": _FastInclusiveRange(0, 0xFFFF_FFFF),
|
||||
}
|
||||
|
||||
def __init__(self, pattern: str, slip44_id: Union[int, Iterable[int]]) -> None:
|
||||
if not pattern.startswith("m/"):
|
||||
raise ValueError # unsupported path template
|
||||
components = pattern[2:].split("/")
|
||||
|
||||
if isinstance(slip44_id, int):
|
||||
slip44_id = (slip44_id,)
|
||||
|
||||
self.schema: List[Collection[int]] = []
|
||||
self.trailing_components: Collection[int] = ()
|
||||
|
||||
for component in components:
|
||||
if component in self.WILDCARD_RANGES:
|
||||
if len(self.schema) != len(components) - 1:
|
||||
# every component should have resulted in extending self.schema
|
||||
# so if self.schema does not have the appropriate length (yet),
|
||||
# the asterisk is not the last item
|
||||
raise ValueError # asterisk is not last item of pattern
|
||||
|
||||
self.trailing_components = self.WILDCARD_RANGES[component]
|
||||
break
|
||||
|
||||
# figure out if the component is hardened
|
||||
if component[-1] == "'":
|
||||
component = component[:-1]
|
||||
parse: Callable[[Any], int] = lambda s: int(s) | HARDENED # noqa: E731
|
||||
else:
|
||||
parse = int
|
||||
|
||||
# strip brackets
|
||||
if component[0] == "[" and component[-1] == "]":
|
||||
component = component[1:-1]
|
||||
|
||||
# optionally replace a keyword
|
||||
component = self.REPLACEMENTS.get(component, component)
|
||||
|
||||
if "-" in component:
|
||||
# parse as a range
|
||||
a, b = [parse(s) for s in component.split("-", 1)]
|
||||
self.schema.append(_FastInclusiveRange(a, b))
|
||||
|
||||
elif "," in component:
|
||||
# parse as a list of values
|
||||
self.schema.append(set(parse(s) for s in component.split(",")))
|
||||
|
||||
elif component == "coin_type":
|
||||
# substitute SLIP-44 ids
|
||||
self.schema.append(set(parse(s) for s in slip44_id))
|
||||
|
||||
else:
|
||||
# plain constant
|
||||
self.schema.append((parse(component),))
|
||||
|
||||
def match(self, path: Bip32Path) -> bool:
|
||||
# The path must not be _shorter_ than schema. It may be longer.
|
||||
if len(path) < len(self.schema):
|
||||
return False
|
||||
|
||||
path_iter = iter(path)
|
||||
# iterate over length of schema, consuming path components
|
||||
for expected in self.schema:
|
||||
value = next(path_iter)
|
||||
if value not in expected:
|
||||
return False
|
||||
|
||||
# iterate over remaining path components
|
||||
for value in path_iter:
|
||||
if value not in self.trailing_components:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
if __debug__:
|
||||
|
||||
def __repr__(self) -> str:
|
||||
components = ["m"]
|
||||
|
||||
def unharden(item: int) -> int:
|
||||
return item ^ (item & HARDENED)
|
||||
|
||||
for component in self.schema:
|
||||
if isinstance(component, range):
|
||||
a, b = component.start, component.stop - 1
|
||||
components.append(
|
||||
"[{}-{}]{}".format(
|
||||
unharden(a), unharden(b), "'" if a & HARDENED else ""
|
||||
)
|
||||
)
|
||||
else:
|
||||
component_str = ",".join(str(unharden(i)) for i in component)
|
||||
if len(component) > 1:
|
||||
component_str = "[" + component_str + "]"
|
||||
if next(iter(component)) & HARDENED:
|
||||
component_str += "'"
|
||||
components.append(component_str)
|
||||
|
||||
if self.trailing_components:
|
||||
for key, val in self.WILDCARD_RANGES.items():
|
||||
if self.trailing_components is val:
|
||||
components.append(key)
|
||||
break
|
||||
else:
|
||||
components.append("???")
|
||||
|
||||
return "<schema:" + "/".join(components) + ">"
|
||||
|
||||
|
||||
class _AlwaysMatchingSchema:
|
||||
@staticmethod
|
||||
def match(path: Bip32Path) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class _NeverMatchingSchema:
|
||||
@staticmethod
|
||||
def match(path: Bip32Path) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
# type objects _AlwaysMatchingSchema and _NeverMatching schema conform to the
|
||||
# PathSchemaType protocol, but mypy fails to recognize this due to:
|
||||
# https://github.com/python/mypy/issues/4536,
|
||||
# hence the following trickery
|
||||
AlwaysMatchingSchema: PathSchemaType = _AlwaysMatchingSchema # type: ignore
|
||||
NeverMatchingSchema: PathSchemaType = _NeverMatchingSchema # type: ignore
|
||||
|
||||
PATTERN_BIP44 = "m/44'/coin_type'/account'/change/address_index"
|
||||
PATTERN_BIP44_PUBKEY = "m/44'/coin_type'/account'/*"
|
||||
PATTERN_SEP5 = "m/44'/coin_type'/account'"
|
||||
|
||||
|
||||
async def validate_path(
|
||||
ctx: wire.Context,
|
||||
validate_func: Callable[..., bool],
|
||||
keychain: Keychain,
|
||||
path: List[int],
|
||||
curve: str,
|
||||
**kwargs: Any,
|
||||
ctx: wire.Context, keychain: Keychain, path: Bip32Path, *additional_checks: bool
|
||||
) -> None:
|
||||
keychain.verify_path(path)
|
||||
if not validate_func(path, **kwargs):
|
||||
if not keychain.is_in_keychain(path) or not all(additional_checks):
|
||||
await show_path_warning(ctx, path)
|
||||
|
||||
|
||||
@ -41,28 +260,6 @@ async def show_path_warning(ctx: wire.Context, path: Bip32Path) -> None:
|
||||
await require_confirm(ctx, text, ButtonRequestType.UnknownDerivationPath)
|
||||
|
||||
|
||||
def validate_path_for_get_public_key(path: Bip32Path, slip44_id: int) -> bool:
|
||||
"""
|
||||
Checks if path has at least three hardened items and slip44 id matches.
|
||||
The path is allowed to have more than three items, but all the following
|
||||
items have to be non-hardened.
|
||||
"""
|
||||
length = len(path)
|
||||
if length < 3 or length > 5:
|
||||
return False
|
||||
if path[0] != 44 | HARDENED:
|
||||
return False
|
||||
if path[1] != slip44_id | HARDENED:
|
||||
return False
|
||||
if path[2] < HARDENED or path[2] > 20 | HARDENED:
|
||||
return False
|
||||
if length > 3 and is_hardened(path[3]):
|
||||
return False
|
||||
if length > 4 and is_hardened(path[4]):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def is_hardened(i: int) -> bool:
|
||||
return bool(i & HARDENED)
|
||||
|
||||
@ -71,15 +268,19 @@ def path_is_hardened(address_n: Bip32Path) -> bool:
|
||||
return all(is_hardened(n) for n in address_n)
|
||||
|
||||
|
||||
def break_address_n_to_lines(address_n: Bip32Path) -> List[str]:
|
||||
def address_n_to_str(address_n: Bip32Path) -> str:
|
||||
def path_item(i: int) -> str:
|
||||
if i & HARDENED:
|
||||
return str(i ^ HARDENED) + "'"
|
||||
else:
|
||||
return str(i)
|
||||
|
||||
return "m/" + "/".join([path_item(i) for i in address_n])
|
||||
|
||||
|
||||
def break_address_n_to_lines(address_n: Bip32Path) -> List[str]:
|
||||
lines = []
|
||||
path_str = "m/" + "/".join([path_item(i) for i in address_n])
|
||||
path_str = address_n_to_str(address_n)
|
||||
|
||||
per_line = const(17)
|
||||
while len(path_str) > per_line:
|
||||
|
@ -4,8 +4,8 @@ from mock_storage import mock_storage
|
||||
|
||||
from storage import cache
|
||||
import storage.device
|
||||
from apps.common import HARDENED, safety_checks
|
||||
from apps.common.paths import path_is_hardened
|
||||
from apps.common import safety_checks
|
||||
from apps.common.paths import PATTERN_SEP5, PathSchema, path_is_hardened
|
||||
from apps.common.keychain import LRUCache, Keychain, with_slip44_keychain, get_keychain
|
||||
from trezor import wire
|
||||
from trezor.crypto import bip39
|
||||
@ -13,26 +13,31 @@ from trezor.messages import SafetyCheckLevel
|
||||
|
||||
|
||||
class TestKeychain(unittest.TestCase):
|
||||
def setUp(self):
|
||||
cache.start_session()
|
||||
|
||||
def tearDown(self):
|
||||
cache.clear_all()
|
||||
|
||||
@mock_storage
|
||||
def test_verify_path(self):
|
||||
n = [
|
||||
[44 | HARDENED, 134 | HARDENED],
|
||||
[44 | HARDENED, 11 | HARDENED],
|
||||
]
|
||||
keychain = Keychain(b"", "secp256k1", n)
|
||||
schemas = (
|
||||
PathSchema("m/44'/coin_type'", slip44_id=134),
|
||||
PathSchema("m/44'/coin_type'", slip44_id=11),
|
||||
)
|
||||
keychain = Keychain(b"", "secp256k1", schemas)
|
||||
|
||||
correct = (
|
||||
[44 | HARDENED, 134 | HARDENED],
|
||||
[44 | HARDENED, 11 | HARDENED],
|
||||
[44 | HARDENED, 11 | HARDENED, 12],
|
||||
[H_(44), H_(134)],
|
||||
[H_(44), H_(11)],
|
||||
)
|
||||
for path in correct:
|
||||
keychain.verify_path(path)
|
||||
|
||||
fails = (
|
||||
[44 | HARDENED, 134], # path does not match
|
||||
[H_(44), 134], # path does not match
|
||||
[44, 134], # path does not match (non-hardened items)
|
||||
[44 | HARDENED, 13 | HARDENED], # invalid second item
|
||||
[H_(44), H_(13)], # invalid second item
|
||||
)
|
||||
for f in fails:
|
||||
with self.assertRaises(wire.DataError):
|
||||
@ -42,43 +47,40 @@ class TestKeychain(unittest.TestCase):
|
||||
safety_checks.apply_setting(SafetyCheckLevel.PromptTemporarily)
|
||||
for path in correct + fails:
|
||||
keychain.verify_path(path)
|
||||
# turn on restrictions
|
||||
safety_checks.apply_setting(SafetyCheckLevel.Strict)
|
||||
|
||||
def test_verify_path_special_ed25519(self):
|
||||
n = [[44 | HARDENED, 134 | HARDENED]]
|
||||
k = Keychain(b"", "ed25519-keccak", n)
|
||||
schema = PathSchema("m/44'/coin_type'/*", slip44_id=134)
|
||||
k = Keychain(b"", "ed25519-keccak", [schema])
|
||||
|
||||
# OK case
|
||||
k.verify_path([44 | HARDENED, 134 | HARDENED])
|
||||
k.verify_path([H_(44), H_(134)])
|
||||
|
||||
# failing case: non-hardened component with ed25519-like derivation
|
||||
with self.assertRaises(wire.DataError):
|
||||
k.verify_path([44 | HARDENED, 134 | HARDENED, 1])
|
||||
k.verify_path([H_(44), H_(134), 1])
|
||||
|
||||
def test_verify_path_empty_namespace(self):
|
||||
k = Keychain(b"", "secp256k1", [[]])
|
||||
correct = (
|
||||
def test_no_schemas(self):
|
||||
k = Keychain(b"", "secp256k1", [])
|
||||
paths = (
|
||||
[],
|
||||
[1, 2, 3, 4],
|
||||
[44 | HARDENED, 11 | HARDENED],
|
||||
[44 | HARDENED, 11 | HARDENED, 12],
|
||||
[H_(44), H_(11)],
|
||||
[H_(44), H_(11), 12],
|
||||
)
|
||||
for c in correct:
|
||||
k.verify_path(c)
|
||||
for path in paths:
|
||||
self.assertRaises(wire.DataError, k.verify_path, path)
|
||||
|
||||
def test_get_keychain(self):
|
||||
seed = bip39.seed(" ".join(["all"] * 12), "")
|
||||
cache.start_session()
|
||||
cache.set(cache.APP_COMMON_SEED, seed)
|
||||
|
||||
namespaces = [[44 | HARDENED]]
|
||||
schema = PathSchema("m/44'/1'", 0)
|
||||
keychain = await_result(
|
||||
get_keychain(wire.DUMMY_CONTEXT, "secp256k1", namespaces)
|
||||
get_keychain(wire.DUMMY_CONTEXT, "secp256k1", [schema])
|
||||
)
|
||||
|
||||
# valid path:
|
||||
self.assertIsNotNone(keychain.derive([44 | HARDENED, 1 | HARDENED]))
|
||||
self.assertIsNotNone(keychain.derive([H_(44), H_(1)]))
|
||||
|
||||
# invalid path:
|
||||
with self.assertRaises(wire.DataError):
|
||||
@ -86,13 +88,12 @@ class TestKeychain(unittest.TestCase):
|
||||
|
||||
def test_with_slip44(self):
|
||||
seed = bip39.seed(" ".join(["all"] * 12), "")
|
||||
cache.start_session()
|
||||
cache.set(cache.APP_COMMON_SEED, seed)
|
||||
|
||||
slip44_id = 42
|
||||
valid_path = [44 | HARDENED, slip44_id | HARDENED]
|
||||
invalid_path = [44 | HARDENED, 99 | HARDENED]
|
||||
testnet_path = [44 | HARDENED, 1 | HARDENED]
|
||||
valid_path = [H_(44), H_(slip44_id), H_(0)]
|
||||
invalid_path = [H_(44), H_(99), H_(0)]
|
||||
testnet_path = [H_(44), H_(1), H_(0)]
|
||||
|
||||
def check_valid_paths(keychain, *paths):
|
||||
for path in paths:
|
||||
@ -102,24 +103,24 @@ class TestKeychain(unittest.TestCase):
|
||||
for path in paths:
|
||||
self.assertRaises(wire.DataError, keychain.derive, path)
|
||||
|
||||
@with_slip44_keychain(slip44_id)
|
||||
@with_slip44_keychain(PATTERN_SEP5, slip44_id=slip44_id)
|
||||
async def func_id_only(ctx, msg, keychain):
|
||||
check_valid_paths(keychain, valid_path)
|
||||
check_invalid_paths(keychain, testnet_path, invalid_path)
|
||||
|
||||
@with_slip44_keychain(slip44_id, allow_testnet=True)
|
||||
async def func_allow_testnet(ctx, msg, keychain):
|
||||
check_valid_paths(keychain, valid_path, testnet_path)
|
||||
check_invalid_paths(keychain, invalid_path)
|
||||
|
||||
@with_slip44_keychain(slip44_id, curve="ed25519")
|
||||
async def func_with_curve(ctx, msg, keychain):
|
||||
self.assertEqual(keychain.curve, "ed25519")
|
||||
@with_slip44_keychain(PATTERN_SEP5, slip44_id=slip44_id, allow_testnet=False)
|
||||
async def func_disallow_testnet(ctx, msg, keychain):
|
||||
check_valid_paths(keychain, valid_path)
|
||||
check_invalid_paths(keychain, testnet_path, invalid_path)
|
||||
|
||||
@with_slip44_keychain(PATTERN_SEP5, slip44_id=slip44_id, curve="ed25519")
|
||||
async def func_with_curve(ctx, msg, keychain):
|
||||
self.assertEqual(keychain.curve, "ed25519")
|
||||
check_valid_paths(keychain, valid_path, testnet_path)
|
||||
check_invalid_paths(keychain, invalid_path)
|
||||
|
||||
await_result(func_id_only(wire.DUMMY_CONTEXT, None))
|
||||
await_result(func_allow_testnet(wire.DUMMY_CONTEXT, None))
|
||||
await_result(func_disallow_testnet(wire.DUMMY_CONTEXT, None))
|
||||
await_result(func_with_curve(wire.DUMMY_CONTEXT, None))
|
||||
|
||||
def test_lru_cache(self):
|
||||
|
@ -1,54 +1,163 @@
|
||||
from common import *
|
||||
from apps.common import HARDENED
|
||||
from apps.common.paths import validate_path_for_get_public_key, is_hardened, path_is_hardened
|
||||
from trezor.utils import ensure
|
||||
from apps.common.paths import *
|
||||
|
||||
|
||||
class TestPaths(unittest.TestCase):
|
||||
|
||||
def test_is_hardened(self):
|
||||
self.assertTrue(is_hardened(44 | HARDENED))
|
||||
self.assertTrue(is_hardened(0 | HARDENED))
|
||||
self.assertTrue(is_hardened(99999 | HARDENED))
|
||||
self.assertTrue(is_hardened(H_(44)))
|
||||
self.assertTrue(is_hardened(H_(0)))
|
||||
self.assertTrue(is_hardened(H_(99999)))
|
||||
|
||||
self.assertFalse(is_hardened(44))
|
||||
self.assertFalse(is_hardened(0))
|
||||
self.assertFalse(is_hardened(99999))
|
||||
|
||||
def test_path_is_hardened(self):
|
||||
self.assertTrue(path_is_hardened([44 | HARDENED, 1 | HARDENED, 0 | HARDENED]))
|
||||
self.assertTrue(path_is_hardened([0 | HARDENED, ]))
|
||||
self.assertTrue(path_is_hardened([H_(44), H_(1), H_(0)]))
|
||||
self.assertTrue(path_is_hardened([H_(0)]))
|
||||
|
||||
self.assertFalse(path_is_hardened([44, 44 | HARDENED, 0 | HARDENED]))
|
||||
self.assertFalse(path_is_hardened([0, ]))
|
||||
self.assertFalse(path_is_hardened([44 | HARDENED, 1 | HARDENED, 0 | HARDENED, 0 | HARDENED, 0]))
|
||||
|
||||
def test_path_for_get_public_key(self):
|
||||
# 44'/41'/0'
|
||||
self.assertTrue(validate_path_for_get_public_key([44 | HARDENED, 41 | HARDENED, 0 | HARDENED], 41))
|
||||
# 44'/111'/0'
|
||||
self.assertTrue(validate_path_for_get_public_key([44 | HARDENED, 111 | HARDENED, 0 | HARDENED], 111))
|
||||
# 44'/0'/0'/0
|
||||
self.assertTrue(validate_path_for_get_public_key([44 | HARDENED, 0 | HARDENED, 0 | HARDENED, 0], 0))
|
||||
# 44'/0'/0'/0/0
|
||||
self.assertTrue(validate_path_for_get_public_key([44 | HARDENED, 0 | HARDENED, 0 | HARDENED, 0, 0], 0))
|
||||
|
||||
# 44'/41'
|
||||
self.assertFalse(validate_path_for_get_public_key([44 | HARDENED, 41 | HARDENED], 41))
|
||||
# 44'/41'/0
|
||||
self.assertFalse(validate_path_for_get_public_key([44 | HARDENED, 41 | HARDENED, 0], 41))
|
||||
# 44'/41'/0' slip44 mismatch
|
||||
self.assertFalse(validate_path_for_get_public_key([44 | HARDENED, 41 | HARDENED, 0 | HARDENED], 99))
|
||||
# # 44'/41'/0'/0'
|
||||
self.assertFalse(validate_path_for_get_public_key([44 | HARDENED, 41 | HARDENED, 0 | HARDENED, 0 | HARDENED], 41))
|
||||
# # 44'/41'/0'/0'/0
|
||||
self.assertFalse(validate_path_for_get_public_key([44 | HARDENED, 41 | HARDENED, 0 | HARDENED, 0 | HARDENED, 0], 41))
|
||||
# # 44'/41'/0'/0'/0'
|
||||
self.assertFalse(validate_path_for_get_public_key([44 | HARDENED, 41 | HARDENED, 0 | HARDENED, 0 | HARDENED, 0 | HARDENED], 41))
|
||||
# # 44'/41'/0'/0/0/0
|
||||
self.assertFalse(validate_path_for_get_public_key([44 | HARDENED, 41 | HARDENED, 0 | HARDENED, 0, 0, 0], 41))
|
||||
# # 44'/41'/0'/0/0'
|
||||
self.assertFalse(validate_path_for_get_public_key([44 | HARDENED, 41 | HARDENED, 0 | HARDENED, 0, 0 | HARDENED], 41))
|
||||
self.assertFalse(path_is_hardened([44, H_(44), H_(0)]))
|
||||
self.assertFalse(path_is_hardened([0]))
|
||||
self.assertFalse(path_is_hardened([H_(44), H_(1), H_(0), H_(0), 0]))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
class TestPathSchemas(unittest.TestCase):
|
||||
def assertMatch(self, schema, path):
|
||||
self.assertTrue(
|
||||
schema.match(path),
|
||||
"Expected schema {!r} to match path {}".format(
|
||||
schema, address_n_to_str(path)
|
||||
),
|
||||
)
|
||||
|
||||
def assertMismatch(self, schema, path):
|
||||
self.assertFalse(
|
||||
schema.match(path),
|
||||
"Expected schema {!r} to not match path {}".format(
|
||||
schema, address_n_to_str(path)
|
||||
),
|
||||
)
|
||||
|
||||
def assertEqualSchema(self, schema_a, schema_b):
|
||||
def is_equal(a, b):
|
||||
if isinstance(a, range) and isinstance(b, range):
|
||||
return a.start == b.start and a.step == b.step and a.stop == b.stop
|
||||
return a == b
|
||||
|
||||
ensure(
|
||||
all(is_equal(a, b) for a, b in zip(schema_a.schema, schema_b.schema))
|
||||
and is_equal(schema_a.trailing_components, schema_b.trailing_components),
|
||||
"Schemas differ:\nA = {!r}\nB = {!r}".format(schema_a, schema_b),
|
||||
)
|
||||
|
||||
def test_always_never_matching(self):
|
||||
paths = [
|
||||
[],
|
||||
[0],
|
||||
[H_(0)],
|
||||
[44],
|
||||
[H_(44)],
|
||||
[H_(44), H_(0), H_(0), 0, 0],
|
||||
[H_(44), H_(0), H_(0), H_(0), H_(0)],
|
||||
[H_(44), H_(0), H_(0), H_(0), H_(0)] * 10,
|
||||
]
|
||||
for path in paths:
|
||||
self.assertMatch(SCHEMA_ANY_PATH, path)
|
||||
self.assertMismatch(SCHEMA_NO_MATCH, path)
|
||||
|
||||
def test_pattern_fixed(self):
|
||||
pattern = "m/44'/0'/0'/0/0"
|
||||
schema = PathSchema(pattern, 0)
|
||||
|
||||
self.assertMatch(schema, [H_(44), H_(0), H_(0), 0, 0])
|
||||
|
||||
paths = [
|
||||
[],
|
||||
[0],
|
||||
[H_(0)],
|
||||
[44],
|
||||
[H_(44)],
|
||||
[44, 0, 0, 0, 0],
|
||||
[H_(44), H_(0), H_(0), H_(0), H_(0)],
|
||||
[H_(44), H_(0), H_(0), H_(0), H_(0)] * 10,
|
||||
]
|
||||
for path in paths:
|
||||
self.assertMismatch(schema, path)
|
||||
|
||||
def test_ranges_sets(self):
|
||||
pattern_ranges = "m/44'/[100-109]'/[0-20]"
|
||||
pattern_sets = "m/44'/[100,105,109]'/[0,10,20]"
|
||||
schema_ranges = PathSchema(pattern_ranges, 0)
|
||||
schema_sets = PathSchema(pattern_sets, 0)
|
||||
|
||||
paths_good = [
|
||||
[H_(44), H_(100), 0],
|
||||
[H_(44), H_(100), 10],
|
||||
[H_(44), H_(100), 20],
|
||||
[H_(44), H_(105), 0],
|
||||
[H_(44), H_(105), 10],
|
||||
[H_(44), H_(105), 20],
|
||||
[H_(44), H_(109), 0],
|
||||
[H_(44), H_(109), 10],
|
||||
[H_(44), H_(109), 20],
|
||||
]
|
||||
for path in paths_good:
|
||||
self.assertMatch(schema_ranges, path)
|
||||
self.assertMatch(schema_sets, path)
|
||||
|
||||
paths_bad = [
|
||||
[H_(44), H_(100)],
|
||||
[H_(44), H_(100), 0, 0],
|
||||
[H_(44), 100, 0],
|
||||
[H_(44), 100, H_(0)],
|
||||
[H_(44), H_(99), 0],
|
||||
[H_(44), H_(110), 0],
|
||||
[H_(44), H_(100), 21],
|
||||
]
|
||||
for path in paths_bad:
|
||||
self.assertMismatch(schema_ranges, path)
|
||||
self.assertMismatch(schema_sets, path)
|
||||
|
||||
self.assertMatch(schema_ranges, [H_(44), H_(104), 19])
|
||||
self.assertMismatch(schema_sets, [H_(44), H_(104), 19])
|
||||
|
||||
def test_brackets(self):
|
||||
pattern_a = "m/[0]'/[0-5]'/[0,1,2]'/[0]/[0-5]/[0,1,2]"
|
||||
pattern_b = "m/0'/0-5'/0,1,2'/0/0-5/0,1,2"
|
||||
schema_a = PathSchema(pattern_a, 0)
|
||||
schema_b = PathSchema(pattern_b, 0)
|
||||
self.assertEqualSchema(schema_a, schema_b)
|
||||
|
||||
def test_wildcard(self):
|
||||
pattern = "m/44'/0'/*"
|
||||
schema = PathSchema(pattern, 0)
|
||||
|
||||
paths_good = [
|
||||
[H_(44), H_(0)],
|
||||
[H_(44), H_(0), 0],
|
||||
[H_(44), H_(0), 0, 1, 2, 3, 4, 5, 6, 7, 8],
|
||||
]
|
||||
for path in paths_good:
|
||||
self.assertMatch(schema, path)
|
||||
|
||||
paths_bad = [
|
||||
[H_(44)],
|
||||
[H_(44), H_(0), H_(0)],
|
||||
[H_(44), H_(0), 0, 1, 2, 3, 4, 5, 6, 7, H_(8)],
|
||||
]
|
||||
for path in paths_bad:
|
||||
self.assertMismatch(schema, path)
|
||||
|
||||
def test_substitutes(self):
|
||||
pattern_sub = "m/44'/coin_type'/account'/change/address_index"
|
||||
pattern_plain = "m/44'/19'/0-100'/0,1/0-1000000"
|
||||
schema_sub = PathSchema(pattern_sub, slip44_id=19)
|
||||
# use wrong slip44 id to ensure it doesn't affect anything
|
||||
schema_plain = PathSchema(pattern_plain, slip44_id=0)
|
||||
|
||||
self.assertEqualSchema(schema_sub, schema_plain)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user