refactor(trezorlib/cardano): rename create_* functions

To avoid ambiguity between creating messages from parameters and parsing them from JSON.
pull/1582/head
gabrielkerekes 3 years ago committed by matejcik
parent 905970fd6a
commit a3d0016a2f

@ -73,7 +73,7 @@ def create_address_parameters(
raise ValueError("Unknown address type") raise ValueError("Unknown address type")
if address_type == messages.CardanoAddressType.POINTER: if address_type == messages.CardanoAddressType.POINTER:
certificate_pointer = create_certificate_pointer( certificate_pointer = _create_certificate_pointer(
block_index, tx_index, certificate_index block_index, tx_index, certificate_index
) )
@ -86,7 +86,7 @@ def create_address_parameters(
) )
def create_certificate_pointer( def _create_certificate_pointer(
block_index: int, tx_index: int, certificate_index: int block_index: int, tx_index: int, certificate_index: int
) -> messages.CardanoBlockchainPointerType: ) -> messages.CardanoBlockchainPointerType:
if block_index is None or tx_index is None or certificate_index is None: if block_index is None or tx_index is None or certificate_index is None:
@ -97,7 +97,7 @@ def create_certificate_pointer(
) )
def create_input(tx_input) -> messages.CardanoTxInputType: def parse_input(tx_input) -> messages.CardanoTxInputType:
if not all(k in tx_input for k in REQUIRED_FIELDS_INPUT): if not all(k in tx_input for k in REQUIRED_FIELDS_INPUT):
raise ValueError("The input is missing some fields") raise ValueError("The input is missing some fields")
@ -108,7 +108,7 @@ def create_input(tx_input) -> messages.CardanoTxInputType:
) )
def create_output(output) -> messages.CardanoTxOutputType: def parse_output(output) -> messages.CardanoTxOutputType:
contains_address = "address" in output contains_address = "address" in output
contains_address_type = "addressType" in output contains_address_type = "addressType" in output
@ -124,10 +124,10 @@ def create_output(output) -> messages.CardanoTxOutputType:
if contains_address: if contains_address:
address = output["address"] address = output["address"]
else: else:
address_parameters = _create_address_parameters_internal(output) address_parameters = _parse_address_parameters(output)
if "token_bundle" in output: if "token_bundle" in output:
token_bundle = _create_token_bundle(output["token_bundle"]) token_bundle = _parse_token_bundle(output["token_bundle"])
return messages.CardanoTxOutputType( return messages.CardanoTxOutputType(
address=address, address=address,
@ -137,7 +137,7 @@ def create_output(output) -> messages.CardanoTxOutputType:
) )
def _create_token_bundle(token_bundle) -> List[messages.CardanoAssetGroupType]: def _parse_token_bundle(token_bundle) -> List[messages.CardanoAssetGroupType]:
result = [] result = []
for token_group in token_bundle: for token_group in token_bundle:
if not all(k in token_group for k in REQUIRED_FIELDS_TOKEN_GROUP): if not all(k in token_group for k in REQUIRED_FIELDS_TOKEN_GROUP):
@ -146,14 +146,14 @@ def _create_token_bundle(token_bundle) -> List[messages.CardanoAssetGroupType]:
result.append( result.append(
messages.CardanoAssetGroupType( messages.CardanoAssetGroupType(
policy_id=bytes.fromhex(token_group["policy_id"]), policy_id=bytes.fromhex(token_group["policy_id"]),
tokens=_create_tokens(token_group["tokens"]), tokens=_parse_tokens(token_group["tokens"]),
) )
) )
return result return result
def _create_tokens(tokens) -> List[messages.CardanoTokenType]: def _parse_tokens(tokens) -> List[messages.CardanoTokenType]:
result = [] result = []
for token in tokens: for token in tokens:
if not all(k in token for k in REQUIRED_FIELDS_TOKEN): if not all(k in token for k in REQUIRED_FIELDS_TOKEN):
@ -169,7 +169,7 @@ def _create_tokens(tokens) -> List[messages.CardanoTokenType]:
return result return result
def _create_address_parameters_internal( def _parse_address_parameters(
address_parameters, address_parameters,
) -> messages.CardanoAddressParametersType: ) -> messages.CardanoAddressParametersType:
if "path" not in address_parameters: if "path" not in address_parameters:
@ -190,7 +190,7 @@ def _create_address_parameters_internal(
) )
def create_certificate(certificate) -> messages.CardanoTxCertificateType: def parse_certificate(certificate) -> messages.CardanoTxCertificateType:
CERTIFICATE_MISSING_FIELDS_ERROR = ValueError( CERTIFICATE_MISSING_FIELDS_ERROR = ValueError(
"The certificate is missing some fields" "The certificate is missing some fields"
) )
@ -248,11 +248,11 @@ def create_certificate(certificate) -> messages.CardanoTxCertificateType:
reward_account=pool_parameters["reward_account"], reward_account=pool_parameters["reward_account"],
metadata=pool_metadata, metadata=pool_metadata,
owners=[ owners=[
_create_pool_owner(pool_owner) _parse_pool_owner(pool_owner)
for pool_owner in pool_parameters.get("owners", []) for pool_owner in pool_parameters.get("owners", [])
], ],
relays=[ relays=[
_create_pool_relay(pool_relay) _parse_pool_relay(pool_relay)
for pool_relay in pool_parameters.get("relays", []) for pool_relay in pool_parameters.get("relays", [])
] ]
if "relays" in pool_parameters if "relays" in pool_parameters
@ -263,7 +263,7 @@ def create_certificate(certificate) -> messages.CardanoTxCertificateType:
raise ValueError("Unknown certificate type") raise ValueError("Unknown certificate type")
def _create_pool_owner(pool_owner) -> messages.CardanoPoolOwnerType: def _parse_pool_owner(pool_owner) -> messages.CardanoPoolOwnerType:
if "staking_key_path" in pool_owner: if "staking_key_path" in pool_owner:
return messages.CardanoPoolOwnerType( return messages.CardanoPoolOwnerType(
staking_key_path=tools.parse_path(pool_owner["staking_key_path"]) staking_key_path=tools.parse_path(pool_owner["staking_key_path"])
@ -274,7 +274,7 @@ def _create_pool_owner(pool_owner) -> messages.CardanoPoolOwnerType:
) )
def _create_pool_relay(pool_relay) -> messages.CardanoPoolRelayParametersType: def _parse_pool_relay(pool_relay) -> messages.CardanoPoolRelayParametersType:
pool_relay_type = int(pool_relay["type"]) pool_relay_type = int(pool_relay["type"])
if pool_relay_type == messages.CardanoPoolRelayType.SINGLE_HOST_IP: if pool_relay_type == messages.CardanoPoolRelayType.SINGLE_HOST_IP:
@ -310,7 +310,7 @@ def _create_pool_relay(pool_relay) -> messages.CardanoPoolRelayParametersType:
raise ValueError("Unknown pool relay type") raise ValueError("Unknown pool relay type")
def create_withdrawal(withdrawal) -> messages.CardanoTxWithdrawalType: def parse_withdrawal(withdrawal) -> messages.CardanoTxWithdrawalType:
if not all(k in withdrawal for k in REQUIRED_FIELDS_WITHDRAWAL): if not all(k in withdrawal for k in REQUIRED_FIELDS_WITHDRAWAL):
raise ValueError("Withdrawal is missing some fields") raise ValueError("Withdrawal is missing some fields")
@ -321,7 +321,7 @@ def create_withdrawal(withdrawal) -> messages.CardanoTxWithdrawalType:
) )
def create_auxiliary_data(auxiliary_data) -> messages.CardanoTxAuxiliaryDataType: def parse_auxiliary_data(auxiliary_data) -> messages.CardanoTxAuxiliaryDataType:
if auxiliary_data is None: if auxiliary_data is None:
return None return None
@ -349,7 +349,7 @@ def create_auxiliary_data(auxiliary_data) -> messages.CardanoTxAuxiliaryDataType
), ),
staking_path=tools.parse_path(catalyst_registration["staking_path"]), staking_path=tools.parse_path(catalyst_registration["staking_path"]),
nonce=catalyst_registration["nonce"], nonce=catalyst_registration["nonce"],
reward_address_parameters=_create_address_parameters_internal( reward_address_parameters=_parse_address_parameters(
catalyst_registration["reward_address_parameters"] catalyst_registration["reward_address_parameters"]
), ),
) )

