diff --git a/common/protob/messages-management.proto b/common/protob/messages-management.proto index dd914a0d4..adde7b885 100644 --- a/common/protob/messages-management.proto +++ b/common/protob/messages-management.proto @@ -429,7 +429,6 @@ message DoPreauthorized { * @start * @next SignTx * @next GetOwnershipProof - * @next GetPublicKey */ message PreauthorizedRequest { } diff --git a/core/.changelog.d/2289.added b/core/.changelog.d/2289.added new file mode 100644 index 000000000..3f2663d7a --- /dev/null +++ b/core/.changelog.d/2289.added @@ -0,0 +1 @@ +Add SLIP-0025 CoinJoin accounts. diff --git a/core/src/apps/base.py b/core/src/apps/base.py index c9f5e4cac..56da6f80e 100644 --- a/core/src/apps/base.py +++ b/core/src/apps/base.py @@ -4,7 +4,7 @@ import storage.cache import storage.device from trezor import config, utils, wire, workflow from trezor.enums import MessageType -from trezor.messages import Success +from trezor.messages import Success, UnlockPath from . import workflow_handlers @@ -219,6 +219,55 @@ async def handle_DoPreauthorized( return await handler(ctx, req, authorization.get()) # type: ignore [Expected 2 positional arguments] +async def handle_UnlockPath(ctx: wire.Context, msg: UnlockPath) -> protobuf.MessageType: + from trezor.crypto import hmac + from trezor.messages import UnlockedPathRequest + from trezor.ui.layouts import confirm_action + from apps.common.paths import SLIP25_PURPOSE + from apps.common.seed import Slip21Node, get_seed + from apps.common.writers import write_uint32_le + + _KEYCHAIN_MAC_KEY_PATH = [b"TREZOR", b"Keychain MAC key"] + + # UnlockPath is relevant only for SLIP-25 paths. + # Note: Currently we only allow unlocking the entire SLIP-25 purpose subtree instead of + # per-coin or per-account unlocking in order to avoid UI complexity. + if msg.address_n != [SLIP25_PURPOSE]: + raise wire.DataError("Invalid path") + + seed = await get_seed(ctx) + node = Slip21Node(seed) + node.derive_path(_KEYCHAIN_MAC_KEY_PATH) + mac = utils.HashWriter(hmac(hmac.SHA256, node.key())) + for i in msg.address_n: + write_uint32_le(mac, i) + expected_mac = mac.get_digest() + + # Require confirmation to access SLIP25 paths unless already authorized. + if msg.mac: + if len(msg.mac) != len(expected_mac) or not utils.consteq( + expected_mac, msg.mac + ): + raise wire.DataError("Invalid MAC") + else: + await confirm_action( + ctx, + "confirm_coinjoin_access", + title="CoinJoin account", + description="Do you want to allow access to CoinJoin accounts?", + ) + + wire_types = (MessageType.GetAddress, MessageType.GetPublicKey, MessageType.SignTx) + req = await ctx.call_any(UnlockedPathRequest(mac=expected_mac), *wire_types) + + assert req.MESSAGE_WIRE_TYPE in wire_types + handler = workflow_handlers.find_registered_handler( + ctx.iface, req.MESSAGE_WIRE_TYPE + ) + assert handler is not None + return await handler(ctx, req, msg) # type: ignore [Expected 2 positional arguments] + + async def handle_CancelAuthorization( ctx: wire.Context, msg: CancelAuthorization ) -> protobuf.MessageType: @@ -336,6 +385,7 @@ def boot() -> None: workflow_handlers.register(MessageType.EndSession, handle_EndSession) workflow_handlers.register(MessageType.Ping, handle_Ping) workflow_handlers.register(MessageType.DoPreauthorized, handle_DoPreauthorized) + workflow_handlers.register(MessageType.UnlockPath, handle_UnlockPath) workflow_handlers.register( MessageType.CancelAuthorization, handle_CancelAuthorization ) diff --git a/core/src/apps/bitcoin/authorize_coinjoin.py b/core/src/apps/bitcoin/authorize_coinjoin.py index fb50a07f5..d73d7ba90 100644 --- a/core/src/apps/bitcoin/authorize_coinjoin.py +++ b/core/src/apps/bitcoin/authorize_coinjoin.py @@ -8,7 +8,8 @@ from trezor.strings import format_amount from trezor.ui.layouts import confirm_action, confirm_coinjoin, confirm_metadata from apps.common import authorization, safety_checks -from apps.common.paths import validate_path +from apps.common.keychain import FORBIDDEN_KEY_PATH +from apps.common.paths import SLIP25_PURPOSE, validate_path from .authorization import FEE_RATE_DECIMALS from .common import BIP32_WALLET_DEPTH @@ -47,6 +48,9 @@ async def authorize_coinjoin( if not msg.address_n: raise wire.DataError("Empty path not allowed.") + if msg.address_n[0] != SLIP25_PURPOSE and safety_checks.is_strict(): + raise FORBIDDEN_KEY_PATH + await confirm_action( ctx, "coinjoin_coordinator", @@ -57,19 +61,19 @@ async def authorize_coinjoin( icon=ui.ICON_RECOVERY, ) - max_fee_per_vbyte = format_amount(msg.max_fee_per_kvbyte, 3) - await confirm_coinjoin(ctx, coin.coin_name, msg.max_rounds, max_fee_per_vbyte) - validation_path = msg.address_n + [0] * BIP32_WALLET_DEPTH await validate_path( ctx, keychain, validation_path, + msg.address_n[0] == SLIP25_PURPOSE, validate_path_against_script_type( coin, address_n=validation_path, script_type=msg.script_type ), ) + max_fee_per_vbyte = format_amount(msg.max_fee_per_kvbyte, 3) + if msg.max_fee_per_kvbyte > coin.maxfee_kb: await confirm_metadata( ctx, @@ -80,6 +84,8 @@ async def authorize_coinjoin( ButtonRequestType.FeeOverThreshold, ) + await confirm_coinjoin(ctx, coin.coin_name, msg.max_rounds, max_fee_per_vbyte) + authorization.set(msg) return Success(message="CoinJoin authorized") diff --git a/core/src/apps/bitcoin/get_public_key.py b/core/src/apps/bitcoin/get_public_key.py index 5db6a4fa5..bb7acbf89 100644 --- a/core/src/apps/bitcoin/get_public_key.py +++ b/core/src/apps/bitcoin/get_public_key.py @@ -2,21 +2,33 @@ from typing import TYPE_CHECKING from trezor import wire from trezor.enums import InputScriptType -from trezor.messages import HDNodeType, PublicKey +from trezor.messages import HDNodeType, PublicKey, UnlockPath from apps.common import coininfo, paths -from apps.common.keychain import get_keychain +from apps.common.keychain import FORBIDDEN_KEY_PATH, get_keychain if TYPE_CHECKING: from trezor.messages import GetPublicKey + from trezor.protobuf import MessageType -async def get_public_key(ctx: wire.Context, msg: GetPublicKey) -> PublicKey: +async def get_public_key( + ctx: wire.Context, msg: GetPublicKey, auth_msg: MessageType | None = None +) -> PublicKey: coin_name = msg.coin_name or "Bitcoin" script_type = msg.script_type or InputScriptType.SPENDADDRESS coin = coininfo.by_name(coin_name) curve_name = msg.ecdsa_curve_name or coin.curve_name + if msg.address_n and msg.address_n[0] == paths.SLIP25_PURPOSE: + # UnlockPath is required to access SLIP25 paths. + if not UnlockPath.is_type_of(auth_msg): + raise FORBIDDEN_KEY_PATH + + # Verify that the desired path lies in the unlocked subtree. + if auth_msg.address_n != msg.address_n[: len(auth_msg.address_n)]: + raise FORBIDDEN_KEY_PATH + keychain = await get_keychain(ctx, curve_name, [paths.AlwaysMatchingSchema]) node = keychain.derive(msg.address_n) diff --git a/core/src/apps/bitcoin/keychain.py b/core/src/apps/bitcoin/keychain.py index 59176b236..539bf090c 100644 --- a/core/src/apps/bitcoin/keychain.py +++ b/core/src/apps/bitcoin/keychain.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING from trezor import wire from trezor.enums import InputScriptType +from trezor.messages import AuthorizeCoinJoin, GetOwnershipProof, SignTx, UnlockPath from apps.common import coininfo from apps.common.keychain import get_keychain @@ -19,13 +20,10 @@ if TYPE_CHECKING: from trezor.protobuf import MessageType from trezor.messages import ( - AuthorizeCoinJoin, GetAddress, GetOwnershipId, - GetOwnershipProof, GetPublicKey, SignMessage, - SignTx, VerifyMessage, ) @@ -66,6 +64,10 @@ PATTERN_BIP49 = "m/49'/coin_type'/account'/change/address_index" PATTERN_BIP84 = "m/84'/coin_type'/account'/change/address_index" # BIP-86 for taproot: https://github.com/bitcoin/bips/blob/master/bip-0086.mediawiki PATTERN_BIP86 = "m/86'/coin_type'/account'/change/address_index" +# SLIP-25 for CoinJoin: https://github.com/satoshilabs/slips/blob/master/slip-0025.md +# Only account=0 and script_type=1 are supported for now. +PATTERN_SLIP25_TAPROOT = "m/10025'/coin_type'/0'/1'/change/address_index" +PATTERN_SLIP25_TAPROOT_EXTERNAL = "m/10025'/coin_type'/0'/1'/0/address_index" # compatibility patterns, will be removed in the future PATTERN_GREENADDRESS_A = "m/[1,4]/address_index" @@ -151,13 +153,16 @@ def validate_path_against_script_type( elif coin.taproot and script_type == InputScriptType.SPENDTAPROOT: patterns.append(PATTERN_BIP86) + patterns.append(PATTERN_SLIP25_TAPROOT) return any( PathSchema.parse(pattern, coin.slip44).match(address_n) for pattern in patterns ) -def get_schemas_for_coin(coin: coininfo.CoinInfo) -> Iterable[PathSchema]: +def get_schemas_for_coin( + coin: coininfo.CoinInfo, unlock_schemas: Iterable[PathSchema] = () +) -> Iterable[PathSchema]: # basic patterns patterns = [ PATTERN_BIP44, @@ -206,6 +211,16 @@ def get_schemas_for_coin(coin: coininfo.CoinInfo) -> Iterable[PathSchema]: if coin.taproot: patterns.append(PATTERN_BIP86) + schemas = get_schemas_from_patterns(patterns, coin) + schemas.extend(unlock_schemas) + + gc.collect() + return [schema.copy() for schema in schemas] + + +def get_schemas_from_patterns( + patterns: Iterable[str], coin: coininfo.CoinInfo +) -> list[PathSchema]: schemas = [PathSchema.parse(pattern, coin.slip44) for pattern in patterns] # Some wallets such as Electron-Cash (BCH) store coins on Bitcoin paths. @@ -219,8 +234,7 @@ def get_schemas_for_coin(coin: coininfo.CoinInfo) -> Iterable[PathSchema]: PathSchema.parse(pattern, SLIP44_BITCOIN) for pattern in patterns ) - gc.collect() - return [schema.copy() for schema in schemas] + return schemas def get_coin_by_name(coin_name: str | None) -> coininfo.CoinInfo: @@ -234,13 +248,52 @@ def get_coin_by_name(coin_name: str | None) -> coininfo.CoinInfo: async def get_keychain_for_coin( - ctx: wire.Context, coin_name: str | None -) -> tuple[Keychain, coininfo.CoinInfo]: - coin = get_coin_by_name(coin_name) - schemas = get_schemas_for_coin(coin) + ctx: wire.Context, + coin: coininfo.CoinInfo, + unlock_schemas: Iterable[PathSchema] = (), +) -> Keychain: + schemas = get_schemas_for_coin(coin, unlock_schemas) slip21_namespaces = [[b"SLIP-0019"], [b"SLIP-0024"]] keychain = await get_keychain(ctx, coin.curve_name, schemas, slip21_namespaces) - return keychain, coin + return keychain + + +def _get_unlock_schemas( + msg: MessageType, auth_msg: MessageType | None, coin: coininfo.CoinInfo +) -> list[PathSchema]: + """ + Provides additional keychain schemas that are unlocked by the particular + combination of `msg` and `auth_msg`. + """ + + if AuthorizeCoinJoin.is_type_of(msg): + # When processing the AuthorizeCoinJoin message, validate_path() always + # needs to treat SLIP-25 paths as valid, so add SLIP-25 to the schemas. + return get_schemas_from_patterns([PATTERN_SLIP25_TAPROOT], coin) + + if AuthorizeCoinJoin.is_type_of(auth_msg) or UnlockPath.is_type_of(auth_msg): + # The user has preauthorized access to certain paths. Here we create a + # list of all the patterns that can be unlocked by AuthorizeCoinJoin or + # by UnlockPath. At the moment only SLIP-25 paths can be unlocked. + patterns = [] + if SignTx.is_type_of(msg) or GetOwnershipProof.is_type_of(msg): + # SignTx and GetOwnershipProof need access to all SLIP-25 addresses + # to create CoinJoin outputs. + patterns.append(PATTERN_SLIP25_TAPROOT) + else: + # In case of other messages like GetAddress or SignMessage there is + # no reason for the user to work with SLIP-25 change-addresses. For + # example, using a change-address to receive a payment may + # compromise privacy. + patterns.append(PATTERN_SLIP25_TAPROOT_EXTERNAL) + + # Convert the unlockable patterns to schemas and select only the ones + # that are unlocked by the auth_msg, i.e. lie in a subtree of the + # auth_msg's path. + schemas = get_schemas_from_patterns(patterns, coin) + return [s for s in schemas if s.restrict(auth_msg.address_n)] + + return [] def with_keychain(func: HandlerWithCoinInfo[MsgOut]) -> Handler[MsgIn, MsgOut]: @@ -249,8 +302,10 @@ def with_keychain(func: HandlerWithCoinInfo[MsgOut]) -> Handler[MsgIn, MsgOut]: msg: MsgIn, auth_msg: MessageType | None = None, ) -> MsgOut: - keychain, coin = await get_keychain_for_coin(ctx, msg.coin_name) - if auth_msg: + coin = get_coin_by_name(msg.coin_name) + unlock_schemas = _get_unlock_schemas(msg, auth_msg, coin) + keychain = await get_keychain_for_coin(ctx, coin, unlock_schemas) + if AuthorizeCoinJoin.is_type_of(auth_msg): auth_obj = authorization.from_cached_message(auth_msg) return await func(ctx, msg, keychain, coin, auth_obj) else: diff --git a/core/src/apps/common/paths.py b/core/src/apps/common/paths.py index 64218f6b1..eeb2a2175 100644 --- a/core/src/apps/common/paths.py +++ b/core/src/apps/common/paths.py @@ -2,6 +2,7 @@ from micropython import const from typing import TYPE_CHECKING HARDENED = const(0x8000_0000) +SLIP25_PURPOSE = const(10025 | HARDENED) if TYPE_CHECKING: from typing import ( @@ -246,6 +247,41 @@ class PathSchema: return True + def set_never_matching(self) -> None: + """Sets the schema to never match any paths.""" + self.schema = [] + self.trailing_components = self._EMPTY_TUPLE + + def restrict(self, path: Bip32Path) -> bool: + """ + Restricts the schema to patterns that are prefixed by the specified + path. If the restriction results in a never-matching schema, then False + is returned. + """ + + for i, value in enumerate(path): + if i < len(self.schema): + # Ensure that the path is a prefix of the schema. + if value not in self.schema[i]: + self.set_never_matching() + return False + + # Restrict the schema component if there are multiple choices. + component = self.schema[i] + if not isinstance(component, tuple) or len(component) != 1: + self.schema[i] = (value,) + else: + # The path is longer than the schema. We need to restrict the + # trailing components. + + if value not in self.trailing_components: + self.set_never_matching() + return False + + self.schema.append((value,)) + + return True + if __debug__: def __repr__(self) -> str: diff --git a/core/tests/test_apps.bitcoin.keychain.py b/core/tests/test_apps.bitcoin.keychain.py index 695a272e9..2156bd24b 100644 --- a/core/tests/test_apps.bitcoin.keychain.py +++ b/core/tests/test_apps.bitcoin.keychain.py @@ -4,7 +4,7 @@ from trezor import wire from trezor.crypto import bip39 from apps.common.paths import HARDENED -from apps.bitcoin.keychain import get_keychain_for_coin +from apps.bitcoin.keychain import get_coin_by_name, get_keychain_for_coin class TestBitcoinKeychain(unittest.TestCase): @@ -14,9 +14,8 @@ class TestBitcoinKeychain(unittest.TestCase): cache.set(cache.APP_COMMON_SEED, seed) def test_bitcoin(self): - keychain, coin = await_result( - get_keychain_for_coin(wire.DUMMY_CONTEXT, "Bitcoin") - ) + coin = get_coin_by_name("Bitcoin") + keychain = await_result(get_keychain_for_coin(wire.DUMMY_CONTEXT, coin)) self.assertEqual(coin.coin_name, "Bitcoin") valid_addresses = ( @@ -46,9 +45,8 @@ class TestBitcoinKeychain(unittest.TestCase): self.assertRaises(wire.DataError, keychain.derive, addr) def test_testnet(self): - keychain, coin = await_result( - get_keychain_for_coin(wire.DUMMY_CONTEXT, "Testnet") - ) + coin = get_coin_by_name("Testnet") + keychain = await_result(get_keychain_for_coin(wire.DUMMY_CONTEXT, coin)) self.assertEqual(coin.coin_name, "Testnet") valid_addresses = ( @@ -78,13 +76,14 @@ class TestBitcoinKeychain(unittest.TestCase): self.assertRaises(wire.DataError, keychain.derive, addr) def test_unspecified(self): - keychain, coin = await_result(get_keychain_for_coin(wire.DUMMY_CONTEXT, None)) + coin = get_coin_by_name(None) + keychain = await_result(get_keychain_for_coin(wire.DUMMY_CONTEXT, coin)) self.assertEqual(coin.coin_name, "Bitcoin") keychain.derive([H_(44), H_(0), H_(0), 0, 0]) def test_unknown(self): with self.assertRaises(wire.DataError): - await_result(get_keychain_for_coin(wire.DUMMY_CONTEXT, "MadeUpCoin2020")) + get_coin_by_name("MadeUpCoin2020") @unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin") @@ -95,9 +94,8 @@ class TestAltcoinKeychains(unittest.TestCase): cache.set(cache.APP_COMMON_SEED, seed) def test_bcash(self): - keychain, coin = await_result( - get_keychain_for_coin(wire.DUMMY_CONTEXT, "Bcash") - ) + coin = get_coin_by_name("Bcash") + keychain = await_result(get_keychain_for_coin(wire.DUMMY_CONTEXT, coin)) self.assertEqual(coin.coin_name, "Bcash") self.assertFalse(coin.segwit) @@ -133,9 +131,8 @@ class TestAltcoinKeychains(unittest.TestCase): self.assertRaises(wire.DataError, keychain.derive, addr) def test_litecoin(self): - keychain, coin = await_result( - get_keychain_for_coin(wire.DUMMY_CONTEXT, "Litecoin") - ) + coin = get_coin_by_name("Litecoin") + keychain = await_result(get_keychain_for_coin(wire.DUMMY_CONTEXT, coin)) self.assertEqual(coin.coin_name, "Litecoin") self.assertTrue(coin.segwit) diff --git a/core/tests/test_apps.common.paths.py b/core/tests/test_apps.common.paths.py index fef408428..ce76536d9 100644 --- a/core/tests/test_apps.common.paths.py +++ b/core/tests/test_apps.common.paths.py @@ -174,6 +174,31 @@ class TestPathSchemas(unittest.TestCase): ) self.assertEqualSchema(schema_manual, schema_parsed) + def test_restrict(self): + PATTERN_BIP44 = "m/44'/coin_type'/account'/change/address_index" + + # Restrict coin type to Bitcoin. + schema = PathSchema.parse(PATTERN_BIP44, (0, 145)) + self.assertTrue(schema.restrict([H_(44), H_(0)])) + expected = PathSchema.parse(PATTERN_BIP44, 0) + self.assertEqualSchema(schema, expected) + + # Restrict coin type to Bitcoin Cash and account 2. + schema = PathSchema.parse(PATTERN_BIP44, (0, 145)) + self.assertTrue(schema.restrict([H_(44), H_(145), H_(2)])) + expected = PathSchema.parse("m/44'/145'/2'/change/address_index", 0) + self.assertEqualSchema(schema, expected) + + # Restrict wildcards. + schema = PathSchema.parse("m/10018'/**", 0) + self.assertTrue(schema.restrict([H_(10018), H_(3), 7])) + expected = PathSchema.parse("m/10018'/3'/7/**", 0) + self.assertEqualSchema(schema, expected) + + # Restrict to a never-matching schema. + schema = PathSchema.parse(PATTERN_BIP44, (0, 145)) + self.assertFalse(schema.restrict([H_(44), H_(0), 0])) + if __name__ == "__main__": unittest.main()