1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-14 17:31:04 +00:00

fix(core/bitcoin): fix CoinJoin authorization with new cache

This commit is contained in:
matejcik 2021-03-30 10:57:07 +02:00 committed by matejcik
parent dd655422f1
commit aaa3ce6117
12 changed files with 165 additions and 135 deletions

View File

@ -553,7 +553,7 @@ message AuthorizeCoinJoin {
required string coordinator = 1; // coordinator identifier to approve as a prefix in commitment data (max. 18 ASCII characters) required string coordinator = 1; // coordinator identifier to approve as a prefix in commitment data (max. 18 ASCII characters)
required uint64 max_total_fee = 2; // maximum total fees required uint64 max_total_fee = 2; // maximum total fees
optional uint32 fee_per_anonymity = 3; // fee per anonymity set in units of 10^-9 percent optional uint32 fee_per_anonymity = 3 [default=0]; // fee per anonymity set in units of 10^-9 percent
repeated uint32 address_n = 4; // prefix of the BIP-32 path leading to the account (m / purpose' / coin_type' / account') repeated uint32 address_n = 4; // prefix of the BIP-32 path leading to the account (m / purpose' / coin_type' / account')
optional string coin_name = 5 [default='Bitcoin']; // coin to use optional string coin_name = 5 [default='Bitcoin']; // coin to use
optional InputScriptType script_type = 6 [default=SPENDADDRESS]; // used to distinguish between various address formats (non-segwit, segwit, etc.) optional InputScriptType script_type = 6 [default=SPENDADDRESS]; // used to distinguish between various address formats (non-segwit, segwit, etc.)

View File

@ -8,7 +8,7 @@ from . import workflow_handlers
if False: if False:
import protobuf import protobuf
from typing import Iterable, NoReturn, Protocol from typing import NoReturn
from trezor.messages.Features import Features from trezor.messages.Features import Features
from trezor.messages.Initialize import Initialize from trezor.messages.Initialize import Initialize
from trezor.messages.EndSession import EndSession from trezor.messages.EndSession import EndSession
@ -19,15 +19,6 @@ if False:
from trezor.messages.DoPreauthorized import DoPreauthorized from trezor.messages.DoPreauthorized import DoPreauthorized
from trezor.messages.CancelAuthorization import CancelAuthorization from trezor.messages.CancelAuthorization import CancelAuthorization
if False:
class Authorization(Protocol):
def expected_wire_types(self) -> Iterable[int]:
...
def __del__(self) -> None:
...
def get_features() -> Features: def get_features() -> Features:
import storage.recovery import storage.recovery
@ -144,16 +135,15 @@ async def handle_DoPreauthorized(
ctx: wire.Context, msg: DoPreauthorized ctx: wire.Context, msg: DoPreauthorized
) -> protobuf.MessageType: ) -> protobuf.MessageType:
from trezor.messages.PreauthorizedRequest import PreauthorizedRequest from trezor.messages.PreauthorizedRequest import PreauthorizedRequest
from apps.common import authorization
authorization: Authorization = storage.cache.get( if not authorization.is_set():
storage.cache.APP_BASE_AUTHORIZATION
)
if not authorization:
raise wire.ProcessError("No preauthorized operation") raise wire.ProcessError("No preauthorized operation")
req = await ctx.call_any( wire_types = authorization.get_wire_types()
PreauthorizedRequest(), *authorization.expected_wire_types() utils.ensure(bool(wire_types), "Unsupported preauthorization found")
)
req = await ctx.call_any(PreauthorizedRequest(), *wire_types)
handler = workflow_handlers.find_registered_handler( handler = workflow_handlers.find_registered_handler(
ctx.iface, req.MESSAGE_WIRE_TYPE ctx.iface, req.MESSAGE_WIRE_TYPE
@ -161,28 +151,15 @@ async def handle_DoPreauthorized(
if handler is None: if handler is None:
return wire.unexpected_message() return wire.unexpected_message()
return await handler(ctx, req, authorization) # type: ignore return await handler(ctx, req, authorization.get()) # type: ignore
def set_authorization(authorization: Authorization) -> None:
previous: Authorization = storage.cache.get(storage.cache.APP_BASE_AUTHORIZATION)
if previous:
previous.__del__()
storage.cache.set(storage.cache.APP_BASE_AUTHORIZATION, authorization)
async def handle_CancelAuthorization( async def handle_CancelAuthorization(
ctx: wire.Context, msg: CancelAuthorization ctx: wire.Context, msg: CancelAuthorization
) -> protobuf.MessageType: ) -> protobuf.MessageType:
authorization: Authorization = storage.cache.get( from apps.common import authorization
storage.cache.APP_BASE_AUTHORIZATION
)
if not authorization:
raise wire.ProcessError("No preauthorized operation")
authorization.__del__()
storage.cache.set(storage.cache.APP_BASE_AUTHORIZATION, b"")
authorization.clear()
return Success(message="Authorization cancelled") return Success(message="Authorization cancelled")

View File

@ -1,64 +1,59 @@
from micropython import const from micropython import const
from trezor.messages import MessageType from trezor import wire
from trezor.messages.AuthorizeCoinJoin import AuthorizeCoinJoin
from apps.common import authorization
from .common import BIP32_WALLET_DEPTH from .common import BIP32_WALLET_DEPTH
if False: if False:
from typing import Iterable import protobuf
from trezor.messages.AuthorizeCoinJoin import AuthorizeCoinJoin
from trezor.messages.GetOwnershipProof import GetOwnershipProof from trezor.messages.GetOwnershipProof import GetOwnershipProof
from trezor.messages.SignTx import SignTx from trezor.messages.SignTx import SignTx
from trezor.messages.TxInput import TxInput from trezor.messages.TxInput import TxInput
from apps.common.coininfo import CoinInfo from apps.common.coininfo import CoinInfo
from apps.common.keychain import Keychain
_ROUND_ID_LEN = const(32) _ROUND_ID_LEN = const(32)
FEE_PER_ANONYMITY_DECIMALS = const(9) FEE_PER_ANONYMITY_DECIMALS = const(9)
class CoinJoinAuthorization: class CoinJoinAuthorization:
def __init__( def __init__(self, params: AuthorizeCoinJoin) -> None:
self, msg: AuthorizeCoinJoin, keychain: Keychain, coin: CoinInfo self.params = params
) -> None:
self.coordinator = msg.coordinator
self.remaining_fee = msg.max_total_fee
self.fee_per_anonymity = msg.fee_per_anonymity or 0
self.address_n = msg.address_n
self.keychain = keychain
self.coin = coin
self.script_type = msg.script_type
def __del__(self) -> None:
self.keychain.__del__()
def expected_wire_types(self) -> Iterable[int]:
return (MessageType.SignTx, MessageType.GetOwnershipProof)
def check_get_ownership_proof(self, msg: GetOwnershipProof) -> bool: def check_get_ownership_proof(self, msg: GetOwnershipProof) -> bool:
# Check whether the current authorization matches the parameters of the request. # Check whether the current params matches the parameters of the request.
return ( return (
len(msg.address_n) >= BIP32_WALLET_DEPTH len(msg.address_n) >= BIP32_WALLET_DEPTH
and msg.address_n[:-BIP32_WALLET_DEPTH] == self.address_n and msg.address_n[:-BIP32_WALLET_DEPTH] == self.params.address_n
and msg.coin_name == self.coin.coin_name and msg.coin_name == self.params.coin_name
and msg.script_type == self.script_type and msg.script_type == self.params.script_type
and len(msg.commitment_data) >= _ROUND_ID_LEN and len(msg.commitment_data) >= _ROUND_ID_LEN
and msg.commitment_data[:-_ROUND_ID_LEN] == self.coordinator.encode() and msg.commitment_data[:-_ROUND_ID_LEN] == self.params.coordinator.encode()
) )
def check_sign_tx_input(self, txi: TxInput, coin: CoinInfo) -> bool: def check_sign_tx_input(self, txi: TxInput, coin: CoinInfo) -> bool:
# Check whether the current input matches the parameters of the request. # Check whether the current input matches the parameters of the request.
return ( return (
len(txi.address_n) >= BIP32_WALLET_DEPTH len(txi.address_n) >= BIP32_WALLET_DEPTH
and txi.address_n[:-BIP32_WALLET_DEPTH] == self.address_n and txi.address_n[:-BIP32_WALLET_DEPTH] == self.params.address_n
and coin.coin_name == self.coin.coin_name and coin.coin_name == self.params.coin_name
and txi.script_type == self.script_type and txi.script_type == self.params.script_type
) )
def approve_sign_tx(self, msg: SignTx, fee: int) -> bool: def approve_sign_tx(self, msg: SignTx, fee: int) -> bool:
if self.remaining_fee < fee or msg.coin_name != self.coin.coin_name: if self.params.max_total_fee < fee or msg.coin_name != self.params.coin_name:
return False return False
self.remaining_fee -= fee self.params.max_total_fee -= fee
authorization.set(self.params)
return True return True
def from_cached_message(auth_msg: protobuf.MessageType) -> CoinJoinAuthorization:
if not isinstance(auth_msg, AuthorizeCoinJoin):
raise wire.ProcessError("Appropriate params was not found")
return CoinJoinAuthorization(auth_msg)

View File

@ -1,32 +1,30 @@
from micropython import const from micropython import const
from trezor import ui from trezor import ui, wire
from trezor.messages.AuthorizeCoinJoin import AuthorizeCoinJoin from trezor.messages.AuthorizeCoinJoin import AuthorizeCoinJoin
from trezor.messages.Success import Success from trezor.messages.Success import Success
from trezor.strings import format_amount from trezor.strings import format_amount
from trezor.ui.layouts import confirm_action, confirm_coinjoin from trezor.ui.layouts import confirm_action, confirm_coinjoin
from apps.base import set_authorization from apps.common import authorization
from apps.common.paths import validate_path from apps.common.paths import validate_path
from .authorization import FEE_PER_ANONYMITY_DECIMALS, CoinJoinAuthorization from .authorization import FEE_PER_ANONYMITY_DECIMALS
from .common import BIP32_WALLET_DEPTH from .common import BIP32_WALLET_DEPTH
from .keychain import get_keychain_for_coin, validate_path_against_script_type from .keychain import validate_path_against_script_type, with_keychain
from .sign_tx.layout import format_coin_amount from .sign_tx.layout import format_coin_amount
if False: if False:
from trezor import wire from apps.common.coininfo import CoinInfo
from apps.common.keychain import Keychain
_MAX_COORDINATOR_LEN = const(18) _MAX_COORDINATOR_LEN = const(18)
async def authorize_coinjoin(ctx: wire.Context, msg: AuthorizeCoinJoin) -> Success: @with_keychain
# We cannot use the @with_keychain decorator here, because we need the keychain async def authorize_coinjoin(
# to survive the function exit. The ownership of the keychain is transferred to ctx: wire.Context, msg: AuthorizeCoinJoin, keychain: Keychain, coin: CoinInfo
# the CoinJoinAuthorization object, which takes care of its destruction. ) -> Success:
keychain, coin = await get_keychain_for_coin(ctx, msg.coin_name)
try:
if len(msg.coordinator) > _MAX_COORDINATOR_LEN or not all( if len(msg.coordinator) > _MAX_COORDINATOR_LEN or not all(
32 <= ord(x) <= 126 for x in msg.coordinator 32 <= ord(x) <= 126 for x in msg.coordinator
): ):
@ -55,21 +53,19 @@ async def authorize_coinjoin(ctx: wire.Context, msg: AuthorizeCoinJoin) -> Succe
icon=ui.ICON_RECOVERY, icon=ui.ICON_RECOVERY,
) )
fee_per_anonymity = None if msg.fee_per_anonymity:
if msg.fee_per_anonymity is not None: fee_per_anonymity: str | None = format_amount(
fee_per_anonymity = format_amount(
msg.fee_per_anonymity, FEE_PER_ANONYMITY_DECIMALS msg.fee_per_anonymity, FEE_PER_ANONYMITY_DECIMALS
) )
else:
fee_per_anonymity = None
await confirm_coinjoin( await confirm_coinjoin(
ctx, ctx,
fee_per_anonymity, fee_per_anonymity,
format_coin_amount(msg.max_total_fee, coin, msg.amount_unit), format_coin_amount(msg.max_total_fee, coin, msg.amount_unit),
) )
set_authorization(CoinJoinAuthorization(msg, keychain, coin)) authorization.set(msg)
except BaseException:
keychain.__del__()
raise
return Success(message="CoinJoin authorized") return Success(message="CoinJoin authorized")

View File

@ -7,19 +7,20 @@ from apps.common import coininfo
from apps.common.keychain import get_keychain from apps.common.keychain import get_keychain
from apps.common.paths import PATTERN_BIP44, PathSchema from apps.common.paths import PATTERN_BIP44, PathSchema
from . import authorization
from .common import BITCOIN_NAMES from .common import BITCOIN_NAMES
if False: if False:
from typing import Awaitable, Callable, Iterable, TypeVar from typing import Awaitable, Callable, Iterable, TypeVar
from typing_extensions import Protocol from typing_extensions import Protocol
from protobuf import MessageType
from trezor.messages.TxInputType import EnumTypeInputScriptType from trezor.messages.TxInputType import EnumTypeInputScriptType
from apps.common.keychain import Keychain, MsgOut, Handler from apps.common.keychain import Keychain, MsgOut, Handler
from apps.common.paths import Bip32Path from apps.common.paths import Bip32Path
from .authorization import CoinJoinAuthorization
class MsgWithCoinName(Protocol): class MsgWithCoinName(Protocol):
coin_name: str coin_name: str
@ -189,14 +190,13 @@ def with_keychain(func: HandlerWithCoinInfo[MsgOut]) -> Handler[MsgIn, MsgOut]:
async def wrapper( async def wrapper(
ctx: wire.Context, ctx: wire.Context,
msg: MsgIn, msg: MsgIn,
authorization: CoinJoinAuthorization | None = None, auth_msg: MessageType | None = None,
) -> MsgOut: ) -> MsgOut:
if authorization:
keychain = authorization.keychain
coin = get_coin_by_name(msg.coin_name)
return await func(ctx, msg, keychain, coin, authorization)
else:
keychain, coin = await get_keychain_for_coin(ctx, msg.coin_name) keychain, coin = await get_keychain_for_coin(ctx, msg.coin_name)
if auth_msg:
auth_obj = authorization.from_cached_message(auth_msg)
return await func(ctx, msg, keychain, coin, auth_obj)
else:
with keychain: with keychain:
return await func(ctx, msg, keychain, coin) return await func(ctx, msg, keychain, coin)

View File

@ -263,6 +263,9 @@ class CoinJoinApprover(Approver):
super().__init__(tx, coin) super().__init__(tx, coin)
self.authorization = authorization self.authorization = authorization
if authorization.params.coin_name != tx.coin_name:
raise wire.DataError("Coin name does not match authorization.")
# Upper bound on the user's contribution to the weight of the transaction. # Upper bound on the user's contribution to the weight of the transaction.
self.our_weight = tx_weight.TxWeightCalculator( self.our_weight = tx_weight.TxWeightCalculator(
tx.inputs_count, tx.outputs_count tx.inputs_count, tx.outputs_count
@ -352,7 +355,7 @@ class CoinJoinApprover(Approver):
decimal_divisor: float = pow(10, FEE_PER_ANONYMITY_DECIMALS + 2) decimal_divisor: float = pow(10, FEE_PER_ANONYMITY_DECIMALS + 2)
return ( return (
self.coordinator_fee_base self.coordinator_fee_base
* self.authorization.fee_per_anonymity * self.authorization.params.fee_per_anonymity
/ decimal_divisor / decimal_divisor
) )

View File

@ -0,0 +1,53 @@
import protobuf
import storage.cache
from trezor import messages, utils
from trezor.messages import MessageType
if False:
from typing import Iterable
WIRE_TYPES: dict[int, tuple[int, ...]] = {
MessageType.AuthorizeCoinJoin: (MessageType.SignTx, MessageType.GetOwnershipProof),
}
def is_set() -> bool:
return bool(storage.cache.get(storage.cache.APP_COMMON_AUTHORIZATION_TYPE))
def set(auth_message: protobuf.MessageType) -> None:
buffer = bytearray(protobuf.count_message(auth_message))
writer = utils.BufferWriter(buffer)
protobuf.dump_message(writer, auth_message)
storage.cache.set(
storage.cache.APP_COMMON_AUTHORIZATION_TYPE,
auth_message.MESSAGE_WIRE_TYPE.to_bytes(2, "big"),
)
storage.cache.set(storage.cache.APP_COMMON_AUTHORIZATION_DATA, buffer)
def get() -> protobuf.MessageType | None:
stored_auth_type = storage.cache.get(storage.cache.APP_COMMON_AUTHORIZATION_TYPE)
if not stored_auth_type:
return None
msg_wire_type = int.from_bytes(stored_auth_type, "big")
msg_type = messages.get_type(msg_wire_type)
buffer = storage.cache.get(storage.cache.APP_COMMON_AUTHORIZATION_DATA)
reader = utils.BufferReader(buffer)
return protobuf.load_message(reader, msg_type)
def get_wire_types() -> Iterable[int]:
stored_auth_type = storage.cache.get(storage.cache.APP_COMMON_AUTHORIZATION_TYPE)
if not stored_auth_type:
return ()
msg_wire_type = int.from_bytes(stored_auth_type, "big")
return WIRE_TYPES.get(msg_wire_type, ())
def clear() -> None:
storage.cache.set(storage.cache.APP_COMMON_AUTHORIZATION_TYPE, b"")
storage.cache.set(storage.cache.APP_COMMON_AUTHORIZATION_DATA, b"")

View File

@ -14,7 +14,8 @@ _SESSION_ID_LENGTH = 32
APP_COMMON_SEED = 0 APP_COMMON_SEED = 0
APP_CARDANO_PASSPHRASE = 1 APP_CARDANO_PASSPHRASE = 1
APP_MONERO_LIVE_REFRESH = 2 APP_MONERO_LIVE_REFRESH = 2
APP_BASE_AUTHORIZATION = 3 APP_COMMON_AUTHORIZATION_TYPE = 3
APP_COMMON_AUTHORIZATION_DATA = 4
# Keys that are valid across sessions # Keys that are valid across sessions
APP_COMMON_SEED_WITHOUT_PASSPHRASE = 0 | _SESSIONLESS_FLAG APP_COMMON_SEED_WITHOUT_PASSPHRASE = 0 | _SESSIONLESS_FLAG
@ -52,7 +53,8 @@ class SessionCache(DataCache):
64, # APP_COMMON_SEED 64, # APP_COMMON_SEED
50, # APP_CARDANO_PASSPHRASE 50, # APP_CARDANO_PASSPHRASE
1, # APP_MONERO_LIVE_REFRESH 1, # APP_MONERO_LIVE_REFRESH
128, # APP_BASE_AUTHORIZATION 2, # APP_COMMON_AUTHORIZATION_TYPE
128, # APP_COMMON_AUTHORIZATION_DATA
) )
self.last_usage = 0 self.last_usage = 0
super().__init__() super().__init__()

View File

@ -23,7 +23,7 @@ class AuthorizeCoinJoin(p.MessageType):
coordinator: str, coordinator: str,
max_total_fee: int, max_total_fee: int,
address_n: Optional[List[int]] = None, address_n: Optional[List[int]] = None,
fee_per_anonymity: Optional[int] = None, fee_per_anonymity: int = 0,
coin_name: str = "Bitcoin", coin_name: str = "Bitcoin",
script_type: EnumTypeInputScriptType = 0, script_type: EnumTypeInputScriptType = 0,
amount_unit: EnumTypeAmountUnit = 0, amount_unit: EnumTypeAmountUnit = 0,
@ -41,7 +41,7 @@ class AuthorizeCoinJoin(p.MessageType):
return { return {
1: ('coordinator', p.UnicodeType, p.FLAG_REQUIRED), 1: ('coordinator', p.UnicodeType, p.FLAG_REQUIRED),
2: ('max_total_fee', p.UVarintType, p.FLAG_REQUIRED), 2: ('max_total_fee', p.UVarintType, p.FLAG_REQUIRED),
3: ('fee_per_anonymity', p.UVarintType, None), 3: ('fee_per_anonymity', p.UVarintType, 0), # default=0
4: ('address_n', p.UVarintType, p.FLAG_REPEATED), 4: ('address_n', p.UVarintType, p.FLAG_REPEATED),
5: ('coin_name', p.UnicodeType, "Bitcoin"), # default=Bitcoin 5: ('coin_name', p.UnicodeType, "Bitcoin"), # default=Bitcoin
6: ('script_type', p.EnumType("InputScriptType", (0, 1, 2, 3, 4,)), 0), # default=SPENDADDRESS 6: ('script_type', p.EnumType("InputScriptType", (0, 1, 2, 3, 4,)), 0), # default=SPENDADDRESS

View File

@ -1,5 +1,6 @@
from common import unittest, await_result, H_ from common import unittest, await_result, H_
import storage.cache
from trezor import wire from trezor import wire
from trezor.messages.AuthorizeCoinJoin import AuthorizeCoinJoin from trezor.messages.AuthorizeCoinJoin import AuthorizeCoinJoin
from trezor.messages.TxInput import TxInput from trezor.messages.TxInput import TxInput
@ -23,11 +24,12 @@ class TestApprover(unittest.TestCase):
self.msg_auth = AuthorizeCoinJoin( self.msg_auth = AuthorizeCoinJoin(
coordinator="www.example.com", coordinator="www.example.com",
max_total_fee=40000, max_total_fee=40000,
fee_per_anonymity=self.fee_per_anonymity_percent * 10**9, fee_per_anonymity=int(self.fee_per_anonymity_percent * 10**9),
address_n=[H_(84), H_(0), H_(0)], address_n=[H_(84), H_(0), H_(0)],
coin_name=self.coin.coin_name, coin_name=self.coin.coin_name,
script_type=InputScriptType.SPENDWITNESS, script_type=InputScriptType.SPENDWITNESS,
) )
storage.cache.start_session()
def test_coinjoin_lots_of_inputs(self): def test_coinjoin_lots_of_inputs(self):
denomination = 10000000 denomination = 10000000
@ -74,7 +76,7 @@ class TestApprover(unittest.TestCase):
) )
) )
coordinator_fee = self.fee_per_anonymity_percent / 100 * len(outputs) * denomination coordinator_fee = int(self.fee_per_anonymity_percent / 100 * len(outputs) * denomination)
fees = coordinator_fee + 10000 fees = coordinator_fee + 10000
total_coordinator_fee = coordinator_fee * len(outputs) total_coordinator_fee = coordinator_fee * len(outputs)
@ -103,7 +105,7 @@ class TestApprover(unittest.TestCase):
) )
) )
authorization = CoinJoinAuthorization(self.msg_auth, None, self.coin) authorization = CoinJoinAuthorization(self.msg_auth)
tx = SignTx(outputs_count=len(outputs), inputs_count=len(inputs), coin_name=self.coin.coin_name, lock_time=0) tx = SignTx(outputs_count=len(outputs), inputs_count=len(inputs), coin_name=self.coin.coin_name, lock_time=0)
approver = CoinJoinApprover(tx, self.coin, authorization) approver = CoinJoinApprover(tx, self.coin, authorization)
signer = Bitcoin(tx, None, self.coin, approver) signer = Bitcoin(tx, None, self.coin, approver)
@ -123,7 +125,7 @@ class TestApprover(unittest.TestCase):
await_result(approver.approve_tx(TxInfo(signer, tx), [])) await_result(approver.approve_tx(TxInfo(signer, tx), []))
def test_coinjoin_input_account_depth_mismatch(self): def test_coinjoin_input_account_depth_mismatch(self):
authorization = CoinJoinAuthorization(self.msg_auth, None, self.coin) authorization = CoinJoinAuthorization(self.msg_auth)
tx = SignTx(outputs_count=201, inputs_count=100, coin_name=self.coin.coin_name, lock_time=0) tx = SignTx(outputs_count=201, inputs_count=100, coin_name=self.coin.coin_name, lock_time=0)
approver = CoinJoinApprover(tx, self.coin, authorization) approver = CoinJoinApprover(tx, self.coin, authorization)
@ -139,7 +141,7 @@ class TestApprover(unittest.TestCase):
await_result(approver.add_internal_input(txi)) await_result(approver.add_internal_input(txi))
def test_coinjoin_input_account_path_mismatch(self): def test_coinjoin_input_account_path_mismatch(self):
authorization = CoinJoinAuthorization(self.msg_auth, None, self.coin) authorization = CoinJoinAuthorization(self.msg_auth)
tx = SignTx(outputs_count=201, inputs_count=100, coin_name=self.coin.coin_name, lock_time=0) tx = SignTx(outputs_count=201, inputs_count=100, coin_name=self.coin.coin_name, lock_time=0)
approver = CoinJoinApprover(tx, self.coin, authorization) approver = CoinJoinApprover(tx, self.coin, authorization)

