From f3db4f2dd35efc7835ea84ddf17b3d96bbe00f00 Mon Sep 17 00:00:00 2001 From: matejcik Date: Fri, 19 Mar 2021 16:32:55 +0100 Subject: [PATCH] refactor(core): defragment PathSchema memory usage --- core/src/apps/bitcoin/keychain.py | 11 ++- core/src/apps/cardano/helpers/paths.py | 8 +- core/src/apps/common/keychain.py | 3 +- core/src/apps/common/paths.py | 97 +++++++++++++++++++++---- core/src/apps/ethereum/keychain.py | 8 +- core/tests/test_apps.common.keychain.py | 11 ++- core/tests/test_apps.common.paths.py | 36 +++++++-- 7 files changed, 135 insertions(+), 39 deletions(-) diff --git a/core/src/apps/bitcoin/keychain.py b/core/src/apps/bitcoin/keychain.py index 3b454fe95a..80debf0bf0 100644 --- a/core/src/apps/bitcoin/keychain.py +++ b/core/src/apps/bitcoin/keychain.py @@ -1,3 +1,5 @@ +import gc + from trezor import wire from trezor.messages import InputScriptType as I @@ -113,7 +115,7 @@ def validate_path_against_script_type( patterns.append(PATTERN_GREENADDRESS_B) return any( - PathSchema(pattern, coin.slip44).match(address_n) for pattern in patterns + PathSchema.parse(pattern, coin.slip44).match(address_n) for pattern in patterns ) @@ -151,15 +153,16 @@ def get_schemas_for_coin(coin: coininfo.CoinInfo) -> Iterable[PathSchema]: ) ) - schemas = [PathSchema(pattern, coin.slip44) for pattern in patterns] + schemas = [PathSchema.parse(pattern, coin.slip44) for pattern in patterns] # some wallets such as Electron-Cash (BCH) store coins on Bitcoin paths # we can allow spending these coins from Bitcoin paths if the coin has # implemented strong replay protection via SIGHASH_FORKID if coin.fork_id is not None: - schemas.extend(PathSchema(pattern, 0) for pattern in patterns) + schemas.extend(PathSchema.parse(pattern, 0) for pattern in patterns) - return schemas + gc.collect() + return [schema.copy() for schema in schemas] def get_coin_by_name(coin_name: str | None) -> coininfo.CoinInfo: diff --git a/core/src/apps/cardano/helpers/paths.py b/core/src/apps/cardano/helpers/paths.py index c2d1744c03..453cec6538 100644 --- a/core/src/apps/cardano/helpers/paths.py +++ b/core/src/apps/cardano/helpers/paths.py @@ -9,11 +9,11 @@ BYRON_ROOT = [44 | HARDENED, SLIP44_ID | HARDENED] SHELLEY_ROOT = [1852 | HARDENED, SLIP44_ID | HARDENED] # fmt: off -SCHEMA_PUBKEY = PathSchema("m/[44,1852]'/coin_type'/account'/*", SLIP44_ID) -SCHEMA_ADDRESS = PathSchema("m/[44,1852]'/coin_type'/account'/[0,1,2]/address_index", SLIP44_ID) +SCHEMA_PUBKEY = PathSchema.parse("m/[44,1852]'/coin_type'/account'/*", SLIP44_ID) +SCHEMA_ADDRESS = PathSchema.parse("m/[44,1852]'/coin_type'/account'/[0,1,2]/address_index", SLIP44_ID) # staking is only allowed on Shelley paths with suffix /2/0 -SCHEMA_STAKING = PathSchema("m/1852'/coin_type'/account'/2/0", SLIP44_ID) -SCHEMA_STAKING_ANY_ACCOUNT = PathSchema("m/1852'/coin_type'/[0-%s]'/2/0" % (HARDENED - 1), SLIP44_ID) +SCHEMA_STAKING = PathSchema.parse("m/1852'/coin_type'/account'/2/0", SLIP44_ID) +SCHEMA_STAKING_ANY_ACCOUNT = PathSchema.parse("m/1852'/coin_type'/[0-%s]'/2/0" % (HARDENED - 1), SLIP44_ID) # fmt: on # the maximum allowed change address. this should be large enough for normal diff --git a/core/src/apps/common/keychain.py b/core/src/apps/common/keychain.py index 32ab472f8d..ec61625743 100644 --- a/core/src/apps/common/keychain.py +++ b/core/src/apps/common/keychain.py @@ -197,7 +197,8 @@ def with_slip44_keychain( schemas = [] for pattern in patterns: - schemas.append(paths.PathSchema(pattern=pattern, slip44_id=slip44_ids)) + schemas.append(paths.PathSchema.parse(pattern=pattern, slip44_id=slip44_ids)) + schemas = [s.copy() for s in schemas] def decorator(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]: async def wrapper(ctx: wire.Context, msg: MsgIn) -> MsgOut: diff --git a/core/src/apps/common/paths.py b/core/src/apps/common/paths.py index 10abc8a967..197db19e1e 100644 --- a/core/src/apps/common/paths.py +++ b/core/src/apps/common/paths.py @@ -104,7 +104,63 @@ class PathSchema: "**": Interval(0, 0xFFFF_FFFF), } - def __init__(self, pattern: str, slip44_id: int | Iterable[int]) -> None: + _EMPTY_TUPLE = () + + @staticmethod + def _parse_hardened(s: str) -> int: + return int(s) | HARDENED + + @staticmethod + def _copy_container(container: Container[int]) -> Container[int]: + if isinstance(container, Interval): + return Interval(container.min, container.max) + if isinstance(container, set): + return set(container) + if isinstance(container, tuple): + return container[:] + raise RuntimeError("Unsupported container for copy") + + def __init__( + self, + schema: list[Container[int]], + trailing_components: Container[int] = (), + compact: bool = False, + ) -> None: + """Create a new PathSchema from a list of containers and trailing components. + + Mainly for internal use in `PathSchema.parse`, which is the method you should + be using. + + Can be used to create a schema manually without parsing a path string: + + >>> SCHEMA_MINE = PathSchema([ + >>> (44 | HARDENED,), + >>> (0 | HARDENED,), + >>> Interval(0 | HARDENED, 10 | HARDENED), + >>> ], + >>> trailing_components=Interval(0, 0xFFFF_FFFF), + >>> ) + + Setting `compact=True` creates a compact copy of the provided components, so + as to prevent memory fragmentation. + """ + if compact: + self.schema: list[Container[int]] = [self._EMPTY_TUPLE] * len(schema) + for i in range(len(schema)): + self.schema[i] = self._copy_container(schema[i]) + self.trailing_components = self._copy_container(trailing_components) + + else: + self.schema = schema + self.trailing_components = trailing_components + + @classmethod + def parse(cls, pattern: str, slip44_id: int | Iterable[int]) -> "PathSchema": + """Parse a path schema string into a PathSchema instance. + + The parsing process trashes the memory layout, so at the end a compact-allocated + copy of the resulting structures is returned. + """ if not pattern.startswith("m/"): raise ValueError # unsupported path template components = pattern[2:].split("/") @@ -112,24 +168,24 @@ class PathSchema: if isinstance(slip44_id, int): slip44_id = (slip44_id,) - self.schema: list[Container[int]] = [] - self.trailing_components: Container[int] = () + schema: list[Container[int]] = [] + trailing_components: Container[int] = () for component in components: - if component in self.WILDCARD_RANGES: - if len(self.schema) != len(components) - 1: - # every component should have resulted in extending self.schema - # so if self.schema does not have the appropriate length (yet), + if component in cls.WILDCARD_RANGES: + if len(schema) != len(components) - 1: + # every component should have resulted in extending schema + # so if schema does not have the appropriate length (yet), # the asterisk is not the last item raise ValueError # asterisk is not last item of pattern - self.trailing_components = self.WILDCARD_RANGES[component] + trailing_components = cls.WILDCARD_RANGES[component] break # figure out if the component is hardened if component[-1] == "'": component = component[:-1] - parse: Callable[[Any], int] = lambda s: int(s) | HARDENED # noqa: E731 + parse: Callable[[Any], int] = cls._parse_hardened else: parse = int @@ -138,24 +194,37 @@ class PathSchema: component = component[1:-1] # optionally replace a keyword - component = self.REPLACEMENTS.get(component, component) + component = cls.REPLACEMENTS.get(component, component) if "-" in component: # parse as a range a, b = [parse(s) for s in component.split("-", 1)] - self.schema.append(Interval(a, b)) + schema.append(Interval(a, b)) elif "," in component: # parse as a list of values - self.schema.append(set(parse(s) for s in component.split(","))) + schema.append(set(parse(s) for s in component.split(","))) elif component == "coin_type": # substitute SLIP-44 ids - self.schema.append(set(parse(s) for s in slip44_id)) + schema.append(set(parse(s) for s in slip44_id)) else: # plain constant - self.schema.append((parse(component),)) + schema.append((parse(component),)) + + return cls(schema, trailing_components, compact=True) + + def copy(self) -> "PathSchema": + """Create a compact copy of the schema. + + Useful when creating multiple schemas in a row. The following code ensures + that the set of schemas is allocated in a contiguous block of memory: + + >>> some_schemas = make_multiple_schemas() + >>> some_schemas = [s.copy() for s in some_schemas] + """ + return PathSchema(self.schema, self.trailing_components, compact=True) def match(self, path: Bip32Path) -> bool: # The path must not be _shorter_ than schema. It may be longer. diff --git a/core/src/apps/ethereum/keychain.py b/core/src/apps/ethereum/keychain.py index 9a1dd26172..5e79fecf3c 100644 --- a/core/src/apps/ethereum/keychain.py +++ b/core/src/apps/ethereum/keychain.py @@ -41,7 +41,8 @@ def _schemas_from_address_n( return () slip44_id = slip44_hardened - HARDENED - return (paths.PathSchema(pattern, slip44_id) for pattern in patterns) + schemas = [paths.PathSchema.parse(pattern, slip44_id) for pattern in patterns] + return [s.copy() for s in schemas] def with_keychain_from_path( @@ -79,7 +80,10 @@ def _schemas_from_chain_id(msg: EthereumSignTx) -> Iterable[paths.PathSchema]: else: slip44_id = (info.slip44,) - return (paths.PathSchema(pattern, slip44_id) for pattern in PATTERNS_ADDRESS) + schemas = [ + paths.PathSchema.parse(pattern, slip44_id) for pattern in PATTERNS_ADDRESS + ] + return [s.copy() for s in schemas] def with_keychain_from_chain_id( diff --git a/core/tests/test_apps.common.keychain.py b/core/tests/test_apps.common.keychain.py index e55ed1c7ed..0ccfafa49a 100644 --- a/core/tests/test_apps.common.keychain.py +++ b/core/tests/test_apps.common.keychain.py @@ -3,9 +3,8 @@ from common import * from mock_storage import mock_storage from storage import cache -import storage.device from apps.common import safety_checks -from apps.common.paths import PATTERN_SEP5, PathSchema, path_is_hardened +from apps.common.paths import PATTERN_SEP5, PathSchema from apps.common.keychain import LRUCache, Keychain, with_slip44_keychain, get_keychain from trezor import wire from trezor.crypto import bip39 @@ -22,8 +21,8 @@ class TestKeychain(unittest.TestCase): @mock_storage def test_verify_path(self): schemas = ( - PathSchema("m/44'/coin_type'", slip44_id=134), - PathSchema("m/44'/coin_type'", slip44_id=11), + PathSchema.parse("m/44'/coin_type'", slip44_id=134), + PathSchema.parse("m/44'/coin_type'", slip44_id=11), ) keychain = Keychain(b"", "secp256k1", schemas) @@ -49,7 +48,7 @@ class TestKeychain(unittest.TestCase): keychain.verify_path(path) def test_verify_path_special_ed25519(self): - schema = PathSchema("m/44'/coin_type'/*", slip44_id=134) + schema = PathSchema.parse("m/44'/coin_type'/*", slip44_id=134) k = Keychain(b"", "ed25519-keccak", [schema]) # OK case @@ -74,7 +73,7 @@ class TestKeychain(unittest.TestCase): seed = bip39.seed(" ".join(["all"] * 12), "") cache.set(cache.APP_COMMON_SEED, seed) - schema = PathSchema("m/44'/1'", 0) + schema = PathSchema.parse("m/44'/1'", 0) keychain = await_result( get_keychain(wire.DUMMY_CONTEXT, "secp256k1", [schema]) ) diff --git a/core/tests/test_apps.common.paths.py b/core/tests/test_apps.common.paths.py index 0b1a2df19d..dde5149dd0 100644 --- a/core/tests/test_apps.common.paths.py +++ b/core/tests/test_apps.common.paths.py @@ -68,7 +68,7 @@ class TestPathSchemas(unittest.TestCase): def test_pattern_fixed(self): pattern = "m/44'/0'/0'/0/0" - schema = PathSchema(pattern, 0) + schema = PathSchema.parse(pattern, 0) self.assertMatch(schema, [H_(44), H_(0), H_(0), 0, 0]) @@ -88,8 +88,8 @@ class TestPathSchemas(unittest.TestCase): def test_ranges_sets(self): pattern_ranges = "m/44'/[100-109]'/[0-20]" pattern_sets = "m/44'/[100,105,109]'/[0,10,20]" - schema_ranges = PathSchema(pattern_ranges, 0) - schema_sets = PathSchema(pattern_sets, 0) + schema_ranges = PathSchema.parse(pattern_ranges, 0) + schema_sets = PathSchema.parse(pattern_sets, 0) paths_good = [ [H_(44), H_(100), 0], @@ -125,13 +125,13 @@ class TestPathSchemas(unittest.TestCase): def test_brackets(self): pattern_a = "m/[0]'/[0-5]'/[0,1,2]'/[0]/[0-5]/[0,1,2]" pattern_b = "m/0'/0-5'/0,1,2'/0/0-5/0,1,2" - schema_a = PathSchema(pattern_a, 0) - schema_b = PathSchema(pattern_b, 0) + schema_a = PathSchema.parse(pattern_a, 0) + schema_b = PathSchema.parse(pattern_b, 0) self.assertEqualSchema(schema_a, schema_b) def test_wildcard(self): pattern = "m/44'/0'/*" - schema = PathSchema(pattern, 0) + schema = PathSchema.parse(pattern, 0) paths_good = [ [H_(44), H_(0)], @@ -152,12 +152,32 @@ class TestPathSchemas(unittest.TestCase): def test_substitutes(self): pattern_sub = "m/44'/coin_type'/account'/change/address_index" pattern_plain = "m/44'/19'/0-100'/0,1/0-1000000" - schema_sub = PathSchema(pattern_sub, slip44_id=19) + schema_sub = PathSchema.parse(pattern_sub, slip44_id=19) # use wrong slip44 id to ensure it doesn't affect anything - schema_plain = PathSchema(pattern_plain, slip44_id=0) + schema_plain = PathSchema.parse(pattern_plain, slip44_id=0) self.assertEqualSchema(schema_sub, schema_plain) + def test_copy(self): + schema_normal = PathSchema.parse("m/44'/0'/0'/0/0", slip44_id=0) + self.assertEqualSchema(schema_normal, schema_normal.copy()) + + schema_wildcard = PathSchema.parse("m/44'/0'/0'/0/**", slip44_id=0) + self.assertEqualSchema(schema_wildcard, schema_wildcard.copy()) + + def test_parse(self): + schema_parsed = PathSchema.parse("m/44'/0-5'/0,1,2'/0/**", slip44_id=0) + schema_manual = PathSchema( + [ + (H_(44),), + Interval(H_(0), H_(5)), + set((H_(0), H_(1), H_(2))), + (0,), + ], + trailing_components=Interval(0, 0xFFFF_FFFF), + ) + self.assertEqualSchema(schema_manual, schema_parsed) + if __name__ == "__main__": unittest.main()