1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-19 12:58:13 +00:00

refactor(core): defragment PathSchema memory usage

This commit is contained in:
matejcik 2021-03-19 16:32:55 +01:00 committed by matejcik
parent e5a481ded5
commit f3db4f2dd3
7 changed files with 135 additions and 39 deletions

View File

@ -1,3 +1,5 @@
import gc
from trezor import wire from trezor import wire
from trezor.messages import InputScriptType as I from trezor.messages import InputScriptType as I
@ -113,7 +115,7 @@ def validate_path_against_script_type(
patterns.append(PATTERN_GREENADDRESS_B) patterns.append(PATTERN_GREENADDRESS_B)
return any( return any(
PathSchema(pattern, coin.slip44).match(address_n) for pattern in patterns PathSchema.parse(pattern, coin.slip44).match(address_n) for pattern in patterns
) )
@ -151,15 +153,16 @@ def get_schemas_for_coin(coin: coininfo.CoinInfo) -> Iterable[PathSchema]:
) )
) )
schemas = [PathSchema(pattern, coin.slip44) for pattern in patterns] schemas = [PathSchema.parse(pattern, coin.slip44) for pattern in patterns]
# some wallets such as Electron-Cash (BCH) store coins on Bitcoin paths # some wallets such as Electron-Cash (BCH) store coins on Bitcoin paths
# we can allow spending these coins from Bitcoin paths if the coin has # we can allow spending these coins from Bitcoin paths if the coin has
# implemented strong replay protection via SIGHASH_FORKID # implemented strong replay protection via SIGHASH_FORKID
if coin.fork_id is not None: if coin.fork_id is not None:
schemas.extend(PathSchema(pattern, 0) for pattern in patterns) schemas.extend(PathSchema.parse(pattern, 0) for pattern in patterns)
return schemas gc.collect()
return [schema.copy() for schema in schemas]
def get_coin_by_name(coin_name: str | None) -> coininfo.CoinInfo: def get_coin_by_name(coin_name: str | None) -> coininfo.CoinInfo:

View File

@ -9,11 +9,11 @@ BYRON_ROOT = [44 | HARDENED, SLIP44_ID | HARDENED]
SHELLEY_ROOT = [1852 | HARDENED, SLIP44_ID | HARDENED] SHELLEY_ROOT = [1852 | HARDENED, SLIP44_ID | HARDENED]
# fmt: off # fmt: off
SCHEMA_PUBKEY = PathSchema("m/[44,1852]'/coin_type'/account'/*", SLIP44_ID) SCHEMA_PUBKEY = PathSchema.parse("m/[44,1852]'/coin_type'/account'/*", SLIP44_ID)
SCHEMA_ADDRESS = PathSchema("m/[44,1852]'/coin_type'/account'/[0,1,2]/address_index", SLIP44_ID) SCHEMA_ADDRESS = PathSchema.parse("m/[44,1852]'/coin_type'/account'/[0,1,2]/address_index", SLIP44_ID)
# staking is only allowed on Shelley paths with suffix /2/0 # staking is only allowed on Shelley paths with suffix /2/0
SCHEMA_STAKING = PathSchema("m/1852'/coin_type'/account'/2/0", SLIP44_ID) SCHEMA_STAKING = PathSchema.parse("m/1852'/coin_type'/account'/2/0", SLIP44_ID)
SCHEMA_STAKING_ANY_ACCOUNT = PathSchema("m/1852'/coin_type'/[0-%s]'/2/0" % (HARDENED - 1), SLIP44_ID) SCHEMA_STAKING_ANY_ACCOUNT = PathSchema.parse("m/1852'/coin_type'/[0-%s]'/2/0" % (HARDENED - 1), SLIP44_ID)
# fmt: on # fmt: on
# the maximum allowed change address. this should be large enough for normal # the maximum allowed change address. this should be large enough for normal

View File