View File

@ -1,5 +1,6 @@
from common import unittest, H_ from common import unittest, H_
import storage.cache
from trezor.messages.AuthorizeCoinJoin import AuthorizeCoinJoin from trezor.messages.AuthorizeCoinJoin import AuthorizeCoinJoin
from trezor.messages.GetOwnershipProof import GetOwnershipProof from trezor.messages.GetOwnershipProof import GetOwnershipProof
from trezor.messages.SignTx import SignTx from trezor.messages.SignTx import SignTx
@ -19,13 +20,14 @@ class TestAuthorization(unittest.TestCase):
self.msg_auth = AuthorizeCoinJoin( self.msg_auth = AuthorizeCoinJoin(
coordinator="www.example.com", coordinator="www.example.com",
max_total_fee=40000, max_total_fee=40000,
fee_per_anonymity=0.003 * 10**9, fee_per_anonymity=int(0.003 * 10**9),
address_n=[H_(84), H_(0), H_(0)], address_n=[H_(84), H_(0), H_(0)],
coin_name=self.coin.coin_name, coin_name=self.coin.coin_name,
script_type=InputScriptType.SPENDWITNESS, script_type=InputScriptType.SPENDWITNESS,
) )
self.authorization = CoinJoinAuthorization(self.msg_auth, None, self.coin) self.authorization = CoinJoinAuthorization(self.msg_auth)
storage.cache.start_session()
def test_ownership_proof_account_depth_mismatch(self): def test_ownership_proof_account_depth_mismatch(self):
# Account depth mismatch. # Account depth mismatch.

