diff --git a/bignum.c b/bignum.c index d94dc6fcc1..60846195aa 100644 --- a/bignum.c +++ b/bignum.c @@ -177,8 +177,11 @@ void bn_muli(bignum256 *a, uint32_t b) a->val[8] += t; } -// x = k * x -// both inputs and result may be bigger than prime but not bigger than 2 * prime +// Compute x := k * x (mod prime) +// both inputs must be smaller than 2 * prime. +// result is reduced to 0 <= x < 2 * prime +// This only works for primes between 2^256-2^196 and 2^256. +// this particular implementation accepts inputs up to 2^263 or 128*prime. void bn_multiply(const bignum256 *k, bignum256 *x, const bignum256 *prime) { int i, j; @@ -204,11 +207,21 @@ void bn_multiply(const bignum256 *k, bignum256 *x, const bignum256 *prime) temp >>= 30; } res[17] = temp; + // res = k * x is a normalized number (every limb < 2^30) + // 0 <= res < 2^526. // compute modulo p division is only estimated so this may give result greater than prime but not bigger than 2 * prime for (i = 16; i >= 8; i--) { + // let k = i-8. + // invariants: + // res[0..(i+1)] = k * x (mod prime) + // 0 <= res < 2^(30k + 256) * (2^30 + 1) // estimate (res / prime) coef = (res[i] >> 16) + (res[i + 1] << 14); - // substract (coef * prime) from res + + // coef = res / 2^(30k + 256) rounded down + // 0 <= coef <= 2^30 + // subtract (coef * 2^(30k) * prime) from res + // note that we unrolled the first iteration temp = 0x1000000000000000ull + res[i - 8] - prime->val[0] * (uint64_t)coef; res[i - 8] = temp & 0x3FFFFFFF; for (j = 1; j < 9; j++) { @@ -216,6 +229,16 @@ void bn_multiply(const bignum256 *k, bignum256 *x, const bignum256 *prime) temp += 0xFFFFFFFC0000000ull + res[i - 8 + j] - prime->val[j] * (uint64_t)coef; res[i - 8 + j] = temp & 0x3FFFFFFF; } + // we don't clear res[i+1] but we never read it again. + + // we rely on the fact that prime > 2^256 - 2^196 + // res = oldres - coef*2^(30k) * prime; + // and + // coef * 2^(30k + 256) <= oldres < (coef+1) * 2^(30k + 256) + // Hence, 0 <= res < 2^30k (2^256 + coef * (2^256 - prime)) + // Since coef * (2^256 - prime) < 2^226, we get + // 0 <= res < 2^(30k + 226) (2^30 + 1) + // Thus the invariant holds again. } // store the result for (i = 0; i < 9; i++) { @@ -223,6 +246,8 @@ void bn_multiply(const bignum256 *k, bignum256 *x, const bignum256 *prime) } } +// input x can be any normalized number that fits (0 <= x < 2^270). +// prime must be between (2^256 - 2^196) and 2^256 // result is smaller than 2*prime void bn_fast_mod(bignum256 *x, const bignum256 *prime) { @@ -233,6 +258,7 @@ void bn_fast_mod(bignum256 *x, const bignum256 *prime) coef = x->val[8] >> 16; if (!coef) return; // substract (coef * prime) from x + // note that we unrolled the first iteration temp = 0x1000000000000000ull + x->val[0] - prime->val[0] * (uint64_t)coef; x->val[0] = temp & 0x3FFFFFFF; for (j = 1; j < 9; j++) { @@ -246,16 +272,26 @@ void bn_fast_mod(bignum256 *x, const bignum256 *prime) // http://en.wikipedia.org/wiki/Quadratic_residue#Prime_or_prime_power_modulus void bn_sqrt(bignum256 *x, const bignum256 *prime) { + // this method compute x^1/2 = x^(prime+1)/4 uint32_t i, j, limb; bignum256 res, p; bn_zero(&res); res.val[0] = 1; + // compute p = (prime+1)/4 memcpy(&p, prime, sizeof(bignum256)); p.val[0] += 1; bn_rshift(&p); bn_rshift(&p); for (i = 0; i < 9; i++) { + // invariants: + // x = old(x)^(2^(i*30)) + // res = old(x)^(p % 2^(i*30)) + // get the i-th limb of prime - 2 limb = p.val[i]; for (j = 0; j < 30; j++) { + // invariants: + // x = old(x)^(2^(i*30+j)) + // res = old(x)^(p % 2^(i*30+j)) + // limb = (p % 2^(i*30+30)) / 2^(i*30+j) if (i == 8 && limb == 0) break; if (limb & 1) { bn_multiply(x, &res, prime); @@ -277,14 +313,24 @@ void bn_sqrt(bignum256 *x, const bignum256 *prime) // in field G_prime, small but slow void bn_inverse(bignum256 *x, const bignum256 *prime) { + // this method compute x^-1 = x^(prime-2) uint32_t i, j, limb; bignum256 res; bn_zero(&res); res.val[0] = 1; for (i = 0; i < 9; i++) { + // invariants: + // x = old(x)^(2^(i*30)) + // res = old(x)^((prime-2) % 2^(i*30)) + // get the i-th limb of prime - 2 limb = prime->val[i]; // this is not enough in general but fine for secp256k1 because prime->val[0] > 1 if (i == 0) limb -= 2; for (j = 0; j < 30; j++) { + // invariants: + // x = old(x)^(2^(i*30+j)) + // res = old(x)^((prime-2) % 2^(i*30+j)) + // limb = ((prime-2) % 2^(i*30+30)) / 2^(i*30+j) + // early abort when only zero bits follow if (i == 8 && limb == 0) break; if (limb & 1) { bn_multiply(x, &res, prime); @@ -300,14 +346,18 @@ void bn_inverse(bignum256 *x, const bignum256 *prime) #else // in field G_prime, big but fast +// this algorithm is based on the Euklidean algorithm +// the result is smaller than 2*prime void bn_inverse(bignum256 *x, const bignum256 *prime) { int i, j, k, len1, len2, mask; uint8_t buf[32]; uint32_t u[8], v[8], s[9], r[10], temp32; uint64_t temp, temp2; + // reduce x modulo prime bn_fast_mod(x, prime); bn_mod(x, prime); + // convert x and prime it to 8x32 bit limb form bn_write_be(prime, buf); for (i = 0; i < 8; i++) { u[i] = read_be(buf + 28 - i * 4); @@ -321,59 +371,98 @@ void bn_inverse(bignum256 *x, const bignum256 *prime) r[0] = 0; len2 = 1; k = 0; + // u = prime, v = x len1 = numlimbs(u,v) + // r = 0 , s = 1 len2 = numlimbs(r,s) + // k = 0 for (;;) { + // invariants: + // r,s,u,v >= 0 + // x*-r = u*2^k mod prime + // x*s = v*2^k mod prime + // u*s + v*r = prime + // floor(log2(u)) + floor(log2(v)) + k <= 510 + // max(u,v) <= 2^k + // gcd(u,v) = 1 + // len1 = numlimbs(u,v) + // len2 = numlimbs(r,s) + // + // first u,v are large and s,r small + // later u,v are small and s,r large + + // if (is_zero(v)) break; for (i = 0; i < len1; i++) { if (v[i]) break; } if (i == len1) break; + + // reduce u while it is even for (;;) { + // count up to 30 zero bits of u. for (i = 0; i < 30; i++) { if (u[0] & (1 << i)) break; } + // if u was odd break if (i == 0) break; + + // shift u right by i bits. mask = (1 << i) - 1; for (j = 0; j + 1 < len1; j++) { u[j] = (u[j] >> i) | ((u[j + 1] & mask) << (32 - i)); } u[j] = (u[j] >> i); + + // shift s left by i bits. mask = (1 << (32 - i)) - 1; s[len2] = s[len2 - 1] >> (32 - i); for (j = len2 - 1; j > 0; j--) { s[j] = (s[j - 1] >> (32 - i)) | ((s[j] & mask) << i); } s[0] = (s[0] & mask) << i; + // update len2 if necessary if (s[len2]) { r[len2] = 0; len2++; } + // add i bits to k. k += i; } + // reduce v while it is even for (;;) { + // count up to 30 zero bits of v. for (i = 0; i < 30; i++) { if (v[0] & (1 << i)) break; } + // if v was odd break if (i == 0) break; + + // shift v right by i bits. mask = (1 << i) - 1; for (j = 0; j + 1 < len1; j++) { v[j] = (v[j] >> i) | ((v[j + 1] & mask) << (32 - i)); } v[j] = (v[j] >> i); mask = (1 << (32 - i)) - 1; + // shift r left by i bits. r[len2] = r[len2 - 1] >> (32 - i); for (j = len2 - 1; j > 0; j--) { r[j] = (r[j - 1] >> (32 - i)) | ((r[j] & mask) << i); } r[0] = (r[0] & mask) << i; + // update len2 if necessary if (r[len2]) { s[len2] = 0; len2++; } + // add i bits to k. k += i; } + // invariant is reestablished. i = len1 - 1; while (i > 0 && u[i] == v[i]) i--; if (u[i] > v[i]) { + // u > v: + // u = (u - v)/2; temp = 0x100000000ull + u[0] - v[0]; u[0] = (temp >> 1) & 0x7FFFFFFF; temp >>= 32; @@ -384,6 +473,8 @@ void bn_inverse(bignum256 *x, const bignum256 *prime) temp >>= 32; } temp = temp2 = 0; + // r += s; + // s += s; for (i = 0; i < len2; i++) { temp += s[i]; temp += r[i]; @@ -394,12 +485,19 @@ void bn_inverse(bignum256 *x, const bignum256 *prime) temp >>= 32; temp2 >>= 32; } + // expand if necessary. if (temp != 0 || temp2 != 0) { r[len2] = temp; s[len2] = temp2; len2++; } + // note that + // u'2^(k+1) = (u - v) 2^k = x -(r + s) = x -r' mod prime + // v'2^(k+1) = 2*v 2^k = x (s + s) = x s' mod prime + // u's' + v'r' = (u-v)/2(2s) + v(r+s) = us + vr } else { + // v >= u: + // v = v - u; temp = 0x100000000ull + v[0] - u[0]; v[0] = (temp >> 1) & 0x7FFFFFFF; temp >>= 32; @@ -409,6 +507,8 @@ void bn_inverse(bignum256 *x, const bignum256 *prime) v[i] = (temp >> 1) & 0x7FFFFFFF; temp >>= 32; } + // s = s + r + // r = r + r temp = temp2 = 0; for (i = 0; i < len2; i++) { temp += s[i]; @@ -425,11 +525,28 @@ void bn_inverse(bignum256 *x, const bignum256 *prime) r[len2] = temp2; len2++; } + // note that + // u'2^(k+1) = 2*u 2^k = x -(r + r) = x -r' mod prime + // v'2^(k+1) = (v - u) 2^k = x (s + r) = x s' mod prime + // u's' + v'r' = u(r+s) + (v-u)/2(2r) = us + vr } + // adjust len1 if possible. if (u[len1 - 1] == 0 && v[len1 - 1] == 0) len1--; + // increase k k++; } + // In the last iteration just before the comparison and subtraction + // we had u=1, v=1, s+r = prime, k <= 510, 2^k > max(s,r) >= prime/2 + // hence 0 <= r < prime and 255 <= k <= 510. + // + // Afterwards r is doubled, k is incremented by 1. + // Hence 0 <= r < 2*prime and 256 <= k < 512. + // + // The invariants give us x*-r = 2^k mod prime, + // hence r = -2^k * x^-1 mod prime. + // We need to compute -r/2^k mod prime. + // convert r to bignum style j = r[0] >> 30; r[0] = r[0] & 0x3FFFFFFFu; for (i = 1; i < len2; i++) { @@ -441,6 +558,7 @@ void bn_inverse(bignum256 *x, const bignum256 *prime) i++; for (; i < 9; i++) r[i] = 0; + // r = r mod prime, note that r<2*prime. i = 8; while (i > 0 && r[i] == prime->val[i]) i--; if (r[i] >= prime->val[i]) { @@ -451,26 +569,39 @@ void bn_inverse(bignum256 *x, const bignum256 *prime) temp32 >>= 30; } } + // negate r: r = prime - r temp32 = 1; for (i = 0; i < 9; i++) { temp32 += 0x3FFFFFFF + prime->val[i] - r[i]; r[i] = temp32 & 0x3FFFFFFF; temp32 >>= 30; } + // now: r = 2^k * x^-1 mod prime + // compute r/2^k, 256 <= k < 511 int done = 0; #if USE_PRECOMPUTED_IV if (prime == &prime256k1) { for (j = 0; j < 9; j++) { x->val[j] = r[j]; } + // secp256k1_iv[k-256] = 2^-k mod prime bn_multiply(secp256k1_iv + k - 256, x, prime); + // bn_fast_mod is unnecessary as bn_multiply already + // guarantees x < 2*prime bn_fast_mod(x, prime); + // We don't guarantee x < prime! + // the slow variant and the slow case below guarantee + // this. done = 1; } #endif if (!done) { + // compute r = r/2^k mod prime for (j = 0; j < k; j++) { + // invariant: r = 2^(k-j) * x^-1 mod prime + // in each iteration divide r by 2 modulo prime. if (r[0] & 1) { + // r is odd; compute r = (prime + r)/2 temp32 = r[0] + prime->val[0]; r[0] = (temp32 >> 1) & 0x1FFFFFFF; temp32 >>= 30; @@ -481,12 +612,14 @@ void bn_inverse(bignum256 *x, const bignum256 *prime) temp32 >>= 30; } } else { + // r = r / 2 for (i = 0; i < 8; i++) { r[i] = (r[i] >> 1) | ((r[i + 1] & 1) << 29); } r[8] = r[8] >> 1; } } + // r = x^-1 mod prime, since j = k for (j = 0; j < 9; j++) { x->val[j] = r[j]; }