From c772de9d3c3288046f4fe96bd6638f0b86f545e2 Mon Sep 17 00:00:00 2001 From: Andrew Kozlik Date: Thu, 23 Jul 2020 17:04:05 +0200 Subject: [PATCH] core/bitcoin: Support preauthorization in @with_keychain decorator. --- core/src/apps/bitcoin/keychain.py | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/core/src/apps/bitcoin/keychain.py b/core/src/apps/bitcoin/keychain.py index 338508d86..e3a2c05d7 100644 --- a/core/src/apps/bitcoin/keychain.py +++ b/core/src/apps/bitcoin/keychain.py @@ -12,6 +12,8 @@ if False: from apps.common.keychain import Keychain, MsgOut, Handler + from .authorization import CoinJoinAuthorization + class MsgWithCoinName(MessageType, Protocol): coin_name = ... # type: Optional[str] @@ -57,17 +59,20 @@ def get_namespaces_for_coin(coin: coininfo.CoinInfo): return namespaces -async def get_keychain_for_coin( - ctx: wire.Context, coin_name: Optional[str] -) -> Tuple[Keychain, coininfo.CoinInfo]: +def get_coin_by_name(coin_name: Optional[str]) -> coininfo.CoinInfo: if coin_name is None: coin_name = "Bitcoin" try: - coin = coininfo.by_name(coin_name) + return coininfo.by_name(coin_name) except ValueError: raise wire.DataError("Unsupported coin type") + +async def get_keychain_for_coin( + ctx: wire.Context, coin_name: Optional[str] +) -> Tuple[Keychain, coininfo.CoinInfo]: + coin = get_coin_by_name(coin_name) namespaces = get_namespaces_for_coin(coin) slip21_namespaces = [[b"SLIP-0019"]] keychain = await get_keychain(ctx, coin.curve_name, namespaces, slip21_namespaces) @@ -75,9 +80,18 @@ async def get_keychain_for_coin( def with_keychain(func: HandlerWithCoinInfo[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]: - async def wrapper(ctx: wire.Context, msg: MsgIn) -> MsgOut: - keychain, coin = await get_keychain_for_coin(ctx, msg.coin_name) - with keychain: - return await func(ctx, msg, keychain, coin) + async def wrapper( + ctx: wire.Context, + msg: MsgIn, + authorization: Optional[CoinJoinAuthorization] = None, + ) -> MsgOut: + if authorization: + keychain = authorization.keychain + coin = get_coin_by_name(msg.coin_name) + return await func(ctx, msg, keychain, coin, authorization) + else: + keychain, coin = await get_keychain_for_coin(ctx, msg.coin_name) + with keychain: + return await func(ctx, msg, keychain, coin) return wrapper