refactor(core,crypto): make public key derivation functions return

status
pull/1849/head
Ondřej Vejpustek 3 years ago
parent 172f399b29
commit 15bb085509

@ -71,27 +71,32 @@ STATIC mp_obj_t mod_trezorcrypto_secp256k1_publickey(size_t n_args,
mp_raise_ValueError("Invalid length of secret key"); mp_raise_ValueError("Invalid length of secret key");
} }
vstr_t pk = {0}; vstr_t pk = {0};
int ret = 0;
bool compressed = n_args < 2 || args[1] == mp_const_true; bool compressed = n_args < 2 || args[1] == mp_const_true;
if (compressed) { if (compressed) {
vstr_init_len(&pk, 33); vstr_init_len(&pk, 33);
#ifdef USE_SECP256K1_ZKP_ECDSA #ifdef USE_SECP256K1_ZKP_ECDSA
zkp_ecdsa_get_public_key33(&secp256k1, (const uint8_t *)sk.buf, ret = zkp_ecdsa_get_public_key33(&secp256k1, (const uint8_t *)sk.buf,
(uint8_t *)pk.buf); (uint8_t *)pk.buf);
#else #else
ecdsa_get_public_key33(&secp256k1, (const uint8_t *)sk.buf, ret = ecdsa_get_public_key33(&secp256k1, (const uint8_t *)sk.buf,
(uint8_t *)pk.buf); (uint8_t *)pk.buf);
#endif #endif
} else { } else {
vstr_init_len(&pk, 65); vstr_init_len(&pk, 65);
#ifdef USE_SECP256K1_ZKP_ECDSA #ifdef USE_SECP256K1_ZKP_ECDSA
zkp_ecdsa_get_public_key65(&secp256k1, (const uint8_t *)sk.buf, ret = zkp_ecdsa_get_public_key65(&secp256k1, (const uint8_t *)sk.buf,
(uint8_t *)pk.buf); (uint8_t *)pk.buf);
#else #else
ecdsa_get_public_key65(&secp256k1, (const uint8_t *)sk.buf, ret = ecdsa_get_public_key65(&secp256k1, (const uint8_t *)sk.buf,
(uint8_t *)pk.buf); (uint8_t *)pk.buf);
#endif #endif
} }
if (0 != ret) {
vstr_clear(&pk);
mp_raise_ValueError("Invalid secret key");
}
return mp_obj_new_str_from_vstr(&mp_type_bytes, &pk); return mp_obj_new_str_from_vstr(&mp_type_bytes, &pk);
} }
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN( STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(

@ -403,13 +403,16 @@ void point_jacobian_double(jacobian_curve_point *p, const ecdsa_curve *curve) {
} }
// res = k * p // res = k * p
void point_multiply(const ecdsa_curve *curve, const bignum256 *k, // returns 0 on success
const curve_point *p, curve_point *res) { int point_multiply(const ecdsa_curve *curve, const bignum256 *k,
const curve_point *p, curve_point *res) {
// this algorithm is loosely based on // this algorithm is loosely based on
// Katsuyuki Okeya and Tsuyoshi Takagi, The Width-w NAF Method Provides // Katsuyuki Okeya and Tsuyoshi Takagi, The Width-w NAF Method Provides
// Small Memory and Fast Elliptic Scalar Multiplications Secure against // Small Memory and Fast Elliptic Scalar Multiplications Secure against
// Side Channel Attacks. // Side Channel Attacks.
assert(bn_is_less(k, &curve->order)); if (!bn_is_less(k, &curve->order)) {
return 1;
}
int i = 0, j = 0; int i = 0, j = 0;
static CONFIDENTIAL bignum256 a; static CONFIDENTIAL bignum256 a;
@ -441,7 +444,7 @@ void point_multiply(const ecdsa_curve *curve, const bignum256 *k,
// special case 0*p: just return zero. We don't care about constant time. // special case 0*p: just return zero. We don't care about constant time.
if (!is_non_zero) { if (!is_non_zero) {
point_set_infinity(res); point_set_infinity(res);
return; return 1;
} }
// Now a = k + 2^256 (mod curve->order) and a is odd. // Now a = k + 2^256 (mod curve->order) and a is odd.
@ -522,15 +525,20 @@ void point_multiply(const ecdsa_curve *curve, const bignum256 *k,
jacobian_to_curve(&jres, res, prime); jacobian_to_curve(&jres, res, prime);
memzero(&a, sizeof(a)); memzero(&a, sizeof(a));
memzero(&jres, sizeof(jres)); memzero(&jres, sizeof(jres));
return 0;
} }
#if USE_PRECOMPUTED_CP #if USE_PRECOMPUTED_CP
// res = k * G // res = k * G
// k must be a normalized number with 0 <= k < curve->order // k must be a normalized number with 0 <= k < curve->order
void scalar_multiply(const ecdsa_curve *curve, const bignum256 *k, // returns 0 on success
curve_point *res) { int scalar_multiply(const ecdsa_curve *curve, const bignum256 *k,
assert(bn_is_less(k, &curve->order)); curve_point *res) {
if (!bn_is_less(k, &curve->order)) {
return 1;
}
int i = {0}, j = {0}; int i = {0}, j = {0};
static CONFIDENTIAL bignum256 a; static CONFIDENTIAL bignum256 a;
@ -558,7 +566,7 @@ void scalar_multiply(const ecdsa_curve *curve, const bignum256 *k,
// special case 0*G: just return zero. We don't care about constant time. // special case 0*G: just return zero. We don't care about constant time.
if (!is_non_zero) { if (!is_non_zero) {
point_set_infinity(res); point_set_infinity(res);
return; return 0;
} }
// Now a = k + 2^256 (mod curve->order) and a is odd. // Now a = k + 2^256 (mod curve->order) and a is odd.
@ -611,13 +619,15 @@ void scalar_multiply(const ecdsa_curve *curve, const bignum256 *k,
jacobian_to_curve(&jres, res, prime); jacobian_to_curve(&jres, res, prime);
memzero(&a, sizeof(a)); memzero(&a, sizeof(a));
memzero(&jres, sizeof(jres)); memzero(&jres, sizeof(jres));
return 0;
} }
#else #else
void scalar_multiply(const ecdsa_curve *curve, const bignum256 *k, int scalar_multiply(const ecdsa_curve *curve, const bignum256 *k,
curve_point *res) { curve_point *res) {
point_multiply(curve, k, &curve->G, res); return point_multiply(curve, k, &curve->G, res);
} }
#endif #endif
@ -754,33 +764,43 @@ int ecdsa_sign_digest(const ecdsa_curve *curve, const uint8_t *priv_key,
return -1; return -1;
} }
void ecdsa_get_public_key33(const ecdsa_curve *curve, const uint8_t *priv_key, // returns 0 on success
uint8_t *pub_key) { int ecdsa_get_public_key33(const ecdsa_curve *curve, const uint8_t *priv_key,
uint8_t *pub_key) {
curve_point R = {0}; curve_point R = {0};
bignum256 k = {0}; bignum256 k = {0};
bn_read_be(priv_key, &k); bn_read_be(priv_key, &k);
// compute k*G // compute k*G
scalar_multiply(curve, &k, &R); if (scalar_multiply(curve, &k, &R) != 0) {
memzero(&k, sizeof(k));
return 1;
}
pub_key[0] = 0x02 | (R.y.val[0] & 0x01); pub_key[0] = 0x02 | (R.y.val[0] & 0x01);
bn_write_be(&R.x, pub_key + 1); bn_write_be(&R.x, pub_key + 1);
memzero(&R, sizeof(R)); memzero(&R, sizeof(R));
memzero(&k, sizeof(k)); memzero(&k, sizeof(k));
return 0;
} }
void ecdsa_get_public_key65(const ecdsa_curve *curve, const uint8_t *priv_key, // returns 0 on success
uint8_t *pub_key) { int ecdsa_get_public_key65(const ecdsa_curve *curve, const uint8_t *priv_key,
uint8_t *pub_key) {
curve_point R = {0}; curve_point R = {0};
bignum256 k = {0}; bignum256 k = {0};
bn_read_be(priv_key, &k); bn_read_be(priv_key, &k);
// compute k*G // compute k*G
scalar_multiply(curve, &k, &R); if (scalar_multiply(curve, &k, &R) != 0) {
memzero(&k, sizeof(k));
return 1;
}
pub_key[0] = 0x04; pub_key[0] = 0x04;
bn_write_be(&R.x, pub_key + 1); bn_write_be(&R.x, pub_key + 1);
bn_write_be(&R.y, pub_key + 33); bn_write_be(&R.y, pub_key + 33);
memzero(&R, sizeof(R)); memzero(&R, sizeof(R));
memzero(&k, sizeof(k)); memzero(&k, sizeof(k));
return 0;
} }
int ecdsa_uncompress_pubkey(const ecdsa_curve *curve, const uint8_t *pub_key, int ecdsa_uncompress_pubkey(const ecdsa_curve *curve, const uint8_t *pub_key,

@ -65,14 +65,14 @@ void point_copy(const curve_point *cp1, curve_point *cp2);
void point_add(const ecdsa_curve *curve, const curve_point *cp1, void point_add(const ecdsa_curve *curve, const curve_point *cp1,
curve_point *cp2); curve_point *cp2);
void point_double(const ecdsa_curve *curve, curve_point *cp); void point_double(const ecdsa_curve *curve, curve_point *cp);
void point_multiply(const ecdsa_curve *curve, const bignum256 *k, int point_multiply(const ecdsa_curve *curve, const bignum256 *k,
const curve_point *p, curve_point *res); const curve_point *p, curve_point *res);
void point_set_infinity(curve_point *p); void point_set_infinity(curve_point *p);
int point_is_infinity(const curve_point *p); int point_is_infinity(const curve_point *p);
int point_is_equal(const curve_point *p, const curve_point *q); int point_is_equal(const curve_point *p, const curve_point *q);
int point_is_negative_of(const curve_point *p, const curve_point *q); int point_is_negative_of(const curve_point *p, const curve_point *q);
void scalar_multiply(const ecdsa_curve *curve, const bignum256 *k, int scalar_multiply(const ecdsa_curve *curve, const bignum256 *k,
curve_point *res); curve_point *res);
int ecdh_multiply(const ecdsa_curve *curve, const uint8_t *priv_key, int ecdh_multiply(const ecdsa_curve *curve, const uint8_t *priv_key,
const uint8_t *pub_key, uint8_t *session_key); const uint8_t *pub_key, uint8_t *session_key);
void compress_coords(const curve_point *cp, uint8_t *compressed); void compress_coords(const curve_point *cp, uint8_t *compressed);
@ -88,10 +88,10 @@ int ecdsa_sign(const ecdsa_curve *curve, HasherType hasher_sign,
int ecdsa_sign_digest(const ecdsa_curve *curve, const uint8_t *priv_key, int ecdsa_sign_digest(const ecdsa_curve *curve, const uint8_t *priv_key,
const uint8_t *digest, uint8_t *sig, uint8_t *pby, const uint8_t *digest, uint8_t *sig, uint8_t *pby,
int (*is_canonical)(uint8_t by, uint8_t sig[64])); int (*is_canonical)(uint8_t by, uint8_t sig[64]));
void ecdsa_get_public_key33(const ecdsa_curve *curve, const uint8_t *priv_key, int ecdsa_get_public_key33(const ecdsa_curve *curve, const uint8_t *priv_key,
uint8_t *pub_key); uint8_t *pub_key);
void ecdsa_get_public_key65(const ecdsa_curve *curve, const uint8_t *priv_key, int ecdsa_get_public_key65(const ecdsa_curve *curve, const uint8_t *priv_key,
uint8_t *pub_key); uint8_t *pub_key);
void ecdsa_get_pubkeyhash(const uint8_t *pub_key, HasherType hasher_pubkey, void ecdsa_get_pubkeyhash(const uint8_t *pub_key, HasherType hasher_pubkey,
uint8_t *pubkeyhash); uint8_t *pubkeyhash);
void ecdsa_get_address_raw(const uint8_t *pub_key, uint32_t version, void ecdsa_get_address_raw(const uint8_t *pub_key, uint32_t version,

@ -3429,19 +3429,20 @@ START_TEST(test_bip32_decred_vector_2) {
} }
END_TEST END_TEST
static void test_ecdsa_get_public_key33_helper( static void test_ecdsa_get_public_key33_helper(int (*ecdsa_get_public_key33_fn)(
void (*ecdsa_get_public_key33_fn)(const ecdsa_curve *, const uint8_t *, const ecdsa_curve *, const uint8_t *, uint8_t *)) {
uint8_t *)) { uint8_t privkey[32] = {0};
uint8_t privkey[32]; uint8_t pubkey[65] = {0};
uint8_t pubkey[65];
const ecdsa_curve *curve = &secp256k1; const ecdsa_curve *curve = &secp256k1;
int res = 0;
memcpy( memcpy(
privkey, privkey,
fromhex( fromhex(
"c46f5b217f04ff28886a89d3c762ed84e5fa318d1c9a635d541131e69f1f49f5"), "c46f5b217f04ff28886a89d3c762ed84e5fa318d1c9a635d541131e69f1f49f5"),
32); 32);
ecdsa_get_public_key33_fn(curve, privkey, pubkey); res = ecdsa_get_public_key33_fn(curve, privkey, pubkey);
ck_assert_int_eq(res, 0);
ck_assert_mem_eq( ck_assert_mem_eq(
pubkey, pubkey,
fromhex( fromhex(
@ -3453,7 +3454,8 @@ static void test_ecdsa_get_public_key33_helper(
fromhex( fromhex(
"3b90a4de80fb00d77795762c389d1279d4b4ab5992ae3cde6bc12ca63116f74c"), "3b90a4de80fb00d77795762c389d1279d4b4ab5992ae3cde6bc12ca63116f74c"),
32); 32);
ecdsa_get_public_key33_fn(curve, privkey, pubkey); res = ecdsa_get_public_key33_fn(curve, privkey, pubkey);
ck_assert_int_eq(res, 0);
ck_assert_mem_eq( ck_assert_mem_eq(
pubkey, pubkey,
fromhex( fromhex(
@ -3471,19 +3473,20 @@ START_TEST(test_zkp_ecdsa_get_public_key33) {
} }
END_TEST END_TEST
static void test_ecdsa_get_public_key65_helper( static void test_ecdsa_get_public_key65_helper(int (*ecdsa_get_public_key65_fn)(
void (*ecdsa_get_public_key65_fn)(const ecdsa_curve *, const uint8_t *, const ecdsa_curve *, const uint8_t *, uint8_t *)) {
uint8_t *)) { uint8_t privkey[32] = {0};
uint8_t privkey[32]; uint8_t pubkey[65] = {0};
uint8_t pubkey[65];
const ecdsa_curve *curve = &secp256k1; const ecdsa_curve *curve = &secp256k1;
int res = 0;
memcpy( memcpy(
privkey, privkey,
fromhex( fromhex(
"c46f5b217f04ff28886a89d3c762ed84e5fa318d1c9a635d541131e69f1f49f5"), "c46f5b217f04ff28886a89d3c762ed84e5fa318d1c9a635d541131e69f1f49f5"),
32); 32);
ecdsa_get_public_key65_fn(curve, privkey, pubkey); res = ecdsa_get_public_key65_fn(curve, privkey, pubkey);
ck_assert_int_eq(res, 0);
ck_assert_mem_eq( ck_assert_mem_eq(
pubkey, pubkey,
fromhex( fromhex(
@ -6337,11 +6340,11 @@ static void test_point_mult_curve(const ecdsa_curve *curve) {
/* test distributivity: (a + b)P = aP + bP */ /* test distributivity: (a + b)P = aP + bP */
bn_mod(&a, &curve->order); bn_mod(&a, &curve->order);
bn_mod(&b, &curve->order); bn_mod(&b, &curve->order);
point_multiply(curve, &a, &p, &p1); ck_assert_int_eq(point_multiply(curve, &a, &p, &p1), 0);
point_multiply(curve, &b, &p, &p2); ck_assert_int_eq(point_multiply(curve, &b, &p, &p2), 0);
bn_addmod(&a, &b, &curve->order); bn_addmod(&a, &b, &curve->order);
bn_mod(&a, &curve->order); bn_mod(&a, &curve->order);
point_multiply(curve, &a, &p, &p3); ck_assert_int_eq(point_multiply(curve, &a, &p, &p3), 0);
point_add(curve, &p1, &p2); point_add(curve, &p1, &p2);
ck_assert_mem_eq(&p2, &p3, sizeof(curve_point)); ck_assert_mem_eq(&p2, &p3, sizeof(curve_point));
// new "random" numbers and a "random" point // new "random" numbers and a "random" point
@ -6368,17 +6371,17 @@ static void test_scalar_point_mult_curve(const ecdsa_curve *curve) {
*/ */
bn_mod(&a, &curve->order); bn_mod(&a, &curve->order);
bn_mod(&b, &curve->order); bn_mod(&b, &curve->order);
scalar_multiply(curve, &a, &p1); ck_assert_int_eq(scalar_multiply(curve, &a, &p1), 0);
point_multiply(curve, &b, &p1, &p1); ck_assert_int_eq(point_multiply(curve, &b, &p1, &p1), 0);
scalar_multiply(curve, &b, &p2); ck_assert_int_eq(scalar_multiply(curve, &b, &p2), 0);
point_multiply(curve, &a, &p2, &p2); ck_assert_int_eq(point_multiply(curve, &a, &p2, &p2), 0);
ck_assert_mem_eq(&p1, &p2, sizeof(curve_point)); ck_assert_mem_eq(&p1, &p2, sizeof(curve_point));
bn_multiply(&a, &b, &curve->order); bn_multiply(&a, &b, &curve->order);
bn_mod(&b, &curve->order); bn_mod(&b, &curve->order);
scalar_multiply(curve, &b, &p2); ck_assert_int_eq(scalar_multiply(curve, &b, &p2), 0);
ck_assert_mem_eq(&p1, &p2, sizeof(curve_point)); ck_assert_mem_eq(&p1, &p2, sizeof(curve_point));

@ -44,9 +44,10 @@ static bool is_zero_digest(const uint8_t *digest) {
// curve has to be &secp256k1 // curve has to be &secp256k1
// private_key_bytes has 32 bytes // private_key_bytes has 32 bytes
// public_key_bytes has 33 bytes // public_key_bytes has 33 bytes
void zkp_ecdsa_get_public_key33(const ecdsa_curve *curve, // returns 0 on success
const uint8_t *private_key_bytes, int zkp_ecdsa_get_public_key33(const ecdsa_curve *curve,
uint8_t *public_key_bytes) { const uint8_t *private_key_bytes,
uint8_t *public_key_bytes) {
assert(curve == &secp256k1); assert(curve == &secp256k1);
int result = 0; int result = 0;
@ -75,16 +76,17 @@ void zkp_ecdsa_get_public_key33(const ecdsa_curve *curve,
} }
memzero(&public_key, sizeof(public_key)); memzero(&public_key, sizeof(public_key));
assert(result == 0); return result;
} }
// ECDSA uncompressed public key derivation // ECDSA uncompressed public key derivation
// curve has to be &secp256k1 // curve has to be &secp256k1
// private_key_bytes has 32 bytes // private_key_bytes has 32 bytes
// public_key_bytes has 65 bytes // public_key_bytes has 65 bytes
void zkp_ecdsa_get_public_key65(const ecdsa_curve *curve, // returns 0 on success
const uint8_t *private_key_bytes, int zkp_ecdsa_get_public_key65(const ecdsa_curve *curve,
uint8_t *public_key_bytes) { const uint8_t *private_key_bytes,
uint8_t *public_key_bytes) {
assert(curve == &secp256k1); assert(curve == &secp256k1);
int result = 0; int result = 0;
@ -113,7 +115,7 @@ void zkp_ecdsa_get_public_key65(const ecdsa_curve *curve,
} }
memzero(&public_key, sizeof(public_key)); memzero(&public_key, sizeof(public_key));
assert(result == 0); return result;
} }
// ECDSA signing // ECDSA signing

@ -5,12 +5,12 @@
#include "hasher.h" #include "hasher.h"
void zkp_ecdsa_get_public_key33(const ecdsa_curve *curve, int zkp_ecdsa_get_public_key33(const ecdsa_curve *curve,
const uint8_t *private_key_bytes, const uint8_t *private_key_bytes,
uint8_t *public_key_bytes); uint8_t *public_key_bytes);
void zkp_ecdsa_get_public_key65(const ecdsa_curve *curve, int zkp_ecdsa_get_public_key65(const ecdsa_curve *curve,
const uint8_t *private_key_bytes, const uint8_t *private_key_bytes,
uint8_t *public_key_bytes); uint8_t *public_key_bytes);
int zkp_ecdsa_sign_digest(const ecdsa_curve *curve, int zkp_ecdsa_sign_digest(const ecdsa_curve *curve,
const uint8_t *private_key_bytes, const uint8_t *private_key_bytes,
const uint8_t *digest, uint8_t *signature_bytes, const uint8_t *digest, uint8_t *signature_bytes,

Loading…
Cancel
Save