@ -197,7 +197,8 @@ def with_slip44_keychain(
schemas = [] schemas = []
for pattern in patterns: for pattern in patterns:
schemas.append(paths.PathSchema(pattern=pattern, slip44_id=slip44_ids)) schemas.append(paths.PathSchema.parse(pattern=pattern, slip44_id=slip44_ids))
schemas = [s.copy() for s in schemas]
def decorator(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]: def decorator(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]:
async def wrapper(ctx: wire.Context, msg: MsgIn) -> MsgOut: async def wrapper(ctx: wire.Context, msg: MsgIn) -> MsgOut:

View File

@ -104,7 +104,63 @@ class PathSchema:
"**": Interval(0, 0xFFFF_FFFF), "**": Interval(0, 0xFFFF_FFFF),
} }
def __init__(self, pattern: str, slip44_id: int | Iterable[int]) -> None: _EMPTY_TUPLE = ()
@staticmethod
def _parse_hardened(s: str) -> int:
return int(s) | HARDENED
@staticmethod
def _copy_container(container: Container[int]) -> Container[int]:
if isinstance(container, Interval):
return Interval(container.min, container.max)
if isinstance(container, set):
return set(container)
if isinstance(container, tuple):
return container[:]
raise RuntimeError("Unsupported container for copy")
def __init__(
self,
schema: list[Container[int]],
trailing_components: Container[int] = (),
compact: bool = False,
) -> None:
"""Create a new PathSchema from a list of containers and trailing components.
Mainly for internal use in `PathSchema.parse`, which is the method you should
be using.
Can be used to create a schema manually without parsing a path string:
>>> SCHEMA_MINE = PathSchema([
>>> (44 | HARDENED,),
>>> (0 | HARDENED,),
>>> Interval(0 | HARDENED, 10 | HARDENED),
>>> ],
>>> trailing_components=Interval(0, 0xFFFF_FFFF),
>>> )
Setting `compact=True` creates a compact copy of the provided components, so
as to prevent memory fragmentation.
"""
if compact:
self.schema: list[Container[int]] = [self._EMPTY_TUPLE] * len(schema)
for i in range(len(schema)):
self.schema[i] = self._copy_container(schema[i])
self.trailing_components = self._copy_container(trailing_components)
else:
self.schema = schema
self.trailing_components = trailing_components
@classmethod
def parse(cls, pattern: str, slip44_id: int | Iterable[int]) -> "PathSchema":
"""Parse a path schema string into a PathSchema instance.
The parsing process trashes the memory layout, so at the end a compact-allocated
copy of the resulting structures is returned.
"""
if not pattern.startswith("m/"): if not pattern.startswith("m/"):
raise ValueError # unsupported path template raise ValueError # unsupported path template
components = pattern[2:].split("/") components = pattern[2:].split("/")
@ -112,24 +168,24 @@ class PathSchema:
if isinstance(slip44_id, int): if isinstance(slip44_id, int):
slip44_id = (slip44_id,) slip44_id = (slip44_id,)
self.schema: list[Container[int]] = [] schema: list[Container[int]] = []
self.trailing_components: Container[int] = () trailing_components: Container[int] = ()
for component in components: for component in components:
if component in self.WILDCARD_RANGES: if component in cls.WILDCARD_RANGES:
if len(self.schema) != len(components) - 1: if len(schema) != len(components) - 1:
# every component should have resulted in extending self.schema # every component should have resulted in extending schema
# so if self.schema does not have the appropriate length (yet), # so if schema does not have the appropriate length (yet),
# the asterisk is not the last item # the asterisk is not the last item
raise ValueError # asterisk is not last item of pattern raise ValueError # asterisk is not last item of pattern
self.trailing_components = self.WILDCARD_RANGES[component] trailing_components = cls.WILDCARD_RANGES[component]
break break
# figure out if the component is hardened # figure out if the component is hardened
if component[-1] == "'": if component[-1] == "'":
component = component[:-1] component = component[:-1]
parse: Callable[[Any], int] = lambda s: int(s) | HARDENED # noqa: E731 parse: Callable[[Any], int] = cls._parse_hardened
else: else:
parse = int parse = int
@ -138,24 +194,37 @@ class PathSchema:
component = component[1:-1] component = component[1:-1]
# optionally replace a keyword # optionally replace a keyword
component = self.REPLACEMENTS.get(component, component) component = cls.REPLACEMENTS.get(component, component)
if "-" in component: if "-" in component:
# parse as a range # parse as a range
a, b = [parse(s) for s in component.split("-", 1)] a, b = [parse(s) for s in component.split("-", 1)]
self.schema.append(Interval(a, b)) schema.append(Interval(a, b))
elif "," in component: elif "," in component:
# parse as a list of values # parse as a list of values
self.schema.append(set(parse(s) for s in component.split(","))) schema.append(set(parse(s) for s in component.split(",")))
elif component == "coin_type": elif component == "coin_type":
# substitute SLIP-44 ids # substitute SLIP-44 ids
self.schema.append(set(parse(s) for s in slip44_id)) schema.append(set(parse(s) for s in slip44_id))
else: else:
# plain constant # plain constant
self.schema.append((parse(component),)) schema.append((parse(component),))
return cls(schema, trailing_components, compact=True)
def copy(self) -> "PathSchema":
"""Create a compact copy of the schema.
Useful when creating multiple schemas in a row. The following code ensures
that the set of schemas is allocated in a contiguous block of memory:
>>> some_schemas = make_multiple_schemas()
>>> some_schemas = [s.copy() for s in some_schemas]
"""
return PathSchema(self.schema, self.trailing_components, compact=True)
def match(self, path: Bip32Path) -> bool: def match(self, path: Bip32Path) -> bool:
# The path must not be _shorter_ than schema. It may be longer. # The path must not be _shorter_ than schema. It may be longer.

