diff --git a/bignum.c b/bignum.c index 694df3c7c7..7796a0e09e 100644 --- a/bignum.c +++ b/bignum.c @@ -88,28 +88,32 @@ void bn_zero(bignum256 *a) int bn_is_zero(const bignum256 *a) { int i; + uint32_t result = 0; for (i = 0; i < 9; i++) { - if (a->val[i] != 0) return 0; + result |= a->val[i]; } - return 1; + return !result; } int bn_is_less(const bignum256 *a, const bignum256 *b) { int i; + uint32_t res1 = 0; + uint32_t res2 = 0; for (i = 8; i >= 0; i--) { - if (a->val[i] < b->val[i]) return 1; - if (a->val[i] > b->val[i]) return 0; + res1 = (res1 << 1) | (a->val[i] < b->val[i]); + res2 = (res2 << 1) | (a->val[i] > b->val[i]); } - return 0; + return res1 > res2; } int bn_is_equal(const bignum256 *a, const bignum256 *b) { int i; + uint32_t result = 0; for (i = 0; i < 9; i++) { - if (a->val[i] != b->val[i]) return 0; + result |= (a->val[i] ^ b->val[i]); } - return 1; + return !result; } int bn_bitlen(const bignum256 *a) { diff --git a/test_curves.py b/test_curves.py index 2ee15e90bd..57df5e946b 100755 --- a/test_curves.py +++ b/test_curves.py @@ -97,6 +97,48 @@ def test_inverse(curve, r): assert y == y_ +def test_is_less(curve, r): + x = r.randrange(0, curve.p) + y = r.randrange(0, curve.p) + x_ = int2bn(x) + y_ = int2bn(y) + + res = lib.bn_is_less(x_, y_) + assert res == (x < y) + + res = lib.bn_is_less(y_, x_) + assert res == (y < x) + + +def test_is_equal(curve, r): + x = r.randrange(0, curve.p) + y = r.randrange(0, curve.p) + x_ = int2bn(x) + y_ = int2bn(y) + + assert lib.bn_is_equal(x_, y_) == (x == y) + assert lib.bn_is_equal(x_, x_) == 1 + assert lib.bn_is_equal(y_, y_) == 1 + + +def test_is_zero(curve, r): + x = r.randrange(0, curve.p); + assert lib.bn_is_zero(int2bn(x)) == (not x) + + +def test_simple_comparisons(): + assert lib.bn_is_zero(int2bn(0)) == 1 + assert lib.bn_is_zero(int2bn(1)) == 0 + + assert lib.bn_is_less(int2bn(0), int2bn(0)) == 0 + assert lib.bn_is_less(int2bn(1), int2bn(0)) == 0 + assert lib.bn_is_less(int2bn(0), int2bn(1)) == 1 + + assert lib.bn_is_equal(int2bn(0), int2bn(0)) == 1 + assert lib.bn_is_equal(int2bn(1), int2bn(0)) == 0 + assert lib.bn_is_equal(int2bn(0), int2bn(1)) == 0 + + def test_mult_half(curve, r): x = r.randrange(0, 2*curve.p) y = int2bn(x)