1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-22 06:18:07 +00:00

refactor(crypto): introduce bignum512

This commit is contained in:
Ondřej Vejpustek 2023-09-19 14:05:15 +02:00
parent 3a2bdf16dd
commit 2b00c72094
3 changed files with 206 additions and 150 deletions

View File

@ -525,8 +525,7 @@ void bn_mod(bignum256 *x, const bignum256 *prime) {
// res = k * x // res = k * x
// Assumes k and x are normalized // Assumes k and x are normalized
// Guarantees res is normalized 18 digit little endian number in base 2**29 // Guarantees res is normalized 18 digit little endian number in base 2**29
void bn_multiply_long(const bignum256 *k, const bignum256 *x, void bn_multiply_long(const bignum256 *k, const bignum256 *x, bignum512 *res) {
uint32_t res[2 * BN_LIMBS]) {
// Uses long multiplication in base 2**29, see // Uses long multiplication in base 2**29, see
// https://en.wikipedia.org/wiki/Multiplication_algorithm#Long_multiplication // https://en.wikipedia.org/wiki/Multiplication_algorithm#Long_multiplication
@ -545,7 +544,7 @@ void bn_multiply_long(const bignum256 *k, const bignum256 *x,
// <= 2**35 + 9 * 2**58 < 2**64 // <= 2**35 + 9 * 2**58 < 2**64
} }
res[i] = acc & BN_LIMB_MASK; res->val[i] = acc & BN_LIMB_MASK;
acc >>= BN_BITS_PER_LIMB; acc >>= BN_BITS_PER_LIMB;
// acc <= 2**35 - 1 == 2**(64 - BITS_PER_LIMB) - 1 // acc <= 2**35 - 1 == 2**(64 - BITS_PER_LIMB) - 1
} }
@ -563,12 +562,12 @@ void bn_multiply_long(const bignum256 *k, const bignum256 *x,
// <= 2**35 + 9 * 2**58 < 2**64 // <= 2**35 + 9 * 2**58 < 2**64
} }
res[i] = acc & (BN_BASE - 1); res->val[i] = acc & (BN_BASE - 1);
acc >>= BN_BITS_PER_LIMB; acc >>= BN_BITS_PER_LIMB;
// acc < 2**35 == 2**(64 - BITS_PER_LIMB) // acc < 2**35 == 2**(64 - BITS_PER_LIMB)
} }
res[2 * BN_LIMBS - 1] = acc; res->val[2 * BN_LIMBS - 1] = acc;
} }
// Auxiliary function for bn_multiply // Auxiliary function for bn_multiply
@ -576,7 +575,7 @@ void bn_multiply_long(const bignum256 *k, const bignum256 *x,
// Assumes res is normalized and res < 2**(256 + 29*d + 31) // Assumes res is normalized and res < 2**(256 + 29*d + 31)
// Guarantess res in normalized and res < 2 * prime * 2**(29*d) // Guarantess res in normalized and res < 2 * prime * 2**(29*d)
// Assumes prime is normalized, 2**256 - 2**224 <= prime <= 2**256 // Assumes prime is normalized, 2**256 - 2**224 <= prime <= 2**256
void bn_multiply_reduce_step(uint32_t res[2 * BN_LIMBS], const bignum256 *prime, void bn_multiply_reduce_step(bignum512 *res, const bignum256 *prime,
uint32_t d) { uint32_t d) {
// clang-format off // clang-format off
// Computes res = res - (res // 2**(256 + BITS_PER_LIMB * d)) * prime * 2**(BITS_PER_LIMB * d) // Computes res = res - (res // 2**(256 + BITS_PER_LIMB * d)) * prime * 2**(BITS_PER_LIMB * d)
@ -598,8 +597,9 @@ void bn_multiply_reduce_step(uint32_t res[2 * BN_LIMBS], const bignum256 *prime,
// clang-format on // clang-format on
uint32_t coef = uint32_t coef =
(res[d + BN_LIMBS - 1] >> (256 - (BN_LIMBS - 1) * BN_BITS_PER_LIMB)) + (res->val[d + BN_LIMBS - 1] >>
(res[d + BN_LIMBS] << ((BN_LIMBS * BN_BITS_PER_LIMB) - 256)); (256 - (BN_LIMBS - 1) * BN_BITS_PER_LIMB)) +
(res->val[d + BN_LIMBS] << ((BN_LIMBS * BN_BITS_PER_LIMB) - 256));
// coef == res // 2**(256 + BITS_PER_LIMB * d) // coef == res // 2**(256 + BITS_PER_LIMB * d)
@ -613,7 +613,7 @@ void bn_multiply_reduce_step(uint32_t res[2 * BN_LIMBS], const bignum256 *prime,
uint64_t acc = 1ull << shift; uint64_t acc = 1ull << shift;
for (int i = 0; i < BN_LIMBS; i++) { for (int i = 0; i < BN_LIMBS; i++) {
acc += (((uint64_t)(BN_BASE - 1)) << shift) + res[d + i] - acc += (((uint64_t)(BN_BASE - 1)) << shift) + res->val[d + i] -
prime->val[i] * (uint64_t)coef; prime->val[i] * (uint64_t)coef;
// acc neither overflow 64 bits nor underflow zero // acc neither overflow 64 bits nor underflow zero
// Proof: // Proof:
@ -633,7 +633,7 @@ void bn_multiply_reduce_step(uint32_t res[2 * BN_LIMBS], const bignum256 *prime,
// == (2**35 - 1) + (2**31 + 1) * (2**29 - 1) // == (2**35 - 1) + (2**31 + 1) * (2**29 - 1)
// <= 2**35 + 2**60 + 2**29 < 2**64 // <= 2**35 + 2**60 + 2**29 < 2**64
res[d + i] = acc & BN_LIMB_MASK; res->val[d + i] = acc & BN_LIMB_MASK;
acc >>= BN_BITS_PER_LIMB; acc >>= BN_BITS_PER_LIMB;
// acc <= 2**(64 - BITS_PER_LIMB) - 1 == 2**35 - 1 // acc <= 2**(64 - BITS_PER_LIMB) - 1 == 2**35 - 1
@ -664,7 +664,7 @@ void bn_multiply_reduce_step(uint32_t res[2 * BN_LIMBS], const bignum256 *prime,
// == 1 << shift // == 1 << shift
// clang-format on // clang-format on
res[d + BN_LIMBS] = 0; res->val[d + BN_LIMBS] = 0;
} }
// Auxiliary function for bn_multiply // Auxiliary function for bn_multiply
@ -672,8 +672,7 @@ void bn_multiply_reduce_step(uint32_t res[2 * BN_LIMBS], const bignum256 *prime,
// Assumes res in normalized and res < 2**519 // Assumes res in normalized and res < 2**519
// Guarantees x is normalized and partly reduced modulo prime // Guarantees x is normalized and partly reduced modulo prime
// Assumes prime is normalized, 2**256 - 2**224 <= prime <= 2**256 // Assumes prime is normalized, 2**256 - 2**224 <= prime <= 2**256
void bn_multiply_reduce(bignum256 *x, uint32_t res[2 * BN_LIMBS], void bn_multiply_reduce(bignum256 *x, bignum512 *res, const bignum256 *prime) {
const bignum256 *prime) {
for (int i = BN_LIMBS - 1; i >= 0; i--) { for (int i = BN_LIMBS - 1; i >= 0; i--) {
// res < 2**(256 + 29*i + 31) // res < 2**(256 + 29*i + 31)
// Proof: // Proof:
@ -688,7 +687,7 @@ void bn_multiply_reduce(bignum256 *x, uint32_t res[2 * BN_LIMBS],
} }
for (int i = 0; i < BN_LIMBS; i++) { for (int i = 0; i < BN_LIMBS; i++) {
x->val[i] = res[i]; x->val[i] = res->val[i];
} }
} }
@ -697,12 +696,12 @@ void bn_multiply_reduce(bignum256 *x, uint32_t res[2 * BN_LIMBS],
// Guarantees x is normalized and partly reduced modulo prime // Guarantees x is normalized and partly reduced modulo prime
// Assumes prime is normalized, 2**256 - 2**224 <= prime <= 2**256 // Assumes prime is normalized, 2**256 - 2**224 <= prime <= 2**256
void bn_multiply(const bignum256 *k, bignum256 *x, const bignum256 *prime) { void bn_multiply(const bignum256 *k, bignum256 *x, const bignum256 *prime) {
uint32_t res[2 * BN_LIMBS] = {0}; bignum512 res = {0};
bn_multiply_long(k, x, res); bn_multiply_long(k, x, &res);
bn_multiply_reduce(x, res, prime); bn_multiply_reduce(x, &res, prime);
memzero(res, sizeof(res)); memzero(&res, sizeof(res));
} }
// Partly reduces x modulo prime // Partly reduces x modulo prime

View File

@ -43,6 +43,11 @@ typedef struct {
uint32_t val[BN_LIMBS]; uint32_t val[BN_LIMBS];
} bignum256; } bignum256;
// Represents the number sum([val[i] * 2**(29*i) for i in range(18))
typedef struct {
uint32_t val[2 * BN_LIMBS];
} bignum512;
static inline uint32_t read_be(const uint8_t *data) { static inline uint32_t read_be(const uint8_t *data) {
return (((uint32_t)data[0]) << 24) | (((uint32_t)data[1]) << 16) | return (((uint32_t)data[0]) << 24) | (((uint32_t)data[1]) << 16) |
(((uint32_t)data[2]) << 8) | (((uint32_t)data[3])); (((uint32_t)data[2]) << 8) | (((uint32_t)data[3]));

View File

@ -69,15 +69,23 @@ def uint32_p():
limb_type = c_uint32 limb_type = c_uint32
def bignum(limbs_number=limbs_number): def bignum(limbs_number):
return (limbs_number * limb_type)() return (limbs_number * limb_type)()
def bignum256():
return bignum(limbs_number)
def bignum512():
return bignum(2 * limbs_number)
def limbs_to_bignum(limbs): def limbs_to_bignum(limbs):
return (limbs_number * limb_type)(*limbs) return (limbs_number * limb_type)(*limbs)
def int_to_bignum(number, limbs_number=limbs_number): def int_to_bignum(number, limbs_number):
assert number >= 0 assert number >= 0
assert number.bit_length() <= limbs_number * bits_per_limb assert number.bit_length() <= limbs_number * bits_per_limb
@ -89,7 +97,15 @@ def int_to_bignum(number, limbs_number=limbs_number):
return bn return bn
def bignum_to_int(bignum, limbs_number=limbs_number): def int_to_bignum256(number):
return int_to_bignum(number, limbs_number)
def int_to_bignum512(number):
return int_to_bignum(number, 2 * limbs_number)
def bignum_to_int(bignum, limbs_number):
number = 0 number = 0
for i in reversed(range(limbs_number)): for i in reversed(range(limbs_number)):
@ -99,16 +115,40 @@ def bignum_to_int(bignum, limbs_number=limbs_number):
return number return number
def raw_number(): def bignum256_to_int(bignum):
return (32 * c_uint8)() return bignum_to_int(bignum, limbs_number)
def bignum512_to_int(bignum):
return bignum_to_int(bignum, 2 * limbs_number)
def raw_number(byte_size):
return (byte_size * c_uint8)()
def raw_number256():
return raw_number(32)
def raw_number512():
return raw_number(64)
def raw_number_to_integer(raw_number, endianess): def raw_number_to_integer(raw_number, endianess):
return int.from_bytes(raw_number, endianess) return int.from_bytes(raw_number, endianess)
def integer_to_raw_number(number, endianess): def integer_to_raw_number(number, endianess, byte_size):
return (32 * c_uint8)(*number.to_bytes(32, endianess)) return (byte_size * c_uint8)(*number.to_bytes(byte_size, endianess))
def integer_to_raw_number256(number, endianess):
return integer_to_raw_number(number, endianess, 32)
def integer_to_raw_number512(number, endianess):
return integer_to_raw_number(number, endianess, 64)
def bignum_is_normalised(bignum): def bignum_is_normalised(bignum):
@ -133,6 +173,9 @@ class Random(random.Random):
def rand_int_256(self): def rand_int_256(self):
return self.randrange(0, 2**256) return self.randrange(0, 2**256)
def rand_int_512(self):
return self.randrange(0, 2**512)
def rand_int_reduced(self, p): def rand_int_reduced(self, p):
return self.randrange(0, 2 * p) return self.randrange(0, 2 * p)
@ -142,35 +185,44 @@ class Random(random.Random):
def rand_bit_index(self): def rand_bit_index(self):
return self.randrange(0, limbs_number * bits_per_limb) return self.randrange(0, limbs_number * bits_per_limb)
def rand_bignum(self, limbs_number=limbs_number): def rand_bignum(self, limbs_number):
return (limb_type * limbs_number)( return (limb_type * limbs_number)(
*[self.randrange(0, 256**4) for _ in range(limbs_number)] *[self.randrange(0, 256**4) for _ in range(limbs_number)]
) )
def rand_bignum256(self):
return self.rand_bignum(limbs_number)
def rand_bignum512(self):
return self.rand_bignum(2 * limbs_number)
def assert_bn_read_be(in_number): def assert_bn_read_be(in_number):
raw_in_number = integer_to_raw_number(in_number, "big") raw_in_number = integer_to_raw_number256(in_number, "big")
bn_out_number = bignum() bn_out_number = bignum256()
lib.bn_read_be(raw_in_number, bn_out_number) lib.bn_read_be(raw_in_number, bn_out_number)
out_number = bignum_to_int(bn_out_number) out_number = bignum256_to_int(bn_out_number)
assert bignum_is_normalised(bn_out_number)
assert out_number == in_number
assert bignum_is_normalised(bn_out_number) assert bignum_is_normalised(bn_out_number)
assert out_number == in_number assert out_number == in_number
def assert_bn_read_le(in_number): def assert_bn_read_le(in_number):
raw_in_number = integer_to_raw_number(in_number, "little") raw_in_number = integer_to_raw_number256(in_number, "little")
bn_out_number = bignum() bn_out_number = bignum256()
lib.bn_read_le(raw_in_number, bn_out_number) lib.bn_read_le(raw_in_number, bn_out_number)
out_number = bignum_to_int(bn_out_number) out_number = bignum256_to_int(bn_out_number)
assert bignum_is_normalised(bn_out_number) assert bignum_is_normalised(bn_out_number)
assert out_number == in_number assert out_number == in_number
def assert_bn_write_be(in_number): def assert_bn_write_be(in_number):
bn_in_number = int_to_bignum(in_number) bn_in_number = int_to_bignum256(in_number)
raw_out_number = raw_number() raw_out_number = raw_number256()
lib.bn_write_be(bn_in_number, raw_out_number) lib.bn_write_be(bn_in_number, raw_out_number)
out_number = raw_number_to_integer(raw_out_number, "big") out_number = raw_number_to_integer(raw_out_number, "big")
@ -178,8 +230,8 @@ def assert_bn_write_be(in_number):
def assert_bn_write_le(in_number): def assert_bn_write_le(in_number):
bn_in_number = int_to_bignum(in_number) bn_in_number = int_to_bignum256(in_number)
raw_out_number = raw_number() raw_out_number = raw_number256()
lib.bn_write_le(bn_in_number, raw_out_number) lib.bn_write_le(bn_in_number, raw_out_number)
out_number = raw_number_to_integer(raw_out_number, "little") out_number = raw_number_to_integer(raw_out_number, "little")
@ -187,100 +239,100 @@ def assert_bn_write_le(in_number):
def assert_bn_read_uint32(x): def assert_bn_read_uint32(x):
bn_out_number = bignum() bn_out_number = bignum256()
lib.bn_read_uint32(c_uint32(x), bn_out_number) lib.bn_read_uint32(c_uint32(x), bn_out_number)
out_number = bignum_to_int(bn_out_number) out_number = bignum256_to_int(bn_out_number)
assert bignum_is_normalised(bn_out_number) assert bignum_is_normalised(bn_out_number)
assert out_number == x assert out_number == x
def assert_bn_read_uint64(x): def assert_bn_read_uint64(x):
bn_out_number = bignum() bn_out_number = bignum256()
lib.bn_read_uint64(c_uint64(x), bn_out_number) lib.bn_read_uint64(c_uint64(x), bn_out_number)
out_number = bignum_to_int(bn_out_number) out_number = bignum256_to_int(bn_out_number)
assert bignum_is_normalised(bn_out_number) assert bignum_is_normalised(bn_out_number)
assert out_number == x assert out_number == x
def assert_bn_bitcount(x): def assert_bn_bitcount(x):
bn_x = int_to_bignum(x) bn_x = int_to_bignum256(x)
return_value = lib.bn_bitcount(bn_x) return_value = lib.bn_bitcount(bn_x)
assert return_value == x.bit_length() assert return_value == x.bit_length()
def assert_bn_digitcount(x): def assert_bn_digitcount(x):
bn_x = int_to_bignum(x) bn_x = int_to_bignum256(x)
return_value = lib.bn_digitcount(bn_x) return_value = lib.bn_digitcount(bn_x)
assert return_value == len(str(x)) assert return_value == len(str(x))
def assert_bn_zero(): def assert_bn_zero():
bn_x = bignum() bn_x = bignum256()
lib.bn_zero(bn_x) lib.bn_zero(bn_x)
x = bignum_to_int(bn_x) x = bignum256_to_int(bn_x)
assert bignum_is_normalised(bn_x) assert bignum_is_normalised(bn_x)
assert x == 0 assert x == 0
def assert_bn_one(): def assert_bn_one():
bn_x = bignum() bn_x = bignum256()
lib.bn_one(bn_x) lib.bn_one(bn_x)
x = bignum_to_int(bn_x) x = bignum256_to_int(bn_x)
assert bignum_is_normalised(bn_x) assert bignum_is_normalised(bn_x)
assert x == 1 assert x == 1
def assert_bn_is_zero(x): def assert_bn_is_zero(x):
bn_x = int_to_bignum(x) bn_x = int_to_bignum256(x)
return_value = lib.bn_is_zero(bn_x) return_value = lib.bn_is_zero(bn_x)
assert return_value == (x == 0) assert return_value == (x == 0)
def assert_bn_is_one(x): def assert_bn_is_one(x):
bn_x = int_to_bignum(x) bn_x = int_to_bignum256(x)
return_value = lib.bn_is_one(bn_x) return_value = lib.bn_is_one(bn_x)
assert return_value == (x == 1) assert return_value == (x == 1)
def assert_bn_is_less(x, y): def assert_bn_is_less(x, y):
bn_x = int_to_bignum(x) bn_x = int_to_bignum256(x)
bn_y = int_to_bignum(y) bn_y = int_to_bignum256(y)
return_value = lib.bn_is_less(bn_x, bn_y) return_value = lib.bn_is_less(bn_x, bn_y)
assert return_value == (x < y) assert return_value == (x < y)
def assert_bn_is_equal(x, y): def assert_bn_is_equal(x, y):
bn_x = int_to_bignum(x) bn_x = int_to_bignum256(x)
bn_y = int_to_bignum(y) bn_y = int_to_bignum256(y)
return_value = lib.bn_is_equal(bn_x, bn_y) return_value = lib.bn_is_equal(bn_x, bn_y)
assert return_value == (x == y) assert return_value == (x == y)
def assert_bn_cmov(cond, truecase, falsecase): def assert_bn_cmov(cond, truecase, falsecase):
bn_res = bignum() bn_res = bignum256()
bn_truecase = int_to_bignum(truecase) bn_truecase = int_to_bignum256(truecase)
bn_falsecase = int_to_bignum(falsecase) bn_falsecase = int_to_bignum256(falsecase)
lib.bn_cmov(bn_res, c_uint32(cond), bn_truecase, bn_falsecase) lib.bn_cmov(bn_res, c_uint32(cond), bn_truecase, bn_falsecase)
res = bignum_to_int(bn_res) res = bignum256_to_int(bn_res)
assert res == truecase if cond else falsecase assert res == truecase if cond else falsecase
def assert_bn_cnegate(cond, x_old, prime): def assert_bn_cnegate(cond, x_old, prime):
bn_x = int_to_bignum(x_old) bn_x = int_to_bignum256(x_old)
bn_prime = int_to_bignum(prime) bn_prime = int_to_bignum256(prime)
lib.bn_cnegate(c_uint32(cond), bn_x, bn_prime) lib.bn_cnegate(c_uint32(cond), bn_x, bn_prime)
x_new = bignum_to_int(bn_x) x_new = bignum256_to_int(bn_x)
assert bignum_is_normalised(bn_x) assert bignum_is_normalised(bn_x)
assert number_is_partly_reduced(x_new, prime) assert number_is_partly_reduced(x_new, prime)
@ -288,63 +340,63 @@ def assert_bn_cnegate(cond, x_old, prime):
def assert_bn_lshift(x_old): def assert_bn_lshift(x_old):
bn_x = int_to_bignum(x_old) bn_x = int_to_bignum256(x_old)
lib.bn_lshift(bn_x) lib.bn_lshift(bn_x)
x_new = bignum_to_int(bn_x) x_new = bignum256_to_int(bn_x)
assert bignum_is_normalised(bn_x) assert bignum_is_normalised(bn_x)
assert x_new == (x_old << 1) assert x_new == (x_old << 1)
def assert_bn_rshift(x_old): def assert_bn_rshift(x_old):
bn_x = int_to_bignum(x_old) bn_x = int_to_bignum256(x_old)
lib.bn_rshift(bn_x) lib.bn_rshift(bn_x)
x_new = bignum_to_int(bn_x) x_new = bignum256_to_int(bn_x)
assert bignum_is_normalised(bn_x) assert bignum_is_normalised(bn_x)
assert x_new == (x_old >> 1) assert x_new == (x_old >> 1)
def assert_bn_setbit(x_old, i): def assert_bn_setbit(x_old, i):
bn_x = int_to_bignum(x_old) bn_x = int_to_bignum256(x_old)
lib.bn_setbit(bn_x, c_uint16(i)) lib.bn_setbit(bn_x, c_uint16(i))
x_new = bignum_to_int(bn_x) x_new = bignum256_to_int(bn_x)
assert bignum_is_normalised(bn_x) assert bignum_is_normalised(bn_x)
assert x_new == x_old | (1 << i) assert x_new == x_old | (1 << i)
def assert_bn_clearbit(x_old, i): def assert_bn_clearbit(x_old, i):
bn_x = int_to_bignum(x_old) bn_x = int_to_bignum256(x_old)
lib.bn_clearbit(bn_x, c_uint16(i)) lib.bn_clearbit(bn_x, c_uint16(i))
x_new = bignum_to_int(bn_x) x_new = bignum256_to_int(bn_x)
assert bignum_is_normalised(bn_x) assert bignum_is_normalised(bn_x)
assert x_new == x_old & ~(1 << i) assert x_new == x_old & ~(1 << i)
def assert_bn_testbit(x_old, i): def assert_bn_testbit(x_old, i):
bn_x = int_to_bignum(x_old) bn_x = int_to_bignum256(x_old)
return_value = lib.bn_testbit(bn_x, c_uint16(i)) return_value = lib.bn_testbit(bn_x, c_uint16(i))
assert return_value == x_old >> i & 1 assert return_value == x_old >> i & 1
def assert_bn_xor(x, y): def assert_bn_xor(x, y):
bn_res = bignum() bn_res = bignum256()
bn_x = int_to_bignum(x) bn_x = int_to_bignum256(x)
bn_y = int_to_bignum(y) bn_y = int_to_bignum256(y)
lib.bn_xor(bn_res, bn_x, bn_y) lib.bn_xor(bn_res, bn_x, bn_y)
res = bignum_to_int(bn_res) res = bignum256_to_int(bn_res)
assert res == x ^ y assert res == x ^ y
def assert_bn_mult_half(x_old, prime): def assert_bn_mult_half(x_old, prime):
bn_x = int_to_bignum(x_old) bn_x = int_to_bignum256(x_old)
bn_prime = int_to_bignum(prime) bn_prime = int_to_bignum256(prime)
lib.bn_mult_half(bn_x, bn_prime) lib.bn_mult_half(bn_x, bn_prime)
x_new = bignum_to_int(bn_x) x_new = bignum256_to_int(bn_x)
assert implication( assert implication(
number_is_partly_reduced(x_old, prime), number_is_partly_reduced(x_new, prime) number_is_partly_reduced(x_old, prime), number_is_partly_reduced(x_new, prime)
@ -353,10 +405,10 @@ def assert_bn_mult_half(x_old, prime):
def assert_bn_mult_k(x_old, k, prime): def assert_bn_mult_k(x_old, k, prime):
bn_x = int_to_bignum(x_old) bn_x = int_to_bignum256(x_old)
bn_prime = int_to_bignum(prime) bn_prime = int_to_bignum256(prime)
lib.bn_mult_k(bn_x, c_uint8(k), bn_prime) lib.bn_mult_k(bn_x, c_uint8(k), bn_prime)
x_new = bignum_to_int(bn_x) x_new = bignum256_to_int(bn_x)
assert bignum_is_normalised(bn_x) assert bignum_is_normalised(bn_x)
assert number_is_partly_reduced(x_new, prime) assert number_is_partly_reduced(x_new, prime)
@ -364,10 +416,10 @@ def assert_bn_mult_k(x_old, k, prime):
def assert_bn_mod(x_old, prime): def assert_bn_mod(x_old, prime):
bn_x = int_to_bignum(x_old) bn_x = int_to_bignum256(x_old)
bn_prime = int_to_bignum(prime) bn_prime = int_to_bignum256(prime)
lib.bn_mod(bn_x, bn_prime) lib.bn_mod(bn_x, bn_prime)
x_new = bignum_to_int(bn_x) x_new = bignum256_to_int(bn_x)
assert bignum_is_normalised(bn_x) assert bignum_is_normalised(bn_x)
assert number_is_fully_reduced(x_new, prime) assert number_is_fully_reduced(x_new, prime)
@ -375,31 +427,31 @@ def assert_bn_mod(x_old, prime):
def assert_bn_multiply_long(k_old, x_old): def assert_bn_multiply_long(k_old, x_old):
bn_k = int_to_bignum(k_old) bn_k = int_to_bignum256(k_old)
bn_x = int_to_bignum(x_old) bn_x = int_to_bignum256(x_old)
bn_res = bignum(2 * limbs_number) bn_res = bignum512()
lib.bn_multiply_long(bn_k, bn_x, bn_res) lib.bn_multiply_long(bn_k, bn_x, bn_res)
res = bignum_to_int(bn_res, 2 * limbs_number) res = bignum512_to_int(bn_res)
assert res == k_old * x_old assert res == k_old * x_old
def assert_bn_multiply_reduce_step(res_old, prime, d): def assert_bn_multiply_reduce_step(res_old, prime, d):
bn_res = int_to_bignum(res_old, 2 * limbs_number) bn_res = int_to_bignum512(res_old)
bn_prime = int_to_bignum(prime) bn_prime = int_to_bignum256(prime)
lib.bn_multiply_reduce_step(bn_res, bn_prime, d) lib.bn_multiply_reduce_step(bn_res, bn_prime, d)
res_new = bignum_to_int(bn_res, 2 * limbs_number) res_new = bignum512_to_int(bn_res)
assert bignum_is_normalised(bn_res) assert bignum_is_normalised(bn_res)
assert res_new < 2 * prime * 2 ** (d * bits_per_limb) assert res_new < 2 * prime * 2 ** (d * bits_per_limb)
def assert_bn_multiply(k, x_old, prime): def assert_bn_multiply(k, x_old, prime):
bn_k = int_to_bignum(k) bn_k = int_to_bignum256(k)
bn_x = int_to_bignum(x_old) bn_x = int_to_bignum256(x_old)
bn_prime = int_to_bignum(prime) bn_prime = int_to_bignum256(prime)
lib.bn_multiply(bn_k, bn_x, bn_prime) lib.bn_multiply(bn_k, bn_x, bn_prime)
x_new = bignum_to_int(bn_x) x_new = bignum256_to_int(bn_x)
assert bignum_is_normalised(bn_x) assert bignum_is_normalised(bn_x)
assert number_is_partly_reduced(x_new, prime) assert number_is_partly_reduced(x_new, prime)
@ -407,10 +459,10 @@ def assert_bn_multiply(k, x_old, prime):
def assert_bn_fast_mod(x_old, prime): def assert_bn_fast_mod(x_old, prime):
bn_x = int_to_bignum(x_old) bn_x = int_to_bignum256(x_old)
bn_prime = int_to_bignum(prime) bn_prime = int_to_bignum256(prime)
lib.bn_fast_mod(bn_x, bn_prime) lib.bn_fast_mod(bn_x, bn_prime)
x_new = bignum_to_int(bn_x) x_new = bignum256_to_int(bn_x)
assert bignum_is_normalised(bn_x) assert bignum_is_normalised(bn_x)
assert number_is_partly_reduced(x_new, prime) assert number_is_partly_reduced(x_new, prime)
@ -419,10 +471,10 @@ def assert_bn_fast_mod(x_old, prime):
def assert_bn_fast_mod_bn(bn_x, prime): def assert_bn_fast_mod_bn(bn_x, prime):
bn_x bn_x
x_old = bignum_to_int(bn_x) x_old = bignum256_to_int(bn_x)
bn_prime = int_to_bignum(prime) bn_prime = int_to_bignum256(prime)
lib.bn_fast_mod(bn_x, bn_prime) lib.bn_fast_mod(bn_x, bn_prime)
x_new = bignum_to_int(bn_x) x_new = bignum256_to_int(bn_x)
assert bignum_is_normalised(bn_x) assert bignum_is_normalised(bn_x)
assert number_is_partly_reduced(x_new, prime) assert number_is_partly_reduced(x_new, prime)
@ -430,12 +482,12 @@ def assert_bn_fast_mod_bn(bn_x, prime):
def assert_bn_power_mod(x, e, prime): def assert_bn_power_mod(x, e, prime):
bn_x = int_to_bignum(x) bn_x = int_to_bignum256(x)
bn_e = int_to_bignum(e) bn_e = int_to_bignum256(e)
bn_prime = int_to_bignum(prime) bn_prime = int_to_bignum256(prime)
bn_res_new = bignum() bn_res_new = bignum256()
lib.bn_power_mod(bn_x, bn_e, bn_prime, bn_res_new) lib.bn_power_mod(bn_x, bn_e, bn_prime, bn_res_new)
res_new = bignum_to_int(bn_res_new) res_new = bignum256_to_int(bn_res_new)
assert bignum_is_normalised(bn_res_new) assert bignum_is_normalised(bn_res_new)
assert number_is_partly_reduced(res_new, prime) assert number_is_partly_reduced(res_new, prime)
@ -443,10 +495,10 @@ def assert_bn_power_mod(x, e, prime):
def assert_bn_sqrt(x_old, prime): def assert_bn_sqrt(x_old, prime):
bn_x = int_to_bignum(x_old) bn_x = int_to_bignum256(x_old)
bn_prime = int_to_bignum(prime) bn_prime = int_to_bignum256(prime)
lib.bn_sqrt(bn_x, bn_prime) lib.bn_sqrt(bn_x, bn_prime)
x_new = bignum_to_int(bn_x) x_new = bignum256_to_int(bn_x)
assert bignum_is_normalised(bn_x) assert bignum_is_normalised(bn_x)
assert number_is_fully_reduced(x_new, prime) assert number_is_fully_reduced(x_new, prime)
@ -460,10 +512,10 @@ def assert_inverse_mod_power_two(x, m):
def assert_bn_divide_base(x_old, prime): def assert_bn_divide_base(x_old, prime):
bn_x = int_to_bignum(x_old) bn_x = int_to_bignum256(x_old)
bn_prime = int_to_bignum(prime) bn_prime = int_to_bignum256(prime)
lib.bn_divide_base(bn_x, bn_prime) lib.bn_divide_base(bn_x, bn_prime)
x_new = bignum_to_int(bn_x) x_new = bignum256_to_int(bn_x)
assert implication( assert implication(
number_is_fully_reduced(x_old, prime), number_is_fully_reduced(x_new, prime) number_is_fully_reduced(x_old, prime), number_is_fully_reduced(x_new, prime)
@ -475,10 +527,10 @@ def assert_bn_divide_base(x_old, prime):
def assert_bn_inverse(x_old, prime): def assert_bn_inverse(x_old, prime):
bn_x = int_to_bignum(x_old) bn_x = int_to_bignum256(x_old)
bn_prime = int_to_bignum(prime) bn_prime = int_to_bignum256(prime)
lib.bn_inverse(bn_x, bn_prime) lib.bn_inverse(bn_x, bn_prime)
x_new = bignum_to_int(bn_x) x_new = bignum256_to_int(bn_x)
assert bignum_is_normalised(bn_x) assert bignum_is_normalised(bn_x)
assert number_is_fully_reduced(x_new, prime) assert number_is_fully_reduced(x_new, prime)
@ -486,31 +538,31 @@ def assert_bn_inverse(x_old, prime):
def assert_bn_normalize(bn_x): def assert_bn_normalize(bn_x):
x_old = bignum_to_int(bn_x) x_old = bignum256_to_int(bn_x)
lib.bn_normalize(bn_x) lib.bn_normalize(bn_x)
x_new = bignum_to_int(bn_x) x_new = bignum256_to_int(bn_x)
assert x_new == x_old % 2 ** (bits_per_limb * limbs_number) assert x_new == x_old % 2 ** (bits_per_limb * limbs_number)
assert bignum_is_normalised(bn_x) assert bignum_is_normalised(bn_x)
def assert_bn_add(x_old, y): def assert_bn_add(x_old, y):
bn_x = int_to_bignum(x_old) bn_x = int_to_bignum256(x_old)
bn_y = int_to_bignum(y) bn_y = int_to_bignum256(y)
lib.bn_add(bn_x, bn_y) lib.bn_add(bn_x, bn_y)
x_new = bignum_to_int(bn_x) x_new = bignum256_to_int(bn_x)
y = bignum_to_int(bn_y) y = bignum256_to_int(bn_y)
assert bignum_is_normalised(bn_x) assert bignum_is_normalised(bn_x)
assert x_new == x_old + y assert x_new == x_old + y
def assert_bn_addmod(x_old, y, prime): def assert_bn_addmod(x_old, y, prime):
bn_x = int_to_bignum(x_old) bn_x = int_to_bignum256(x_old)
bn_y = int_to_bignum(y) bn_y = int_to_bignum256(y)
bn_prime = int_to_bignum(prime) bn_prime = int_to_bignum256(prime)
lib.bn_addmod(bn_x, bn_y, bn_prime) lib.bn_addmod(bn_x, bn_y, bn_prime)
x_new = bignum_to_int(bn_x) x_new = bignum256_to_int(bn_x)
assert bignum_is_normalised(bn_x) assert bignum_is_normalised(bn_x)
assert number_is_partly_reduced(x_new, prime) assert number_is_partly_reduced(x_new, prime)
@ -518,19 +570,19 @@ def assert_bn_addmod(x_old, y, prime):
def assert_bn_addi(x_old, y): def assert_bn_addi(x_old, y):
bn_x = int_to_bignum(x_old) bn_x = int_to_bignum256(x_old)
lib.bn_addi(bn_x, c_uint32(y)) lib.bn_addi(bn_x, c_uint32(y))
x_new = bignum_to_int(bn_x) x_new = bignum256_to_int(bn_x)
assert bignum_is_normalised(bn_x) assert bignum_is_normalised(bn_x)
assert x_new == x_old + y assert x_new == x_old + y
def assert_bn_subi(x_old, y, prime): def assert_bn_subi(x_old, y, prime):
bn_x = int_to_bignum(x_old) bn_x = int_to_bignum256(x_old)
bn_prime = int_to_bignum(prime) bn_prime = int_to_bignum256(prime)
lib.bn_subi(bn_x, c_uint32(y), bn_prime) lib.bn_subi(bn_x, c_uint32(y), bn_prime)
x_new = bignum_to_int(bn_x) x_new = bignum256_to_int(bn_x)
assert bignum_is_normalised(bn_x) assert bignum_is_normalised(bn_x)
assert implication( assert implication(
@ -540,12 +592,12 @@ def assert_bn_subi(x_old, y, prime):
def assert_bn_subtractmod(x, y, prime): def assert_bn_subtractmod(x, y, prime):
bn_x = int_to_bignum(x) bn_x = int_to_bignum256(x)
bn_y = int_to_bignum(y) bn_y = int_to_bignum256(y)
bn_prime = int_to_bignum(prime) bn_prime = int_to_bignum256(prime)
bn_res = bignum() bn_res = bignum256()
lib.bn_subtractmod(bn_x, bn_y, bn_res, bn_prime) lib.bn_subtractmod(bn_x, bn_y, bn_res, bn_prime)
res = bignum_to_int(bn_res) res = bignum256_to_int(bn_res)
assert bignum_is_normalised(bn_x) assert bignum_is_normalised(bn_x)
assert res % prime == (x - y) % prime assert res % prime == (x - y) % prime
@ -559,31 +611,31 @@ def legendre(x, prime):
def assert_bn_legendre(x, prime): def assert_bn_legendre(x, prime):
bn_x = int_to_bignum(x) bn_x = int_to_bignum256(x)
bn_prime = int_to_bignum(prime) bn_prime = int_to_bignum256(prime)
return_value = lib.bn_legendre(bn_x, bn_prime) return_value = lib.bn_legendre(bn_x, bn_prime)
assert return_value == legendre(x, prime) assert return_value == legendre(x, prime)
def assert_bn_subtract(x, y): def assert_bn_subtract(x, y):
bn_x = int_to_bignum(x) bn_x = int_to_bignum256(x)
bn_y = int_to_bignum(y) bn_y = int_to_bignum256(y)
bn_res = bignum() bn_res = bignum256()
lib.bn_subtract(bn_x, bn_y, bn_res) lib.bn_subtract(bn_x, bn_y, bn_res)
res = bignum_to_int(bn_res) res = bignum256_to_int(bn_res)
assert bignum_is_normalised(bn_x) assert bignum_is_normalised(bn_x)
assert res == x - y assert res == x - y
def assert_bn_long_division(x, d): def assert_bn_long_division(x, d):
bn_x = int_to_bignum(x) bn_x = int_to_bignum256(x)
bn_q = bignum() bn_q = bignum256()
uint32_p_r = uint32_p() uint32_p_r = uint32_p()
lib.bn_long_division(bn_x, d, bn_q, uint32_p_r) lib.bn_long_division(bn_x, d, bn_q, uint32_p_r)
r = uint32_p_to_int(uint32_p_r) r = uint32_p_to_int(uint32_p_r)
q = bignum_to_int(bn_q) q = bignum256_to_int(bn_q)
assert bignum_is_normalised(bn_q) assert bignum_is_normalised(bn_q)
assert q == x // d assert q == x // d
@ -591,10 +643,10 @@ def assert_bn_long_division(x, d):
def assert_bn_divmod58(x_old): def assert_bn_divmod58(x_old):
bn_x = int_to_bignum(x_old) bn_x = int_to_bignum256(x_old)
uint32_p_r = uint32_p() uint32_p_r = uint32_p()
lib.bn_divmod58(bn_x, uint32_p_r) lib.bn_divmod58(bn_x, uint32_p_r)
x_new = bignum_to_int(bn_x) x_new = bignum256_to_int(bn_x)
r = uint32_p_to_int(uint32_p_r) r = uint32_p_to_int(uint32_p_r)
assert bignum_is_normalised(bn_x) assert bignum_is_normalised(bn_x)
@ -603,10 +655,10 @@ def assert_bn_divmod58(x_old):
def assert_bn_divmod1000(x_old): def assert_bn_divmod1000(x_old):
bn_x = int_to_bignum(x_old) bn_x = int_to_bignum256(x_old)
uint32_p_r = uint32_p() uint32_p_r = uint32_p()
lib.bn_divmod1000(bn_x, uint32_p_r) lib.bn_divmod1000(bn_x, uint32_p_r)
x_new = bignum_to_int(bn_x) x_new = bignum256_to_int(bn_x)
r = uint32_p_to_int(uint32_p_r) r = uint32_p_to_int(uint32_p_r)
assert bignum_is_normalised(bn_x) assert bignum_is_normalised(bn_x)
@ -615,10 +667,10 @@ def assert_bn_divmod1000(x_old):
def assert_bn_divmod10(x_old): def assert_bn_divmod10(x_old):
bn_x = int_to_bignum(x_old) bn_x = int_to_bignum256(x_old)
uint32_p_r = uint32_p() uint32_p_r = uint32_p()
lib.bn_divmod10(bn_x, uint32_p_r) lib.bn_divmod10(bn_x, uint32_p_r)
x_new = bignum_to_int(bn_x) x_new = bignum256_to_int(bn_x)
r = uint32_p_to_int(uint32_p_r) r = uint32_p_to_int(uint32_p_r)
assert bignum_is_normalised(bn_x) assert bignum_is_normalised(bn_x)
@ -654,7 +706,7 @@ def assert_bn_format(x, prefix, suffix, decimals, exponent, trailing, thousands)
def char_p_to_string(pointer): def char_p_to_string(pointer):
return str(pointer.value, "ascii") return str(pointer.value, "ascii")
bn_x = int_to_bignum(x) bn_x = int_to_bignum256(x)
output_length = 100 output_length = 100
output = string_to_char_p("?" * output_length) output = string_to_char_p("?" * output_length)
return_value = lib.bn_format( return_value = lib.bn_format(
@ -875,7 +927,7 @@ def test_bn_fast_mod_1(r, prime):
def test_bn_fast_mod_2(r, prime): def test_bn_fast_mod_2(r, prime):
bn_x = r.rand_bignum() bn_x = r.rand_bignum256()
assert_bn_fast_mod_bn(bn_x, prime) assert_bn_fast_mod_bn(bn_x, prime)
@ -929,7 +981,7 @@ def test_bn_inverse_2(r, prime):
def test_bn_normalize(r): def test_bn_normalize(r):
assert_bn_normalize(r.rand_bignum()) assert_bn_normalize(r.rand_bignum256())
def test_bn_add_1(r): def test_bn_add_1(r):