1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-18 20:38:10 +00:00

refactor(core): defragment PathSchema memory usage

This commit is contained in:
matejcik 2021-03-19 16:32:55 +01:00 committed by matejcik
parent e5a481ded5
commit f3db4f2dd3
7 changed files with 135 additions and 39 deletions

View File

@ -1,3 +1,5 @@
import gc
from trezor import wire
from trezor.messages import InputScriptType as I
@ -113,7 +115,7 @@ def validate_path_against_script_type(
patterns.append(PATTERN_GREENADDRESS_B)
return any(
PathSchema(pattern, coin.slip44).match(address_n) for pattern in patterns
PathSchema.parse(pattern, coin.slip44).match(address_n) for pattern in patterns
)
@ -151,15 +153,16 @@ def get_schemas_for_coin(coin: coininfo.CoinInfo) -> Iterable[PathSchema]:
)
)
schemas = [PathSchema(pattern, coin.slip44) for pattern in patterns]
schemas = [PathSchema.parse(pattern, coin.slip44) for pattern in patterns]
# some wallets such as Electron-Cash (BCH) store coins on Bitcoin paths
# we can allow spending these coins from Bitcoin paths if the coin has
# implemented strong replay protection via SIGHASH_FORKID
if coin.fork_id is not None:
schemas.extend(PathSchema(pattern, 0) for pattern in patterns)
schemas.extend(PathSchema.parse(pattern, 0) for pattern in patterns)
return schemas
gc.collect()
return [schema.copy() for schema in schemas]
def get_coin_by_name(coin_name: str | None) -> coininfo.CoinInfo:

View File

@ -9,11 +9,11 @@ BYRON_ROOT = [44 | HARDENED, SLIP44_ID | HARDENED]
SHELLEY_ROOT = [1852 | HARDENED, SLIP44_ID | HARDENED]
# fmt: off
SCHEMA_PUBKEY = PathSchema("m/[44,1852]'/coin_type'/account'/*", SLIP44_ID)
SCHEMA_ADDRESS = PathSchema("m/[44,1852]'/coin_type'/account'/[0,1,2]/address_index", SLIP44_ID)
SCHEMA_PUBKEY = PathSchema.parse("m/[44,1852]'/coin_type'/account'/*", SLIP44_ID)
SCHEMA_ADDRESS = PathSchema.parse("m/[44,1852]'/coin_type'/account'/[0,1,2]/address_index", SLIP44_ID)
# staking is only allowed on Shelley paths with suffix /2/0
SCHEMA_STAKING = PathSchema("m/1852'/coin_type'/account'/2/0", SLIP44_ID)
SCHEMA_STAKING_ANY_ACCOUNT = PathSchema("m/1852'/coin_type'/[0-%s]'/2/0" % (HARDENED - 1), SLIP44_ID)
SCHEMA_STAKING = PathSchema.parse("m/1852'/coin_type'/account'/2/0", SLIP44_ID)
SCHEMA_STAKING_ANY_ACCOUNT = PathSchema.parse("m/1852'/coin_type'/[0-%s]'/2/0" % (HARDENED - 1), SLIP44_ID)
# fmt: on
# the maximum allowed change address. this should be large enough for normal

View File

