refactor(core,crypto): make public key derivation functions return

status
pull/1849/head
Ondřej Vejpustek 3 years ago
parent 172f399b29
commit 15bb085509

@ -71,27 +71,32 @@ STATIC mp_obj_t mod_trezorcrypto_secp256k1_publickey(size_t n_args,
mp_raise_ValueError("Invalid length of secret key");
}
vstr_t pk = {0};
int ret = 0;
bool compressed = n_args < 2 || args[1] == mp_const_true;
if (compressed) {
vstr_init_len(&pk, 33);
#ifdef USE_SECP256K1_ZKP_ECDSA
zkp_ecdsa_get_public_key33(&secp256k1, (const uint8_t *)sk.buf,
(uint8_t *)pk.buf);
ret = zkp_ecdsa_get_public_key33(&secp256k1, (const uint8_t *)sk.buf,
(uint8_t *)pk.buf);
#else
ecdsa_get_public_key33(&secp256k1, (const uint8_t *)sk.buf,
(uint8_t *)pk.buf);
ret = ecdsa_get_public_key33(&secp256k1, (const uint8_t *)sk.buf,
(uint8_t *)pk.buf);
#endif
} else {
vstr_init_len(&pk, 65);
#ifdef USE_SECP256K1_ZKP_ECDSA
zkp_ecdsa_get_public_key65(&secp256k1, (const uint8_t *)sk.buf,
(uint8_t *)pk.buf);
ret = zkp_ecdsa_get_public_key65(&secp256k1, (const uint8_t *)sk.buf,
(uint8_t *)pk.buf);
#else
ecdsa_get_public_key65(&secp256k1, (const uint8_t *)sk.buf,
(uint8_t *)pk.buf);
ret = ecdsa_get_public_key65(&secp256k1, (const uint8_t *)sk.buf,
(uint8_t *)pk.buf);
#endif
}
if (0 != ret) {
vstr_clear(&pk);
mp_raise_ValueError("Invalid secret key");
}
return mp_obj_new_str_from_vstr(&mp_type_bytes, &pk);
}
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(

@ -403,13 +403,16 @@ void point_jacobian_double(jacobian_curve_point *p, const ecdsa_curve *curve) {
}
// res = k * p
void point_multiply(const ecdsa_curve *curve, const bignum256 *k,
const curve_point *p, curve_point *res) {
// returns 0 on success
int point_multiply(const ecdsa_curve *curve, const bignum256 *k,
const curve_point *p, curve_point *res) {
// this algorithm is loosely based on
// Katsuyuki Okeya and Tsuyoshi Takagi, The Width-w NAF Method Provides
// Small Memory and Fast Elliptic Scalar Multiplications Secure against
// Side Channel Attacks.
assert(bn_is_less(k, &curve->order));
if (!bn_is_less(k, &curve->order)) {
return 1;
}
int i = 0, j = 0;
static CONFIDENTIAL bignum256 a;
@ -441,7 +444,7 @@ void point_multiply(const ecdsa_curve *curve, const bignum256 *k,
// special case 0*p: just return zero. We don't care about constant time.
if (!is_non_zero) {
point_set_infinity(res);
return;
return 1;
}
// Now a = k + 2^256 (mod curve->order) and a is odd.
@ -522,15 +525,20 @@ void point_multiply(const ecdsa_curve *curve, const bignum256 *k,
jacobian_to_curve(&jres, res, prime);
memzero(&a, sizeof(a));
memzero(&jres, sizeof(jres));
return 0;
}
#if USE_PRECOMPUTED_CP
// res = k * G
// k must be a normalized number with 0 <= k < curve->order
void scalar_multiply(const ecdsa_curve *curve, const bignum256 *k,
curve_point *res) {
assert(bn_is_less(k, &curve->order));
// returns 0 on success
int scalar_multiply(const ecdsa_curve *curve, const bignum256 *k,
curve_point *res) {
if (!bn_is_less(k, &curve->order)) {
return 1;
}
int i = {0}, j = {0};
static CONFIDENTIAL bignum256 a;
@ -558,7 +566,7 @@ void scalar_multiply(const ecdsa_curve *curve, const bignum256 *k,
// special case 0*G: just return zero. We don't care about constant time.
if (!is_non_zero) {
point_set_infinity(res);
return;
return 0;
}
// Now a = k + 2^256 (mod curve->order) and a is odd.
@ -611,13 +619,15 @@ void scalar_multiply(const ecdsa_curve *curve, const bignum256 *k,
jacobian_to_curve(&jres, res, prime);
memzero(&a, sizeof(a));
memzero(&jres, sizeof(jres));
return 0;
}
#else
void scalar_multiply(const ecdsa_curve *curve, const bignum256 *k,
curve_point *res) {
point_multiply(curve, k, &curve->G, res);
int scalar_multiply(const ecdsa_curve *curve, const bignum256 *k,
curve_point *res) {
return point_multiply(curve, k, &curve->G, res);
}
#endif
@ -754,33 +764,43 @@ int ecdsa_sign_digest(const ecdsa_curve *curve, const uint8_t *priv_key,
return -1;
}
void ecdsa_get_public_key33(const ecdsa_curve *curve, const uint8_t *priv_key,
uint8_t *pub_key) {
// returns 0 on success
int ecdsa_get_public_key33(const ecdsa_curve *curve, const uint8_t *priv_key,
uint8_t *pub_key) {
curve_point R = {0};
bignum256 k = {0};
bn_read_be(priv_key, &k);
// compute k*G
scalar_multiply(curve, &k, &R);
if (scalar_multiply(curve, &k, &R) != 0) {
memzero(&k, sizeof(k));
return 1;
}
pub_key[0] = 0x02 | (R.y.val[0] & 0x01);
bn_write_be(&R.x, pub_key + 1);
memzero(&R, sizeof(R));
memzero(&k, sizeof(k));
return 0;
}
void ecdsa_get_public_key65(const ecdsa_curve *curve, const uint8_t *priv_key,
uint8_t *pub_key) {
// returns 0 on success
int ecdsa_get_public_key65(const ecdsa_curve *curve, const uint8_t *priv_key,
uint8_t *pub_key) {
curve_point R = {0};
bignum256 k = {0};
bn_read_be(priv_key, &k);
// compute k*G
scalar_multiply(curve, &k, &R);
if (scalar_multiply(curve, &k, &R) != 0) {
memzero(&k, sizeof(k));
return 1;
}
pub_key[0] = 0x04;
bn_write_be(&R.x, pub_key + 1);
bn_write_be(&R.y, pub_key + 33);
memzero(&R, sizeof(R));
memzero(&k, sizeof(k));
return 0;
}
int ecdsa_uncompress_pubkey(const ecdsa_curve *curve, const uint8_t *pub_key,

@ -65,14 +65,14 @@ void point_copy(const curve_point *cp1, curve_point *cp2);
void point_add(const ecdsa_curve *curve, const curve_point *cp1,
curve_point *cp2);
void point_double(const ecdsa_curve *curve, curve_point *cp);
void point_multiply(const ecdsa_curve *curve, const bignum256 *k,
const curve_point *p, curve_point *res);
int point_multiply(const ecdsa_curve *curve, const bignum256 *k,
const curve_point *p, curve_point *res);
void point_set_infinity(curve_point *p);
int point_is_infinity(const curve_point *p);
int point_is_equal(const curve_point *p, const curve_point *q);
int point_is_negative_of(const curve_point *p, const curve_point *q);
void scalar_multiply(const ecdsa_curve *curve, const bignum256 *k,
curve_point *res);
int scalar_multiply(const ecdsa_curve *curve, const bignum256 *k,
curve_point *res);
int ecdh_multiply(const ecdsa_curve *curve, const uint8_t *priv_key,
const uint8_t *pub_key, uint8_t *session_key);
void compress_coords(const curve_point *cp, uint8_t *compressed);
@ -88,10 +88,10 @@ int ecdsa_sign(const ecdsa_curve *curve, HasherType hasher_sign,
int ecdsa_sign_digest(const ecdsa_curve *curve, const uint8_t *priv_key,
const uint8_t *digest, uint8_t *sig, uint8_t *pby,
int (*is_canonical)(uint8_t by, uint8_t sig[64]));
void ecdsa_get_public_key33(const ecdsa_curve *curve, const uint8_t *priv_key,
uint8_t *pub_key);
void ecdsa_get_public_key65(const ecdsa_curve *curve, const uint8_t *priv_key,
uint8_t *pub_key);
int ecdsa_get_public_key33(const ecdsa_curve *curve, const uint8_t *priv_key,
uint8_t *pub_key);
int ecdsa_get_public_key65(const ecdsa_curve *curve, const uint8_t *priv_key,
uint8_t *pub_key);
void ecdsa_get_pubkeyhash(const uint8_t *pub_key, HasherType hasher_pubkey,
uint8_t *pubkeyhash);
void ecdsa_get_address_raw(const uint8_t *pub_key, uint32_t version,

@ -3429,19 +3429,20 @@ START_TEST(test_bip32_decred_vector_2) {
}
END_TEST
static void test_ecdsa_get_public_key33_helper(
void (*ecdsa_get_public_key33_fn)(const ecdsa_curve *, const uint8_t *,
uint8_t *)) {
uint8_t privkey[32];
uint8_t pubkey[65];
static void test_ecdsa_get_public_key33_helper(int (*ecdsa_get_public_key33_fn)(
const ecdsa_curve *, const uint8_t *, uint8_t *)) {
uint8_t privkey[32] = {0};
uint8_t pubkey[65] = {0};
const ecdsa_curve *curve = &secp256k1;
int res = 0;
memcpy(
privkey,
fromhex(
"c46f5b217f04ff28886a89d3c762ed84e5fa318d1c9a635d541131e69f1f49f5"),
32);
ecdsa_get_public_key33_fn(curve, privkey, pubkey);
res = ecdsa_get_public_key33_fn(curve, privkey, pubkey);
ck_assert_int_eq(res, 0);
ck_assert_mem_eq(
pubkey,
fromhex(
@ -3453,7 +3454,8 @@ static void test_ecdsa_get_public_key33_helper(
fromhex(
"3b90a4de80fb00d77795762c389d1279d4b4ab5992ae3cde6bc12ca63116f74c"),
32);
ecdsa_get_public_key33_fn(curve, privkey, pubkey);
res = ecdsa_get_public_key33_fn(curve, privkey, pubkey);
ck_assert_int_eq(res, 0);
ck_assert_mem_eq(
pubkey,
fromhex(
@ -3471,19 +3473,20 @@ START_TEST(test_zkp_ecdsa_get_public_key33) {
}
END_TEST
static void test_ecdsa_get_public_key65_helper(
void (*ecdsa_get_public_key65_fn)(const ecdsa_curve *, const uint8_t *,
uint8_t *)) {
uint8_t privkey[32];
uint8_t pubkey[65];
static void test_ecdsa_get_public_key65_helper(int (*ecdsa_get_public_key65_fn)(
const ecdsa_curve *, const uint8_t *, uint8_t *)) {
uint8_t privkey[32] = {0};
uint8_t pubkey[65] = {0};
const ecdsa_curve *curve = &secp256k1;
int res = 0;
memcpy(
privkey,
fromhex(
"c46f5b217f04ff28886a89d3c762ed84e5fa318d1c9a635d541131e69f1f49f5"),
32);
ecdsa_get_public_key65_fn(curve, privkey, pubkey);
res = ecdsa_get_public_key65_fn(curve, privkey, pubkey);
ck_assert_int_eq(res, 0);
ck_assert_mem_eq(
pubkey,
fromhex(
@ -6337,11 +6340,11 @@ static void test_point_mult_curve(const ecdsa_curve *curve) {
/* test distributivity: (a + b)P = aP + bP */
bn_mod(&a, &curve->order);
bn_mod(&b, &curve->order);
point_multiply(curve, &a, &p, &p1);
point_multiply(curve, &b, &p, &p2);
ck_assert_int_eq(point_multiply(curve, &a, &p, &p1), 0);
ck_assert_int_eq(point_multiply(curve, &b, &p, &p2), 0);
bn_addmod(&a, &b, &curve->order);
bn_mod(&a, &curve->order);
point_multiply(curve, &a, &p, &p3);
ck_assert_int_eq(point_multiply(curve, &a, &p, &p3), 0);
point_add(curve, &p1, &p2);
ck_assert_mem_eq(&p2, &p3, sizeof(curve_point));
// new "random" numbers and a "random" point
@ -6368,17 +6371,17 @@ static void test_scalar_point_mult_curve(const ecdsa_curve *curve) {
*/
bn_mod(&a, &curve->order);
bn_mod(&b, &curve->order);
scalar_multiply(curve, &a, &p1);
point_multiply(curve, &b, &p1, &p1);
ck_assert_int_eq(scalar_multiply(curve, &a, &p1), 0);
ck_assert_int_eq(point_multiply(curve, &b, &p1, &p1), 0);
scalar_multiply(curve, &b, &p2);
point_multiply(curve, &a, &p2, &p2);
ck_assert_int_eq(scalar_multiply(curve, &b, &p2), 0);
ck_assert_int_eq(point_multiply(curve, &a, &p2, &p2), 0);
ck_assert_mem_eq(&p1, &p2, sizeof(curve_point));
bn_multiply(&a, &b, &curve->order);
bn_mod(&b, &curve->order);
scalar_multiply(curve, &b, &p2);
ck_assert_int_eq(scalar_multiply(curve, &b, &p2), 0);
ck_assert_mem_eq(&p1, &p2, sizeof(curve_point));

@ -44,9 +44,10 @@ static bool is_zero_digest(const uint8_t *digest) {
// curve has to be &secp256k1
// private_key_bytes has 32 bytes
// public_key_bytes has 33 bytes
void zkp_ecdsa_get_public_key33(const ecdsa_curve *curve,
const uint8_t *private_key_bytes,
uint8_t *public_key_bytes) {
// returns 0 on success
int zkp_ecdsa_get_public_key33(const ecdsa_curve *curve,
const uint8_t *private_key_bytes,
uint8_t *public_key_bytes) {
assert(curve == &secp256k1);
int result = 0;
@ -75,16 +76,17 @@ void zkp_ecdsa_get_public_key33(const ecdsa_curve *curve,
}
memzero(&public_key, sizeof(public_key));
assert(result == 0);
return result;
}
// ECDSA uncompressed public key derivation
// curve has to be &secp256k1
// private_key_bytes has 32 bytes
// public_key_bytes has 65 bytes
void zkp_ecdsa_get_public_key65(const ecdsa_curve *curve,
const uint8_t *private_key_bytes,
uint8_t *public_key_bytes) {
// returns 0 on success
int zkp_ecdsa_get_public_key65(const ecdsa_curve *curve,
const uint8_t *private_key_bytes,
uint8_t *public_key_bytes) {
assert(curve == &secp256k1);
int result = 0;
@ -113,7 +115,7 @@ void zkp_ecdsa_get_public_key65(const ecdsa_curve *curve,
}
memzero(&public_key, sizeof(public_key));
assert(result == 0);
return result;
}
// ECDSA signing

@ -5,12 +5,12 @@
#include "hasher.h"
void zkp_ecdsa_get_public_key33(const ecdsa_curve *curve,
const uint8_t *private_key_bytes,
uint8_t *public_key_bytes);
void zkp_ecdsa_get_public_key65(const ecdsa_curve *curve,
const uint8_t *private_key_bytes,
uint8_t *public_key_bytes);
int zkp_ecdsa_get_public_key33(const ecdsa_curve *curve,
const uint8_t *private_key_bytes,
uint8_t *public_key_bytes);
int zkp_ecdsa_get_public_key65(const ecdsa_curve *curve,
const uint8_t *private_key_bytes,
uint8_t *public_key_bytes);
int zkp_ecdsa_sign_digest(const ecdsa_curve *curve,
const uint8_t *private_key_bytes,
const uint8_t *digest, uint8_t *signature_bytes,

Loading…
Cancel
Save