mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-17 19:00:58 +00:00
refactor(core): defragment PathSchema memory usage
This commit is contained in:
parent
e5a481ded5
commit
f3db4f2dd3
@ -1,3 +1,5 @@
|
||||
import gc
|
||||
|
||||
from trezor import wire
|
||||
from trezor.messages import InputScriptType as I
|
||||
|
||||
@ -113,7 +115,7 @@ def validate_path_against_script_type(
|
||||
patterns.append(PATTERN_GREENADDRESS_B)
|
||||
|
||||
return any(
|
||||
PathSchema(pattern, coin.slip44).match(address_n) for pattern in patterns
|
||||
PathSchema.parse(pattern, coin.slip44).match(address_n) for pattern in patterns
|
||||
)
|
||||
|
||||
|
||||
@ -151,15 +153,16 @@ def get_schemas_for_coin(coin: coininfo.CoinInfo) -> Iterable[PathSchema]:
|
||||
)
|
||||
)
|
||||
|
||||
schemas = [PathSchema(pattern, coin.slip44) for pattern in patterns]
|
||||
schemas = [PathSchema.parse(pattern, coin.slip44) for pattern in patterns]
|
||||
|
||||
# some wallets such as Electron-Cash (BCH) store coins on Bitcoin paths
|
||||
# we can allow spending these coins from Bitcoin paths if the coin has
|
||||
# implemented strong replay protection via SIGHASH_FORKID
|
||||
if coin.fork_id is not None:
|
||||
schemas.extend(PathSchema(pattern, 0) for pattern in patterns)
|
||||
schemas.extend(PathSchema.parse(pattern, 0) for pattern in patterns)
|
||||
|
||||
return schemas
|
||||
gc.collect()
|
||||
return [schema.copy() for schema in schemas]
|
||||
|
||||
|
||||
def get_coin_by_name(coin_name: str | None) -> coininfo.CoinInfo:
|
||||
|
@ -9,11 +9,11 @@ BYRON_ROOT = [44 | HARDENED, SLIP44_ID | HARDENED]
|
||||
SHELLEY_ROOT = [1852 | HARDENED, SLIP44_ID | HARDENED]
|
||||
|
||||
# fmt: off
|
||||
SCHEMA_PUBKEY = PathSchema("m/[44,1852]'/coin_type'/account'/*", SLIP44_ID)
|
||||
SCHEMA_ADDRESS = PathSchema("m/[44,1852]'/coin_type'/account'/[0,1,2]/address_index", SLIP44_ID)
|
||||
SCHEMA_PUBKEY = PathSchema.parse("m/[44,1852]'/coin_type'/account'/*", SLIP44_ID)
|
||||
SCHEMA_ADDRESS = PathSchema.parse("m/[44,1852]'/coin_type'/account'/[0,1,2]/address_index", SLIP44_ID)
|
||||
# staking is only allowed on Shelley paths with suffix /2/0
|
||||
SCHEMA_STAKING = PathSchema("m/1852'/coin_type'/account'/2/0", SLIP44_ID)
|
||||
SCHEMA_STAKING_ANY_ACCOUNT = PathSchema("m/1852'/coin_type'/[0-%s]'/2/0" % (HARDENED - 1), SLIP44_ID)
|
||||
SCHEMA_STAKING = PathSchema.parse("m/1852'/coin_type'/account'/2/0", SLIP44_ID)
|
||||
SCHEMA_STAKING_ANY_ACCOUNT = PathSchema.parse("m/1852'/coin_type'/[0-%s]'/2/0" % (HARDENED - 1), SLIP44_ID)
|
||||
# fmt: on
|
||||
|
||||
# the maximum allowed change address. this should be large enough for normal
|
||||
|
@ -197,7 +197,8 @@ def with_slip44_keychain(
|
||||
|
||||
schemas = []
|
||||
for pattern in patterns:
|
||||
schemas.append(paths.PathSchema(pattern=pattern, slip44_id=slip44_ids))
|
||||
schemas.append(paths.PathSchema.parse(pattern=pattern, slip44_id=slip44_ids))
|
||||
schemas = [s.copy() for s in schemas]
|
||||
|
||||
def decorator(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]:
|
||||
async def wrapper(ctx: wire.Context, msg: MsgIn) -> MsgOut:
|
||||
|
@ -104,7 +104,63 @@ class PathSchema:
|
||||
"**": Interval(0, 0xFFFF_FFFF),
|
||||
}
|
||||
|
||||
def __init__(self, pattern: str, slip44_id: int | Iterable[int]) -> None:
|
||||
_EMPTY_TUPLE = ()
|
||||
|
||||
@staticmethod
|
||||
def _parse_hardened(s: str) -> int:
|
||||
return int(s) | HARDENED
|
||||
|
||||
@staticmethod
|
||||
def _copy_container(container: Container[int]) -> Container[int]:
|
||||
if isinstance(container, Interval):
|
||||
return Interval(container.min, container.max)
|
||||
if isinstance(container, set):
|
||||
return set(container)
|
||||
if isinstance(container, tuple):
|
||||
return container[:]
|
||||
raise RuntimeError("Unsupported container for copy")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
schema: list[Container[int]],
|
||||
trailing_components: Container[int] = (),
|
||||
compact: bool = False,
|
||||
) -> None:
|
||||
"""Create a new PathSchema from a list of containers and trailing components.
|
||||
|
||||
Mainly for internal use in `PathSchema.parse`, which is the method you should
|
||||
be using.
|
||||
|
||||
Can be used to create a schema manually without parsing a path string:
|
||||
|
||||
>>> SCHEMA_MINE = PathSchema([
|
||||
>>> (44 | HARDENED,),
|
||||
>>> (0 | HARDENED,),
|
||||
>>> Interval(0 | HARDENED, 10 | HARDENED),
|
||||
>>> ],
|
||||
>>> trailing_components=Interval(0, 0xFFFF_FFFF),
|
||||
>>> )
|
||||
|
||||
Setting `compact=True` creates a compact copy of the provided components, so
|
||||
as to prevent memory fragmentation.
|
||||
"""
|
||||
if compact:
|
||||
self.schema: list[Container[int]] = [self._EMPTY_TUPLE] * len(schema)
|
||||
for i in range(len(schema)):
|
||||
self.schema[i] = self._copy_container(schema[i])
|
||||
self.trailing_components = self._copy_container(trailing_components)
|
||||
|
||||
else:
|
||||
self.schema = schema
|
||||
self.trailing_components = trailing_components
|
||||
|
||||
@classmethod
|
||||
def parse(cls, pattern: str, slip44_id: int | Iterable[int]) -> "PathSchema":
|
||||
"""Parse a path schema string into a PathSchema instance.
|
||||
|
||||
The parsing process trashes the memory layout, so at the end a compact-allocated
|
||||
copy of the resulting structures is returned.
|
||||
"""
|
||||
if not pattern.startswith("m/"):
|
||||
raise ValueError # unsupported path template
|
||||
components = pattern[2:].split("/")
|
||||
@ -112,24 +168,24 @@ class PathSchema:
|
||||
if isinstance(slip44_id, int):
|
||||
slip44_id = (slip44_id,)
|
||||
|
||||
self.schema: list[Container[int]] = []
|
||||
self.trailing_components: Container[int] = ()
|
||||
schema: list[Container[int]] = []
|
||||
trailing_components: Container[int] = ()
|
||||
|
||||
for component in components:
|
||||
if component in self.WILDCARD_RANGES:
|
||||
if len(self.schema) != len(components) - 1:
|
||||
# every component should have resulted in extending self.schema
|
||||
# so if self.schema does not have the appropriate length (yet),
|
||||
if component in cls.WILDCARD_RANGES:
|
||||
if len(schema) != len(components) - 1:
|
||||
# every component should have resulted in extending schema
|
||||
# so if schema does not have the appropriate length (yet),
|
||||
# the asterisk is not the last item
|
||||
raise ValueError # asterisk is not last item of pattern
|
||||
|
||||
self.trailing_components = self.WILDCARD_RANGES[component]
|
||||
trailing_components = cls.WILDCARD_RANGES[component]
|
||||
break
|
||||
|
||||
# figure out if the component is hardened
|
||||
if component[-1] == "'":
|
||||
component = component[:-1]
|
||||
parse: Callable[[Any], int] = lambda s: int(s) | HARDENED # noqa: E731
|
||||
parse: Callable[[Any], int] = cls._parse_hardened
|
||||
else:
|
||||
parse = int
|
||||
|
||||
@ -138,24 +194,37 @@ class PathSchema:
|
||||
component = component[1:-1]
|
||||
|
||||
# optionally replace a keyword
|
||||
component = self.REPLACEMENTS.get(component, component)
|
||||
component = cls.REPLACEMENTS.get(component, component)
|
||||
|
||||
if "-" in component:
|
||||
# parse as a range
|
||||
a, b = [parse(s) for s in component.split("-", 1)]
|
||||
self.schema.append(Interval(a, b))
|
||||
schema.append(Interval(a, b))
|
||||
|
||||
elif "," in component:
|
||||
# parse as a list of values
|
||||
self.schema.append(set(parse(s) for s in component.split(",")))
|
||||
schema.append(set(parse(s) for s in component.split(",")))
|
||||
|
||||
elif component == "coin_type":
|
||||
# substitute SLIP-44 ids
|
||||
self.schema.append(set(parse(s) for s in slip44_id))
|
||||
schema.append(set(parse(s) for s in slip44_id))
|
||||
|
||||
else:
|
||||
# plain constant
|
||||
self.schema.append((parse(component),))
|
||||
schema.append((parse(component),))
|
||||
|
||||
return cls(schema, trailing_components, compact=True)
|
||||
|
||||
def copy(self) -> "PathSchema":
|
||||
"""Create a compact copy of the schema.
|
||||
|
||||
Useful when creating multiple schemas in a row. The following code ensures
|
||||
that the set of schemas is allocated in a contiguous block of memory:
|
||||
|
||||
>>> some_schemas = make_multiple_schemas()
|
||||
>>> some_schemas = [s.copy() for s in some_schemas]
|
||||
"""
|
||||
return PathSchema(self.schema, self.trailing_components, compact=True)
|
||||
|
||||
def match(self, path: Bip32Path) -> bool:
|
||||
# The path must not be _shorter_ than schema. It may be longer.
|
||||
|
@ -41,7 +41,8 @@ def _schemas_from_address_n(
|
||||
return ()
|
||||
|
||||
slip44_id = slip44_hardened - HARDENED
|
||||
return (paths.PathSchema(pattern, slip44_id) for pattern in patterns)
|
||||
schemas = [paths.PathSchema.parse(pattern, slip44_id) for pattern in patterns]
|
||||
return [s.copy() for s in schemas]
|
||||
|
||||
|
||||
def with_keychain_from_path(
|
||||
@ -79,7 +80,10 @@ def _schemas_from_chain_id(msg: EthereumSignTx) -> Iterable[paths.PathSchema]:
|
||||
else:
|
||||
slip44_id = (info.slip44,)
|
||||
|
||||
return (paths.PathSchema(pattern, slip44_id) for pattern in PATTERNS_ADDRESS)
|
||||
schemas = [
|
||||
paths.PathSchema.parse(pattern, slip44_id) for pattern in PATTERNS_ADDRESS
|
||||
]
|
||||
return [s.copy() for s in schemas]
|
||||
|
||||
|
||||
def with_keychain_from_chain_id(
|
||||
|
@ -3,9 +3,8 @@ from common import *
|
||||
from mock_storage import mock_storage
|
||||
|
||||
from storage import cache
|
||||
import storage.device
|
||||
from apps.common import safety_checks
|
||||
from apps.common.paths import PATTERN_SEP5, PathSchema, path_is_hardened
|
||||
from apps.common.paths import PATTERN_SEP5, PathSchema
|
||||
from apps.common.keychain import LRUCache, Keychain, with_slip44_keychain, get_keychain
|
||||
from trezor import wire
|
||||
from trezor.crypto import bip39
|
||||
@ -22,8 +21,8 @@ class TestKeychain(unittest.TestCase):
|
||||
@mock_storage
|
||||
def test_verify_path(self):
|
||||
schemas = (
|
||||
PathSchema("m/44'/coin_type'", slip44_id=134),
|
||||
PathSchema("m/44'/coin_type'", slip44_id=11),
|
||||
PathSchema.parse("m/44'/coin_type'", slip44_id=134),
|
||||
PathSchema.parse("m/44'/coin_type'", slip44_id=11),
|
||||
)
|
||||
keychain = Keychain(b"", "secp256k1", schemas)
|
||||
|
||||
@ -49,7 +48,7 @@ class TestKeychain(unittest.TestCase):
|
||||
keychain.verify_path(path)
|
||||
|
||||
def test_verify_path_special_ed25519(self):
|
||||
schema = PathSchema("m/44'/coin_type'/*", slip44_id=134)
|
||||
schema = PathSchema.parse("m/44'/coin_type'/*", slip44_id=134)
|
||||
k = Keychain(b"", "ed25519-keccak", [schema])
|
||||
|
||||
# OK case
|
||||
@ -74,7 +73,7 @@ class TestKeychain(unittest.TestCase):
|
||||
seed = bip39.seed(" ".join(["all"] * 12), "")
|
||||
cache.set(cache.APP_COMMON_SEED, seed)
|
||||
|
||||
schema = PathSchema("m/44'/1'", 0)
|
||||
schema = PathSchema.parse("m/44'/1'", 0)
|
||||
keychain = await_result(
|
||||
get_keychain(wire.DUMMY_CONTEXT, "secp256k1", [schema])
|
||||
)
|
||||
|
@ -68,7 +68,7 @@ class TestPathSchemas(unittest.TestCase):
|
||||
|
||||
def test_pattern_fixed(self):
|
||||
pattern = "m/44'/0'/0'/0/0"
|
||||
schema = PathSchema(pattern, 0)
|
||||
schema = PathSchema.parse(pattern, 0)
|
||||
|
||||
self.assertMatch(schema, [H_(44), H_(0), H_(0), 0, 0])
|
||||
|
||||
@ -88,8 +88,8 @@ class TestPathSchemas(unittest.TestCase):
|
||||
def test_ranges_sets(self):
|
||||
pattern_ranges = "m/44'/[100-109]'/[0-20]"
|
||||
pattern_sets = "m/44'/[100,105,109]'/[0,10,20]"
|
||||
schema_ranges = PathSchema(pattern_ranges, 0)
|
||||
schema_sets = PathSchema(pattern_sets, 0)
|
||||
schema_ranges = PathSchema.parse(pattern_ranges, 0)
|
||||
schema_sets = PathSchema.parse(pattern_sets, 0)
|
||||
|
||||
paths_good = [
|
||||
[H_(44), H_(100), 0],
|
||||
@ -125,13 +125,13 @@ class TestPathSchemas(unittest.TestCase):
|
||||
def test_brackets(self):
|
||||
pattern_a = "m/[0]'/[0-5]'/[0,1,2]'/[0]/[0-5]/[0,1,2]"
|
||||
pattern_b = "m/0'/0-5'/0,1,2'/0/0-5/0,1,2"
|
||||
schema_a = PathSchema(pattern_a, 0)
|
||||
schema_b = PathSchema(pattern_b, 0)
|
||||
schema_a = PathSchema.parse(pattern_a, 0)
|
||||
schema_b = PathSchema.parse(pattern_b, 0)
|
||||
self.assertEqualSchema(schema_a, schema_b)
|
||||
|
||||
def test_wildcard(self):
|
||||
pattern = "m/44'/0'/*"
|
||||
schema = PathSchema(pattern, 0)
|
||||
schema = PathSchema.parse(pattern, 0)
|
||||
|
||||
paths_good = [
|
||||
[H_(44), H_(0)],
|
||||
@ -152,12 +152,32 @@ class TestPathSchemas(unittest.TestCase):
|
||||
def test_substitutes(self):
|
||||
pattern_sub = "m/44'/coin_type'/account'/change/address_index"
|
||||
pattern_plain = "m/44'/19'/0-100'/0,1/0-1000000"
|
||||
schema_sub = PathSchema(pattern_sub, slip44_id=19)
|
||||
schema_sub = PathSchema.parse(pattern_sub, slip44_id=19)
|
||||
# use wrong slip44 id to ensure it doesn't affect anything
|
||||
schema_plain = PathSchema(pattern_plain, slip44_id=0)
|
||||
schema_plain = PathSchema.parse(pattern_plain, slip44_id=0)
|
||||
|
||||
self.assertEqualSchema(schema_sub, schema_plain)
|
||||
|
||||
def test_copy(self):
|
||||
schema_normal = PathSchema.parse("m/44'/0'/0'/0/0", slip44_id=0)
|
||||
self.assertEqualSchema(schema_normal, schema_normal.copy())
|
||||
|
||||
schema_wildcard = PathSchema.parse("m/44'/0'/0'/0/**", slip44_id=0)
|
||||
self.assertEqualSchema(schema_wildcard, schema_wildcard.copy())
|
||||
|
||||
def test_parse(self):
|
||||
schema_parsed = PathSchema.parse("m/44'/0-5'/0,1,2'/0/**", slip44_id=0)
|
||||
schema_manual = PathSchema(
|
||||
[
|
||||
(H_(44),),
|
||||
Interval(H_(0), H_(5)),
|
||||
set((H_(0), H_(1), H_(2))),
|
||||
(0,),
|
||||
],
|
||||
trailing_components=Interval(0, 0xFFFF_FFFF),
|
||||
)
|
||||
self.assertEqualSchema(schema_manual, schema_parsed)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user