@ -197,7 +197,8 @@ def with_slip44_keychain(
schemas = []
for pattern in patterns:
schemas.append(paths.PathSchema(pattern=pattern, slip44_id=slip44_ids))
schemas.append(paths.PathSchema.parse(pattern=pattern, slip44_id=slip44_ids))
schemas = [s.copy() for s in schemas]
def decorator(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]:
async def wrapper(ctx: wire.Context, msg: MsgIn) -> MsgOut:

View File

@ -104,7 +104,63 @@ class PathSchema:
"**": Interval(0, 0xFFFF_FFFF),
}
def __init__(self, pattern: str, slip44_id: int | Iterable[int]) -> None:
_EMPTY_TUPLE = ()
@staticmethod
def _parse_hardened(s: str) -> int:
return int(s) | HARDENED
@staticmethod
def _copy_container(container: Container[int]) -> Container[int]:
if isinstance(container, Interval):
return Interval(container.min, container.max)
if isinstance(container, set):
return set(container)
if isinstance(container, tuple):
return container[:]
raise RuntimeError("Unsupported container for copy")
def __init__(
self,
schema: list[Container[int]],
trailing_components: Container[int] = (),
compact: bool = False,
) -> None:
"""Create a new PathSchema from a list of containers and trailing components.
Mainly for internal use in `PathSchema.parse`, which is the method you should
be using.
Can be used to create a schema manually without parsing a path string:
>>> SCHEMA_MINE = PathSchema([
>>> (44 | HARDENED,),
>>> (0 | HARDENED,),
>>> Interval(0 | HARDENED, 10 | HARDENED),
>>> ],
>>> trailing_components=Interval(0, 0xFFFF_FFFF),
>>> )
Setting `compact=True` creates a compact copy of the provided components, so
as to prevent memory fragmentation.
"""
if compact:
self.schema: list[Container[int]] = [self._EMPTY_TUPLE] * len(schema)
for i in range(len(schema)):
self.schema[i] = self._copy_container(schema[i])
self.trailing_components = self._copy_container(trailing_components)
else:
self.schema = schema
self.trailing_components = trailing_components
@classmethod
def parse(cls, pattern: str, slip44_id: int | Iterable[int]) -> "PathSchema":
"""Parse a path schema string into a PathSchema instance.
The parsing process trashes the memory layout, so at the end a compact-allocated
copy of the resulting structures is returned.
"""
if not pattern.startswith("m/"):
raise ValueError # unsupported path template
components = pattern[2:].split("/")
@ -112,24 +168,24 @@ class PathSchema:
if isinstance(slip44_id, int):
slip44_id = (slip44_id,)
self.schema: list[Container[int]] = []
self.trailing_components: Container[int] = ()
schema: list[Container[int]] = []
trailing_components: Container[int] = ()
for component in components:
if component in self.WILDCARD_RANGES:
if len(self.schema) != len(components) - 1:
# every component should have resulted in extending self.schema
# so if self.schema does not have the appropriate length (yet),
if component in cls.WILDCARD_RANGES:
if len(schema) != len(components) - 1:
# every component should have resulted in extending schema
# so if schema does not have the appropriate length (yet),
# the asterisk is not the last item
raise ValueError # asterisk is not last item of pattern
self.trailing_components = self.WILDCARD_RANGES[component]
trailing_components = cls.WILDCARD_RANGES[component]
break
# figure out if the component is hardened
if component[-1] == "'":
component = component[:-1]
parse: Callable[[Any], int] = lambda s: int(s) | HARDENED # noqa: E731
parse: Callable[[Any], int] = cls._parse_hardened
else:
parse = int
@ -138,24 +194,37 @@ class PathSchema:
component = component[1:-1]
# optionally replace a keyword
component = self.REPLACEMENTS.get(component, component)
component = cls.REPLACEMENTS.get(component, component)
if "-" in component:
# parse as a range
a, b = [parse(s) for s in component.split("-", 1)]
self.schema.append(Interval(a, b))
schema.append(Interval(a, b))
elif "," in component:
# parse as a list of values
self.schema.append(set(parse(s) for s in component.split(",")))
schema.append(set(parse(s) for s in component.split(",")))
elif component == "coin_type":
# substitute SLIP-44 ids
self.schema.append(set(parse(s) for s in slip44_id))
schema.append(set(parse(s) for s in slip44_id))
else:
# plain constant
self.schema.append((parse(component),))
schema.append((parse(component),))
return cls(schema, trailing_components, compact=True)
def copy(self) -> "PathSchema":
"""Create a compact copy of the schema.
Useful when creating multiple schemas in a row. The following code ensures
that the set of schemas is allocated in a contiguous block of memory:
>>> some_schemas = make_multiple_schemas()
>>> some_schemas = [s.copy() for s in some_schemas]
"""
return PathSchema(self.schema, self.trailing_components, compact=True)
def match(self, path: Bip32Path) -> bool:
# The path must not be _shorter_ than schema. It may be longer.

View File

