feat(core/cardano): allow derivation type selection

pull/1882/head
matejcik 3 years ago committed by matejcik
parent f6f52445bd
commit 10e338e005

@ -7,6 +7,12 @@ option java_outer_classname = "TrezorMessageCardano";
import "messages-common.proto"; import "messages-common.proto";
enum CardanoDerivationType {
LEDGER = 0;
ICARUS = 1;
ICARUS_TREZOR = 2;
}
/** /**
* Values correspond to address header values given by the spec. * Values correspond to address header values given by the spec.
* Script addresses are only supported in transaction outputs. * Script addresses are only supported in transaction outputs.
@ -103,6 +109,7 @@ message CardanoNativeScript {
message CardanoGetNativeScriptHash { message CardanoGetNativeScriptHash {
required CardanoNativeScript script = 1; required CardanoNativeScript script = 1;
required CardanoNativeScriptHashDisplayFormat display_format = 2; // display hash as bech32 or policy id required CardanoNativeScriptHashDisplayFormat display_format = 2; // display hash as bech32 or policy id
required CardanoDerivationType derivation_type = 3;
} }
/** /**
@ -145,6 +152,7 @@ message CardanoGetAddress {
required uint32 protocol_magic = 3; // network's protocol magic - needed for Byron addresses on testnets required uint32 protocol_magic = 3; // network's protocol magic - needed for Byron addresses on testnets
required uint32 network_id = 4; // network id - mainnet or testnet required uint32 network_id = 4; // network id - mainnet or testnet
required CardanoAddressParametersType address_parameters = 5; // parameters used to derive the address required CardanoAddressParametersType address_parameters = 5; // parameters used to derive the address
required CardanoDerivationType derivation_type = 6;
} }
/** /**
@ -164,6 +172,7 @@ message CardanoAddress {
message CardanoGetPublicKey { message CardanoGetPublicKey {
repeated uint32 address_n = 1; // BIP-32 path to derive the key from master node repeated uint32 address_n = 1; // BIP-32 path to derive the key from master node
optional bool show_display = 2; // optionally show on display before sending the result optional bool show_display = 2; // optionally show on display before sending the result
required CardanoDerivationType derivation_type = 3;
} }
/** /**
@ -195,6 +204,7 @@ message CardanoSignTxInit {
optional uint64 validity_interval_start = 11; optional uint64 validity_interval_start = 11;
required uint32 witness_requests_count = 12; required uint32 witness_requests_count = 12;
required uint32 minting_asset_groups_count = 13; required uint32 minting_asset_groups_count = 13;
required CardanoDerivationType derivation_type = 14;
} }
/** /**

@ -388,6 +388,8 @@ if not utils.BITCOIN_ONLY:
import trezor.enums.CardanoAddressType import trezor.enums.CardanoAddressType
trezor.enums.CardanoCertificateType trezor.enums.CardanoCertificateType
import trezor.enums.CardanoCertificateType import trezor.enums.CardanoCertificateType
trezor.enums.CardanoDerivationType
import trezor.enums.CardanoDerivationType
trezor.enums.CardanoNativeScriptHashDisplayFormat trezor.enums.CardanoNativeScriptHashDisplayFormat
import trezor.enums.CardanoNativeScriptHashDisplayFormat import trezor.enums.CardanoNativeScriptHashDisplayFormat
trezor.enums.CardanoNativeScriptType trezor.enums.CardanoNativeScriptType

@ -1,18 +1,33 @@
from storage import cache, device from storage import cache, device
from trezor import wire from trezor import wire
from trezor.crypto import bip32, cardano from trezor.crypto import bip32, cardano
from trezor.enums import CardanoDerivationType
from apps.common import mnemonic from apps.common import mnemonic
from apps.common.passphrase import get as get_passphrase from apps.common.seed import derive_and_store_roots, get_seed
from apps.common.seed import get_seed, derive_and_store_roots
from .helpers import paths from .helpers import paths
if False: if False:
from typing import Callable, Awaitable from typing import Callable, Awaitable, TypeVar, Union
from apps.common.paths import Bip32Path from apps.common.paths import Bip32Path
from apps.common.keychain import MsgIn, MsgOut, Handler from apps.common.keychain import MsgOut, Handler
from trezor.messages import (
CardanoGetAddress,
CardanoGetPublicKey,
CardanoGetNativeScriptHash,
CardanoSignTxInit,
)
CardanoMessages = Union[
CardanoGetAddress,
CardanoGetPublicKey,
CardanoGetNativeScriptHash,
CardanoSignTxInit,
]
MsgIn = TypeVar("MsgIn", bound=CardanoMessages)
HandlerWithKeychain = Callable[[wire.Context, MsgIn, "Keychain"], Awaitable[MsgOut]] HandlerWithKeychain = Callable[[wire.Context, MsgIn, "Keychain"], Awaitable[MsgOut]]
@ -97,7 +112,7 @@ def is_minting_path(path: Bip32Path) -> bool:
return path[: len(paths.MINTING_ROOT)] == paths.MINTING_ROOT return path[: len(paths.MINTING_ROOT)] == paths.MINTING_ROOT
def derive_and_store_secret(passphrase: str) -> None: def derive_and_store_secrets(passphrase: str) -> None:
assert device.is_initialized() assert device.is_initialized()
assert cache.get(cache.APP_COMMON_DERIVE_CARDANO) assert cache.get(cache.APP_COMMON_DERIVE_CARDANO)
@ -105,33 +120,61 @@ def derive_and_store_secret(passphrase: str) -> None:
# nothing to do for SLIP-39, where we can derive the root from the main seed # nothing to do for SLIP-39, where we can derive the root from the main seed
return return
icarus_trezor_secret = mnemonic.derive_cardano_icarus_trezor(passphrase) icarus_secret = mnemonic.derive_cardano_icarus(passphrase, trezor_derivation=False)
cache.set(cache.APP_CARDANO_SECRET, icarus_trezor_secret)
words = mnemonic.get_secret()
assert words is not None, "Mnemonic is not set"
# count ASCII spaces, add 1 to get number of words
words_count = sum(c == 0x20 for c in words) + 1
@cache.stored_async(cache.APP_CARDANO_SECRET) if words_count == 24:
async def _get_secret(ctx: wire.Context) -> bytes: icarus_trezor_secret = mnemonic.derive_cardano_icarus(
await derive_and_store_roots(ctx) passphrase, trezor_derivation=True
secret = cache.get(cache.APP_CARDANO_SECRET) )
assert secret is not None else:
icarus_trezor_secret = icarus_secret
cache.set(cache.APP_CARDANO_ICARUS_SECRET, icarus_secret)
cache.set(cache.APP_CARDANO_ICARUS_TREZOR_SECRET, icarus_trezor_secret)
async def _get_secret(ctx: wire.Context, cache_entry: int) -> bytes:
secret = cache.get(cache_entry)
if secret is None:
await derive_and_store_roots(ctx)
secret = cache.get(cache_entry)
assert secret is not None
return secret return secret
async def _get_keychain_bip39(ctx: wire.Context) -> Keychain: async def _get_keychain_bip39(
ctx: wire.Context, derivation_type: CardanoDerivationType
) -> Keychain:
if not device.is_initialized(): if not device.is_initialized():
raise wire.NotInitialized("Device is not initialized") raise wire.NotInitialized("Device is not initialized")
if derivation_type == CardanoDerivationType.LEDGER:
seed = await get_seed(ctx)
return Keychain(cardano.from_seed_ledger(seed))
if not cache.get(cache.APP_COMMON_DERIVE_CARDANO): if not cache.get(cache.APP_COMMON_DERIVE_CARDANO):
raise wire.ProcessError("Cardano derivation is not enabled for this session") raise wire.ProcessError("Cardano derivation is not enabled for this session")
secret = await _get_secret(ctx) if derivation_type == CardanoDerivationType.ICARUS:
cache_entry = cache.APP_CARDANO_ICARUS_SECRET
else:
cache_entry = cache.APP_CARDANO_ICARUS_TREZOR_SECRET
secret = await _get_secret(ctx, cache_entry)
root = cardano.from_secret(secret) root = cardano.from_secret(secret)
return Keychain(root) return Keychain(root)
async def get_keychain(ctx: wire.Context) -> Keychain: async def get_keychain(
ctx: wire.Context, derivation_type: CardanoDerivationType
) -> Keychain:
if mnemonic.is_bip39(): if mnemonic.is_bip39():
return await _get_keychain_bip39(ctx) return await _get_keychain_bip39(ctx, derivation_type)
else: else:
# derive the root node via SLIP-0023 https://github.com/satoshilabs/slips/blob/master/slip-0022.md # derive the root node via SLIP-0023 https://github.com/satoshilabs/slips/blob/master/slip-0022.md
seed = await get_seed(ctx) seed = await get_seed(ctx)
@ -140,7 +183,7 @@ async def get_keychain(ctx: wire.Context) -> Keychain:
def with_keychain(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]: def with_keychain(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]:
async def wrapper(ctx: wire.Context, msg: MsgIn) -> MsgOut: async def wrapper(ctx: wire.Context, msg: MsgIn) -> MsgOut:
keychain = await get_keychain(ctx) keychain = await get_keychain(ctx, msg.derivation_type)
return await func(ctx, msg, keychain) return await func(ctx, msg, keychain)
return wrapper return wrapper

@ -52,7 +52,7 @@ if not utils.BITCOIN_ONLY:
need_seed = not cache.is_set(cache.APP_COMMON_SEED) need_seed = not cache.is_set(cache.APP_COMMON_SEED)
need_cardano_secret = cache.get( need_cardano_secret = cache.get(
cache.APP_COMMON_DERIVE_CARDANO cache.APP_COMMON_DERIVE_CARDANO
) and not cache.is_set(cache.APP_CARDANO_SECRET) ) and not cache.is_set(cache.APP_CARDANO_ICARUS_SECRET)
if not need_seed and not need_cardano_secret: if not need_seed and not need_cardano_secret:
return return
@ -64,9 +64,9 @@ if not utils.BITCOIN_ONLY:
cache.set(cache.APP_COMMON_SEED, common_seed) cache.set(cache.APP_COMMON_SEED, common_seed)
if need_cardano_secret: if need_cardano_secret:
from apps.cardano.seed import derive_and_store_secret from apps.cardano.seed import derive_and_store_secrets
derive_and_store_secret(passphrase) derive_and_store_secrets(passphrase)
@cache.stored_async(cache.APP_COMMON_SEED) @cache.stored_async(cache.APP_COMMON_SEED)
async def get_seed(ctx: wire.Context) -> bytes: async def get_seed(ctx: wire.Context) -> bytes:

@ -20,11 +20,12 @@ _SESSION_ID_LENGTH = 32
# Traditional cache keys # Traditional cache keys
APP_COMMON_SEED = 0 APP_COMMON_SEED = 0
APP_COMMON_DERIVE_CARDANO = 1 APP_COMMON_AUTHORIZATION_TYPE = 1
APP_CARDANO_SECRET = 2 APP_COMMON_AUTHORIZATION_DATA = 2
APP_MONERO_LIVE_REFRESH = 3 APP_COMMON_DERIVE_CARDANO = 3
APP_COMMON_AUTHORIZATION_TYPE = 4 APP_CARDANO_ICARUS_SECRET = 4
APP_COMMON_AUTHORIZATION_DATA = 5 APP_CARDANO_ICARUS_TREZOR_SECRET = 5
APP_MONERO_LIVE_REFRESH = 6
# Keys that are valid across sessions # Keys that are valid across sessions
APP_COMMON_SEED_WITHOUT_PASSPHRASE = 0 | _SESSIONLESS_FLAG APP_COMMON_SEED_WITHOUT_PASSPHRASE = 0 | _SESSIONLESS_FLAG
@ -92,11 +93,12 @@ class SessionCache(DataCache):
self.session_id = bytearray(_SESSION_ID_LENGTH) self.session_id = bytearray(_SESSION_ID_LENGTH)
self.fields = ( self.fields = (
64, # APP_COMMON_SEED 64, # APP_COMMON_SEED
1, # APP_COMMON_DERIVE_CARDANO
96, # APP_CARDANO_SECRET
1, # APP_MONERO_LIVE_REFRESH
2, # APP_COMMON_AUTHORIZATION_TYPE 2, # APP_COMMON_AUTHORIZATION_TYPE
128, # APP_COMMON_AUTHORIZATION_DATA 128, # APP_COMMON_AUTHORIZATION_DATA
1, # APP_COMMON_DERIVE_CARDANO
96, # APP_CARDANO_ICARUS_SECRET
96, # APP_CARDANO_ICARUS_TREZOR_SECRET
1, # APP_MONERO_LIVE_REFRESH
) )
self.last_usage = 0 self.last_usage = 0
super().__init__() super().__init__()

@ -0,0 +1,7 @@
# Automatically generated by pb2py
# fmt: off
# isort:skip_file
LEDGER = 0
ICARUS = 1
ICARUS_TREZOR = 2

@ -328,6 +328,11 @@ if TYPE_CHECKING:
TXORIGINPUT = 5 TXORIGINPUT = 5
TXORIGOUTPUT = 6 TXORIGOUTPUT = 6
class CardanoDerivationType(IntEnum):
LEDGER = 0
ICARUS = 1
ICARUS_TREZOR = 2
class CardanoAddressType(IntEnum): class CardanoAddressType(IntEnum):
BASE = 0 BASE = 0
BASE_SCRIPT_KEY = 1 BASE_SCRIPT_KEY = 1

@ -27,6 +27,7 @@ if TYPE_CHECKING:
from trezor.enums import Capability # noqa: F401 from trezor.enums import Capability # noqa: F401
from trezor.enums import CardanoAddressType # noqa: F401 from trezor.enums import CardanoAddressType # noqa: F401
from trezor.enums import CardanoCertificateType # noqa: F401 from trezor.enums import CardanoCertificateType # noqa: F401
from trezor.enums import CardanoDerivationType # noqa: F401
from trezor.enums import CardanoNativeScriptHashDisplayFormat # noqa: F401 from trezor.enums import CardanoNativeScriptHashDisplayFormat # noqa: F401
from trezor.enums import CardanoNativeScriptType # noqa: F401 from trezor.enums import CardanoNativeScriptType # noqa: F401
from trezor.enums import CardanoPoolRelayType # noqa: F401 from trezor.enums import CardanoPoolRelayType # noqa: F401
@ -1093,12 +1094,14 @@ if TYPE_CHECKING:
class CardanoGetNativeScriptHash(protobuf.MessageType): class CardanoGetNativeScriptHash(protobuf.MessageType):
script: "CardanoNativeScript" script: "CardanoNativeScript"
display_format: "CardanoNativeScriptHashDisplayFormat" display_format: "CardanoNativeScriptHashDisplayFormat"
derivation_type: "CardanoDerivationType"
def __init__( def __init__(
self, self,
*, *,
script: "CardanoNativeScript", script: "CardanoNativeScript",
display_format: "CardanoNativeScriptHashDisplayFormat", display_format: "CardanoNativeScriptHashDisplayFormat",
derivation_type: "CardanoDerivationType",
) -> None: ) -> None:
pass pass
@ -1151,6 +1154,7 @@ if TYPE_CHECKING:
protocol_magic: "int" protocol_magic: "int"
network_id: "int" network_id: "int"
address_parameters: "CardanoAddressParametersType" address_parameters: "CardanoAddressParametersType"
derivation_type: "CardanoDerivationType"
def __init__( def __init__(
self, self,
@ -1158,6 +1162,7 @@ if TYPE_CHECKING:
protocol_magic: "int", protocol_magic: "int",
network_id: "int", network_id: "int",
address_parameters: "CardanoAddressParametersType", address_parameters: "CardanoAddressParametersType",
derivation_type: "CardanoDerivationType",
show_display: "bool | None" = None, show_display: "bool | None" = None,
) -> None: ) -> None:
pass pass
@ -1183,10 +1188,12 @@ if TYPE_CHECKING:
class CardanoGetPublicKey(protobuf.MessageType): class CardanoGetPublicKey(protobuf.MessageType):
address_n: "list[int]" address_n: "list[int]"
show_display: "bool | None" show_display: "bool | None"
derivation_type: "CardanoDerivationType"
def __init__( def __init__(
self, self,
*, *,
derivation_type: "CardanoDerivationType",
address_n: "list[int] | None" = None, address_n: "list[int] | None" = None,
show_display: "bool | None" = None, show_display: "bool | None" = None,
) -> None: ) -> None:
@ -1226,6 +1233,7 @@ if TYPE_CHECKING:
validity_interval_start: "int | None" validity_interval_start: "int | None"
witness_requests_count: "int" witness_requests_count: "int"
minting_asset_groups_count: "int" minting_asset_groups_count: "int"
derivation_type: "CardanoDerivationType"
def __init__( def __init__(
self, self,
@ -1241,6 +1249,7 @@ if TYPE_CHECKING:
has_auxiliary_data: "bool", has_auxiliary_data: "bool",
witness_requests_count: "int", witness_requests_count: "int",
minting_asset_groups_count: "int", minting_asset_groups_count: "int",
derivation_type: "CardanoDerivationType",
ttl: "int | None" = None, ttl: "int | None" = None,
validity_interval_start: "int | None" = None, validity_interval_start: "int | None" = None,
) -> None: ) -> None:

@ -341,6 +341,12 @@ class RequestType(IntEnum):
TXORIGOUTPUT = 6 TXORIGOUTPUT = 6
class CardanoDerivationType(IntEnum):
LEDGER = 0
ICARUS = 1
ICARUS_TREZOR = 2
class CardanoAddressType(IntEnum): class CardanoAddressType(IntEnum):
BASE = 0 BASE = 0
BASE_SCRIPT_KEY = 1 BASE_SCRIPT_KEY = 1
@ -1957,6 +1963,7 @@ class CardanoGetNativeScriptHash(protobuf.MessageType):
FIELDS = { FIELDS = {
1: protobuf.Field("script", "CardanoNativeScript", repeated=False, required=True), 1: protobuf.Field("script", "CardanoNativeScript", repeated=False, required=True),
2: protobuf.Field("display_format", "CardanoNativeScriptHashDisplayFormat", repeated=False, required=True), 2: protobuf.Field("display_format", "CardanoNativeScriptHashDisplayFormat", repeated=False, required=True),
3: protobuf.Field("derivation_type", "CardanoDerivationType", repeated=False, required=True),
} }
def __init__( def __init__(
@ -1964,9 +1971,11 @@ class CardanoGetNativeScriptHash(protobuf.MessageType):
*, *,
script: "CardanoNativeScript", script: "CardanoNativeScript",
display_format: "CardanoNativeScriptHashDisplayFormat", display_format: "CardanoNativeScriptHashDisplayFormat",
derivation_type: "CardanoDerivationType",
) -> None: ) -> None:
self.script = script self.script = script
self.display_format = display_format self.display_format = display_format
self.derivation_type = derivation_type
class CardanoNativeScriptHash(protobuf.MessageType): class CardanoNativeScriptHash(protobuf.MessageType):
@ -2022,6 +2031,7 @@ class CardanoGetAddress(protobuf.MessageType):
3: protobuf.Field("protocol_magic", "uint32", repeated=False, required=True), 3: protobuf.Field("protocol_magic", "uint32", repeated=False, required=True),
4: protobuf.Field("network_id", "uint32", repeated=False, required=True), 4: protobuf.Field("network_id", "uint32", repeated=False, required=True),
5: protobuf.Field("address_parameters", "CardanoAddressParametersType", repeated=False, required=True), 5: protobuf.Field("address_parameters", "CardanoAddressParametersType", repeated=False, required=True),
6: protobuf.Field("derivation_type", "CardanoDerivationType", repeated=False, required=True),
} }
def __init__( def __init__(
@ -2030,11 +2040,13 @@ class CardanoGetAddress(protobuf.MessageType):
protocol_magic: "int", protocol_magic: "int",
network_id: "int", network_id: "int",
address_parameters: "CardanoAddressParametersType", address_parameters: "CardanoAddressParametersType",
derivation_type: "CardanoDerivationType",
show_display: Optional["bool"] = False, show_display: Optional["bool"] = False,
) -> None: ) -> None:
self.protocol_magic = protocol_magic self.protocol_magic = protocol_magic
self.network_id = network_id self.network_id = network_id
self.address_parameters = address_parameters self.address_parameters = address_parameters
self.derivation_type = derivation_type
self.show_display = show_display self.show_display = show_display
@ -2057,15 +2069,18 @@ class CardanoGetPublicKey(protobuf.MessageType):
FIELDS = { FIELDS = {
1: protobuf.Field("address_n", "uint32", repeated=True, required=False), 1: protobuf.Field("address_n", "uint32", repeated=True, required=False),
2: protobuf.Field("show_display", "bool", repeated=False, required=False), 2: protobuf.Field("show_display", "bool", repeated=False, required=False),
3: protobuf.Field("derivation_type", "CardanoDerivationType", repeated=False, required=True),
} }
def __init__( def __init__(
self, self,
*, *,
derivation_type: "CardanoDerivationType",
address_n: Optional[List["int"]] = None, address_n: Optional[List["int"]] = None,
show_display: Optional["bool"] = None, show_display: Optional["bool"] = None,
) -> None: ) -> None:
self.address_n = address_n if address_n is not None else [] self.address_n = address_n if address_n is not None else []
self.derivation_type = derivation_type
self.show_display = show_display self.show_display = show_display
@ -2102,6 +2117,7 @@ class CardanoSignTxInit(protobuf.MessageType):
11: protobuf.Field("validity_interval_start", "uint64", repeated=False, required=False), 11: protobuf.Field("validity_interval_start", "uint64", repeated=False, required=False),
12: protobuf.Field("witness_requests_count", "uint32", repeated=False, required=True), 12: protobuf.Field("witness_requests_count", "uint32", repeated=False, required=True),
13: protobuf.Field("minting_asset_groups_count", "uint32", repeated=False, required=True), 13: protobuf.Field("minting_asset_groups_count", "uint32", repeated=False, required=True),
14: protobuf.Field("derivation_type", "CardanoDerivationType", repeated=False, required=True),
} }
def __init__( def __init__(
@ -2118,6 +2134,7 @@ class CardanoSignTxInit(protobuf.MessageType):
has_auxiliary_data: "bool", has_auxiliary_data: "bool",
witness_requests_count: "int", witness_requests_count: "int",
minting_asset_groups_count: "int", minting_asset_groups_count: "int",
derivation_type: "CardanoDerivationType",
ttl: Optional["int"] = None, ttl: Optional["int"] = None,
validity_interval_start: Optional["int"] = None, validity_interval_start: Optional["int"] = None,
) -> None: ) -> None:
@ -2132,6 +2149,7 @@ class CardanoSignTxInit(protobuf.MessageType):
self.has_auxiliary_data = has_auxiliary_data self.has_auxiliary_data = has_auxiliary_data
self.witness_requests_count = witness_requests_count self.witness_requests_count = witness_requests_count
self.minting_asset_groups_count = minting_asset_groups_count self.minting_asset_groups_count = minting_asset_groups_count
self.derivation_type = derivation_type
self.ttl = ttl self.ttl = ttl
self.validity_interval_start = validity_interval_start self.validity_interval_start = validity_interval_start

Loading…
Cancel
Save