mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-22 23:48:12 +00:00
core/bitcoin: Support preauthorization in @with_keychain decorator.
This commit is contained in:
parent
208283e13e
commit
c772de9d3c
@ -12,6 +12,8 @@ if False:
|
||||
|
||||
from apps.common.keychain import Keychain, MsgOut, Handler
|
||||
|
||||
from .authorization import CoinJoinAuthorization
|
||||
|
||||
class MsgWithCoinName(MessageType, Protocol):
|
||||
coin_name = ... # type: Optional[str]
|
||||
|
||||
@ -57,17 +59,20 @@ def get_namespaces_for_coin(coin: coininfo.CoinInfo):
|
||||
return namespaces
|
||||
|
||||
|
||||
async def get_keychain_for_coin(
|
||||
ctx: wire.Context, coin_name: Optional[str]
|
||||
) -> Tuple[Keychain, coininfo.CoinInfo]:
|
||||
def get_coin_by_name(coin_name: Optional[str]) -> coininfo.CoinInfo:
|
||||
if coin_name is None:
|
||||
coin_name = "Bitcoin"
|
||||
|
||||
try:
|
||||
coin = coininfo.by_name(coin_name)
|
||||
return coininfo.by_name(coin_name)
|
||||
except ValueError:
|
||||
raise wire.DataError("Unsupported coin type")
|
||||
|
||||
|
||||
async def get_keychain_for_coin(
|
||||
ctx: wire.Context, coin_name: Optional[str]
|
||||
) -> Tuple[Keychain, coininfo.CoinInfo]:
|
||||
coin = get_coin_by_name(coin_name)
|
||||
namespaces = get_namespaces_for_coin(coin)
|
||||
slip21_namespaces = [[b"SLIP-0019"]]
|
||||
keychain = await get_keychain(ctx, coin.curve_name, namespaces, slip21_namespaces)
|
||||
@ -75,7 +80,16 @@ async def get_keychain_for_coin(
|
||||
|
||||
|
||||
def with_keychain(func: HandlerWithCoinInfo[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]:
|
||||
async def wrapper(ctx: wire.Context, msg: MsgIn) -> MsgOut:
|
||||
async def wrapper(
|
||||
ctx: wire.Context,
|
||||
msg: MsgIn,
|
||||
authorization: Optional[CoinJoinAuthorization] = None,
|
||||
) -> MsgOut:
|
||||
if authorization:
|
||||
keychain = authorization.keychain
|
||||
coin = get_coin_by_name(msg.coin_name)
|
||||
return await func(ctx, msg, keychain, coin, authorization)
|
||||
else:
|
||||
keychain, coin = await get_keychain_for_coin(ctx, msg.coin_name)
|
||||
with keychain:
|
||||
return await func(ctx, msg, keychain, coin)
|
||||
|
Loading…
Reference in New Issue
Block a user