diff --git a/bignum.c b/bignum.c index b243d0280a..45778856bb 100644 --- a/bignum.c +++ b/bignum.c @@ -626,14 +626,15 @@ void bn_addmodi(bignum256 *a, uint32_t b, const bignum256 *prime) { bn_mod(a, prime); } -// res = a - b -// b < 2*prime; result not normalized -void bn_subtractmod(const bignum256 *a, const bignum256 *b, bignum256 *res) +// res = a - b mod prime. More exactly res = a + (2*prime - b). +// precondition: 0 <= b < 2*prime, 0 <= a < prime +// res < 3*prime +void bn_subtractmod(const bignum256 *a, const bignum256 *b, bignum256 *res, const bignum256 *prime) { int i; uint32_t temp = 0; for (i = 0; i < 9; i++) { - temp += a->val[i] + 2u * prime256k1.val[i] - b->val[i]; + temp += a->val[i] + 2u * prime->val[i] - b->val[i]; res->val[i] = temp & 0x3FFFFFFF; temp >>= 30; } diff --git a/bignum.h b/bignum.h index de71fba0eb..97471b7b1b 100644 --- a/bignum.h +++ b/bignum.h @@ -73,7 +73,7 @@ void bn_addmod(bignum256 *a, const bignum256 *b, const bignum256 *prime); void bn_addmodi(bignum256 *a, uint32_t b, const bignum256 *prime); -void bn_subtractmod(const bignum256 *a, const bignum256 *b, bignum256 *res); +void bn_subtractmod(const bignum256 *a, const bignum256 *b, bignum256 *res, const bignum256 *prime); void bn_subtract(const bignum256 *a, const bignum256 *b, bignum256 *res); diff --git a/ecdsa.c b/ecdsa.c index 0c5b5a625b..7ecb44cac6 100644 --- a/ecdsa.c +++ b/ecdsa.c @@ -63,24 +63,24 @@ void point_add(const curve_point *cp1, curve_point *cp2) return; } - bn_subtractmod(&(cp2->x), &(cp1->x), &inv); + bn_subtractmod(&(cp2->x), &(cp1->x), &inv, &prime256k1); bn_inverse(&inv, &prime256k1); - bn_subtractmod(&(cp2->y), &(cp1->y), &lambda); + bn_subtractmod(&(cp2->y), &(cp1->y), &lambda, &prime256k1); bn_multiply(&inv, &lambda, &prime256k1); memcpy(&xr, &lambda, sizeof(bignum256)); bn_multiply(&xr, &xr, &prime256k1); - temp = 0; + temp = 1; for (i = 0; i < 9; i++) { - temp += xr.val[i] + 3u * prime256k1.val[i] - cp1->x.val[i] - cp2->x.val[i]; + temp += 0x3FFFFFFF + xr.val[i] + 2u * prime256k1.val[i] - cp1->x.val[i] - cp2->x.val[i]; xr.val[i] = temp & 0x3FFFFFFF; temp >>= 30; } bn_fast_mod(&xr, &prime256k1); - bn_subtractmod(&(cp1->x), &xr, &yr); + bn_subtractmod(&(cp1->x), &xr, &yr, &prime256k1); // no need to fast_mod here // bn_fast_mod(&yr); bn_multiply(&lambda, &yr, &prime256k1); - bn_subtractmod(&yr, &(cp1->y), &yr); + bn_subtractmod(&yr, &(cp1->y), &yr, &prime256k1); bn_fast_mod(&yr, &prime256k1); memcpy(&(cp2->x), &xr, sizeof(bignum256)); memcpy(&(cp2->y), &yr, sizeof(bignum256)); @@ -111,18 +111,18 @@ void point_double(curve_point *cp) bn_multiply(&(cp->x), &lambda, &prime256k1); memcpy(&xr, &lambda, sizeof(bignum256)); bn_multiply(&xr, &xr, &prime256k1); - temp = 0; + temp = 1; for (i = 0; i < 9; i++) { - temp += xr.val[i] + 3u * prime256k1.val[i] - 2u * cp->x.val[i]; + temp += 0x3FFFFFFF + xr.val[i] + 2u * (prime256k1.val[i] - cp->x.val[i]); xr.val[i] = temp & 0x3FFFFFFF; temp >>= 30; } bn_fast_mod(&xr, &prime256k1); - bn_subtractmod(&(cp->x), &xr, &yr); + bn_subtractmod(&(cp->x), &xr, &yr, &prime256k1); // no need to fast_mod here // bn_fast_mod(&yr); bn_multiply(&lambda, &yr, &prime256k1); - bn_subtractmod(&yr, &(cp->y), &yr); + bn_subtractmod(&yr, &(cp->y), &yr, &prime256k1); bn_fast_mod(&yr, &prime256k1); memcpy(&(cp->x), &xr, sizeof(bignum256)); memcpy(&(cp->y), &yr, sizeof(bignum256));