@ -41,7 +41,8 @@ def _schemas_from_address_n(
return ()
slip44_id = slip44_hardened - HARDENED
return (paths.PathSchema(pattern, slip44_id) for pattern in patterns)
schemas = [paths.PathSchema.parse(pattern, slip44_id) for pattern in patterns]
return [s.copy() for s in schemas]
def with_keychain_from_path(
@ -79,7 +80,10 @@ def _schemas_from_chain_id(msg: EthereumSignTx) -> Iterable[paths.PathSchema]:
else:
slip44_id = (info.slip44,)
return (paths.PathSchema(pattern, slip44_id) for pattern in PATTERNS_ADDRESS)
schemas = [
paths.PathSchema.parse(pattern, slip44_id) for pattern in PATTERNS_ADDRESS
]
return [s.copy() for s in schemas]
def with_keychain_from_chain_id(

View File

@ -3,9 +3,8 @@ from common import *
from mock_storage import mock_storage
from storage import cache
import storage.device
from apps.common import safety_checks
from apps.common.paths import PATTERN_SEP5, PathSchema, path_is_hardened
from apps.common.paths import PATTERN_SEP5, PathSchema
from apps.common.keychain import LRUCache, Keychain, with_slip44_keychain, get_keychain
from trezor import wire
from trezor.crypto import bip39
@ -22,8 +21,8 @@ class TestKeychain(unittest.TestCase):
@mock_storage
def test_verify_path(self):
schemas = (
PathSchema("m/44'/coin_type'", slip44_id=134),
PathSchema("m/44'/coin_type'", slip44_id=11),
PathSchema.parse("m/44'/coin_type'", slip44_id=134),
PathSchema.parse("m/44'/coin_type'", slip44_id=11),
)
keychain = Keychain(b"", "secp256k1", schemas)
@ -49,7 +48,7 @@ class TestKeychain(unittest.TestCase):
keychain.verify_path(path)
def test_verify_path_special_ed25519(self):
schema = PathSchema("m/44'/coin_type'/*", slip44_id=134)
schema = PathSchema.parse("m/44'/coin_type'/*", slip44_id=134)
k = Keychain(b"", "ed25519-keccak", [schema])
# OK case
@ -74,7 +73,7 @@ class TestKeychain(unittest.TestCase):
seed = bip39.seed(" ".join(["all"] * 12), "")
cache.set(cache.APP_COMMON_SEED, seed)
schema = PathSchema("m/44'/1'", 0)
schema = PathSchema.parse("m/44'/1'", 0)
keychain = await_result(
get_keychain(wire.DUMMY_CONTEXT, "secp256k1", [schema])
)

View File

@ -68,7 +68,7 @@ class TestPathSchemas(unittest.TestCase):
def test_pattern_fixed(self):
pattern = "m/44'/0'/0'/0/0"
schema = PathSchema(pattern, 0)
schema = PathSchema.parse(pattern, 0)
self.assertMatch(schema, [H_(44), H_(0), H_(0), 0, 0])
@ -88,8 +88,8 @@ class TestPathSchemas(unittest.TestCase):
def test_ranges_sets(self):
pattern_ranges = "m/44'/[100-109]'/[0-20]"
pattern_sets = "m/44'/[100,105,109]'/[0,10,20]"
schema_ranges = PathSchema(pattern_ranges, 0)
schema_sets = PathSchema(pattern_sets, 0)
schema_ranges = PathSchema.parse(pattern_ranges, 0)
schema_sets = PathSchema.parse(pattern_sets, 0)
paths_good = [
[H_(44), H_(100), 0],
@ -125,13 +125,13 @@ class TestPathSchemas(unittest.TestCase):
def test_brackets(self):
pattern_a = "m/[0]'/[0-5]'/[0,1,2]'/[0]/[0-5]/[0,1,2]"
pattern_b = "m/0'/0-5'/0,1,2'/0/0-5/0,1,2"
schema_a = PathSchema(pattern_a, 0)
schema_b = PathSchema(pattern_b, 0)
schema_a = PathSchema.parse(pattern_a, 0)
schema_b = PathSchema.parse(pattern_b, 0)
self.assertEqualSchema(schema_a, schema_b)
def test_wildcard(self):
pattern = "m/44'/0'/*"
schema = PathSchema(pattern, 0)
schema = PathSchema.parse(pattern, 0)
paths_good = [
[H_(44), H_(0)],
@ -152,12 +152,32 @@ class TestPathSchemas(unittest.TestCase):
def test_substitutes(self):
pattern_sub = "m/44'/coin_type'/account'/change/address_index"
pattern_plain = "m/44'/19'/0-100'/0,1/0-1000000"
schema_sub = PathSchema(pattern_sub, slip44_id=19)
schema_sub = PathSchema.parse(pattern_sub, slip44_id=19)
# use wrong slip44 id to ensure it doesn't affect anything
schema_plain = PathSchema(pattern_plain, slip44_id=0)
schema_plain = PathSchema.parse(pattern_plain, slip44_id=0)
self.assertEqualSchema(schema_sub, schema_plain)
def test_copy(self):
schema_normal = PathSchema.parse("m/44'/0'/0'/0/0", slip44_id=0)
self.assertEqualSchema(schema_normal, schema_normal.copy())
schema_wildcard = PathSchema.parse("m/44'/0'/0'/0/**", slip44_id=0)
self.assertEqualSchema(schema_wildcard, schema_wildcard.copy())
def test_parse(self):
schema_parsed = PathSchema.parse("m/44'/0-5'/0,1,2'/0/**", slip44_id=0)
schema_manual = PathSchema(
[
(H_(44),),
Interval(H_(0), H_(5)),
set((H_(0), H_(1), H_(2))),
(0,),
],
trailing_components=Interval(0, 0xFFFF_FFFF),
)
self.assertEqualSchema(schema_manual, schema_parsed)
if __name__ == "__main__":
unittest.main()