diff --git a/bignum.c b/bignum.c index 02060ad3ec..26cb93814e 100644 --- a/bignum.c +++ b/bignum.c @@ -139,6 +139,26 @@ void bn_rshift(bignum256 *a) a->val[8] >>= 1; } +// multiply x by 3/2 modulo prime. +// assumes x < 2*prime, +// guarantees x < 4*prime on exit. +void bn_mult_3_2(bignum256 * x, const bignum256 *prime) +{ + int j; + uint32_t xodd = -(x->val[0] & 1); + // compute x = 3*x/2 mod prime + // if x is odd compute (3*x+prime)/2 + uint32_t tmp1 = (3*x->val[0] + (prime->val[0] & xodd)) >> 1; + for (j = 0; j < 8; j++) { + uint32_t tmp2 = (3*x->val[j+1] + (prime->val[j+1] & xodd)); + tmp1 += (tmp2 & 1) << 29; + x->val[j] = tmp1 & 0x3fffffff; + tmp1 >>= 30; + tmp1 += tmp2 >> 1; + } + x->val[8] = tmp1; +} + // assumes x < 2*prime, result < prime void bn_mod(bignum256 *x, const bignum256 *prime) { diff --git a/bignum.h b/bignum.h index 97471b7b1b..fa733cd06b 100644 --- a/bignum.h +++ b/bignum.h @@ -57,6 +57,8 @@ void bn_lshift(bignum256 *a); void bn_rshift(bignum256 *a); +void bn_mult_3_2(bignum256 *x, const bignum256 *prime); + void bn_mod(bignum256 *x, const bignum256 *prime); void bn_multiply(const bignum256 *k, bignum256 *x, const bignum256 *prime); diff --git a/ecdsa.c b/ecdsa.c index fcb0c3a641..6f076f279d 100644 --- a/ecdsa.c +++ b/ecdsa.c @@ -37,8 +37,7 @@ // Set cp2 = cp1 void point_copy(const curve_point *cp1, curve_point *cp2) { - memcpy(&(cp2->x), &(cp1->x), sizeof(bignum256)); - memcpy(&(cp2->y), &(cp1->y), sizeof(bignum256)); + *cp2 = *cp1; } // cp2 = cp1 + cp2 @@ -68,7 +67,9 @@ void point_add(const curve_point *cp1, curve_point *cp2) bn_inverse(&inv, &prime256k1); bn_subtractmod(&(cp2->y), &(cp1->y), &lambda, &prime256k1); bn_multiply(&inv, &lambda, &prime256k1); - memcpy(&xr, &lambda, sizeof(bignum256)); + + // xr = lambda^2 - x1 - x2 + xr = lambda; bn_multiply(&xr, &xr, &prime256k1); temp = 1; for (i = 0; i < 9; i++) { @@ -77,16 +78,17 @@ void point_add(const curve_point *cp1, curve_point *cp2) temp >>= 30; } bn_fast_mod(&xr, &prime256k1); + bn_mod(&xr, &prime256k1); + + // yr = lambda (x1 - xr) - y1 bn_subtractmod(&(cp1->x), &xr, &yr, &prime256k1); - // no need to fast_mod here - // bn_fast_mod(&yr); bn_multiply(&lambda, &yr, &prime256k1); bn_subtractmod(&yr, &(cp1->y), &yr, &prime256k1); bn_fast_mod(&yr, &prime256k1); - memcpy(&(cp2->x), &xr, sizeof(bignum256)); - memcpy(&(cp2->y), &yr, sizeof(bignum256)); - bn_mod(&(cp2->x), &prime256k1); - bn_mod(&(cp2->y), &prime256k1); + bn_mod(&yr, &prime256k1); + + cp2->x = xr; + cp2->y = yr; } // cp = cp + cp @@ -94,7 +96,7 @@ void point_double(curve_point *cp) { int i; uint32_t temp; - bignum256 lambda, inverse_y, xr, yr; + bignum256 lambda, xr, yr; if (point_is_infinity(cp)) { return; @@ -104,13 +106,15 @@ void point_double(curve_point *cp) return; } - memcpy(&inverse_y, &(cp->y), sizeof(bignum256)); - bn_inverse(&inverse_y, &prime256k1); - memcpy(&lambda, &three_over_two256k1, sizeof(bignum256)); - bn_multiply(&inverse_y, &lambda, &prime256k1); - bn_multiply(&(cp->x), &lambda, &prime256k1); - bn_multiply(&(cp->x), &lambda, &prime256k1); - memcpy(&xr, &lambda, sizeof(bignum256)); + // lambda = 3/2 x^2 / y + lambda = cp->y; + bn_inverse(&lambda, &prime256k1); + bn_multiply(&cp->x, &lambda, &prime256k1); + bn_multiply(&cp->x, &lambda, &prime256k1); + bn_mult_3_2(&lambda, &prime256k1); + + // xr = lambda^2 - 2*x + xr = lambda; bn_multiply(&xr, &xr, &prime256k1); temp = 1; for (i = 0; i < 9; i++) { @@ -119,16 +123,17 @@ void point_double(curve_point *cp) temp >>= 30; } bn_fast_mod(&xr, &prime256k1); + bn_mod(&xr, &prime256k1); + + // yr = lambda (x - xr) - y bn_subtractmod(&(cp->x), &xr, &yr, &prime256k1); - // no need to fast_mod here - // bn_fast_mod(&yr); bn_multiply(&lambda, &yr, &prime256k1); bn_subtractmod(&yr, &(cp->y), &yr, &prime256k1); bn_fast_mod(&yr, &prime256k1); - memcpy(&(cp->x), &xr, sizeof(bignum256)); - memcpy(&(cp->y), &yr, sizeof(bignum256)); - bn_mod(&(cp->x), &prime256k1); - bn_mod(&(cp->y), &prime256k1); + bn_mod(&yr, &prime256k1); + + cp->x = xr; + cp->y = yr; } // set point to internal representation of point at infinity @@ -322,8 +327,7 @@ static void point_jacobian_add(const curve_point *p1, jacobian_curve_point *p2) static void point_jacobian_double(jacobian_curve_point *p) { bignum256 m, msq, ysq, xysq; int j; - uint32_t tmp1, tmp2; - uint32_t modd; + uint32_t tmp1; /* usual algorithm: * @@ -336,7 +340,7 @@ static void point_jacobian_double(jacobian_curve_point *p) { * Hence, * lambda = m / yz * - * With z3 = 2yz (the denominator of lambda) + * With z3 = yz (the denominator of lambda) * we get x3 = lambda^2*z3^2 - 2*x/z^2*z3^2 * = m^2 - 2*xy^2 * and y3 = (lambda * (x/z^2 - x3/z3^2) - y/z^3) * z3^3 @@ -352,18 +356,7 @@ static void point_jacobian_double(jacobian_curve_point *p) { m = p->x; bn_multiply(&m, &m, &prime256k1); - modd = -(m.val[0] & 1); - // compute m = 3*m/2 mod prime - // if m is odd compute (3*m+prime)/2 - tmp1 = (3*m.val[0] + (prime256k1.val[0] & modd)) >> 1; - for (j = 0; j < 8; j++) { - tmp2 = (3*m.val[j+1] + (prime256k1.val[j+1] & modd)); - tmp1 += (tmp2 & 1) << 29; - m.val[j] = tmp1 & 0x3fffffff; - tmp1 >>= 30; - tmp1 += tmp2 >> 1; - } - m.val[8] = tmp1; + bn_mult_3_2(&m, &prime256k1); // msq = m^2 msq = m; @@ -374,6 +367,8 @@ static void point_jacobian_double(jacobian_curve_point *p) { // xysq = xy^2 xysq = p->x; bn_multiply(&ysq, &xysq, &prime256k1); + + // z3 = yz bn_multiply(&p->y, &p->z, &prime256k1); bn_mod(&p->z, &prime256k1); @@ -387,7 +382,7 @@ static void point_jacobian_double(jacobian_curve_point *p) { bn_fast_mod(&p->x, &prime256k1); bn_mod(&p->x, &prime256k1); - // y = m*(xy^2 - x3) - y^4 + // y3 = m*(xy^2 - x3) - y^4 bn_subtractmod(&xysq, &p->x, &p->y, &prime256k1); bn_multiply(&m, &p->y, &prime256k1); bn_multiply(&ysq, &ysq, &prime256k1);