diff --git a/bignum.c b/bignum.c index 2ba8afe5e6..05a867ab55 100644 --- a/bignum.c +++ b/bignum.c @@ -229,27 +229,10 @@ void bn_mult_k(bignum256 *x, uint8_t k, const bignum256 *prime) // assumes x partly reduced, guarantees x fully reduced. void bn_mod(bignum256 *x, const bignum256 *prime) { - int i = 8; - uint32_t temp; - // compare numbers - while (i >= 0 && prime->val[i] == x->val[i]) i--; - // if equal - if (i == -1) { - // set x to zero - bn_zero(x); - } else { - // if x is greater - if (x->val[i] > prime->val[i]) { - // substract p from x - temp = 0x40000000u; - for (i = 0; i < 9; i++) { - temp += x->val[i] - prime->val[i]; - x->val[i] = temp & 0x3FFFFFFF; - temp >>= 30; - temp += 0x3FFFFFFFu; - } - } - } + const int flag = bn_is_less(x, prime); // x < prime + bignum256 temp; + bn_subtract(x, prime, &temp); // temp = x - prime + bn_cmov(x, flag, x, &temp); } // auxiliary function for multiplication. diff --git a/test_curves.py b/test_curves.py index 67d1808093..c511bbc97f 100755 --- a/test_curves.py +++ b/test_curves.py @@ -261,6 +261,13 @@ def test_mod(curve, r): lib.bn_mod(y, int2bn(curve.p)) assert bn2int(y) == x % curve.p +def test_mod_specific(curve): + p = curve.p + for x in [0, 1, 2, p - 2, p - 1, p, p + 1, p + 2, 2*p - 2, 2*p - 1]: + y = int2bn(x) + lib.bn_mod(y, int2bn(curve.p)) + assert bn2int(y) == x % p + POINT = BIGNUM * 2 to_POINT = lambda p: POINT(int2bn(p.x()), int2bn(p.y())) from_POINT = lambda p: (bn2int(p[0]), bn2int(p[1]))