View File

@ -23,7 +23,7 @@ class AuthorizeCoinJoin(p.MessageType):
coordinator: str, coordinator: str,
max_total_fee: int, max_total_fee: int,
address_n: Optional[List[int]] = None, address_n: Optional[List[int]] = None,
fee_per_anonymity: Optional[int] = None, fee_per_anonymity: int = 0,
coin_name: str = "Bitcoin", coin_name: str = "Bitcoin",
script_type: EnumTypeInputScriptType = 0, script_type: EnumTypeInputScriptType = 0,
amount_unit: EnumTypeAmountUnit = 0, amount_unit: EnumTypeAmountUnit = 0,
@ -41,7 +41,7 @@ class AuthorizeCoinJoin(p.MessageType):
return { return {
1: ('coordinator', p.UnicodeType, p.FLAG_REQUIRED), 1: ('coordinator', p.UnicodeType, p.FLAG_REQUIRED),
2: ('max_total_fee', p.UVarintType, p.FLAG_REQUIRED), 2: ('max_total_fee', p.UVarintType, p.FLAG_REQUIRED),
3: ('fee_per_anonymity', p.UVarintType, None), 3: ('fee_per_anonymity', p.UVarintType, 0), # default=0
4: ('address_n', p.UVarintType, p.FLAG_REPEATED), 4: ('address_n', p.UVarintType, p.FLAG_REPEATED),
5: ('coin_name', p.UnicodeType, "Bitcoin"), # default=Bitcoin 5: ('coin_name', p.UnicodeType, "Bitcoin"), # default=Bitcoin
6: ('script_type', p.EnumType("InputScriptType", (0, 1, 2, 3, 4,)), 0), # default=SPENDADDRESS 6: ('script_type', p.EnumType("InputScriptType", (0, 1, 2, 3, 4,)), 0), # default=SPENDADDRESS