@ -54,20 +54,20 @@ def sign_tx(client, file, protocol_magic, network_id, testnet):
protocol_magic = cardano.PROTOCOL_MAGICS["testnet"] protocol_magic = cardano.PROTOCOL_MAGICS["testnet"]
network_id = cardano.NETWORK_IDS["testnet"] network_id = cardano.NETWORK_IDS["testnet"]
inputs = [cardano.create_input(input) for input in transaction["inputs"]] inputs = [cardano.parse_input(input) for input in transaction["inputs"]]
outputs = [cardano.create_output(output) for output in transaction["outputs"]] outputs = [cardano.parse_output(output) for output in transaction["outputs"]]
fee = transaction["fee"] fee = transaction["fee"]
ttl = transaction.get("ttl") ttl = transaction.get("ttl")
validity_interval_start = transaction.get("validity_interval_start") validity_interval_start = transaction.get("validity_interval_start")
certificates = [ certificates = [
cardano.create_certificate(certificate) cardano.parse_certificate(certificate)
for certificate in transaction.get("certificates", ()) for certificate in transaction.get("certificates", ())
] ]
withdrawals = [ withdrawals = [
cardano.create_withdrawal(withdrawal) cardano.parse_withdrawal(withdrawal)
for withdrawal in transaction.get("withdrawals", ()) for withdrawal in transaction.get("withdrawals", ())
] ]
auxiliary_data = cardano.create_auxiliary_data(transaction.get("auxiliary_data")) auxiliary_data = cardano.parse_auxiliary_data(transaction.get("auxiliary_data"))
signed_transaction = cardano.sign_tx( signed_transaction = cardano.sign_tx(
client, client,

@ -34,11 +34,11 @@ pytestmark = [
"cardano/sign_tx.slip39.json", "cardano/sign_tx.slip39.json",
) )
def test_cardano_sign_tx(client, parameters, result): def test_cardano_sign_tx(client, parameters, result):
inputs = [cardano.create_input(i) for i in parameters["inputs"]] inputs = [cardano.parse_input(i) for i in parameters["inputs"]]
outputs = [cardano.create_output(o) for o in parameters["outputs"]] outputs = [cardano.parse_output(o) for o in parameters["outputs"]]
certificates = [cardano.create_certificate(c) for c in parameters["certificates"]] certificates = [cardano.parse_certificate(c) for c in parameters["certificates"]]
withdrawals = [cardano.create_withdrawal(w) for w in parameters["withdrawals"]] withdrawals = [cardano.parse_withdrawal(w) for w in parameters["withdrawals"]]
auxiliary_data = cardano.create_auxiliary_data(parameters["auxiliary_data"]) auxiliary_data = cardano.parse_auxiliary_data(parameters["auxiliary_data"])
input_flow = parameters.get("input_flow", ()) input_flow = parameters.get("input_flow", ())
@ -73,11 +73,11 @@ def test_cardano_sign_tx(client, parameters, result):
"cardano/sign_tx.failed.json", "cardano/sign_tx_stake_pool_registration.failed.json" "cardano/sign_tx.failed.json", "cardano/sign_tx_stake_pool_registration.failed.json"
) )
def test_cardano_sign_tx_failed(client, parameters, result): def test_cardano_sign_tx_failed(client, parameters, result):
inputs = [cardano.create_input(i) for i in parameters["inputs"]] inputs = [cardano.parse_input(i) for i in parameters["inputs"]]
outputs = [cardano.create_output(o) for o in parameters["outputs"]] outputs = [cardano.parse_output(o) for o in parameters["outputs"]]
certificates = [cardano.create_certificate(c) for c in parameters["certificates"]] certificates = [cardano.parse_certificate(c) for c in parameters["certificates"]]
withdrawals = [cardano.create_withdrawal(w) for w in parameters["withdrawals"]] withdrawals = [cardano.parse_withdrawal(w) for w in parameters["withdrawals"]]
auxiliary_data = cardano.create_auxiliary_data(parameters["auxiliary_data"]) auxiliary_data = cardano.parse_auxiliary_data(parameters["auxiliary_data"])
input_flow = parameters.get("input_flow", ()) input_flow = parameters.get("input_flow", ())
@ -102,11 +102,11 @@ def test_cardano_sign_tx_failed(client, parameters, result):
@parametrize_using_common_fixtures("cardano/sign_tx.chunked.json") @parametrize_using_common_fixtures("cardano/sign_tx.chunked.json")
def test_cardano_sign_tx_with_multiple_chunks(client, parameters, result): def test_cardano_sign_tx_with_multiple_chunks(client, parameters, result):
inputs = [cardano.create_input(i) for i in parameters["inputs"]] inputs = [cardano.parse_input(i) for i in parameters["inputs"]]
outputs = [cardano.create_output(o) for o in parameters["outputs"]] outputs = [cardano.parse_output(o) for o in parameters["outputs"]]
certificates = [cardano.create_certificate(c) for c in parameters["certificates"]] certificates = [cardano.parse_certificate(c) for c in parameters["certificates"]]
withdrawals = [cardano.create_withdrawal(w) for w in parameters["withdrawals"]] withdrawals = [cardano.parse_withdrawal(w) for w in parameters["withdrawals"]]
auxiliary_data = cardano.create_auxiliary_data(parameters["auxiliary_data"]) auxiliary_data = cardano.parse_auxiliary_data(parameters["auxiliary_data"])
input_flow = parameters.get("input_flow", ()) input_flow = parameters.get("input_flow", ())

Loading…
Cancel
Save