1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-02-22 12:32:02 +00:00

Cleaned up bignum code

1. Fixed bn_multiply_step to handle small primes.
2. Removed many calls to bn_mod to prevent side-channel leakage.
This commit is contained in:
Jochen Hoenicke 2015-07-23 13:00:54 -07:00
parent 2e09a9ff35
commit 6ba4d288b0
4 changed files with 31 additions and 43 deletions

View File

@ -172,9 +172,7 @@ void bn_mult_k(bignum256 *x, uint8_t k, const bignum256 *prime)
for (j = 0; j < 9; j++) {
x->val[j] = k * x->val[j];
}
bn_normalize(x);
bn_fast_mod(x, prime);
bn_mod(x, prime);
}
// assumes x < 2*prime, result < prime
@ -233,31 +231,35 @@ void bn_multiply_reduce_step(uint32_t res[18], const bignum256 *prime, uint32_t
// let k = i-8.
// invariants:
// res[0..(i+1)] = k * x (mod prime)
// 0 <= res < 2^(30k + 256) * (2^30 + 1)
// 0 <= res < 2^(30k + 256) * (2^31)
// estimate (res / prime)
// coef = res / 2^(30k + 256) rounded down
// 0 <= coef <= 2^30
// 0 <= coef < 2^31
// subtract (coef * 2^(30k) * prime) from res
// note that we unrolled the first iteration
uint32_t j;
uint32_t coef = (res[i] >> 16) + (res[i + 1] << 14);
uint64_t temp = 0x1000000000000000ull + res[i - 8] - prime->val[0] * (uint64_t)coef;
uint64_t temp = 0x2000000000000000ull + res[i - 8] - prime->val[0] * (uint64_t)coef;
assert (coef < 0x80000000u);
res[i - 8] = temp & 0x3FFFFFFF;
for (j = 1; j < 9; j++) {
temp >>= 30;
temp += 0xFFFFFFFC0000000ull + res[i - 8 + j] - prime->val[j] * (uint64_t)coef;
// Note: coeff * prime->val <= (2^31-1) * (2^30-1)
// Hence, this addition will not underflow.
temp += 0x1FFFFFFF80000000ull + res[i - 8 + j] - prime->val[j] * (uint64_t)coef;
res[i - 8 + j] = temp & 0x3FFFFFFF;
// 0 <= temp < 2^61
}
temp >>= 30;
temp += 0xFFFFFFFC0000000ull + res[i - 8 + j];
temp += 0x1FFFFFFF80000000ull + res[i - 8 + j];
res[i - 8 + j] = temp & 0x3FFFFFFF;
// we rely on the fact that prime > 2^256 - 2^196
// we rely on the fact that prime > 2^256 - 2^224
// res = oldres - coef*2^(30k) * prime;
// and
// coef * 2^(30k + 256) <= oldres < (coef+1) * 2^(30k + 256)
// Hence, 0 <= res < 2^30k (2^256 + coef * (2^256 - prime))
// Since coef * (2^256 - prime) < 2^226, we get
// 0 <= res < 2^(30k + 226) (2^30 + 1)
// Since coef * (2^256 - prime) < 2^256, we get
// 0 <= res < 2^(30k + 226) (2^31)
// Thus the invariant holds again.
}
@ -270,7 +272,6 @@ void bn_multiply_reduce(bignum256 *x, uint32_t res[18], const bignum256 *prime)
// compute modulo p division is only estimated so this may give result greater than prime but not bigger than 2 * prime
for (i = 16; i >= 8; i--) {
bn_multiply_reduce_step(res, prime, i);
bn_multiply_reduce_step(res, prime, i); // apply twice, as a hack for NIST256P1 prime.
assert(res[i + 1] == 0);
}
// store the result
@ -294,7 +295,7 @@ void bn_multiply(const bignum256 *k, bignum256 *x, const bignum256 *prime)
}
// input x can be any normalized number that fits (0 <= x < 2^270).
// prime must be between (2^256 - 2^196) and 2^256
// prime must be between (2^256 - 2^224) and 2^256
// result is smaller than 2*prime
void bn_fast_mod(bignum256 *x, const bignum256 *prime)
{
@ -305,11 +306,11 @@ void bn_fast_mod(bignum256 *x, const bignum256 *prime)
coef = x->val[8] >> 16;
// substract (coef * prime) from x
// note that we unrolled the first iteration
temp = 0x1000000000000000ull + x->val[0] - prime->val[0] * (uint64_t)coef;
temp = 0x2000000000000000ull + x->val[0] - prime->val[0] * (uint64_t)coef;
x->val[0] = temp & 0x3FFFFFFF;
for (j = 1; j < 9; j++) {
temp >>= 30;
temp += 0xFFFFFFFC0000000ull + x->val[j] - prime->val[j] * (uint64_t)coef;
temp += 0x1FFFFFFF80000000ull + x->val[j] - prime->val[j] * (uint64_t)coef;
x->val[j] = temp & 0x3FFFFFFF;
}
}
@ -679,16 +680,12 @@ void bn_addmod(bignum256 *a, const bignum256 *b, const bignum256 *prime)
for (i = 0; i < 9; i++) {
a->val[i] += b->val[i];
}
bn_normalize(a);
bn_fast_mod(a, prime);
bn_mod(a, prime);
}
void bn_addmodi(bignum256 *a, uint32_t b, const bignum256 *prime) {
void bn_addi(bignum256 *a, uint32_t b) {
a->val[0] += b;
bn_normalize(a);
bn_fast_mod(a, prime);
bn_mod(a, prime);
}
// res = a - b mod prime. More exactly res = a + (2*prime - b).

View File

@ -75,7 +75,7 @@ void bn_normalize(bignum256 *a);
void bn_addmod(bignum256 *a, const bignum256 *b, const bignum256 *prime);
void bn_addmodi(bignum256 *a, uint32_t b, const bignum256 *prime);
void bn_addi(bignum256 *a, uint32_t b);
void bn_subtractmod(const bignum256 *a, const bignum256 *b, bignum256 *res, const bignum256 *prime);

15
ecdsa.c
View File

@ -207,13 +207,10 @@ void curve_to_jacobian(const curve_point *p, jacobian_curve_point *jp, const big
bn_multiply(&p->x, &jp->x, prime);
bn_multiply(&p->y, &jp->y, prime);
bn_mod(&jp->x, prime);
bn_mod(&jp->y, prime);
}
void jacobian_to_curve(const jacobian_curve_point *jp, curve_point *p, const bignum256 *prime) {
p->y = jp->z;
bn_mod(&p->y, prime);
bn_inverse(&p->y, prime);
// p->y = z^-1
p->x = p->y;
@ -298,21 +295,18 @@ void point_jacobian_add(const curve_point *p1, jacobian_curve_point *p2, const b
// z3 = h*z2
bn_multiply(&h, &p2->z, prime);
bn_mod(&p2->z, prime);
// x3 = r^2 - h^3 - 2h^2x2
bn_addmod(&hcb, &hsqx2, prime);
bn_addmod(&hcb, &hsqx2, prime);
bn_subtractmod(&rsq, &hcb, &p2->x, prime);
bn_fast_mod(&p2->x, prime);
bn_mod(&p2->x, prime);
// y3 = r*(h^2x2 - x3) - y2*h^3
bn_subtractmod(&hsqx2, &p2->x, &p2->y, prime);
bn_multiply(&r, &p2->y, prime);
bn_subtractmod(&p2->y, &hcby2, &p2->y, prime);
bn_fast_mod(&p2->y, prime);
bn_mod(&p2->y, prime);
}
void point_jacobian_double(jacobian_curve_point *p, const ecdsa_curve *curve) {
@ -366,15 +360,13 @@ void point_jacobian_double(jacobian_curve_point *p, const ecdsa_curve *curve) {
// z3 = yz
bn_multiply(&p->y, &p->z, prime);
bn_mod(&p->z, prime);
// x3 = m^2 - 2*xy^2
p->x = xysq;
bn_mod(&p->x, prime);
bn_lshift(&p->x);
bn_fast_mod(&p->x, prime);
bn_subtractmod(&msq, &p->x, &p->x, prime);
bn_fast_mod(&p->x, prime);
bn_mod(&p->x, prime);
// y3 = m*(xy^2 - x3) - y^4
bn_subtractmod(&xysq, &p->x, &p->y, prime);
@ -382,7 +374,6 @@ void point_jacobian_double(jacobian_curve_point *p, const ecdsa_curve *curve) {
bn_multiply(&ysq, &ysq, prime);
bn_subtractmod(&p->y, &ysq, &p->y, prime);
bn_fast_mod(&p->y, prime);
bn_mod(&p->y, prime);
}
// res = k * p
@ -835,7 +826,7 @@ void uncompress_coords(const ecdsa_curve *curve, uint8_t odd, const bignum256 *x
memcpy(y, x, sizeof(bignum256)); // y is x
bn_multiply(x, y, &curve->prime); // y is x^2
bn_multiply(x, y, &curve->prime); // y is x^3
bn_addmodi(y, 7, &curve->prime); // y is x^3 + 7
bn_addi(y, 7); // y is x^3 + 7
bn_sqrt(y, &curve->prime); // y = sqrt(y)
if ((odd & 0x01) != (y->val[0] & 1)) {
bn_subtract(&curve->prime, y, y); // y = -y
@ -885,7 +876,7 @@ int ecdsa_validate_pubkey(const ecdsa_curve *curve, const curve_point *pub)
// x^3 + b
bn_multiply(&(pub->x), &x_3_b, &curve->prime);
bn_multiply(&(pub->x), &x_3_b, &curve->prime);
bn_addmodi(&x_3_b, 7, &curve->prime);
bn_addi(&x_3_b, 7);
if (!bn_is_equal(&x_3_b, &y_2)) {
return 0;

14
tests.c
View File

@ -40,10 +40,10 @@
#include "secp256k1.h"
#define CURVE (&secp256k1)
#define prime256k1 (secp256k1.prime)
#define G256k1 (secp256k1.G)
#define order256k1 (secp256k1.order)
#define secp256k1_cp (secp256k1.cp)
#define prime256k1 (CURVE->prime)
#define G256k1 (CURVE->G)
#define order256k1 (CURVE->order)
#define secp256k1_cp (CURVE->cp)
uint8_t *fromhex(const char *str)
{
@ -1231,7 +1231,7 @@ START_TEST(test_secp256k1_cp) {
// increment by one and test again
p1 = p;
point_add(CURVE, &G256k1, &p1);
bn_addmodi(&a, 1, &order256k1);
bn_addi(&a, 1);
scalar_multiply(CURVE, &a, &p);
ck_assert_mem_eq(&p, &p1, sizeof(curve_point));
bn_zero(&p.y); // test that point_multiply CURVE, is not a noop
@ -1254,7 +1254,7 @@ START_TEST(test_mult_border_cases) {
point_multiply(CURVE, &a, &G256k1, &p);
ck_assert(point_is_infinity(&p));
bn_addmodi(&a, 1, &order256k1); // a == 1
bn_addi(&a, 1); // a == 1
scalar_multiply(CURVE, &a, &p);
ck_assert_mem_eq(&p, &G256k1, sizeof(curve_point));
point_multiply(CURVE, &a, &G256k1, &p);
@ -1269,7 +1269,7 @@ START_TEST(test_mult_border_cases) {
ck_assert_mem_eq(&p, &expected, sizeof(curve_point));
bn_subtract(&order256k1, &a, &a);
bn_addmodi(&a, 1, &order256k1); // a == 2
bn_addi(&a, 1); // a == 2
expected = G256k1;
point_add(CURVE, &expected, &expected);
scalar_multiply(CURVE, &a, &p);