diff --git a/ecdsa.c b/ecdsa.c index 0b012ecb0e..85a3fab241 100644 --- a/ecdsa.c +++ b/ecdsa.c @@ -195,53 +195,101 @@ int point_is_negative_of(const curve_point *p, const curve_point *q) #if USE_PRECOMPUTED_CP +// Negate a (modulo prime) if cond is 0xffffffff, keep it if cond is 0. +// The timing of this function does not depend on cond. +static void conditional_negate(uint32_t cond, bignum256 *a, const bignum256 *prime) +{ + int j; + uint32_t tmp = 1; + for (j = 0; j < 8; j++) { + tmp += 0x3fffffff + prime->val[j] - a->val[j]; + a->val[j] = ((tmp & 0x3fffffff) & cond) | (a->val[j] & ~cond); + tmp >>= 30; + } + tmp += 0x3fffffff + prime->val[j] - a->val[j]; + a->val[j] = ((tmp & 0x3fffffff) & cond) | (a->val[j] & ~cond); +} + // res = k * G +// k must be a normalized number with 0 <= k < order256k1 void scalar_multiply(const bignum256 *k, curve_point *res) { - int i, j; - // result is zero assert (bn_is_less(k, &order256k1)); - bignum256 a = *k; - int sign = (a.val[0] & 1) ^ 1; - uint32_t lowbits; - // make number odd - if (sign) { - bn_subtract(&order256k1, &a, &a); - } - a.val[8] |= 0x10000; - assert((a.val[0] & 1) != 0); - assert((a.val[8] & 0x10000) != 0); - // now compute res = a *G step by step. - // initial res + int i, j; + bignum256 a; + uint32_t is_even = (k->val[0] & 1) - 1; + uint32_t lowbits; + + // is_even = 0xffffffff if k is even, 0 otherwise. + + // add 2^256. + // make number odd: subtract order256k1 if even + uint32_t tmp = 1; + uint32_t is_non_zero = 0; + for (j = 0; j < 8; j++) { + is_non_zero |= k->val[j]; + tmp += 0x3fffffff + k->val[j] - (order256k1.val[j] & is_even); + a.val[j] = tmp & 0x3fffffff; + tmp >>= 30; + } + is_non_zero |= k->val[j]; + a.val[j] = tmp + 0xffff + k->val[j] - (order256k1.val[j] & is_even); + assert((a.val[0] & 1) != 0); + + // special case 0*G: just return zero. We don't care about constant time. + if (!is_non_zero) { + point_set_infinity(res); + return; + } + + // Now a = k + 2^256 (mod order256k1) and a is odd. + // + // The idea is to bring the new a into the form. + // sum_{i=0..64} a[i] 16^i, where |a[i]| < 16 and a[i] is odd. + // a[0] is odd, since a is odd. If a[i] would be even, we can + // add 1 to it and subtract 16 from a[i-1]. Afterwards, + // a[64] = 1, which is the 2^256 that we added before. + // + // Since k = a - 2^256 (mod order256k1), we can compute + // k*G = sum_{i=0..63} a[i] 16^i * G + // + // We have a big table secp256k1_cp that stores all possible + // values of |a[i]| 16^i * G. + // secp256k1_cp[i][j] = (2*j+1) * 16^i * G + + // now compute res = sum_{i=0..63} a[i] * 16^i * G step by step. + // initial res = |a[0]| * G. Note that a[0] = a & 0xf if (a&0x10) != 0 + // and - (16 - (a & 0xf)) otherwise. We can compute this as + // ((a ^ (((a >> 4) & 1) - 1)) & 0xf) >> 1 + // since a is odd. lowbits = a.val[0] & ((1 << 5) - 1); lowbits ^= (lowbits >> 4) - 1; lowbits &= 15; *res = secp256k1_cp[0][lowbits >> 1]; for (i = 1; i < 64; i ++) { - // invariant res = abs((a % 2*16^i) - 16^i) * G + // invariant res = sign(a[i-1]) sum_{j=0..i-1} (a[j] * 16^j * G) + // Note that sign(a[i-1] + // shift a by 4 places. for (j = 0; j < 8; j++) { a.val[j] = (a.val[j] >> 4) | ((a.val[j + 1] & 0xf) << 26); } a.val[j] >>= 4; + // a = old(a)>>(4*i) + // a is even iff sign(a[i-1]) = -1 lowbits = a.val[0] & ((1 << 5) - 1); lowbits ^= (lowbits >> 4) - 1; lowbits &= 15; - if ((lowbits & 1) == 0) { - // negate last result to make signs of this round and the - // last round equal. - bn_subtract(&prime256k1, &res->y, &res->y); - } + // negate last result to make signs of this round and the + // last round equal. + conditional_negate((lowbits & 1) - 1, &res->y, &prime256k1); // add odd factor point_add(&secp256k1_cp[i][lowbits >> 1], res); } - if (sign) { - // negate - bn_subtract(&prime256k1, &res->y, &res->y); - } + conditional_negate(((a.val[0] >> 4) & 1) - 1, &res->y, &prime256k1); } #else