1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-26 16:18:22 +00:00

fix(core/bitcoin): fix CoinJoin authorization with new cache

This commit is contained in:
matejcik 2021-03-30 10:57:07 +02:00 committed by matejcik
parent dd655422f1
commit aaa3ce6117
12 changed files with 165 additions and 135 deletions

View File

@ -553,7 +553,7 @@ message AuthorizeCoinJoin {
required string coordinator = 1; // coordinator identifier to approve as a prefix in commitment data (max. 18 ASCII characters)
required uint64 max_total_fee = 2; // maximum total fees
optional uint32 fee_per_anonymity = 3; // fee per anonymity set in units of 10^-9 percent
optional uint32 fee_per_anonymity = 3 [default=0]; // fee per anonymity set in units of 10^-9 percent
repeated uint32 address_n = 4; // prefix of the BIP-32 path leading to the account (m / purpose' / coin_type' / account')
optional string coin_name = 5 [default='Bitcoin']; // coin to use
optional InputScriptType script_type = 6 [default=SPENDADDRESS]; // used to distinguish between various address formats (non-segwit, segwit, etc.)

View File

@ -8,7 +8,7 @@ from . import workflow_handlers
if False:
import protobuf
from typing import Iterable, NoReturn, Protocol
from typing import NoReturn
from trezor.messages.Features import Features
from trezor.messages.Initialize import Initialize
from trezor.messages.EndSession import EndSession
@ -19,15 +19,6 @@ if False:
from trezor.messages.DoPreauthorized import DoPreauthorized
from trezor.messages.CancelAuthorization import CancelAuthorization
if False:
class Authorization(Protocol):
def expected_wire_types(self) -> Iterable[int]:
...
def __del__(self) -> None:
...
def get_features() -> Features:
import storage.recovery
@ -144,16 +135,15 @@ async def handle_DoPreauthorized(
ctx: wire.Context, msg: DoPreauthorized
) -> protobuf.MessageType:
from trezor.messages.PreauthorizedRequest import PreauthorizedRequest
from apps.common import authorization
authorization: Authorization = storage.cache.get(
storage.cache.APP_BASE_AUTHORIZATION
)
if not authorization:
if not authorization.is_set():
raise wire.ProcessError("No preauthorized operation")
req = await ctx.call_any(
PreauthorizedRequest(), *authorization.expected_wire_types()
)
wire_types = authorization.get_wire_types()
utils.ensure(bool(wire_types), "Unsupported preauthorization found")
req = await ctx.call_any(PreauthorizedRequest(), *wire_types)
handler = workflow_handlers.find_registered_handler(
ctx.iface, req.MESSAGE_WIRE_TYPE
@ -161,28 +151,15 @@ async def handle_DoPreauthorized(
if handler is None:
return wire.unexpected_message()
return await handler(ctx, req, authorization) # type: ignore
def set_authorization(authorization: Authorization) -> None:
previous: Authorization = storage.cache.get(storage.cache.APP_BASE_AUTHORIZATION)
if previous:
previous.__del__()
storage.cache.set(storage.cache.APP_BASE_AUTHORIZATION, authorization)
return await handler(ctx, req, authorization.get()) # type: ignore
async def handle_CancelAuthorization(
ctx: wire.Context, msg: CancelAuthorization
) -> protobuf.MessageType:
authorization: Authorization = storage.cache.get(
storage.cache.APP_BASE_AUTHORIZATION
)
if not authorization:
raise wire.ProcessError("No preauthorized operation")
authorization.__del__()
storage.cache.set(storage.cache.APP_BASE_AUTHORIZATION, b"")
from apps.common import authorization
authorization.clear()
return Success(message="Authorization cancelled")

View File

@ -1,64 +1,59 @@
from micropython import const
from trezor.messages import MessageType
from trezor import wire
from trezor.messages.AuthorizeCoinJoin import AuthorizeCoinJoin
from apps.common import authorization
from .common import BIP32_WALLET_DEPTH
if False:
from typing import Iterable
from trezor.messages.AuthorizeCoinJoin import AuthorizeCoinJoin
import protobuf
from trezor.messages.GetOwnershipProof import GetOwnershipProof
from trezor.messages.SignTx import SignTx
from trezor.messages.TxInput import TxInput
from apps.common.coininfo import CoinInfo
from apps.common.keychain import Keychain
_ROUND_ID_LEN = const(32)
FEE_PER_ANONYMITY_DECIMALS = const(9)
class CoinJoinAuthorization:
def __init__(
self, msg: AuthorizeCoinJoin, keychain: Keychain, coin: CoinInfo
) -> None:
self.coordinator = msg.coordinator
self.remaining_fee = msg.max_total_fee
self.fee_per_anonymity = msg.fee_per_anonymity or 0
self.address_n = msg.address_n
self.keychain = keychain
self.coin = coin
self.script_type = msg.script_type
def __del__(self) -> None:
self.keychain.__del__()
def expected_wire_types(self) -> Iterable[int]:
return (MessageType.SignTx, MessageType.GetOwnershipProof)
def __init__(self, params: AuthorizeCoinJoin) -> None:
self.params = params
def check_get_ownership_proof(self, msg: GetOwnershipProof) -> bool:
# Check whether the current authorization matches the parameters of the request.
# Check whether the current params matches the parameters of the request.
return (
len(msg.address_n) >= BIP32_WALLET_DEPTH
and msg.address_n[:-BIP32_WALLET_DEPTH] == self.address_n
and msg.coin_name == self.coin.coin_name
and msg.script_type == self.script_type
and msg.address_n[:-BIP32_WALLET_DEPTH] == self.params.address_n
and msg.coin_name == self.params.coin_name
and msg.script_type == self.params.script_type
and len(msg.commitment_data) >= _ROUND_ID_LEN
and msg.commitment_data[:-_ROUND_ID_LEN] == self.coordinator.encode()
and msg.commitment_data[:-_ROUND_ID_LEN] == self.params.coordinator.encode()
)
def check_sign_tx_input(self, txi: TxInput, coin: CoinInfo) -> bool:
# Check whether the current input matches the parameters of the request.
return (
len(txi.address_n) >= BIP32_WALLET_DEPTH
and txi.address_n[:-BIP32_WALLET_DEPTH] == self.address_n
and coin.coin_name == self.coin.coin_name
and txi.script_type == self.script_type
and txi.address_n[:-BIP32_WALLET_DEPTH] == self.params.address_n
and coin.coin_name == self.params.coin_name
and txi.script_type == self.params.script_type
)
def approve_sign_tx(self, msg: SignTx, fee: int) -> bool:
if self.remaining_fee < fee or msg.coin_name != self.coin.coin_name:
if self.params.max_total_fee < fee or msg.coin_name != self.params.coin_name:
return False
self.remaining_fee -= fee
self.params.max_total_fee -= fee
authorization.set(self.params)
return True
def from_cached_message(auth_msg: protobuf.MessageType) -> CoinJoinAuthorization:
if not isinstance(auth_msg, AuthorizeCoinJoin):
raise wire.ProcessError("Appropriate params was not found")
return CoinJoinAuthorization(auth_msg)

View File

@ -1,75 +1,71 @@
from micropython import const
from trezor import ui
from trezor import ui, wire
from trezor.messages.AuthorizeCoinJoin import AuthorizeCoinJoin
from trezor.messages.Success import Success
from trezor.strings import format_amount
from trezor.ui.layouts import confirm_action, confirm_coinjoin
from apps.base import set_authorization
from apps.common import authorization
from apps.common.paths import validate_path
from .authorization import FEE_PER_ANONYMITY_DECIMALS, CoinJoinAuthorization
from .authorization import FEE_PER_ANONYMITY_DECIMALS
from .common import BIP32_WALLET_DEPTH
from .keychain import get_keychain_for_coin, validate_path_against_script_type
from .keychain import validate_path_against_script_type, with_keychain
from .sign_tx.layout import format_coin_amount
if False:
from trezor import wire
from apps.common.coininfo import CoinInfo
from apps.common.keychain import Keychain
_MAX_COORDINATOR_LEN = const(18)
async def authorize_coinjoin(ctx: wire.Context, msg: AuthorizeCoinJoin) -> Success:
# We cannot use the @with_keychain decorator here, because we need the keychain
# to survive the function exit. The ownership of the keychain is transferred to
# the CoinJoinAuthorization object, which takes care of its destruction.
keychain, coin = await get_keychain_for_coin(ctx, msg.coin_name)
@with_keychain
async def authorize_coinjoin(
ctx: wire.Context, msg: AuthorizeCoinJoin, keychain: Keychain, coin: CoinInfo
) -> Success:
if len(msg.coordinator) > _MAX_COORDINATOR_LEN or not all(
32 <= ord(x) <= 126 for x in msg.coordinator
):
raise wire.DataError("Invalid coordinator name.")
try:
if len(msg.coordinator) > _MAX_COORDINATOR_LEN or not all(
32 <= ord(x) <= 126 for x in msg.coordinator
):
raise wire.DataError("Invalid coordinator name.")
if not msg.address_n:
raise wire.DataError("Empty path not allowed.")
if not msg.address_n:
raise wire.DataError("Empty path not allowed.")
validation_path = msg.address_n + [0] * BIP32_WALLET_DEPTH
await validate_path(
ctx,
keychain,
validation_path,
validate_path_against_script_type(
coin, address_n=validation_path, script_type=msg.script_type
),
)
validation_path = msg.address_n + [0] * BIP32_WALLET_DEPTH
await validate_path(
ctx,
keychain,
validation_path,
validate_path_against_script_type(
coin, address_n=validation_path, script_type=msg.script_type
),
await confirm_action(
ctx,
"coinjoin_coordinator",
title="Authorize CoinJoin",
description="Do you really want to take part in a CoinJoin transaction at:\n{}",
description_param=msg.coordinator,
description_param_font=ui.MONO,
icon=ui.ICON_RECOVERY,
)
if msg.fee_per_anonymity:
fee_per_anonymity: str | None = format_amount(
msg.fee_per_anonymity, FEE_PER_ANONYMITY_DECIMALS
)
await confirm_action(
ctx,
"coinjoin_coordinator",
title="Authorize CoinJoin",
description="Do you really want to take part in a CoinJoin transaction at:\n{}",
description_param=msg.coordinator,
description_param_font=ui.MONO,
icon=ui.ICON_RECOVERY,
)
else:
fee_per_anonymity = None
if msg.fee_per_anonymity is not None:
fee_per_anonymity = format_amount(
msg.fee_per_anonymity, FEE_PER_ANONYMITY_DECIMALS
)
await confirm_coinjoin(
ctx,
fee_per_anonymity,
format_coin_amount(msg.max_total_fee, coin, msg.amount_unit),
)
set_authorization(CoinJoinAuthorization(msg, keychain, coin))
await confirm_coinjoin(
ctx,
fee_per_anonymity,
format_coin_amount(msg.max_total_fee, coin, msg.amount_unit),
)
except BaseException:
keychain.__del__()
raise
authorization.set(msg)
return Success(message="CoinJoin authorized")

View File

@ -7,19 +7,20 @@ from apps.common import coininfo
from apps.common.keychain import get_keychain
from apps.common.paths import PATTERN_BIP44, PathSchema
from . import authorization
from .common import BITCOIN_NAMES
if False:
from typing import Awaitable, Callable, Iterable, TypeVar
from typing_extensions import Protocol
from protobuf import MessageType
from trezor.messages.TxInputType import EnumTypeInputScriptType
from apps.common.keychain import Keychain, MsgOut, Handler
from apps.common.paths import Bip32Path
from .authorization import CoinJoinAuthorization
class MsgWithCoinName(Protocol):
coin_name: str
@ -189,14 +190,13 @@ def with_keychain(func: HandlerWithCoinInfo[MsgOut]) -> Handler[MsgIn, MsgOut]:
async def wrapper(
ctx: wire.Context,
msg: MsgIn,
authorization: CoinJoinAuthorization | None = None,
auth_msg: MessageType | None = None,
) -> MsgOut:
if authorization:
keychain = authorization.keychain
coin = get_coin_by_name(msg.coin_name)
return await func(ctx, msg, keychain, coin, authorization)
keychain, coin = await get_keychain_for_coin(ctx, msg.coin_name)
if auth_msg:
auth_obj = authorization.from_cached_message(auth_msg)
return await func(ctx, msg, keychain, coin, auth_obj)
else:
keychain, coin = await get_keychain_for_coin(ctx, msg.coin_name)
with keychain:
return await func(ctx, msg, keychain, coin)

View File

@ -263,6 +263,9 @@ class CoinJoinApprover(Approver):
super().__init__(tx, coin)
self.authorization = authorization
if authorization.params.coin_name != tx.coin_name:
raise wire.DataError("Coin name does not match authorization.")
# Upper bound on the user's contribution to the weight of the transaction.
self.our_weight = tx_weight.TxWeightCalculator(
tx.inputs_count, tx.outputs_count
@ -352,7 +355,7 @@ class CoinJoinApprover(Approver):
decimal_divisor: float = pow(10, FEE_PER_ANONYMITY_DECIMALS + 2)
return (
self.coordinator_fee_base
* self.authorization.fee_per_anonymity
* self.authorization.params.fee_per_anonymity
/ decimal_divisor
)

View File

@ -0,0 +1,53 @@
import protobuf
import storage.cache
from trezor import messages, utils
from trezor.messages import MessageType
if False:
from typing import Iterable
WIRE_TYPES: dict[int, tuple[int, ...]] = {
MessageType.AuthorizeCoinJoin: (MessageType.SignTx, MessageType.GetOwnershipProof),
}
def is_set() -> bool:
return bool(storage.cache.get(storage.cache.APP_COMMON_AUTHORIZATION_TYPE))
def set(auth_message: protobuf.MessageType) -> None:
buffer = bytearray(protobuf.count_message(auth_message))
writer = utils.BufferWriter(buffer)
protobuf.dump_message(writer, auth_message)
storage.cache.set(
storage.cache.APP_COMMON_AUTHORIZATION_TYPE,
auth_message.MESSAGE_WIRE_TYPE.to_bytes(2, "big"),
)
storage.cache.set(storage.cache.APP_COMMON_AUTHORIZATION_DATA, buffer)
def get() -> protobuf.MessageType | None:
stored_auth_type = storage.cache.get(storage.cache.APP_COMMON_AUTHORIZATION_TYPE)
if not stored_auth_type:
return None
msg_wire_type = int.from_bytes(stored_auth_type, "big")
msg_type = messages.get_type(msg_wire_type)
buffer = storage.cache.get(storage.cache.APP_COMMON_AUTHORIZATION_DATA)
reader = utils.BufferReader(buffer)
return protobuf.load_message(reader, msg_type)
def get_wire_types() -> Iterable[int]:
stored_auth_type = storage.cache.get(storage.cache.APP_COMMON_AUTHORIZATION_TYPE)
if not stored_auth_type:
return ()
msg_wire_type = int.from_bytes(stored_auth_type, "big")
return WIRE_TYPES.get(msg_wire_type, ())
def clear() -> None:
storage.cache.set(storage.cache.APP_COMMON_AUTHORIZATION_TYPE, b"")
storage.cache.set(storage.cache.APP_COMMON_AUTHORIZATION_DATA, b"")

View File

@ -14,7 +14,8 @@ _SESSION_ID_LENGTH = 32
APP_COMMON_SEED = 0
APP_CARDANO_PASSPHRASE = 1
APP_MONERO_LIVE_REFRESH = 2
APP_BASE_AUTHORIZATION = 3
APP_COMMON_AUTHORIZATION_TYPE = 3
APP_COMMON_AUTHORIZATION_DATA = 4
# Keys that are valid across sessions
APP_COMMON_SEED_WITHOUT_PASSPHRASE = 0 | _SESSIONLESS_FLAG
@ -52,7 +53,8 @@ class SessionCache(DataCache):
64, # APP_COMMON_SEED
50, # APP_CARDANO_PASSPHRASE
1, # APP_MONERO_LIVE_REFRESH
128, # APP_BASE_AUTHORIZATION
2, # APP_COMMON_AUTHORIZATION_TYPE
128, # APP_COMMON_AUTHORIZATION_DATA
)
self.last_usage = 0
super().__init__()

View File

@ -23,7 +23,7 @@ class AuthorizeCoinJoin(p.MessageType):
coordinator: str,
max_total_fee: int,
address_n: Optional[List[int]] = None,
fee_per_anonymity: Optional[int] = None,
fee_per_anonymity: int = 0,
coin_name: str = "Bitcoin",
script_type: EnumTypeInputScriptType = 0,
amount_unit: EnumTypeAmountUnit = 0,
@ -41,7 +41,7 @@ class AuthorizeCoinJoin(p.MessageType):
return {
1: ('coordinator', p.UnicodeType, p.FLAG_REQUIRED),
2: ('max_total_fee', p.UVarintType, p.FLAG_REQUIRED),
3: ('fee_per_anonymity', p.UVarintType, None),
3: ('fee_per_anonymity', p.UVarintType, 0), # default=0
4: ('address_n', p.UVarintType, p.FLAG_REPEATED),
5: ('coin_name', p.UnicodeType, "Bitcoin"), # default=Bitcoin
6: ('script_type', p.EnumType("InputScriptType", (0, 1, 2, 3, 4,)), 0), # default=SPENDADDRESS

View File

@ -1,5 +1,6 @@
from common import unittest, await_result, H_
import storage.cache
from trezor import wire
from trezor.messages.AuthorizeCoinJoin import AuthorizeCoinJoin
from trezor.messages.TxInput import TxInput
@ -23,11 +24,12 @@ class TestApprover(unittest.TestCase):
self.msg_auth = AuthorizeCoinJoin(
coordinator="www.example.com",
max_total_fee=40000,
fee_per_anonymity=self.fee_per_anonymity_percent * 10**9,
fee_per_anonymity=int(self.fee_per_anonymity_percent * 10**9),
address_n=[H_(84), H_(0), H_(0)],
coin_name=self.coin.coin_name,
script_type=InputScriptType.SPENDWITNESS,
)
storage.cache.start_session()
def test_coinjoin_lots_of_inputs(self):
denomination = 10000000
@ -74,7 +76,7 @@ class TestApprover(unittest.TestCase):
)
)
coordinator_fee = self.fee_per_anonymity_percent / 100 * len(outputs) * denomination
coordinator_fee = int(self.fee_per_anonymity_percent / 100 * len(outputs) * denomination)
fees = coordinator_fee + 10000
total_coordinator_fee = coordinator_fee * len(outputs)
@ -103,7 +105,7 @@ class TestApprover(unittest.TestCase):
)
)
authorization = CoinJoinAuthorization(self.msg_auth, None, self.coin)
authorization = CoinJoinAuthorization(self.msg_auth)
tx = SignTx(outputs_count=len(outputs), inputs_count=len(inputs), coin_name=self.coin.coin_name, lock_time=0)
approver = CoinJoinApprover(tx, self.coin, authorization)
signer = Bitcoin(tx, None, self.coin, approver)
@ -123,7 +125,7 @@ class TestApprover(unittest.TestCase):
await_result(approver.approve_tx(TxInfo(signer, tx), []))
def test_coinjoin_input_account_depth_mismatch(self):
authorization = CoinJoinAuthorization(self.msg_auth, None, self.coin)
authorization = CoinJoinAuthorization(self.msg_auth)
tx = SignTx(outputs_count=201, inputs_count=100, coin_name=self.coin.coin_name, lock_time=0)
approver = CoinJoinApprover(tx, self.coin, authorization)
@ -139,7 +141,7 @@ class TestApprover(unittest.TestCase):
await_result(approver.add_internal_input(txi))
def test_coinjoin_input_account_path_mismatch(self):
authorization = CoinJoinAuthorization(self.msg_auth, None, self.coin)
authorization = CoinJoinAuthorization(self.msg_auth)
tx = SignTx(outputs_count=201, inputs_count=100, coin_name=self.coin.coin_name, lock_time=0)
approver = CoinJoinApprover(tx, self.coin, authorization)

View File

@ -1,5 +1,6 @@
from common import unittest, H_
import storage.cache
from trezor.messages.AuthorizeCoinJoin import AuthorizeCoinJoin
from trezor.messages.GetOwnershipProof import GetOwnershipProof
from trezor.messages.SignTx import SignTx
@ -19,13 +20,14 @@ class TestAuthorization(unittest.TestCase):
self.msg_auth = AuthorizeCoinJoin(
coordinator="www.example.com",
max_total_fee=40000,
fee_per_anonymity=0.003 * 10**9,
fee_per_anonymity=int(0.003 * 10**9),
address_n=[H_(84), H_(0), H_(0)],
coin_name=self.coin.coin_name,
script_type=InputScriptType.SPENDWITNESS,
)
self.authorization = CoinJoinAuthorization(self.msg_auth, None, self.coin)
self.authorization = CoinJoinAuthorization(self.msg_auth)
storage.cache.start_session()
def test_ownership_proof_account_depth_mismatch(self):
# Account depth mismatch.

View File

@ -23,7 +23,7 @@ class AuthorizeCoinJoin(p.MessageType):
coordinator: str,
max_total_fee: int,
address_n: Optional[List[int]] = None,
fee_per_anonymity: Optional[int] = None,
fee_per_anonymity: int = 0,
coin_name: str = "Bitcoin",
script_type: EnumTypeInputScriptType = 0,
amount_unit: EnumTypeAmountUnit = 0,
@ -41,7 +41,7 @@ class AuthorizeCoinJoin(p.MessageType):
return {
1: ('coordinator', p.UnicodeType, p.FLAG_REQUIRED),
2: ('max_total_fee', p.UVarintType, p.FLAG_REQUIRED),
3: ('fee_per_anonymity', p.UVarintType, None),
3: ('fee_per_anonymity', p.UVarintType, 0), # default=0
4: ('address_n', p.UVarintType, p.FLAG_REPEATED),
5: ('coin_name', p.UnicodeType, "Bitcoin"), # default=Bitcoin
6: ('script_type', p.EnumType("InputScriptType", (0, 1, 2, 3, 4,)), 0), # default=SPENDADDRESS