View File

@ -41,7 +41,8 @@ def _schemas_from_address_n(
return () return ()
slip44_id = slip44_hardened - HARDENED slip44_id = slip44_hardened - HARDENED
return (paths.PathSchema(pattern, slip44_id) for pattern in patterns) schemas = [paths.PathSchema.parse(pattern, slip44_id) for pattern in patterns]
return [s.copy() for s in schemas]
def with_keychain_from_path( def with_keychain_from_path(
@ -79,7 +80,10 @@ def _schemas_from_chain_id(msg: EthereumSignTx) -> Iterable[paths.PathSchema]:
else: else:
slip44_id = (info.slip44,) slip44_id = (info.slip44,)
return (paths.PathSchema(pattern, slip44_id) for pattern in PATTERNS_ADDRESS) schemas = [
paths.PathSchema.parse(pattern, slip44_id) for pattern in PATTERNS_ADDRESS
]
return [s.copy() for s in schemas]
def with_keychain_from_chain_id( def with_keychain_from_chain_id(

View File

@ -3,9 +3,8 @@ from common import *
from mock_storage import mock_storage from mock_storage import mock_storage
from storage import cache from storage import cache
import storage.device
from apps.common import safety_checks from apps.common import safety_checks
from apps.common.paths import PATTERN_SEP5, PathSchema, path_is_hardened from apps.common.paths import PATTERN_SEP5, PathSchema
from apps.common.keychain import LRUCache, Keychain, with_slip44_keychain, get_keychain from apps.common.keychain import LRUCache, Keychain, with_slip44_keychain, get_keychain
from trezor import wire from trezor import wire
from trezor.crypto import bip39 from trezor.crypto import bip39
@ -22,8 +21,8 @@ class TestKeychain(unittest.TestCase):
@mock_storage @mock_storage
def test_verify_path(self): def test_verify_path(self):
schemas = ( schemas = (
PathSchema("m/44'/coin_type'", slip44_id=134), PathSchema.parse("m/44'/coin_type'", slip44_id=134),
PathSchema("m/44'/coin_type'", slip44_id=11), PathSchema.parse("m/44'/coin_type'", slip44_id=11),
) )
keychain = Keychain(b"", "secp256k1", schemas) keychain = Keychain(b"", "secp256k1", schemas)
@ -49,7 +48,7 @@ class TestKeychain(unittest.TestCase):
keychain.verify_path(path) keychain.verify_path(path)
def test_verify_path_special_ed25519(self): def test_verify_path_special_ed25519(self):
schema = PathSchema("m/44'/coin_type'/*", slip44_id=134) schema = PathSchema.parse("m/44'/coin_type'/*", slip44_id=134)
k = Keychain(b"", "ed25519-keccak", [schema]) k = Keychain(b"", "ed25519-keccak", [schema])
# OK case # OK case
@ -74,7 +73,7 @@ class TestKeychain(unittest.TestCase):
seed = bip39.seed(" ".join(["all"] * 12), "") seed = bip39.seed(" ".join(["all"] * 12), "")
cache.set(cache.APP_COMMON_SEED, seed) cache.set(cache.APP_COMMON_SEED, seed)
schema = PathSchema("m/44'/1'", 0) schema = PathSchema.parse("m/44'/1'", 0)
keychain = await_result( keychain = await_result(
get_keychain(wire.DUMMY_CONTEXT, "secp256k1", [schema]) get_keychain(wire.DUMMY_CONTEXT, "secp256k1", [schema])
) )

View File

@ -68,7 +68,7 @@ class TestPathSchemas(unittest.TestCase):
def test_pattern_fixed(self): def test_pattern_fixed(self):
pattern = "m/44'/0'/0'/0/0" pattern = "m/44'/0'/0'/0/0"
schema = PathSchema(pattern, 0) schema = PathSchema.parse(pattern, 0)
self.assertMatch(schema, [H_(44), H_(0), H_(0), 0, 0]) self.assertMatch(schema, [H_(44), H_(0), H_(0), 0, 0])
@ -88,8 +88,8 @@ class TestPathSchemas(unittest.TestCase):
def test_ranges_sets(self): def test_ranges_sets(self):
pattern_ranges = "m/44'/[100-109]'/[0-20]" pattern_ranges = "m/44'/[100-109]'/[0-20]"
pattern_sets = "m/44'/[100,105,109]'/[0,10,20]" pattern_sets = "m/44'/[100,105,109]'/[0,10,20]"
schema_ranges = PathSchema(pattern_ranges, 0) schema_ranges = PathSchema.parse(pattern_ranges, 0)
schema_sets = PathSchema(pattern_sets, 0) schema_sets = PathSchema.parse(pattern_sets, 0)
paths_good = [ paths_good = [
[H_(44), H_(100), 0], [H_(44), H_(100), 0],
@ -125,13 +125,13 @@ class TestPathSchemas(unittest.TestCase):
def test_brackets(self): def test_brackets(self):
pattern_a = "m/[0]'/[0-5]'/[0,1,2]'/[0]/[0-5]/[0,1,2]" pattern_a = "m/[0]'/[0-5]'/[0,1,2]'/[0]/[0-5]/[0,1,2]"
pattern_b = "m/0'/0-5'/0,1,2'/0/0-5/0,1,2" pattern_b = "m/0'/0-5'/0,1,2'/0/0-5/0,1,2"
schema_a = PathSchema(pattern_a, 0) schema_a = PathSchema.parse(pattern_a, 0)
schema_b = PathSchema(pattern_b, 0) schema_b = PathSchema.parse(pattern_b, 0)
self.assertEqualSchema(schema_a, schema_b) self.assertEqualSchema(schema_a, schema_b)
def test_wildcard(self): def test_wildcard(self):
pattern = "m/44'/0'/*" pattern = "m/44'/0'/*"
schema = PathSchema(pattern, 0) schema = PathSchema.parse(pattern, 0)
paths_good = [ paths_good = [
[H_(44), H_(0)], [H_(44), H_(0)],
@ -152,12 +152,32 @@ class TestPathSchemas(unittest.TestCase):
def test_substitutes(self): def test_substitutes(self):
pattern_sub = "m/44'/coin_type'/account'/change/address_index" pattern_sub = "m/44'/coin_type'/account'/change/address_index"
pattern_plain = "m/44'/19'/0-100'/0,1/0-1000000" pattern_plain = "m/44'/19'/0-100'/0,1/0-1000000"
schema_sub = PathSchema(pattern_sub, slip44_id=19) schema_sub = PathSchema.parse(pattern_sub, slip44_id=19)
# use wrong slip44 id to ensure it doesn't affect anything # use wrong slip44 id to ensure it doesn't affect anything
schema_plain = PathSchema(pattern_plain, slip44_id=0) schema_plain = PathSchema.parse(pattern_plain, slip44_id=0)
self.assertEqualSchema(schema_sub, schema_plain) self.assertEqualSchema(schema_sub, schema_plain)
def test_copy(self):
schema_normal = PathSchema.parse("m/44'/0'/0'/0/0", slip44_id=0)
self.assertEqualSchema(schema_normal, schema_normal.copy())
schema_wildcard = PathSchema.parse("m/44'/0'/0'/0/**", slip44_id=0)
self.assertEqualSchema(schema_wildcard, schema_wildcard.copy())
def test_parse(self):
schema_parsed = PathSchema.parse("m/44'/0-5'/0,1,2'/0/**", slip44_id=0)
schema_manual = PathSchema(
[
(H_(44),),
Interval(H_(0), H_(5)),
set((H_(0), H_(1), H_(2))),
(0,),
],
trailing_components=Interval(0, 0xFFFF_FFFF),
)
self.assertEqualSchema(schema_manual, schema_parsed)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()