From 6ba4d288b06202cd4c1d55734001432e08a30de6 Mon Sep 17 00:00:00 2001 From: Jochen Hoenicke Date: Thu, 23 Jul 2015 13:00:54 -0700 Subject: [PATCH] Cleaned up bignum code 1. Fixed bn_multiply_step to handle small primes. 2. Removed many calls to bn_mod to prevent side-channel leakage. --- bignum.c | 43 ++++++++++++++++++++----------------------- bignum.h | 2 +- ecdsa.c | 15 +++------------ tests.c | 14 +++++++------- 4 files changed, 31 insertions(+), 43 deletions(-) diff --git a/bignum.c b/bignum.c index 7796a0e09e..f17db05c6d 100644 --- a/bignum.c +++ b/bignum.c @@ -172,9 +172,7 @@ void bn_mult_k(bignum256 *x, uint8_t k, const bignum256 *prime) for (j = 0; j < 9; j++) { x->val[j] = k * x->val[j]; } - bn_normalize(x); bn_fast_mod(x, prime); - bn_mod(x, prime); } // assumes x < 2*prime, result < prime @@ -233,31 +231,35 @@ void bn_multiply_reduce_step(uint32_t res[18], const bignum256 *prime, uint32_t // let k = i-8. // invariants: // res[0..(i+1)] = k * x (mod prime) - // 0 <= res < 2^(30k + 256) * (2^30 + 1) + // 0 <= res < 2^(30k + 256) * (2^31) // estimate (res / prime) // coef = res / 2^(30k + 256) rounded down - // 0 <= coef <= 2^30 + // 0 <= coef < 2^31 // subtract (coef * 2^(30k) * prime) from res // note that we unrolled the first iteration uint32_t j; uint32_t coef = (res[i] >> 16) + (res[i + 1] << 14); - uint64_t temp = 0x1000000000000000ull + res[i - 8] - prime->val[0] * (uint64_t)coef; + uint64_t temp = 0x2000000000000000ull + res[i - 8] - prime->val[0] * (uint64_t)coef; + assert (coef < 0x80000000u); res[i - 8] = temp & 0x3FFFFFFF; for (j = 1; j < 9; j++) { temp >>= 30; - temp += 0xFFFFFFFC0000000ull + res[i - 8 + j] - prime->val[j] * (uint64_t)coef; + // Note: coeff * prime->val <= (2^31-1) * (2^30-1) + // Hence, this addition will not underflow. + temp += 0x1FFFFFFF80000000ull + res[i - 8 + j] - prime->val[j] * (uint64_t)coef; res[i - 8 + j] = temp & 0x3FFFFFFF; + // 0 <= temp < 2^61 } - temp >>= 30; - temp += 0xFFFFFFFC0000000ull + res[i - 8 + j]; - res[i - 8 + j] = temp & 0x3FFFFFFF; - // we rely on the fact that prime > 2^256 - 2^196 + temp >>= 30; + temp += 0x1FFFFFFF80000000ull + res[i - 8 + j]; + res[i - 8 + j] = temp & 0x3FFFFFFF; + // we rely on the fact that prime > 2^256 - 2^224 // res = oldres - coef*2^(30k) * prime; // and // coef * 2^(30k + 256) <= oldres < (coef+1) * 2^(30k + 256) // Hence, 0 <= res < 2^30k (2^256 + coef * (2^256 - prime)) - // Since coef * (2^256 - prime) < 2^226, we get - // 0 <= res < 2^(30k + 226) (2^30 + 1) + // Since coef * (2^256 - prime) < 2^256, we get + // 0 <= res < 2^(30k + 226) (2^31) // Thus the invariant holds again. } @@ -269,9 +271,8 @@ void bn_multiply_reduce(bignum256 *x, uint32_t res[18], const bignum256 *prime) // 0 <= res < 2^526. // compute modulo p division is only estimated so this may give result greater than prime but not bigger than 2 * prime for (i = 16; i >= 8; i--) { - bn_multiply_reduce_step(res, prime, i); - bn_multiply_reduce_step(res, prime, i); // apply twice, as a hack for NIST256P1 prime. - assert(res[i + 1] == 0); + bn_multiply_reduce_step(res, prime, i); + assert(res[i + 1] == 0); } // store the result for (i = 0; i < 9; i++) { @@ -294,7 +295,7 @@ void bn_multiply(const bignum256 *k, bignum256 *x, const bignum256 *prime) } // input x can be any normalized number that fits (0 <= x < 2^270). -// prime must be between (2^256 - 2^196) and 2^256 +// prime must be between (2^256 - 2^224) and 2^256 // result is smaller than 2*prime void bn_fast_mod(bignum256 *x, const bignum256 *prime) { @@ -305,11 +306,11 @@ void bn_fast_mod(bignum256 *x, const bignum256 *prime) coef = x->val[8] >> 16; // substract (coef * prime) from x // note that we unrolled the first iteration - temp = 0x1000000000000000ull + x->val[0] - prime->val[0] * (uint64_t)coef; + temp = 0x2000000000000000ull + x->val[0] - prime->val[0] * (uint64_t)coef; x->val[0] = temp & 0x3FFFFFFF; for (j = 1; j < 9; j++) { temp >>= 30; - temp += 0xFFFFFFFC0000000ull + x->val[j] - prime->val[j] * (uint64_t)coef; + temp += 0x1FFFFFFF80000000ull + x->val[j] - prime->val[j] * (uint64_t)coef; x->val[j] = temp & 0x3FFFFFFF; } } @@ -679,16 +680,12 @@ void bn_addmod(bignum256 *a, const bignum256 *b, const bignum256 *prime) for (i = 0; i < 9; i++) { a->val[i] += b->val[i]; } - bn_normalize(a); bn_fast_mod(a, prime); - bn_mod(a, prime); } -void bn_addmodi(bignum256 *a, uint32_t b, const bignum256 *prime) { +void bn_addi(bignum256 *a, uint32_t b) { a->val[0] += b; bn_normalize(a); - bn_fast_mod(a, prime); - bn_mod(a, prime); } // res = a - b mod prime. More exactly res = a + (2*prime - b). diff --git a/bignum.h b/bignum.h index 4774ca7d3a..8851305aac 100644 --- a/bignum.h +++ b/bignum.h @@ -75,7 +75,7 @@ void bn_normalize(bignum256 *a); void bn_addmod(bignum256 *a, const bignum256 *b, const bignum256 *prime); -void bn_addmodi(bignum256 *a, uint32_t b, const bignum256 *prime); +void bn_addi(bignum256 *a, uint32_t b); void bn_subtractmod(const bignum256 *a, const bignum256 *b, bignum256 *res, const bignum256 *prime); diff --git a/ecdsa.c b/ecdsa.c index 380ddaabf2..3f2e4d429d 100644 --- a/ecdsa.c +++ b/ecdsa.c @@ -207,13 +207,10 @@ void curve_to_jacobian(const curve_point *p, jacobian_curve_point *jp, const big bn_multiply(&p->x, &jp->x, prime); bn_multiply(&p->y, &jp->y, prime); - bn_mod(&jp->x, prime); - bn_mod(&jp->y, prime); } void jacobian_to_curve(const jacobian_curve_point *jp, curve_point *p, const bignum256 *prime) { p->y = jp->z; - bn_mod(&p->y, prime); bn_inverse(&p->y, prime); // p->y = z^-1 p->x = p->y; @@ -298,21 +295,18 @@ void point_jacobian_add(const curve_point *p1, jacobian_curve_point *p2, const b // z3 = h*z2 bn_multiply(&h, &p2->z, prime); - bn_mod(&p2->z, prime); // x3 = r^2 - h^3 - 2h^2x2 bn_addmod(&hcb, &hsqx2, prime); bn_addmod(&hcb, &hsqx2, prime); bn_subtractmod(&rsq, &hcb, &p2->x, prime); bn_fast_mod(&p2->x, prime); - bn_mod(&p2->x, prime); // y3 = r*(h^2x2 - x3) - y2*h^3 bn_subtractmod(&hsqx2, &p2->x, &p2->y, prime); bn_multiply(&r, &p2->y, prime); bn_subtractmod(&p2->y, &hcby2, &p2->y, prime); bn_fast_mod(&p2->y, prime); - bn_mod(&p2->y, prime); } void point_jacobian_double(jacobian_curve_point *p, const ecdsa_curve *curve) { @@ -366,15 +360,13 @@ void point_jacobian_double(jacobian_curve_point *p, const ecdsa_curve *curve) { // z3 = yz bn_multiply(&p->y, &p->z, prime); - bn_mod(&p->z, prime); // x3 = m^2 - 2*xy^2 p->x = xysq; - bn_mod(&p->x, prime); bn_lshift(&p->x); + bn_fast_mod(&p->x, prime); bn_subtractmod(&msq, &p->x, &p->x, prime); bn_fast_mod(&p->x, prime); - bn_mod(&p->x, prime); // y3 = m*(xy^2 - x3) - y^4 bn_subtractmod(&xysq, &p->x, &p->y, prime); @@ -382,7 +374,6 @@ void point_jacobian_double(jacobian_curve_point *p, const ecdsa_curve *curve) { bn_multiply(&ysq, &ysq, prime); bn_subtractmod(&p->y, &ysq, &p->y, prime); bn_fast_mod(&p->y, prime); - bn_mod(&p->y, prime); } // res = k * p @@ -835,7 +826,7 @@ void uncompress_coords(const ecdsa_curve *curve, uint8_t odd, const bignum256 *x memcpy(y, x, sizeof(bignum256)); // y is x bn_multiply(x, y, &curve->prime); // y is x^2 bn_multiply(x, y, &curve->prime); // y is x^3 - bn_addmodi(y, 7, &curve->prime); // y is x^3 + 7 + bn_addi(y, 7); // y is x^3 + 7 bn_sqrt(y, &curve->prime); // y = sqrt(y) if ((odd & 0x01) != (y->val[0] & 1)) { bn_subtract(&curve->prime, y, y); // y = -y @@ -885,7 +876,7 @@ int ecdsa_validate_pubkey(const ecdsa_curve *curve, const curve_point *pub) // x^3 + b bn_multiply(&(pub->x), &x_3_b, &curve->prime); bn_multiply(&(pub->x), &x_3_b, &curve->prime); - bn_addmodi(&x_3_b, 7, &curve->prime); + bn_addi(&x_3_b, 7); if (!bn_is_equal(&x_3_b, &y_2)) { return 0; diff --git a/tests.c b/tests.c index 511993df16..4401c31c11 100644 --- a/tests.c +++ b/tests.c @@ -40,10 +40,10 @@ #include "secp256k1.h" #define CURVE (&secp256k1) -#define prime256k1 (secp256k1.prime) -#define G256k1 (secp256k1.G) -#define order256k1 (secp256k1.order) -#define secp256k1_cp (secp256k1.cp) +#define prime256k1 (CURVE->prime) +#define G256k1 (CURVE->G) +#define order256k1 (CURVE->order) +#define secp256k1_cp (CURVE->cp) uint8_t *fromhex(const char *str) { @@ -1231,7 +1231,7 @@ START_TEST(test_secp256k1_cp) { // increment by one and test again p1 = p; point_add(CURVE, &G256k1, &p1); - bn_addmodi(&a, 1, &order256k1); + bn_addi(&a, 1); scalar_multiply(CURVE, &a, &p); ck_assert_mem_eq(&p, &p1, sizeof(curve_point)); bn_zero(&p.y); // test that point_multiply CURVE, is not a noop @@ -1254,7 +1254,7 @@ START_TEST(test_mult_border_cases) { point_multiply(CURVE, &a, &G256k1, &p); ck_assert(point_is_infinity(&p)); - bn_addmodi(&a, 1, &order256k1); // a == 1 + bn_addi(&a, 1); // a == 1 scalar_multiply(CURVE, &a, &p); ck_assert_mem_eq(&p, &G256k1, sizeof(curve_point)); point_multiply(CURVE, &a, &G256k1, &p); @@ -1269,7 +1269,7 @@ START_TEST(test_mult_border_cases) { ck_assert_mem_eq(&p, &expected, sizeof(curve_point)); bn_subtract(&order256k1, &a, &a); - bn_addmodi(&a, 1, &order256k1); // a == 2 + bn_addi(&a, 1); // a == 2 expected = G256k1; point_add(CURVE, &expected, &expected); scalar_multiply(CURVE, &a, &p);