1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-18 04:18:10 +00:00

feat(core): Implement SLIP-0025 CoinJoin accounts.

This commit is contained in:
Andrew Kozlik 2022-06-03 11:41:44 +02:00 committed by Andrew Kozlik
parent 9d89c3cb1b
commit 77be3653b4
9 changed files with 218 additions and 37 deletions

View File

@ -429,7 +429,6 @@ message DoPreauthorized {
* @start * @start
* @next SignTx * @next SignTx
* @next GetOwnershipProof * @next GetOwnershipProof
* @next GetPublicKey
*/ */
message PreauthorizedRequest { message PreauthorizedRequest {
} }

View File

@ -0,0 +1 @@
Add SLIP-0025 CoinJoin accounts.

View File

@ -4,7 +4,7 @@ import storage.cache
import storage.device import storage.device
from trezor import config, utils, wire, workflow from trezor import config, utils, wire, workflow
from trezor.enums import MessageType from trezor.enums import MessageType
from trezor.messages import Success from trezor.messages import Success, UnlockPath
from . import workflow_handlers from . import workflow_handlers
@ -219,6 +219,55 @@ async def handle_DoPreauthorized(
return await handler(ctx, req, authorization.get()) # type: ignore [Expected 2 positional arguments] return await handler(ctx, req, authorization.get()) # type: ignore [Expected 2 positional arguments]
async def handle_UnlockPath(ctx: wire.Context, msg: UnlockPath) -> protobuf.MessageType:
from trezor.crypto import hmac
from trezor.messages import UnlockedPathRequest
from trezor.ui.layouts import confirm_action
from apps.common.paths import SLIP25_PURPOSE
from apps.common.seed import Slip21Node, get_seed
from apps.common.writers import write_uint32_le
_KEYCHAIN_MAC_KEY_PATH = [b"TREZOR", b"Keychain MAC key"]
# UnlockPath is relevant only for SLIP-25 paths.
# Note: Currently we only allow unlocking the entire SLIP-25 purpose subtree instead of
# per-coin or per-account unlocking in order to avoid UI complexity.
if msg.address_n != [SLIP25_PURPOSE]:
raise wire.DataError("Invalid path")
seed = await get_seed(ctx)
node = Slip21Node(seed)
node.derive_path(_KEYCHAIN_MAC_KEY_PATH)
mac = utils.HashWriter(hmac(hmac.SHA256, node.key()))
for i in msg.address_n:
write_uint32_le(mac, i)
expected_mac = mac.get_digest()
# Require confirmation to access SLIP25 paths unless already authorized.
if msg.mac:
if len(msg.mac) != len(expected_mac) or not utils.consteq(
expected_mac, msg.mac
):
raise wire.DataError("Invalid MAC")
else:
await confirm_action(
ctx,
"confirm_coinjoin_access",
title="CoinJoin account",
description="Do you want to allow access to CoinJoin accounts?",
)
wire_types = (MessageType.GetAddress, MessageType.GetPublicKey, MessageType.SignTx)
req = await ctx.call_any(UnlockedPathRequest(mac=expected_mac), *wire_types)
assert req.MESSAGE_WIRE_TYPE in wire_types
handler = workflow_handlers.find_registered_handler(
ctx.iface, req.MESSAGE_WIRE_TYPE
)
assert handler is not None
return await handler(ctx, req, msg) # type: ignore [Expected 2 positional arguments]
async def handle_CancelAuthorization( async def handle_CancelAuthorization(
ctx: wire.Context, msg: CancelAuthorization ctx: wire.Context, msg: CancelAuthorization
) -> protobuf.MessageType: ) -> protobuf.MessageType:
@ -336,6 +385,7 @@ def boot() -> None:
workflow_handlers.register(MessageType.EndSession, handle_EndSession) workflow_handlers.register(MessageType.EndSession, handle_EndSession)
workflow_handlers.register(MessageType.Ping, handle_Ping) workflow_handlers.register(MessageType.Ping, handle_Ping)
workflow_handlers.register(MessageType.DoPreauthorized, handle_DoPreauthorized) workflow_handlers.register(MessageType.DoPreauthorized, handle_DoPreauthorized)
workflow_handlers.register(MessageType.UnlockPath, handle_UnlockPath)
workflow_handlers.register( workflow_handlers.register(
MessageType.CancelAuthorization, handle_CancelAuthorization MessageType.CancelAuthorization, handle_CancelAuthorization
) )

View File

@ -8,7 +8,8 @@ from trezor.strings import format_amount
from trezor.ui.layouts import confirm_action, confirm_coinjoin, confirm_metadata from trezor.ui.layouts import confirm_action, confirm_coinjoin, confirm_metadata
from apps.common import authorization, safety_checks from apps.common import authorization, safety_checks
from apps.common.paths import validate_path from apps.common.keychain import FORBIDDEN_KEY_PATH
from apps.common.paths import SLIP25_PURPOSE, validate_path
from .authorization import FEE_RATE_DECIMALS from .authorization import FEE_RATE_DECIMALS
from .common import BIP32_WALLET_DEPTH from .common import BIP32_WALLET_DEPTH
@ -47,6 +48,9 @@ async def authorize_coinjoin(
if not msg.address_n: if not msg.address_n:
raise wire.DataError("Empty path not allowed.") raise wire.DataError("Empty path not allowed.")
if msg.address_n[0] != SLIP25_PURPOSE and safety_checks.is_strict():
raise FORBIDDEN_KEY_PATH
await confirm_action( await confirm_action(
ctx, ctx,
"coinjoin_coordinator", "coinjoin_coordinator",
@ -57,19 +61,19 @@ async def authorize_coinjoin(
icon=ui.ICON_RECOVERY, icon=ui.ICON_RECOVERY,
) )
max_fee_per_vbyte = format_amount(msg.max_fee_per_kvbyte, 3)
await confirm_coinjoin(ctx, coin.coin_name, msg.max_rounds, max_fee_per_vbyte)
validation_path = msg.address_n + [0] * BIP32_WALLET_DEPTH validation_path = msg.address_n + [0] * BIP32_WALLET_DEPTH
await validate_path( await validate_path(
ctx, ctx,
keychain, keychain,
validation_path, validation_path,
msg.address_n[0] == SLIP25_PURPOSE,
validate_path_against_script_type( validate_path_against_script_type(
coin, address_n=validation_path, script_type=msg.script_type coin, address_n=validation_path, script_type=msg.script_type
), ),
) )
max_fee_per_vbyte = format_amount(msg.max_fee_per_kvbyte, 3)
if msg.max_fee_per_kvbyte > coin.maxfee_kb: if msg.max_fee_per_kvbyte > coin.maxfee_kb:
await confirm_metadata( await confirm_metadata(
ctx, ctx,
@ -80,6 +84,8 @@ async def authorize_coinjoin(
ButtonRequestType.FeeOverThreshold, ButtonRequestType.FeeOverThreshold,
) )
await confirm_coinjoin(ctx, coin.coin_name, msg.max_rounds, max_fee_per_vbyte)
authorization.set(msg) authorization.set(msg)
return Success(message="CoinJoin authorized") return Success(message="CoinJoin authorized")

View File

@ -2,21 +2,33 @@ from typing import TYPE_CHECKING
from trezor import wire from trezor import wire
from trezor.enums import InputScriptType from trezor.enums import InputScriptType
from trezor.messages import HDNodeType, PublicKey from trezor.messages import HDNodeType, PublicKey, UnlockPath
from apps.common import coininfo, paths from apps.common import coininfo, paths
from apps.common.keychain import get_keychain from apps.common.keychain import FORBIDDEN_KEY_PATH, get_keychain
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import GetPublicKey from trezor.messages import GetPublicKey
from trezor.protobuf import MessageType
async def get_public_key(ctx: wire.Context, msg: GetPublicKey) -> PublicKey: async def get_public_key(
ctx: wire.Context, msg: GetPublicKey, auth_msg: MessageType | None = None
) -> PublicKey:
coin_name = msg.coin_name or "Bitcoin" coin_name = msg.coin_name or "Bitcoin"
script_type = msg.script_type or InputScriptType.SPENDADDRESS script_type = msg.script_type or InputScriptType.SPENDADDRESS
coin = coininfo.by_name(coin_name) coin = coininfo.by_name(coin_name)
curve_name = msg.ecdsa_curve_name or coin.curve_name curve_name = msg.ecdsa_curve_name or coin.curve_name
if msg.address_n and msg.address_n[0] == paths.SLIP25_PURPOSE:
# UnlockPath is required to access SLIP25 paths.
if not UnlockPath.is_type_of(auth_msg):
raise FORBIDDEN_KEY_PATH
# Verify that the desired path lies in the unlocked subtree.
if auth_msg.address_n != msg.address_n[: len(auth_msg.address_n)]:
raise FORBIDDEN_KEY_PATH
keychain = await get_keychain(ctx, curve_name, [paths.AlwaysMatchingSchema]) keychain = await get_keychain(ctx, curve_name, [paths.AlwaysMatchingSchema])
node = keychain.derive(msg.address_n) node = keychain.derive(msg.address_n)

View File

@ -4,6 +4,7 @@ from typing import TYPE_CHECKING
from trezor import wire from trezor import wire
from trezor.enums import InputScriptType from trezor.enums import InputScriptType
from trezor.messages import AuthorizeCoinJoin, GetOwnershipProof, SignTx, UnlockPath
from apps.common import coininfo from apps.common import coininfo
from apps.common.keychain import get_keychain from apps.common.keychain import get_keychain
@ -19,13 +20,10 @@ if TYPE_CHECKING:
from trezor.protobuf import MessageType from trezor.protobuf import MessageType
from trezor.messages import ( from trezor.messages import (
AuthorizeCoinJoin,
GetAddress, GetAddress,
GetOwnershipId, GetOwnershipId,
GetOwnershipProof,
GetPublicKey, GetPublicKey,
SignMessage, SignMessage,
SignTx,
VerifyMessage, VerifyMessage,
) )
@ -66,6 +64,10 @@ PATTERN_BIP49 = "m/49'/coin_type'/account'/change/address_index"
PATTERN_BIP84 = "m/84'/coin_type'/account'/change/address_index" PATTERN_BIP84 = "m/84'/coin_type'/account'/change/address_index"
# BIP-86 for taproot: https://github.com/bitcoin/bips/blob/master/bip-0086.mediawiki # BIP-86 for taproot: https://github.com/bitcoin/bips/blob/master/bip-0086.mediawiki
PATTERN_BIP86 = "m/86'/coin_type'/account'/change/address_index" PATTERN_BIP86 = "m/86'/coin_type'/account'/change/address_index"
# SLIP-25 for CoinJoin: https://github.com/satoshilabs/slips/blob/master/slip-0025.md
# Only account=0 and script_type=1 are supported for now.
PATTERN_SLIP25_TAPROOT = "m/10025'/coin_type'/0'/1'/change/address_index"
PATTERN_SLIP25_TAPROOT_EXTERNAL = "m/10025'/coin_type'/0'/1'/0/address_index"
# compatibility patterns, will be removed in the future # compatibility patterns, will be removed in the future
PATTERN_GREENADDRESS_A = "m/[1,4]/address_index" PATTERN_GREENADDRESS_A = "m/[1,4]/address_index"
@ -151,13 +153,16 @@ def validate_path_against_script_type(
elif coin.taproot and script_type == InputScriptType.SPENDTAPROOT: elif coin.taproot and script_type == InputScriptType.SPENDTAPROOT:
patterns.append(PATTERN_BIP86) patterns.append(PATTERN_BIP86)
patterns.append(PATTERN_SLIP25_TAPROOT)
return any( return any(
PathSchema.parse(pattern, coin.slip44).match(address_n) for pattern in patterns PathSchema.parse(pattern, coin.slip44).match(address_n) for pattern in patterns
) )
def get_schemas_for_coin(coin: coininfo.CoinInfo) -> Iterable[PathSchema]: def get_schemas_for_coin(
coin: coininfo.CoinInfo, unlock_schemas: Iterable[PathSchema] = ()
) -> Iterable[PathSchema]:
# basic patterns # basic patterns
patterns = [ patterns = [
PATTERN_BIP44, PATTERN_BIP44,
@ -206,6 +211,16 @@ def get_schemas_for_coin(coin: coininfo.CoinInfo) -> Iterable[PathSchema]:
if coin.taproot: if coin.taproot:
patterns.append(PATTERN_BIP86) patterns.append(PATTERN_BIP86)
schemas = get_schemas_from_patterns(patterns, coin)
schemas.extend(unlock_schemas)
gc.collect()
return [schema.copy() for schema in schemas]
def get_schemas_from_patterns(
patterns: Iterable[str], coin: coininfo.CoinInfo
) -> list[PathSchema]:
schemas = [PathSchema.parse(pattern, coin.slip44) for pattern in patterns] schemas = [PathSchema.parse(pattern, coin.slip44) for pattern in patterns]
# Some wallets such as Electron-Cash (BCH) store coins on Bitcoin paths. # Some wallets such as Electron-Cash (BCH) store coins on Bitcoin paths.
@ -219,8 +234,7 @@ def get_schemas_for_coin(coin: coininfo.CoinInfo) -> Iterable[PathSchema]:
PathSchema.parse(pattern, SLIP44_BITCOIN) for pattern in patterns PathSchema.parse(pattern, SLIP44_BITCOIN) for pattern in patterns
) )
gc.collect() return schemas
return [schema.copy() for schema in schemas]
def get_coin_by_name(coin_name: str | None) -> coininfo.CoinInfo: def get_coin_by_name(coin_name: str | None) -> coininfo.CoinInfo:
@ -234,13 +248,52 @@ def get_coin_by_name(coin_name: str | None) -> coininfo.CoinInfo:
async def get_keychain_for_coin( async def get_keychain_for_coin(
ctx: wire.Context, coin_name: str | None ctx: wire.Context,
) -> tuple[Keychain, coininfo.CoinInfo]: coin: coininfo.CoinInfo,
coin = get_coin_by_name(coin_name) unlock_schemas: Iterable[PathSchema] = (),
schemas = get_schemas_for_coin(coin) ) -> Keychain:
schemas = get_schemas_for_coin(coin, unlock_schemas)
slip21_namespaces = [[b"SLIP-0019"], [b"SLIP-0024"]] slip21_namespaces = [[b"SLIP-0019"], [b"SLIP-0024"]]
keychain = await get_keychain(ctx, coin.curve_name, schemas, slip21_namespaces) keychain = await get_keychain(ctx, coin.curve_name, schemas, slip21_namespaces)
return keychain, coin return keychain
def _get_unlock_schemas(
msg: MessageType, auth_msg: MessageType | None, coin: coininfo.CoinInfo
) -> list[PathSchema]:
"""
Provides additional keychain schemas that are unlocked by the particular
combination of `msg` and `auth_msg`.
"""
if AuthorizeCoinJoin.is_type_of(msg):
# When processing the AuthorizeCoinJoin message, validate_path() always
# needs to treat SLIP-25 paths as valid, so add SLIP-25 to the schemas.
return get_schemas_from_patterns([PATTERN_SLIP25_TAPROOT], coin)
if AuthorizeCoinJoin.is_type_of(auth_msg) or UnlockPath.is_type_of(auth_msg):
# The user has preauthorized access to certain paths. Here we create a
# list of all the patterns that can be unlocked by AuthorizeCoinJoin or
# by UnlockPath. At the moment only SLIP-25 paths can be unlocked.
patterns = []
if SignTx.is_type_of(msg) or GetOwnershipProof.is_type_of(msg):
# SignTx and GetOwnershipProof need access to all SLIP-25 addresses
# to create CoinJoin outputs.
patterns.append(PATTERN_SLIP25_TAPROOT)
else:
# In case of other messages like GetAddress or SignMessage there is
# no reason for the user to work with SLIP-25 change-addresses. For
# example, using a change-address to receive a payment may
# compromise privacy.
patterns.append(PATTERN_SLIP25_TAPROOT_EXTERNAL)
# Convert the unlockable patterns to schemas and select only the ones
# that are unlocked by the auth_msg, i.e. lie in a subtree of the
# auth_msg's path.
schemas = get_schemas_from_patterns(patterns, coin)
return [s for s in schemas if s.restrict(auth_msg.address_n)]
return []
def with_keychain(func: HandlerWithCoinInfo[MsgOut]) -> Handler[MsgIn, MsgOut]: def with_keychain(func: HandlerWithCoinInfo[MsgOut]) -> Handler[MsgIn, MsgOut]:
@ -249,8 +302,10 @@ def with_keychain(func: HandlerWithCoinInfo[MsgOut]) -> Handler[MsgIn, MsgOut]:
msg: MsgIn, msg: MsgIn,
auth_msg: MessageType | None = None, auth_msg: MessageType | None = None,
) -> MsgOut: ) -> MsgOut:
keychain, coin = await get_keychain_for_coin(ctx, msg.coin_name) coin = get_coin_by_name(msg.coin_name)
if auth_msg: unlock_schemas = _get_unlock_schemas(msg, auth_msg, coin)
keychain = await get_keychain_for_coin(ctx, coin, unlock_schemas)
if AuthorizeCoinJoin.is_type_of(auth_msg):
auth_obj = authorization.from_cached_message(auth_msg) auth_obj = authorization.from_cached_message(auth_msg)
return await func(ctx, msg, keychain, coin, auth_obj) return await func(ctx, msg, keychain, coin, auth_obj)
else: else:

View File

@ -2,6 +2,7 @@ from micropython import const
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
HARDENED = const(0x8000_0000) HARDENED = const(0x8000_0000)
SLIP25_PURPOSE = const(10025 | HARDENED)
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import ( from typing import (
@ -246,6 +247,41 @@ class PathSchema:
return True return True
def set_never_matching(self) -> None:
"""Sets the schema to never match any paths."""
self.schema = []
self.trailing_components = self._EMPTY_TUPLE
def restrict(self, path: Bip32Path) -> bool:
"""
Restricts the schema to patterns that are prefixed by the specified
path. If the restriction results in a never-matching schema, then False
is returned.
"""
for i, value in enumerate(path):
if i < len(self.schema):
# Ensure that the path is a prefix of the schema.
if value not in self.schema[i]:
self.set_never_matching()
return False
# Restrict the schema component if there are multiple choices.
component = self.schema[i]
if not isinstance(component, tuple) or len(component) != 1:
self.schema[i] = (value,)
else:
# The path is longer than the schema. We need to restrict the
# trailing components.
if value not in self.trailing_components:
self.set_never_matching()
return False
self.schema.append((value,))
return True
if __debug__: if __debug__:
def __repr__(self) -> str: def __repr__(self) -> str:

View File

@ -4,7 +4,7 @@ from trezor import wire
from trezor.crypto import bip39 from trezor.crypto import bip39
from apps.common.paths import HARDENED from apps.common.paths import HARDENED
from apps.bitcoin.keychain import get_keychain_for_coin from apps.bitcoin.keychain import get_coin_by_name, get_keychain_for_coin
class TestBitcoinKeychain(unittest.TestCase): class TestBitcoinKeychain(unittest.TestCase):
@ -14,9 +14,8 @@ class TestBitcoinKeychain(unittest.TestCase):
cache.set(cache.APP_COMMON_SEED, seed) cache.set(cache.APP_COMMON_SEED, seed)
def test_bitcoin(self): def test_bitcoin(self):
keychain, coin = await_result( coin = get_coin_by_name("Bitcoin")
get_keychain_for_coin(wire.DUMMY_CONTEXT, "Bitcoin") keychain = await_result(get_keychain_for_coin(wire.DUMMY_CONTEXT, coin))
)
self.assertEqual(coin.coin_name, "Bitcoin") self.assertEqual(coin.coin_name, "Bitcoin")
valid_addresses = ( valid_addresses = (
@ -46,9 +45,8 @@ class TestBitcoinKeychain(unittest.TestCase):
self.assertRaises(wire.DataError, keychain.derive, addr) self.assertRaises(wire.DataError, keychain.derive, addr)
def test_testnet(self): def test_testnet(self):
keychain, coin = await_result( coin = get_coin_by_name("Testnet")
get_keychain_for_coin(wire.DUMMY_CONTEXT, "Testnet") keychain = await_result(get_keychain_for_coin(wire.DUMMY_CONTEXT, coin))
)
self.assertEqual(coin.coin_name, "Testnet") self.assertEqual(coin.coin_name, "Testnet")
valid_addresses = ( valid_addresses = (
@ -78,13 +76,14 @@ class TestBitcoinKeychain(unittest.TestCase):
self.assertRaises(wire.DataError, keychain.derive, addr) self.assertRaises(wire.DataError, keychain.derive, addr)
def test_unspecified(self): def test_unspecified(self):
keychain, coin = await_result(get_keychain_for_coin(wire.DUMMY_CONTEXT, None)) coin = get_coin_by_name(None)
keychain = await_result(get_keychain_for_coin(wire.DUMMY_CONTEXT, coin))
self.assertEqual(coin.coin_name, "Bitcoin") self.assertEqual(coin.coin_name, "Bitcoin")
keychain.derive([H_(44), H_(0), H_(0), 0, 0]) keychain.derive([H_(44), H_(0), H_(0), 0, 0])
def test_unknown(self): def test_unknown(self):
with self.assertRaises(wire.DataError): with self.assertRaises(wire.DataError):
await_result(get_keychain_for_coin(wire.DUMMY_CONTEXT, "MadeUpCoin2020")) get_coin_by_name("MadeUpCoin2020")
@unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin") @unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin")
@ -95,9 +94,8 @@ class TestAltcoinKeychains(unittest.TestCase):
cache.set(cache.APP_COMMON_SEED, seed) cache.set(cache.APP_COMMON_SEED, seed)
def test_bcash(self): def test_bcash(self):
keychain, coin = await_result( coin = get_coin_by_name("Bcash")
get_keychain_for_coin(wire.DUMMY_CONTEXT, "Bcash") keychain = await_result(get_keychain_for_coin(wire.DUMMY_CONTEXT, coin))
)
self.assertEqual(coin.coin_name, "Bcash") self.assertEqual(coin.coin_name, "Bcash")
self.assertFalse(coin.segwit) self.assertFalse(coin.segwit)
@ -133,9 +131,8 @@ class TestAltcoinKeychains(unittest.TestCase):
self.assertRaises(wire.DataError, keychain.derive, addr) self.assertRaises(wire.DataError, keychain.derive, addr)
def test_litecoin(self): def test_litecoin(self):
keychain, coin = await_result( coin = get_coin_by_name("Litecoin")
get_keychain_for_coin(wire.DUMMY_CONTEXT, "Litecoin") keychain = await_result(get_keychain_for_coin(wire.DUMMY_CONTEXT, coin))
)
self.assertEqual(coin.coin_name, "Litecoin") self.assertEqual(coin.coin_name, "Litecoin")
self.assertTrue(coin.segwit) self.assertTrue(coin.segwit)

View File

@ -174,6 +174,31 @@ class TestPathSchemas(unittest.TestCase):
) )
self.assertEqualSchema(schema_manual, schema_parsed) self.assertEqualSchema(schema_manual, schema_parsed)
def test_restrict(self):
PATTERN_BIP44 = "m/44'/coin_type'/account'/change/address_index"
# Restrict coin type to Bitcoin.
schema = PathSchema.parse(PATTERN_BIP44, (0, 145))
self.assertTrue(schema.restrict([H_(44), H_(0)]))
expected = PathSchema.parse(PATTERN_BIP44, 0)
self.assertEqualSchema(schema, expected)
# Restrict coin type to Bitcoin Cash and account 2.
schema = PathSchema.parse(PATTERN_BIP44, (0, 145))
self.assertTrue(schema.restrict([H_(44), H_(145), H_(2)]))
expected = PathSchema.parse("m/44'/145'/2'/change/address_index", 0)
self.assertEqualSchema(schema, expected)
# Restrict wildcards.
schema = PathSchema.parse("m/10018'/**", 0)
self.assertTrue(schema.restrict([H_(10018), H_(3), 7]))
expected = PathSchema.parse("m/10018'/3'/7/**", 0)
self.assertEqualSchema(schema, expected)
# Restrict to a never-matching schema.
schema = PathSchema.parse(PATTERN_BIP44, (0, 145))
self.assertFalse(schema.restrict([H_(44), H_(0), 0]))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()