diff --git a/core/embed/extmod/modtrezorcrypto/modtrezorcrypto-sha512.h b/core/embed/extmod/modtrezorcrypto/modtrezorcrypto-sha512.h index 5ea27f8f7..8f16ed423 100644 --- a/core/embed/extmod/modtrezorcrypto/modtrezorcrypto-sha512.h +++ b/core/embed/extmod/modtrezorcrypto/modtrezorcrypto-sha512.h @@ -54,7 +54,7 @@ STATIC mp_obj_t mod_trezorcrypto_Sha512_make_new(const mp_obj_type_t *type, return MP_OBJ_FROM_PTR(o); } -/// def hash(self, data: bytes) -> None: +/// def update(self, data: bytes) -> None: /// """ /// Update the hash context with hashed data. /// """ diff --git a/core/mocks/generated/trezorcrypto/__init__.pyi b/core/mocks/generated/trezorcrypto/__init__.pyi index 4c26b5f8f..446b897ab 100644 --- a/core/mocks/generated/trezorcrypto/__init__.pyi +++ b/core/mocks/generated/trezorcrypto/__init__.pyi @@ -358,7 +358,7 @@ class sha512: Creates a hash context object. """ - def hash(self, data: bytes) -> None: + def update(self, data: bytes) -> None: """ Update the hash context with hashed data. """ diff --git a/core/src/apps/common/seed.py b/core/src/apps/common/seed.py index 6129e765c..384333d24 100644 --- a/core/src/apps/common/seed.py +++ b/core/src/apps/common/seed.py @@ -1,11 +1,36 @@ from trezor import wire -from trezor.crypto import bip32 +from trezor.crypto import bip32, hashlib, hmac from apps.common import HARDENED, cache, mnemonic, storage from apps.common.request_passphrase import protect_by_passphrase if False: - from typing import List, Optional + from typing import List, Union + + +class Slip21Node: + def __init__(self, seed: bytes = None) -> None: + if seed is not None: + self.data = hmac.new(b"Symmetric key seed", seed, hashlib.sha512).digest() + else: + self.data = b"" + + def __del__(self) -> None: + del self.data + + def derive_path(self, path: list) -> None: + for label in path: + h = hmac.new(self.data[0:32], b"\x00", hashlib.sha512) + h.update(label) + self.data = h.digest() + + def key(self) -> bytes: + return self.data[32:64] + + def clone(self) -> "Slip21Node": + node = Slip21Node() + node.data = self.data + return node class Keychain: @@ -17,7 +42,9 @@ class Keychain: def __init__(self, seed: bytes, namespaces: list): self.seed = seed self.namespaces = namespaces - self.roots = [None] * len(namespaces) # type: List[Optional[bip32.HDNode]] + self.roots = [None] * len( + namespaces + ) # type: List[Union[bip32.HDNode, Slip21Node, None]] def __del__(self) -> None: for root in self.roots: @@ -34,7 +61,12 @@ class Keychain: return raise wire.DataError("Forbidden key path") - def derive(self, node_path: list, curve_name: str = "secp256k1") -> bip32.HDNode: + def derive( + self, node_path: list, curve_name: str = "secp256k1" + ) -> Union[bip32.HDNode, Slip21Node]: + if "ed25519" in curve_name and not _path_hardened(node_path): + raise wire.DataError("Forbidden key path") + # find the root node index root_index = 0 for curve, *path in self.namespaces: @@ -49,11 +81,13 @@ class Keychain: # create the root node if not cached root = self.roots[root_index] if root is None: - root = bip32.from_seed(self.seed, curve_name) + if curve_name != "slip21": + root = bip32.from_seed(self.seed, curve_name) + else: + root = Slip21Node(self.seed) root.derive_path(path) self.roots[root_index] = root - # TODO check for ed25519? # derive child node from the root node = root.clone() node.derive_path(suffix) @@ -86,6 +120,15 @@ def derive_node_without_passphrase( return node +def derive_slip21_node_without_passphrase(path: list) -> Slip21Node: + if not storage.is_initialized(): + raise Exception("Device is not initialized") + seed = mnemonic.get_seed(progress_bar=False) + node = Slip21Node(seed) + node.derive_path(path) + return node + + def remove_ed25519_prefix(pubkey: bytes) -> bytes: # 0x01 prefix is not part of the actual public key, hence removed return pubkey[1:] diff --git a/core/tests/test_apps.common.seed.py b/core/tests/test_apps.common.seed.py index f8b309735..6e7b6aef4 100644 --- a/core/tests/test_apps.common.seed.py +++ b/core/tests/test_apps.common.seed.py @@ -1,7 +1,8 @@ from common import * from apps.common import HARDENED -from apps.common.seed import Keychain, _path_hardened +from apps.common.seed import Keychain, Slip21Node, _path_hardened from trezor import wire +from trezor.crypto import bip39 class TestKeychain(unittest.TestCase): @@ -75,5 +76,45 @@ class TestKeychain(unittest.TestCase): self.assertFalse(_path_hardened([0, ])) self.assertFalse(_path_hardened([44 | HARDENED, 1 | HARDENED, 0 | HARDENED, 0 | HARDENED, 0])) + def test_slip21(self): + seed = bip39.seed(' '.join(['all'] * 12), '') + node1 = Slip21Node(seed) + node2 = node1.clone() + keychain = Keychain(seed, [["slip21", b"SLIP-0021"]]) + + # Key(m) + KEY_M = unhexlify(b"dbf12b44133eaab506a740f6565cc117228cbf1dd70635cfa8ddfdc9af734756") + self.assertEqual(node1.key(), KEY_M) + + # Key(m/"SLIP-0021") + KEY_M_SLIP0021 = unhexlify(b"1d065e3ac1bbe5c7fad32cf2305f7d709dc070d672044a19e610c77cdf33de0d") + node1.derive_path([b"SLIP-0021"]) + self.assertEqual(node1.key(), KEY_M_SLIP0021) + self.assertIsNone(keychain.validate_path([b"SLIP-0021"], "slip21")) + self.assertEqual(keychain.derive([b"SLIP-0021"], "slip21").key(), KEY_M_SLIP0021) + + # Key(m/"SLIP-0021"/"Master encryption key") + KEY_M_SLIP0021_MEK = unhexlify(b"ea163130e35bbafdf5ddee97a17b39cef2be4b4f390180d65b54cf05c6a82fde") + node1.derive_path([b"Master encryption key"]) + self.assertEqual(node1.key(), KEY_M_SLIP0021_MEK) + self.assertIsNone(keychain.validate_path([b"SLIP-0021", b"Master encryption key"], "slip21")) + self.assertEqual(keychain.derive([b"SLIP-0021", b"Master encryption key"], "slip21").key(), KEY_M_SLIP0021_MEK) + + # Key(m/"SLIP-0021"/"Authentication key") + KEY_M_SLIP0021_AK = unhexlify(b"47194e938ab24cc82bfa25f6486ed54bebe79c40ae2a5a32ea6db294d81861a6") + node2.derive_path([b"SLIP-0021", b"Authentication key"]) + self.assertEqual(node2.key(), KEY_M_SLIP0021_AK) + self.assertIsNone(keychain.validate_path([b"SLIP-0021", b"Authentication key"], "slip21")) + self.assertEqual(keychain.derive([b"SLIP-0021", b"Authentication key"], "slip21").key(), KEY_M_SLIP0021_AK) + + # Forbidden paths. + with self.assertRaises(wire.DataError): + self.assertFalse(keychain.validate_path([], "slip21")) + with self.assertRaises(wire.DataError): + self.assertFalse(keychain.validate_path([b"SLIP-9999", b"Authentication key"], "slip21")) + with self.assertRaises(wire.DataError): + keychain.derive([b"SLIP-9999", b"Authentication key"], "slip21").key() + + if __name__ == '__main__': unittest.main() diff --git a/core/tests/test_apps.wallet.signtx.fee_threshold.py b/core/tests/test_apps.wallet.signtx.fee_threshold.py index 5d272d602..3741e5130 100644 --- a/core/tests/test_apps.wallet.signtx.fee_threshold.py +++ b/core/tests/test_apps.wallet.signtx.fee_threshold.py @@ -156,9 +156,8 @@ class TestSignTxFeeThreshold(unittest.TestCase): ] seed = bip39.seed('alcohol woman abuse must during monitor noble actual mixed trade anger aisle', '') - root = bip32.from_seed(seed, 'secp256k1') - keychain = Keychain([[coin_bitcoin.curve_name]], [root]) + keychain = Keychain(seed, [[coin_bitcoin.curve_name]]) signer = signing.sign_tx(tx, keychain) for request, response in chunks(messages, 2): self.assertEqual(signer.send(request), response)