mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-19 12:58:13 +00:00
refactor(core): defragment PathSchema memory usage
This commit is contained in:
parent
e5a481ded5
commit
f3db4f2dd3
@ -1,3 +1,5 @@
|
|||||||
|
import gc
|
||||||
|
|
||||||
from trezor import wire
|
from trezor import wire
|
||||||
from trezor.messages import InputScriptType as I
|
from trezor.messages import InputScriptType as I
|
||||||
|
|
||||||
@ -113,7 +115,7 @@ def validate_path_against_script_type(
|
|||||||
patterns.append(PATTERN_GREENADDRESS_B)
|
patterns.append(PATTERN_GREENADDRESS_B)
|
||||||
|
|
||||||
return any(
|
return any(
|
||||||
PathSchema(pattern, coin.slip44).match(address_n) for pattern in patterns
|
PathSchema.parse(pattern, coin.slip44).match(address_n) for pattern in patterns
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -151,15 +153,16 @@ def get_schemas_for_coin(coin: coininfo.CoinInfo) -> Iterable[PathSchema]:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
schemas = [PathSchema(pattern, coin.slip44) for pattern in patterns]
|
schemas = [PathSchema.parse(pattern, coin.slip44) for pattern in patterns]
|
||||||
|
|
||||||
# some wallets such as Electron-Cash (BCH) store coins on Bitcoin paths
|
# some wallets such as Electron-Cash (BCH) store coins on Bitcoin paths
|
||||||
# we can allow spending these coins from Bitcoin paths if the coin has
|
# we can allow spending these coins from Bitcoin paths if the coin has
|
||||||
# implemented strong replay protection via SIGHASH_FORKID
|
# implemented strong replay protection via SIGHASH_FORKID
|
||||||
if coin.fork_id is not None:
|
if coin.fork_id is not None:
|
||||||
schemas.extend(PathSchema(pattern, 0) for pattern in patterns)
|
schemas.extend(PathSchema.parse(pattern, 0) for pattern in patterns)
|
||||||
|
|
||||||
return schemas
|
gc.collect()
|
||||||
|
return [schema.copy() for schema in schemas]
|
||||||
|
|
||||||
|
|
||||||
def get_coin_by_name(coin_name: str | None) -> coininfo.CoinInfo:
|
def get_coin_by_name(coin_name: str | None) -> coininfo.CoinInfo:
|
||||||
|
@ -9,11 +9,11 @@ BYRON_ROOT = [44 | HARDENED, SLIP44_ID | HARDENED]
|
|||||||
SHELLEY_ROOT = [1852 | HARDENED, SLIP44_ID | HARDENED]
|
SHELLEY_ROOT = [1852 | HARDENED, SLIP44_ID | HARDENED]
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
SCHEMA_PUBKEY = PathSchema("m/[44,1852]'/coin_type'/account'/*", SLIP44_ID)
|
SCHEMA_PUBKEY = PathSchema.parse("m/[44,1852]'/coin_type'/account'/*", SLIP44_ID)
|
||||||
SCHEMA_ADDRESS = PathSchema("m/[44,1852]'/coin_type'/account'/[0,1,2]/address_index", SLIP44_ID)
|
SCHEMA_ADDRESS = PathSchema.parse("m/[44,1852]'/coin_type'/account'/[0,1,2]/address_index", SLIP44_ID)
|
||||||
# staking is only allowed on Shelley paths with suffix /2/0
|
# staking is only allowed on Shelley paths with suffix /2/0
|
||||||
SCHEMA_STAKING = PathSchema("m/1852'/coin_type'/account'/2/0", SLIP44_ID)
|
SCHEMA_STAKING = PathSchema.parse("m/1852'/coin_type'/account'/2/0", SLIP44_ID)
|
||||||
SCHEMA_STAKING_ANY_ACCOUNT = PathSchema("m/1852'/coin_type'/[0-%s]'/2/0" % (HARDENED - 1), SLIP44_ID)
|
SCHEMA_STAKING_ANY_ACCOUNT = PathSchema.parse("m/1852'/coin_type'/[0-%s]'/2/0" % (HARDENED - 1), SLIP44_ID)
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
# the maximum allowed change address. this should be large enough for normal
|
# the maximum allowed change address. this should be large enough for normal
|
||||||
|
@ -197,7 +197,8 @@ def with_slip44_keychain(
|
|||||||
|
|
||||||
schemas = []
|
schemas = []
|
||||||
for pattern in patterns:
|
for pattern in patterns:
|
||||||
schemas.append(paths.PathSchema(pattern=pattern, slip44_id=slip44_ids))
|
schemas.append(paths.PathSchema.parse(pattern=pattern, slip44_id=slip44_ids))
|
||||||
|
schemas = [s.copy() for s in schemas]
|
||||||
|
|
||||||
def decorator(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]:
|
def decorator(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]:
|
||||||
async def wrapper(ctx: wire.Context, msg: MsgIn) -> MsgOut:
|
async def wrapper(ctx: wire.Context, msg: MsgIn) -> MsgOut:
|
||||||
|
@ -104,7 +104,63 @@ class PathSchema:
|
|||||||
"**": Interval(0, 0xFFFF_FFFF),
|
"**": Interval(0, 0xFFFF_FFFF),
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, pattern: str, slip44_id: int | Iterable[int]) -> None:
|
_EMPTY_TUPLE = ()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_hardened(s: str) -> int:
|
||||||
|
return int(s) | HARDENED
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _copy_container(container: Container[int]) -> Container[int]:
|
||||||
|
if isinstance(container, Interval):
|
||||||
|
return Interval(container.min, container.max)
|
||||||
|
if isinstance(container, set):
|
||||||
|
return set(container)
|
||||||
|
if isinstance(container, tuple):
|
||||||
|
return container[:]
|
||||||
|
raise RuntimeError("Unsupported container for copy")
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
schema: list[Container[int]],
|
||||||
|
trailing_components: Container[int] = (),
|
||||||
|
compact: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""Create a new PathSchema from a list of containers and trailing components.
|
||||||
|
|
||||||
|
Mainly for internal use in `PathSchema.parse`, which is the method you should
|
||||||
|
be using.
|
||||||
|
|
||||||
|
Can be used to create a schema manually without parsing a path string:
|
||||||
|
|
||||||
|
>>> SCHEMA_MINE = PathSchema([
|
||||||
|
>>> (44 | HARDENED,),
|
||||||
|
>>> (0 | HARDENED,),
|
||||||
|
>>> Interval(0 | HARDENED, 10 | HARDENED),
|
||||||
|
>>> ],
|
||||||
|
>>> trailing_components=Interval(0, 0xFFFF_FFFF),
|
||||||
|
>>> )
|
||||||
|
|
||||||
|
Setting `compact=True` creates a compact copy of the provided components, so
|
||||||
|
as to prevent memory fragmentation.
|
||||||
|
"""
|
||||||
|
if compact:
|
||||||
|
self.schema: list[Container[int]] = [self._EMPTY_TUPLE] * len(schema)
|
||||||
|
for i in range(len(schema)):
|
||||||
|
self.schema[i] = self._copy_container(schema[i])
|
||||||
|
self.trailing_components = self._copy_container(trailing_components)
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.schema = schema
|
||||||
|
self.trailing_components = trailing_components
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def parse(cls, pattern: str, slip44_id: int | Iterable[int]) -> "PathSchema":
|
||||||
|
"""Parse a path schema string into a PathSchema instance.
|
||||||
|
|
||||||
|
The parsing process trashes the memory layout, so at the end a compact-allocated
|
||||||
|
copy of the resulting structures is returned.
|
||||||
|
"""
|
||||||
if not pattern.startswith("m/"):
|
if not pattern.startswith("m/"):
|
||||||
raise ValueError # unsupported path template
|
raise ValueError # unsupported path template
|
||||||
components = pattern[2:].split("/")
|
components = pattern[2:].split("/")
|
||||||
@ -112,24 +168,24 @@ class PathSchema:
|
|||||||
if isinstance(slip44_id, int):
|
if isinstance(slip44_id, int):
|
||||||
slip44_id = (slip44_id,)
|
slip44_id = (slip44_id,)
|
||||||
|
|
||||||
self.schema: list[Container[int]] = []
|
schema: list[Container[int]] = []
|
||||||
self.trailing_components: Container[int] = ()
|
trailing_components: Container[int] = ()
|
||||||
|
|
||||||
for component in components:
|
for component in components:
|
||||||
if component in self.WILDCARD_RANGES:
|
if component in cls.WILDCARD_RANGES:
|
||||||
if len(self.schema) != len(components) - 1:
|
if len(schema) != len(components) - 1:
|
||||||
# every component should have resulted in extending self.schema
|
# every component should have resulted in extending schema
|
||||||
# so if self.schema does not have the appropriate length (yet),
|
# so if schema does not have the appropriate length (yet),
|
||||||
# the asterisk is not the last item
|
# the asterisk is not the last item
|
||||||
raise ValueError # asterisk is not last item of pattern
|
raise ValueError # asterisk is not last item of pattern
|
||||||
|
|
||||||
self.trailing_components = self.WILDCARD_RANGES[component]
|
trailing_components = cls.WILDCARD_RANGES[component]
|
||||||
break
|
break
|
||||||
|
|
||||||
# figure out if the component is hardened
|
# figure out if the component is hardened
|
||||||
if component[-1] == "'":
|
if component[-1] == "'":
|
||||||
component = component[:-1]
|
component = component[:-1]
|
||||||
parse: Callable[[Any], int] = lambda s: int(s) | HARDENED # noqa: E731
|
parse: Callable[[Any], int] = cls._parse_hardened
|
||||||
else:
|
else:
|
||||||
parse = int
|
parse = int
|
||||||
|
|
||||||
@ -138,24 +194,37 @@ class PathSchema:
|
|||||||
component = component[1:-1]
|
component = component[1:-1]
|
||||||
|
|
||||||
# optionally replace a keyword
|
# optionally replace a keyword
|
||||||
component = self.REPLACEMENTS.get(component, component)
|
component = cls.REPLACEMENTS.get(component, component)
|
||||||
|
|
||||||
if "-" in component:
|
if "-" in component:
|
||||||
# parse as a range
|
# parse as a range
|
||||||
a, b = [parse(s) for s in component.split("-", 1)]
|
a, b = [parse(s) for s in component.split("-", 1)]
|
||||||
self.schema.append(Interval(a, b))
|
schema.append(Interval(a, b))
|
||||||
|
|
||||||
elif "," in component:
|
elif "," in component:
|
||||||
# parse as a list of values
|
# parse as a list of values
|
||||||
self.schema.append(set(parse(s) for s in component.split(",")))
|
schema.append(set(parse(s) for s in component.split(",")))
|
||||||
|
|
||||||
elif component == "coin_type":
|
elif component == "coin_type":
|
||||||
# substitute SLIP-44 ids
|
# substitute SLIP-44 ids
|
||||||
self.schema.append(set(parse(s) for s in slip44_id))
|
schema.append(set(parse(s) for s in slip44_id))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# plain constant
|
# plain constant
|
||||||
self.schema.append((parse(component),))
|
schema.append((parse(component),))
|
||||||
|
|
||||||
|
return cls(schema, trailing_components, compact=True)
|
||||||
|
|
||||||
|
def copy(self) -> "PathSchema":
|
||||||
|
"""Create a compact copy of the schema.
|
||||||
|
|
||||||
|
Useful when creating multiple schemas in a row. The following code ensures
|
||||||
|
that the set of schemas is allocated in a contiguous block of memory:
|
||||||
|
|
||||||
|
>>> some_schemas = make_multiple_schemas()
|
||||||
|
>>> some_schemas = [s.copy() for s in some_schemas]
|
||||||
|
"""
|
||||||
|
return PathSchema(self.schema, self.trailing_components, compact=True)
|
||||||
|
|
||||||
def match(self, path: Bip32Path) -> bool:
|
def match(self, path: Bip32Path) -> bool:
|
||||||
# The path must not be _shorter_ than schema. It may be longer.
|
# The path must not be _shorter_ than schema. It may be longer.
|
||||||
|
@ -41,7 +41,8 @@ def _schemas_from_address_n(
|
|||||||
return ()
|
return ()
|
||||||
|
|
||||||
slip44_id = slip44_hardened - HARDENED
|
slip44_id = slip44_hardened - HARDENED
|
||||||
return (paths.PathSchema(pattern, slip44_id) for pattern in patterns)
|
schemas = [paths.PathSchema.parse(pattern, slip44_id) for pattern in patterns]
|
||||||
|
return [s.copy() for s in schemas]
|
||||||
|
|
||||||
|
|
||||||
def with_keychain_from_path(
|
def with_keychain_from_path(
|
||||||
@ -79,7 +80,10 @@ def _schemas_from_chain_id(msg: EthereumSignTx) -> Iterable[paths.PathSchema]:
|
|||||||
else:
|
else:
|
||||||
slip44_id = (info.slip44,)
|
slip44_id = (info.slip44,)
|
||||||
|
|
||||||
return (paths.PathSchema(pattern, slip44_id) for pattern in PATTERNS_ADDRESS)
|
schemas = [
|
||||||
|
paths.PathSchema.parse(pattern, slip44_id) for pattern in PATTERNS_ADDRESS
|
||||||
|
]
|
||||||
|
return [s.copy() for s in schemas]
|
||||||
|
|
||||||
|
|
||||||
def with_keychain_from_chain_id(
|
def with_keychain_from_chain_id(
|
||||||
|
@ -3,9 +3,8 @@ from common import *
|
|||||||
from mock_storage import mock_storage
|
from mock_storage import mock_storage
|
||||||
|
|
||||||
from storage import cache
|
from storage import cache
|
||||||
import storage.device
|
|
||||||
from apps.common import safety_checks
|
from apps.common import safety_checks
|
||||||
from apps.common.paths import PATTERN_SEP5, PathSchema, path_is_hardened
|
from apps.common.paths import PATTERN_SEP5, PathSchema
|
||||||
from apps.common.keychain import LRUCache, Keychain, with_slip44_keychain, get_keychain
|
from apps.common.keychain import LRUCache, Keychain, with_slip44_keychain, get_keychain
|
||||||
from trezor import wire
|
from trezor import wire
|
||||||
from trezor.crypto import bip39
|
from trezor.crypto import bip39
|
||||||
@ -22,8 +21,8 @@ class TestKeychain(unittest.TestCase):
|
|||||||
@mock_storage
|
@mock_storage
|
||||||
def test_verify_path(self):
|
def test_verify_path(self):
|
||||||
schemas = (
|
schemas = (
|
||||||
PathSchema("m/44'/coin_type'", slip44_id=134),
|
PathSchema.parse("m/44'/coin_type'", slip44_id=134),
|
||||||
PathSchema("m/44'/coin_type'", slip44_id=11),
|
PathSchema.parse("m/44'/coin_type'", slip44_id=11),
|
||||||
)
|
)
|
||||||
keychain = Keychain(b"", "secp256k1", schemas)
|
keychain = Keychain(b"", "secp256k1", schemas)
|
||||||
|
|
||||||
@ -49,7 +48,7 @@ class TestKeychain(unittest.TestCase):
|
|||||||
keychain.verify_path(path)
|
keychain.verify_path(path)
|
||||||
|
|
||||||
def test_verify_path_special_ed25519(self):
|
def test_verify_path_special_ed25519(self):
|
||||||
schema = PathSchema("m/44'/coin_type'/*", slip44_id=134)
|
schema = PathSchema.parse("m/44'/coin_type'/*", slip44_id=134)
|
||||||
k = Keychain(b"", "ed25519-keccak", [schema])
|
k = Keychain(b"", "ed25519-keccak", [schema])
|
||||||
|
|
||||||
# OK case
|
# OK case
|
||||||
@ -74,7 +73,7 @@ class TestKeychain(unittest.TestCase):
|
|||||||
seed = bip39.seed(" ".join(["all"] * 12), "")
|
seed = bip39.seed(" ".join(["all"] * 12), "")
|
||||||
cache.set(cache.APP_COMMON_SEED, seed)
|
cache.set(cache.APP_COMMON_SEED, seed)
|
||||||
|
|
||||||
schema = PathSchema("m/44'/1'", 0)
|
schema = PathSchema.parse("m/44'/1'", 0)
|
||||||
keychain = await_result(
|
keychain = await_result(
|
||||||
get_keychain(wire.DUMMY_CONTEXT, "secp256k1", [schema])
|
get_keychain(wire.DUMMY_CONTEXT, "secp256k1", [schema])
|
||||||
)
|
)
|
||||||
|
@ -68,7 +68,7 @@ class TestPathSchemas(unittest.TestCase):
|
|||||||
|
|
||||||
def test_pattern_fixed(self):
|
def test_pattern_fixed(self):
|
||||||
pattern = "m/44'/0'/0'/0/0"
|
pattern = "m/44'/0'/0'/0/0"
|
||||||
schema = PathSchema(pattern, 0)
|
schema = PathSchema.parse(pattern, 0)
|
||||||
|
|
||||||
self.assertMatch(schema, [H_(44), H_(0), H_(0), 0, 0])
|
self.assertMatch(schema, [H_(44), H_(0), H_(0), 0, 0])
|
||||||
|
|
||||||
@ -88,8 +88,8 @@ class TestPathSchemas(unittest.TestCase):
|
|||||||
def test_ranges_sets(self):
|
def test_ranges_sets(self):
|
||||||
pattern_ranges = "m/44'/[100-109]'/[0-20]"
|
pattern_ranges = "m/44'/[100-109]'/[0-20]"
|
||||||
pattern_sets = "m/44'/[100,105,109]'/[0,10,20]"
|
pattern_sets = "m/44'/[100,105,109]'/[0,10,20]"
|
||||||
schema_ranges = PathSchema(pattern_ranges, 0)
|
schema_ranges = PathSchema.parse(pattern_ranges, 0)
|
||||||
schema_sets = PathSchema(pattern_sets, 0)
|
schema_sets = PathSchema.parse(pattern_sets, 0)
|
||||||
|
|
||||||
paths_good = [
|
paths_good = [
|
||||||
[H_(44), H_(100), 0],
|
[H_(44), H_(100), 0],
|
||||||
@ -125,13 +125,13 @@ class TestPathSchemas(unittest.TestCase):
|
|||||||
def test_brackets(self):
|
def test_brackets(self):
|
||||||
pattern_a = "m/[0]'/[0-5]'/[0,1,2]'/[0]/[0-5]/[0,1,2]"
|
pattern_a = "m/[0]'/[0-5]'/[0,1,2]'/[0]/[0-5]/[0,1,2]"
|
||||||
pattern_b = "m/0'/0-5'/0,1,2'/0/0-5/0,1,2"
|
pattern_b = "m/0'/0-5'/0,1,2'/0/0-5/0,1,2"
|
||||||
schema_a = PathSchema(pattern_a, 0)
|
schema_a = PathSchema.parse(pattern_a, 0)
|
||||||
schema_b = PathSchema(pattern_b, 0)
|
schema_b = PathSchema.parse(pattern_b, 0)
|
||||||
self.assertEqualSchema(schema_a, schema_b)
|
self.assertEqualSchema(schema_a, schema_b)
|
||||||
|
|
||||||
def test_wildcard(self):
|
def test_wildcard(self):
|
||||||
pattern = "m/44'/0'/*"
|
pattern = "m/44'/0'/*"
|
||||||
schema = PathSchema(pattern, 0)
|
schema = PathSchema.parse(pattern, 0)
|
||||||
|
|
||||||
paths_good = [
|
paths_good = [
|
||||||
[H_(44), H_(0)],
|
[H_(44), H_(0)],
|
||||||
@ -152,12 +152,32 @@ class TestPathSchemas(unittest.TestCase):
|
|||||||
def test_substitutes(self):
|
def test_substitutes(self):
|
||||||
pattern_sub = "m/44'/coin_type'/account'/change/address_index"
|
pattern_sub = "m/44'/coin_type'/account'/change/address_index"
|
||||||
pattern_plain = "m/44'/19'/0-100'/0,1/0-1000000"
|
pattern_plain = "m/44'/19'/0-100'/0,1/0-1000000"
|
||||||
schema_sub = PathSchema(pattern_sub, slip44_id=19)
|
schema_sub = PathSchema.parse(pattern_sub, slip44_id=19)
|
||||||
# use wrong slip44 id to ensure it doesn't affect anything
|
# use wrong slip44 id to ensure it doesn't affect anything
|
||||||
schema_plain = PathSchema(pattern_plain, slip44_id=0)
|
schema_plain = PathSchema.parse(pattern_plain, slip44_id=0)
|
||||||
|
|
||||||
self.assertEqualSchema(schema_sub, schema_plain)
|
self.assertEqualSchema(schema_sub, schema_plain)
|
||||||
|
|
||||||
|
def test_copy(self):
|
||||||
|
schema_normal = PathSchema.parse("m/44'/0'/0'/0/0", slip44_id=0)
|
||||||
|
self.assertEqualSchema(schema_normal, schema_normal.copy())
|
||||||
|
|
||||||
|
schema_wildcard = PathSchema.parse("m/44'/0'/0'/0/**", slip44_id=0)
|
||||||
|
self.assertEqualSchema(schema_wildcard, schema_wildcard.copy())
|
||||||
|
|
||||||
|
def test_parse(self):
|
||||||
|
schema_parsed = PathSchema.parse("m/44'/0-5'/0,1,2'/0/**", slip44_id=0)
|
||||||
|
schema_manual = PathSchema(
|
||||||
|
[
|
||||||
|
(H_(44),),
|
||||||
|
Interval(H_(0), H_(5)),
|
||||||
|
set((H_(0), H_(1), H_(2))),
|
||||||
|
(0,),
|
||||||
|
],
|
||||||
|
trailing_components=Interval(0, 0xFFFF_FFFF),
|
||||||
|
)
|
||||||
|
self.assertEqualSchema(schema_manual, schema_parsed)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user