diff --git a/crypto/bignum.c b/crypto/bignum.c index 0219344b6c..1bc2f3698d 100644 --- a/crypto/bignum.c +++ b/crypto/bignum.c @@ -80,6 +80,15 @@ #define BN_MAX_DECIMAL_DIGITS \ 79 // floor(log(2**(LIMBS * BITS_PER_LIMB), 10)) + 1 +// y = (bignum256) x +// Assumes x is normalized and x < 2**261 == 2**(BITS_PER_LIMB * LIMBS) +// Guarantees y is normalized +void bn_copy_lower(const bignum512 *x, bignum256 *y) { + for (int i = 0; i < BN_LIMBS; i++) { + y->val[i] = x->val[i]; + } +} + // out_number = (bignum256) in_number // Assumes in_number is a raw bigendian 256-bit number // Guarantees out_number is normalized @@ -667,12 +676,11 @@ void bn_multiply_reduce_step(bignum512 *res, const bignum256 *prime, res->val[d + BN_LIMBS] = 0; } -// Auxiliary function for bn_multiply -// Partly reduces res and stores both in x and res -// Assumes res in normalized and res < 2**519 +// Partly reduces x +// Assumes x in normalized and res < 2**519 // Guarantees x is normalized and partly reduced modulo prime // Assumes prime is normalized, 2**256 - 2**224 <= prime <= 2**256 -void bn_multiply_reduce(bignum256 *x, bignum512 *res, const bignum256 *prime) { +void bn_reduce(bignum512 *x, const bignum256 *prime) { for (int i = BN_LIMBS - 1; i >= 0; i--) { // res < 2**(256 + 29*i + 31) // Proof: @@ -683,11 +691,7 @@ void bn_multiply_reduce(bignum256 *x, bignum512 *res, const bignum256 *prime) { // else: // res < 2 * prime * 2**(29 * (i + 1)) // <= 2**256 * 2**(29*i + 29) < 2**(256 + 29*i + 31) - bn_multiply_reduce_step(res, prime, i); - } - - for (int i = 0; i < BN_LIMBS; i++) { - x->val[i] = res->val[i]; + bn_multiply_reduce_step(x, prime, i); } } @@ -699,7 +703,8 @@ void bn_multiply(const bignum256 *k, bignum256 *x, const bignum256 *prime) { bignum512 res = {0}; bn_multiply_long(k, x, &res); - bn_multiply_reduce(x, &res, prime); + bn_reduce(&res, prime); + bn_copy_lower(&res, x); memzero(&res, sizeof(res)); } diff --git a/crypto/bignum.h b/crypto/bignum.h index 778228894e..17ef553a5f 100644 --- a/crypto/bignum.h +++ b/crypto/bignum.h @@ -72,6 +72,7 @@ static inline void write_le(uint8_t *data, uint32_t x) { data[0] = x; } +void bn_copy_lower(const bignum512 *x, bignum256 *y); void bn_read_be(const uint8_t *in_number, bignum256 *out_number); void bn_write_be(const bignum256 *in_number, uint8_t *out_number); void bn_read_le(const uint8_t *in_number, bignum256 *out_number); @@ -99,6 +100,7 @@ void bn_mult_half(bignum256 *x, const bignum256 *prime); void bn_mult_k(bignum256 *x, uint8_t k, const bignum256 *prime); void bn_mod(bignum256 *x, const bignum256 *prime); void bn_multiply(const bignum256 *k, bignum256 *x, const bignum256 *prime); +void bn_reduce(bignum512 *x, const bignum256 *prime); void bn_fast_mod(bignum256 *x, const bignum256 *prime); void bn_power_mod(const bignum256 *x, const bignum256 *e, const bignum256 *prime, bignum256 *res); diff --git a/crypto/tests/test_bignum.py b/crypto/tests/test_bignum.py index dbea144d70..800d05f6a6 100755 --- a/crypto/tests/test_bignum.py +++ b/crypto/tests/test_bignum.py @@ -197,6 +197,16 @@ class Random(random.Random): return self.rand_bignum(2 * limbs_number) +def assert_bn_copy_lower(x): + x_number = int_to_bignum512(x) + y_number = bignum256() + lib.bn_copy_lower(x_number, y_number) + y = bignum256_to_int(y_number) + + assert bignum_is_normalised(y_number) + assert y == x + + def assert_bn_read_be(in_number): raw_in_number = integer_to_raw_number256(in_number, "big") bn_out_number = bignum256() @@ -458,6 +468,17 @@ def assert_bn_multiply(k, x_old, prime): assert x_new % prime == (k * x_old) % prime +def assert_bn_reduce(x_old, prime): + bn_x = int_to_bignum512(x_old) + bn_prime = int_to_bignum256(prime) + lib.bn_reduce(bn_x, bn_prime) + x_new = bignum256_to_int(bn_x) + + assert bignum_is_normalised(bn_x) + assert number_is_partly_reduced(x_new, prime) + assert x_new % prime == x_old % prime + + def assert_bn_fast_mod(x_old, prime): bn_x = int_to_bignum256(x_old) bn_prime = int_to_bignum256(prime) @@ -731,6 +752,10 @@ def assert_bn_format(x, prefix, suffix, decimals, exponent, trailing, thousands) assert return_value == correct_return_value +def test_bn_copy_lower(r): + assert_bn_copy_lower(r.rand_int_bitsize(261)) + + def test_bn_read_be(r): assert_bn_read_be(r.rand_int_256()) @@ -916,6 +941,11 @@ def test_bn_multiply_reduce_step(r, prime): assert_bn_multiply_reduce_step(res, prime, k) +def test_bn_reduce(r, prime): + x = r.rand_int_bitsize(519) + assert_bn_reduce(x, prime) + + def test_bn_multiply(r, prime): x = r.randrange(floor(sqrt(2**519))) k = r.randrange(floor(sqrt(2**519)))