feat(core): Implement SLIP-0025 CoinJoin accounts.

pull/2398/head
Andrew Kozlik 2 years ago committed by Andrew Kozlik
parent 9d89c3cb1b
commit 77be3653b4

@ -429,7 +429,6 @@ message DoPreauthorized {
* @start
* @next SignTx
* @next GetOwnershipProof
* @next GetPublicKey
*/
message PreauthorizedRequest {
}

@ -0,0 +1 @@
Add SLIP-0025 CoinJoin accounts.

@ -4,7 +4,7 @@ import storage.cache
import storage.device
from trezor import config, utils, wire, workflow
from trezor.enums import MessageType
from trezor.messages import Success
from trezor.messages import Success, UnlockPath
from . import workflow_handlers
@ -219,6 +219,55 @@ async def handle_DoPreauthorized(
return await handler(ctx, req, authorization.get()) # type: ignore [Expected 2 positional arguments]
async def handle_UnlockPath(ctx: wire.Context, msg: UnlockPath) -> protobuf.MessageType:
from trezor.crypto import hmac
from trezor.messages import UnlockedPathRequest
from trezor.ui.layouts import confirm_action
from apps.common.paths import SLIP25_PURPOSE
from apps.common.seed import Slip21Node, get_seed
from apps.common.writers import write_uint32_le
_KEYCHAIN_MAC_KEY_PATH = [b"TREZOR", b"Keychain MAC key"]
# UnlockPath is relevant only for SLIP-25 paths.
# Note: Currently we only allow unlocking the entire SLIP-25 purpose subtree instead of
# per-coin or per-account unlocking in order to avoid UI complexity.
if msg.address_n != [SLIP25_PURPOSE]:
raise wire.DataError("Invalid path")
seed = await get_seed(ctx)
node = Slip21Node(seed)
node.derive_path(_KEYCHAIN_MAC_KEY_PATH)
mac = utils.HashWriter(hmac(hmac.SHA256, node.key()))
for i in msg.address_n:
write_uint32_le(mac, i)
expected_mac = mac.get_digest()
# Require confirmation to access SLIP25 paths unless already authorized.
if msg.mac:
if len(msg.mac) != len(expected_mac) or not utils.consteq(
expected_mac, msg.mac
):
raise wire.DataError("Invalid MAC")
else:
await confirm_action(
ctx,
"confirm_coinjoin_access",
title="CoinJoin account",
description="Do you want to allow access to CoinJoin accounts?",
)
wire_types = (MessageType.GetAddress, MessageType.GetPublicKey, MessageType.SignTx)
req = await ctx.call_any(UnlockedPathRequest(mac=expected_mac), *wire_types)
assert req.MESSAGE_WIRE_TYPE in wire_types
handler = workflow_handlers.find_registered_handler(
ctx.iface, req.MESSAGE_WIRE_TYPE
)
assert handler is not None
return await handler(ctx, req, msg) # type: ignore [Expected 2 positional arguments]
async def handle_CancelAuthorization(
ctx: wire.Context, msg: CancelAuthorization
) -> protobuf.MessageType:
@ -336,6 +385,7 @@ def boot() -> None:
workflow_handlers.register(MessageType.EndSession, handle_EndSession)
workflow_handlers.register(MessageType.Ping, handle_Ping)
workflow_handlers.register(MessageType.DoPreauthorized, handle_DoPreauthorized)
workflow_handlers.register(MessageType.UnlockPath, handle_UnlockPath)
workflow_handlers.register(
MessageType.CancelAuthorization, handle_CancelAuthorization
)

@ -8,7 +8,8 @@ from trezor.strings import format_amount
from trezor.ui.layouts import confirm_action, confirm_coinjoin, confirm_metadata
from apps.common import authorization, safety_checks
from apps.common.paths import validate_path
from apps.common.keychain import FORBIDDEN_KEY_PATH
from apps.common.paths import SLIP25_PURPOSE, validate_path
from .authorization import FEE_RATE_DECIMALS
from .common import BIP32_WALLET_DEPTH
@ -47,6 +48,9 @@ async def authorize_coinjoin(
if not msg.address_n:
raise wire.DataError("Empty path not allowed.")
if msg.address_n[0] != SLIP25_PURPOSE and safety_checks.is_strict():
raise FORBIDDEN_KEY_PATH
await confirm_action(
ctx,
"coinjoin_coordinator",
@ -57,19 +61,19 @@ async def authorize_coinjoin(
icon=ui.ICON_RECOVERY,
)
max_fee_per_vbyte = format_amount(msg.max_fee_per_kvbyte, 3)
await confirm_coinjoin(ctx, coin.coin_name, msg.max_rounds, max_fee_per_vbyte)
validation_path = msg.address_n + [0] * BIP32_WALLET_DEPTH
await validate_path(
ctx,
keychain,
validation_path,
msg.address_n[0] == SLIP25_PURPOSE,
validate_path_against_script_type(
coin, address_n=validation_path, script_type=msg.script_type
),
)
max_fee_per_vbyte = format_amount(msg.max_fee_per_kvbyte, 3)
if msg.max_fee_per_kvbyte > coin.maxfee_kb:
await confirm_metadata(
ctx,
@ -80,6 +84,8 @@ async def authorize_coinjoin(
ButtonRequestType.FeeOverThreshold,
)
await confirm_coinjoin(ctx, coin.coin_name, msg.max_rounds, max_fee_per_vbyte)
authorization.set(msg)
return Success(message="CoinJoin authorized")

@ -2,21 +2,33 @@ from typing import TYPE_CHECKING
from trezor import wire
from trezor.enums import InputScriptType
from trezor.messages import HDNodeType, PublicKey
from trezor.messages import HDNodeType, PublicKey, UnlockPath
from apps.common import coininfo, paths
from apps.common.keychain import get_keychain
from apps.common.keychain import FORBIDDEN_KEY_PATH, get_keychain
if TYPE_CHECKING:
from trezor.messages import GetPublicKey
from trezor.protobuf import MessageType
async def get_public_key(ctx: wire.Context, msg: GetPublicKey) -> PublicKey:
async def get_public_key(
ctx: wire.Context, msg: GetPublicKey, auth_msg: MessageType | None = None
) -> PublicKey:
coin_name = msg.coin_name or "Bitcoin"
script_type = msg.script_type or InputScriptType.SPENDADDRESS
coin = coininfo.by_name(coin_name)
curve_name = msg.ecdsa_curve_name or coin.curve_name
if msg.address_n and msg.address_n[0] == paths.SLIP25_PURPOSE:
# UnlockPath is required to access SLIP25 paths.
if not UnlockPath.is_type_of(auth_msg):
raise FORBIDDEN_KEY_PATH
# Verify that the desired path lies in the unlocked subtree.
if auth_msg.address_n != msg.address_n[: len(auth_msg.address_n)]:
raise FORBIDDEN_KEY_PATH
keychain = await get_keychain(ctx, curve_name, [paths.AlwaysMatchingSchema])
node = keychain.derive(msg.address_n)

@ -4,6 +4,7 @@ from typing import TYPE_CHECKING
from trezor import wire
from trezor.enums import InputScriptType
from trezor.messages import AuthorizeCoinJoin, GetOwnershipProof, SignTx, UnlockPath
from apps.common import coininfo
from apps.common.keychain import get_keychain
@ -19,13 +20,10 @@ if TYPE_CHECKING:
from trezor.protobuf import MessageType
from trezor.messages import (
AuthorizeCoinJoin,
GetAddress,
GetOwnershipId,
GetOwnershipProof,
GetPublicKey,
SignMessage,
SignTx,
VerifyMessage,
)
@ -66,6 +64,10 @@ PATTERN_BIP49 = "m/49'/coin_type'/account'/change/address_index"
PATTERN_BIP84 = "m/84'/coin_type'/account'/change/address_index"
# BIP-86 for taproot: https://github.com/bitcoin/bips/blob/master/bip-0086.mediawiki
PATTERN_BIP86 = "m/86'/coin_type'/account'/change/address_index"
# SLIP-25 for CoinJoin: https://github.com/satoshilabs/slips/blob/master/slip-0025.md
# Only account=0 and script_type=1 are supported for now.
PATTERN_SLIP25_TAPROOT = "m/10025'/coin_type'/0'/1'/change/address_index"
PATTERN_SLIP25_TAPROOT_EXTERNAL = "m/10025'/coin_type'/0'/1'/0/address_index"
# compatibility patterns, will be removed in the future
PATTERN_GREENADDRESS_A = "m/[1,4]/address_index"
@ -151,13 +153,16 @@ def validate_path_against_script_type(
elif coin.taproot and script_type == InputScriptType.SPENDTAPROOT:
patterns.append(PATTERN_BIP86)
patterns.append(PATTERN_SLIP25_TAPROOT)
return any(
PathSchema.parse(pattern, coin.slip44).match(address_n) for pattern in patterns
)
def get_schemas_for_coin(coin: coininfo.CoinInfo) -> Iterable[PathSchema]:
def get_schemas_for_coin(
coin: coininfo.CoinInfo, unlock_schemas: Iterable[PathSchema] = ()
) -> Iterable[PathSchema]:
# basic patterns
patterns = [
PATTERN_BIP44,
@ -206,6 +211,16 @@ def get_schemas_for_coin(coin: coininfo.CoinInfo) -> Iterable[PathSchema]:
if coin.taproot:
patterns.append(PATTERN_BIP86)
schemas = get_schemas_from_patterns(patterns, coin)
schemas.extend(unlock_schemas)
gc.collect()
return [schema.copy() for schema in schemas]
def get_schemas_from_patterns(
patterns: Iterable[str], coin: coininfo.CoinInfo
) -> list[PathSchema]:
schemas = [PathSchema.parse(pattern, coin.slip44) for pattern in patterns]
# Some wallets such as Electron-Cash (BCH) store coins on Bitcoin paths.
@ -219,8 +234,7 @@ def get_schemas_for_coin(coin: coininfo.CoinInfo) -> Iterable[PathSchema]:
PathSchema.parse(pattern, SLIP44_BITCOIN) for pattern in patterns
)
gc.collect()
return [schema.copy() for schema in schemas]
return schemas
def get_coin_by_name(coin_name: str | None) -> coininfo.CoinInfo:
@ -234,13 +248,52 @@ def get_coin_by_name(coin_name: str | None) -> coininfo.CoinInfo:
async def get_keychain_for_coin(
ctx: wire.Context, coin_name: str | None
) -> tuple[Keychain, coininfo.CoinInfo]:
coin = get_coin_by_name(coin_name)
schemas = get_schemas_for_coin(coin)
ctx: wire.Context,
coin: coininfo.CoinInfo,
unlock_schemas: Iterable[PathSchema] = (),
) -> Keychain:
schemas = get_schemas_for_coin(coin, unlock_schemas)
slip21_namespaces = [[b"SLIP-0019"], [b"SLIP-0024"]]
keychain = await get_keychain(ctx, coin.curve_name, schemas, slip21_namespaces)
return keychain, coin
return keychain
def _get_unlock_schemas(
msg: MessageType, auth_msg: MessageType | None, coin: coininfo.CoinInfo
) -> list[PathSchema]:
"""
Provides additional keychain schemas that are unlocked by the particular
combination of `msg` and `auth_msg`.
"""
if AuthorizeCoinJoin.is_type_of(msg):
# When processing the AuthorizeCoinJoin message, validate_path() always
# needs to treat SLIP-25 paths as valid, so add SLIP-25 to the schemas.
return get_schemas_from_patterns([PATTERN_SLIP25_TAPROOT], coin)
if AuthorizeCoinJoin.is_type_of(auth_msg) or UnlockPath.is_type_of(auth_msg):
# The user has preauthorized access to certain paths. Here we create a
# list of all the patterns that can be unlocked by AuthorizeCoinJoin or
# by UnlockPath. At the moment only SLIP-25 paths can be unlocked.
patterns = []
if SignTx.is_type_of(msg) or GetOwnershipProof.is_type_of(msg):
# SignTx and GetOwnershipProof need access to all SLIP-25 addresses
# to create CoinJoin outputs.
patterns.append(PATTERN_SLIP25_TAPROOT)
else:
# In case of other messages like GetAddress or SignMessage there is
# no reason for the user to work with SLIP-25 change-addresses. For
# example, using a change-address to receive a payment may
# compromise privacy.
patterns.append(PATTERN_SLIP25_TAPROOT_EXTERNAL)
# Convert the unlockable patterns to schemas and select only the ones
# that are unlocked by the auth_msg, i.e. lie in a subtree of the
# auth_msg's path.
schemas = get_schemas_from_patterns(patterns, coin)
return [s for s in schemas if s.restrict(auth_msg.address_n)]
return []
def with_keychain(func: HandlerWithCoinInfo[MsgOut]) -> Handler[MsgIn, MsgOut]:
@ -249,8 +302,10 @@ def with_keychain(func: HandlerWithCoinInfo[MsgOut]) -> Handler[MsgIn, MsgOut]:
msg: MsgIn,
auth_msg: MessageType | None = None,
) -> MsgOut:
keychain, coin = await get_keychain_for_coin(ctx, msg.coin_name)
if auth_msg:
coin = get_coin_by_name(msg.coin_name)
unlock_schemas = _get_unlock_schemas(msg, auth_msg, coin)
keychain = await get_keychain_for_coin(ctx, coin, unlock_schemas)
if AuthorizeCoinJoin.is_type_of(auth_msg):
auth_obj = authorization.from_cached_message(auth_msg)
return await func(ctx, msg, keychain, coin, auth_obj)
else:

@ -2,6 +2,7 @@ from micropython import const
from typing import TYPE_CHECKING
HARDENED = const(0x8000_0000)
SLIP25_PURPOSE = const(10025 | HARDENED)
if TYPE_CHECKING:
from typing import (
@ -246,6 +247,41 @@ class PathSchema:
return True
def set_never_matching(self) -> None:
"""Sets the schema to never match any paths."""
self.schema = []
self.trailing_components = self._EMPTY_TUPLE
def restrict(self, path: Bip32Path) -> bool:
"""
Restricts the schema to patterns that are prefixed by the specified
path. If the restriction results in a never-matching schema, then False
is returned.
"""
for i, value in enumerate(path):
if i < len(self.schema):
# Ensure that the path is a prefix of the schema.
if value not in self.schema[i]:
self.set_never_matching()
return False
# Restrict the schema component if there are multiple choices.
component = self.schema[i]
if not isinstance(component, tuple) or len(component) != 1:
self.schema[i] = (value,)
else:
# The path is longer than the schema. We need to restrict the
# trailing components.
if value not in self.trailing_components:
self.set_never_matching()
return False
self.schema.append((value,))
return True
if __debug__:
def __repr__(self) -> str:

@ -4,7 +4,7 @@ from trezor import wire
from trezor.crypto import bip39
from apps.common.paths import HARDENED
from apps.bitcoin.keychain import get_keychain_for_coin
from apps.bitcoin.keychain import get_coin_by_name, get_keychain_for_coin
class TestBitcoinKeychain(unittest.TestCase):
@ -14,9 +14,8 @@ class TestBitcoinKeychain(unittest.TestCase):
cache.set(cache.APP_COMMON_SEED, seed)
def test_bitcoin(self):
keychain, coin = await_result(
get_keychain_for_coin(wire.DUMMY_CONTEXT, "Bitcoin")
)
coin = get_coin_by_name("Bitcoin")
keychain = await_result(get_keychain_for_coin(wire.DUMMY_CONTEXT, coin))
self.assertEqual(coin.coin_name, "Bitcoin")
valid_addresses = (
@ -46,9 +45,8 @@ class TestBitcoinKeychain(unittest.TestCase):
self.assertRaises(wire.DataError, keychain.derive, addr)
def test_testnet(self):
keychain, coin = await_result(
get_keychain_for_coin(wire.DUMMY_CONTEXT, "Testnet")
)
coin = get_coin_by_name("Testnet")
keychain = await_result(get_keychain_for_coin(wire.DUMMY_CONTEXT, coin))
self.assertEqual(coin.coin_name, "Testnet")
valid_addresses = (
@ -78,13 +76,14 @@ class TestBitcoinKeychain(unittest.TestCase):
self.assertRaises(wire.DataError, keychain.derive, addr)
def test_unspecified(self):
keychain, coin = await_result(get_keychain_for_coin(wire.DUMMY_CONTEXT, None))
coin = get_coin_by_name(None)
keychain = await_result(get_keychain_for_coin(wire.DUMMY_CONTEXT, coin))
self.assertEqual(coin.coin_name, "Bitcoin")
keychain.derive([H_(44), H_(0), H_(0), 0, 0])
def test_unknown(self):
with self.assertRaises(wire.DataError):
await_result(get_keychain_for_coin(wire.DUMMY_CONTEXT, "MadeUpCoin2020"))
get_coin_by_name("MadeUpCoin2020")
@unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin")
@ -95,9 +94,8 @@ class TestAltcoinKeychains(unittest.TestCase):
cache.set(cache.APP_COMMON_SEED, seed)
def test_bcash(self):
keychain, coin = await_result(
get_keychain_for_coin(wire.DUMMY_CONTEXT, "Bcash")
)
coin = get_coin_by_name("Bcash")
keychain = await_result(get_keychain_for_coin(wire.DUMMY_CONTEXT, coin))
self.assertEqual(coin.coin_name, "Bcash")
self.assertFalse(coin.segwit)
@ -133,9 +131,8 @@ class TestAltcoinKeychains(unittest.TestCase):
self.assertRaises(wire.DataError, keychain.derive, addr)
def test_litecoin(self):
keychain, coin = await_result(
get_keychain_for_coin(wire.DUMMY_CONTEXT, "Litecoin")
)
coin = get_coin_by_name("Litecoin")
keychain = await_result(get_keychain_for_coin(wire.DUMMY_CONTEXT, coin))
self.assertEqual(coin.coin_name, "Litecoin")
self.assertTrue(coin.segwit)

@ -174,6 +174,31 @@ class TestPathSchemas(unittest.TestCase):
)
self.assertEqualSchema(schema_manual, schema_parsed)
def test_restrict(self):
PATTERN_BIP44 = "m/44'/coin_type'/account'/change/address_index"
# Restrict coin type to Bitcoin.
schema = PathSchema.parse(PATTERN_BIP44, (0, 145))
self.assertTrue(schema.restrict([H_(44), H_(0)]))
expected = PathSchema.parse(PATTERN_BIP44, 0)
self.assertEqualSchema(schema, expected)
# Restrict coin type to Bitcoin Cash and account 2.
schema = PathSchema.parse(PATTERN_BIP44, (0, 145))
self.assertTrue(schema.restrict([H_(44), H_(145), H_(2)]))
expected = PathSchema.parse("m/44'/145'/2'/change/address_index", 0)
self.assertEqualSchema(schema, expected)
# Restrict wildcards.
schema = PathSchema.parse("m/10018'/**", 0)
self.assertTrue(schema.restrict([H_(10018), H_(3), 7]))
expected = PathSchema.parse("m/10018'/3'/7/**", 0)
self.assertEqualSchema(schema, expected)
# Restrict to a never-matching schema.
schema = PathSchema.parse(PATTERN_BIP44, (0, 145))
self.assertFalse(schema.restrict([H_(44), H_(0), 0]))
if __name__ == "__main__":
unittest.main()

Loading…
Cancel
Save