diff --git a/firmware/fsm.c b/firmware/fsm.c index 558761ad87..bf233031ef 100644 --- a/firmware/fsm.c +++ b/firmware/fsm.c @@ -89,7 +89,7 @@ const CoinType *fsm_getCoin(const char *name) return coin; } -HDNode *fsm_getRootNode(void) +const HDNode *fsm_getDerivedNode(uint32_t *address_n, size_t address_n_count) { static HDNode node; if (!storage_getRootNode(&node)) { @@ -97,17 +97,15 @@ HDNode *fsm_getRootNode(void) layoutHome(); return 0; } - return &node; -} - -int fsm_deriveKey(HDNode *node, uint32_t *address_n, size_t address_n_count) -{ + if (!address_n || address_n_count == 0) { + return &node; + } size_t i; if (address_n_count > 3) { layoutProgressSwipe("Preparing keys", 0); } for (i = 0; i < address_n_count; i++) { - if (hdnode_private_ckd(node, address_n[i]) == 0) { + if (hdnode_private_ckd(&node, address_n[i]) == 0) { fsm_sendFailure(FailureType_Failure_Other, "Failed to derive private key"); layoutHome(); return 0; @@ -116,7 +114,7 @@ int fsm_deriveKey(HDNode *node, uint32_t *address_n, size_t address_n_count) layoutProgress("Preparing keys", 1000 * i / address_n_count); } } - return 1; + return &node; } void fsm_msgInitialize(Initialize *msg) @@ -279,9 +277,8 @@ void fsm_msgGetPublicKey(GetPublicKey *msg) { RESP_INIT(PublicKey); - HDNode *node = fsm_getRootNode(); + const HDNode *node = fsm_getDerivedNode(msg->address_n, msg->address_n_count); if (!node) return; - if (fsm_deriveKey(node, msg->address_n, msg->address_n_count) == 0) return; resp->node.depth = node->depth; resp->node.fingerprint = node->fingerprint; @@ -364,7 +361,7 @@ void fsm_msgSignTx(SignTx *msg) const CoinType *coin = fsm_getCoin(msg->coin_name); if (!coin) return; - HDNode *node = fsm_getRootNode(); + const HDNode *node = fsm_getDerivedNode(0, 0); if (!node) return; signing_init(msg->inputs_count, msg->outputs_count, coin, node); @@ -404,9 +401,8 @@ void fsm_msgCipherKeyValue(CipherKeyValue *msg) layoutHome(); return; } - HDNode *node = fsm_getRootNode(); + const HDNode *node = fsm_getDerivedNode(msg->address_n, msg->address_n_count); if (!node) return; - if (fsm_deriveKey(node, msg->address_n, msg->address_n_count) == 0) return; bool encrypt = msg->has_encrypt && msg->encrypt; bool ask_on_encrypt = msg->has_ask_on_encrypt && msg->ask_on_encrypt; @@ -504,9 +500,8 @@ void fsm_msgGetAddress(GetAddress *msg) const CoinType *coin = fsm_getCoin(msg->coin_name); if (!coin) return; - HDNode *node = fsm_getRootNode(); + const HDNode *node = fsm_getDerivedNode(msg->address_n, msg->address_n_count); if (!node) return; - if (fsm_deriveKey(node, msg->address_n, msg->address_n_count) == 0) return; if (msg->has_multisig) { layoutProgressSwipe("Preparing", 0); @@ -568,9 +563,8 @@ void fsm_msgSignMessage(SignMessage *msg) const CoinType *coin = fsm_getCoin(msg->coin_name); if (!coin) return; - HDNode *node = fsm_getRootNode(); + const HDNode *node = fsm_getDerivedNode(msg->address_n, msg->address_n_count); if (!node) return; - if (fsm_deriveKey(node, msg->address_n, msg->address_n_count) == 0) return; layoutProgressSwipe("Signing", 0); if (cryptoMessageSign(msg->message.bytes, msg->message.size, node->private_key, resp->signature.bytes) == 0) { @@ -631,7 +625,7 @@ void fsm_msgEncryptMessage(EncryptMessage *msg) bool signing = msg->address_n_count > 0; RESP_INIT(EncryptedMessage); const CoinType *coin = 0; - HDNode *node = 0; + const HDNode *node = 0; uint8_t address_raw[21]; if (signing) { coin = coinByName(msg->coin_name); @@ -643,12 +637,11 @@ void fsm_msgEncryptMessage(EncryptMessage *msg) layoutHome(); return; } - node = fsm_getRootNode(); + node = fsm_getDerivedNode(msg->address_n, msg->address_n_count); if (!node) return; - if (fsm_deriveKey(node, msg->address_n, msg->address_n_count) == 0) return; - - hdnode_fill_public_key(node); - ecdsa_get_address_raw(node->public_key, coin->address_type, address_raw); + uint8_t public_key[33]; + ecdsa_get_public_key33(node->private_key, public_key); + ecdsa_get_address_raw(public_key, coin->address_type, address_raw); } layoutEncryptMessage(msg->message.bytes, msg->message.size, signing); if (!protectButton(ButtonRequestType_ButtonRequest_ProtectCall, false)) { @@ -692,9 +685,8 @@ void fsm_msgDecryptMessage(DecryptMessage *msg) layoutHome(); return; } - HDNode *node = fsm_getRootNode(); + const HDNode *node = fsm_getDerivedNode(msg->address_n, msg->address_n_count); if (!node) return; - if (fsm_deriveKey(node, msg->address_n, msg->address_n_count) == 0) return; layoutProgressSwipe("Decrypting", 0); RESP_INIT(DecryptedMessage); diff --git a/firmware/signing.c b/firmware/signing.c index f3f8bf6dbb..e41d377e7f 100644 --- a/firmware/signing.c +++ b/firmware/signing.c @@ -196,7 +196,7 @@ void send_req_finished(void) msg_write(MessageType_MessageType_TxRequest, &resp); } -void signing_init(uint32_t _inputs_count, uint32_t _outputs_count, const CoinType *_coin, HDNode *_root) +void signing_init(uint32_t _inputs_count, uint32_t _outputs_count, const CoinType *_coin, const HDNode *_root) { inputs_count = _inputs_count; outputs_count = _outputs_count; diff --git a/firmware/signing.h b/firmware/signing.h index 84d609b877..bcb66fec6b 100644 --- a/firmware/signing.h +++ b/firmware/signing.h @@ -25,7 +25,7 @@ #include "bip32.h" #include "types.pb.h" -void signing_init(uint32_t _inputs_count, uint32_t _outputs_count, const CoinType *_coin, HDNode *_root); +void signing_init(uint32_t _inputs_count, uint32_t _outputs_count, const CoinType *_coin, const HDNode *_root); void signing_abort(void); void signing_txack(TransactionType *tx);