1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-27 01:48:17 +00:00

Merge pull request #44 from jhoenicke/bignumcleanup

My bignum cleanup branch
This commit is contained in:
Pavol Rusnak 2015-08-06 00:31:29 +02:00
commit 57166295c4
8 changed files with 311 additions and 213 deletions

View File

@ -116,6 +116,19 @@ int bn_is_equal(const bignum256 *a, const bignum256 *b) {
return !result; return !result;
} }
void bn_cmov(bignum256 *res, int cond, const bignum256 *truecase, const bignum256 *falsecase)
{
int i;
uint32_t tmask = (uint32_t) -cond;
uint32_t fmask = ~tmask;
assert (cond == 1 || cond == 0);
for (i = 0; i < 9; i++) {
res->val[i] = (truecase->val[i] & tmask) |
(falsecase->val[i] & fmask);
}
}
int bn_bitlen(const bignum256 *a) { int bn_bitlen(const bignum256 *a) {
int i = 8, j; int i = 8, j;
while (i >= 0 && a->val[i] == 0) i--; while (i >= 0 && a->val[i] == 0) i--;
@ -172,9 +185,7 @@ void bn_mult_k(bignum256 *x, uint8_t k, const bignum256 *prime)
for (j = 0; j < 9; j++) { for (j = 0; j < 9; j++) {
x->val[j] = k * x->val[j]; x->val[j] = k * x->val[j];
} }
bn_normalize(x);
bn_fast_mod(x, prime); bn_fast_mod(x, prime);
bn_mod(x, prime);
} }
// assumes x < 2*prime, result < prime // assumes x < 2*prime, result < prime
@ -233,31 +244,35 @@ void bn_multiply_reduce_step(uint32_t res[18], const bignum256 *prime, uint32_t
// let k = i-8. // let k = i-8.
// invariants: // invariants:
// res[0..(i+1)] = k * x (mod prime) // res[0..(i+1)] = k * x (mod prime)
// 0 <= res < 2^(30k + 256) * (2^30 + 1) // 0 <= res < 2^(30k + 256) * (2^31)
// estimate (res / prime) // estimate (res / prime)
// coef = res / 2^(30k + 256) rounded down // coef = res / 2^(30k + 256) rounded down
// 0 <= coef <= 2^30 // 0 <= coef < 2^31
// subtract (coef * 2^(30k) * prime) from res // subtract (coef * 2^(30k) * prime) from res
// note that we unrolled the first iteration // note that we unrolled the first iteration
uint32_t j; uint32_t j;
uint32_t coef = (res[i] >> 16) + (res[i + 1] << 14); uint32_t coef = (res[i] >> 16) + (res[i + 1] << 14);
uint64_t temp = 0x1000000000000000ull + res[i - 8] - prime->val[0] * (uint64_t)coef; uint64_t temp = 0x2000000000000000ull + res[i - 8] - prime->val[0] * (uint64_t)coef;
assert (coef < 0x80000000u);
res[i - 8] = temp & 0x3FFFFFFF; res[i - 8] = temp & 0x3FFFFFFF;
for (j = 1; j < 9; j++) { for (j = 1; j < 9; j++) {
temp >>= 30; temp >>= 30;
temp += 0xFFFFFFFC0000000ull + res[i - 8 + j] - prime->val[j] * (uint64_t)coef; // Note: coeff * prime->val <= (2^31-1) * (2^30-1)
// Hence, this addition will not underflow.
temp += 0x1FFFFFFF80000000ull + res[i - 8 + j] - prime->val[j] * (uint64_t)coef;
res[i - 8 + j] = temp & 0x3FFFFFFF; res[i - 8 + j] = temp & 0x3FFFFFFF;
// 0 <= temp < 2^61
} }
temp >>= 30; temp >>= 30;
temp += 0xFFFFFFFC0000000ull + res[i - 8 + j]; temp += 0x1FFFFFFF80000000ull + res[i - 8 + j];
res[i - 8 + j] = temp & 0x3FFFFFFF; res[i - 8 + j] = temp & 0x3FFFFFFF;
// we rely on the fact that prime > 2^256 - 2^196 // we rely on the fact that prime > 2^256 - 2^224
// res = oldres - coef*2^(30k) * prime; // res = oldres - coef*2^(30k) * prime;
// and // and
// coef * 2^(30k + 256) <= oldres < (coef+1) * 2^(30k + 256) // coef * 2^(30k + 256) <= oldres < (coef+1) * 2^(30k + 256)
// Hence, 0 <= res < 2^30k (2^256 + coef * (2^256 - prime)) // Hence, 0 <= res < 2^30k (2^256 + coef * (2^256 - prime))
// Since coef * (2^256 - prime) < 2^226, we get // Since coef * (2^256 - prime) < 2^256, we get
// 0 <= res < 2^(30k + 226) (2^30 + 1) // 0 <= res < 2^(30k + 226) (2^31)
// Thus the invariant holds again. // Thus the invariant holds again.
} }
@ -269,9 +284,8 @@ void bn_multiply_reduce(bignum256 *x, uint32_t res[18], const bignum256 *prime)
// 0 <= res < 2^526. // 0 <= res < 2^526.
// compute modulo p division is only estimated so this may give result greater than prime but not bigger than 2 * prime // compute modulo p division is only estimated so this may give result greater than prime but not bigger than 2 * prime
for (i = 16; i >= 8; i--) { for (i = 16; i >= 8; i--) {
bn_multiply_reduce_step(res, prime, i); bn_multiply_reduce_step(res, prime, i);
bn_multiply_reduce_step(res, prime, i); // apply twice, as a hack for NIST256P1 prime. assert(res[i + 1] == 0);
assert(res[i + 1] == 0);
} }
// store the result // store the result
for (i = 0; i < 9; i++) { for (i = 0; i < 9; i++) {
@ -294,7 +308,7 @@ void bn_multiply(const bignum256 *k, bignum256 *x, const bignum256 *prime)
} }
// input x can be any normalized number that fits (0 <= x < 2^270). // input x can be any normalized number that fits (0 <= x < 2^270).
// prime must be between (2^256 - 2^196) and 2^256 // prime must be between (2^256 - 2^224) and 2^256
// result is smaller than 2*prime // result is smaller than 2*prime
void bn_fast_mod(bignum256 *x, const bignum256 *prime) void bn_fast_mod(bignum256 *x, const bignum256 *prime)
{ {
@ -305,11 +319,11 @@ void bn_fast_mod(bignum256 *x, const bignum256 *prime)
coef = x->val[8] >> 16; coef = x->val[8] >> 16;
// substract (coef * prime) from x // substract (coef * prime) from x
// note that we unrolled the first iteration // note that we unrolled the first iteration
temp = 0x1000000000000000ull + x->val[0] - prime->val[0] * (uint64_t)coef; temp = 0x2000000000000000ull + x->val[0] - prime->val[0] * (uint64_t)coef;
x->val[0] = temp & 0x3FFFFFFF; x->val[0] = temp & 0x3FFFFFFF;
for (j = 1; j < 9; j++) { for (j = 1; j < 9; j++) {
temp >>= 30; temp >>= 30;
temp += 0xFFFFFFFC0000000ull + x->val[j] - prime->val[j] * (uint64_t)coef; temp += 0x1FFFFFFF80000000ull + x->val[j] - prime->val[j] * (uint64_t)coef;
x->val[j] = temp & 0x3FFFFFFF; x->val[j] = temp & 0x3FFFFFFF;
} }
} }
@ -679,16 +693,21 @@ void bn_addmod(bignum256 *a, const bignum256 *b, const bignum256 *prime)
for (i = 0; i < 9; i++) { for (i = 0; i < 9; i++) {
a->val[i] += b->val[i]; a->val[i] += b->val[i];
} }
bn_normalize(a);
bn_fast_mod(a, prime); bn_fast_mod(a, prime);
bn_mod(a, prime);
} }
void bn_addmodi(bignum256 *a, uint32_t b, const bignum256 *prime) { void bn_addi(bignum256 *a, uint32_t b) {
a->val[0] += b; a->val[0] += b;
bn_normalize(a); bn_normalize(a);
}
void bn_subi(bignum256 *a, uint32_t b, const bignum256 *prime) {
int i;
for (i = 0; i < 9; i++) {
a->val[i] += prime->val[i];
}
a->val[0] -= b;
bn_fast_mod(a, prime); bn_fast_mod(a, prime);
bn_mod(a, prime);
} }
// res = a - b mod prime. More exactly res = a + (2*prime - b). // res = a - b mod prime. More exactly res = a + (2*prime - b).

View File

@ -51,6 +51,8 @@ int bn_is_less(const bignum256 *a, const bignum256 *b);
int bn_is_equal(const bignum256 *a, const bignum256 *b); int bn_is_equal(const bignum256 *a, const bignum256 *b);
void bn_cmov(bignum256 *res, int cond, const bignum256 *truecase, const bignum256 *falsecase);
int bn_bitlen(const bignum256 *a); int bn_bitlen(const bignum256 *a);
void bn_lshift(bignum256 *a); void bn_lshift(bignum256 *a);
@ -75,7 +77,9 @@ void bn_normalize(bignum256 *a);
void bn_addmod(bignum256 *a, const bignum256 *b, const bignum256 *prime); void bn_addmod(bignum256 *a, const bignum256 *b, const bignum256 *prime);
void bn_addmodi(bignum256 *a, uint32_t b, const bignum256 *prime); void bn_addi(bignum256 *a, uint32_t b);
void bn_subi(bignum256 *a, uint32_t b, const bignum256 *prime);
void bn_subtractmod(const bignum256 *a, const bignum256 *b, bignum256 *res, const bignum256 *prime); void bn_subtractmod(const bignum256 *a, const bignum256 *b, bignum256 *res, const bignum256 *prime);

183
ecdsa.c
View File

@ -111,7 +111,7 @@ void point_double(const ecdsa_curve *curve, curve_point *cp)
xr = cp->x; xr = cp->x;
bn_multiply(&xr, &xr, &curve->prime); bn_multiply(&xr, &xr, &curve->prime);
bn_mult_k(&xr, 3, &curve->prime); bn_mult_k(&xr, 3, &curve->prime);
bn_addmod(&xr, &curve->a, &curve->prime); bn_subi(&xr, -curve->a, &curve->prime);
bn_multiply(&xr, &lambda, &curve->prime); bn_multiply(&xr, &lambda, &curve->prime);
// xr = lambda^2 - 2*x // xr = lambda^2 - 2*x
@ -177,13 +177,15 @@ void conditional_negate(uint32_t cond, bignum256 *a, const bignum256 *prime)
{ {
int j; int j;
uint32_t tmp = 1; uint32_t tmp = 1;
assert(a->val[8] < 0x20000);
for (j = 0; j < 8; j++) { for (j = 0; j < 8; j++) {
tmp += 0x3fffffff + prime->val[j] - a->val[j]; tmp += 0x3fffffff + 2*prime->val[j] - a->val[j];
a->val[j] = ((tmp & 0x3fffffff) & cond) | (a->val[j] & ~cond); a->val[j] = ((tmp & 0x3fffffff) & cond) | (a->val[j] & ~cond);
tmp >>= 30; tmp >>= 30;
} }
tmp += 0x3fffffff + prime->val[j] - a->val[j]; tmp += 0x3fffffff + 2*prime->val[j] - a->val[j];
a->val[j] = ((tmp & 0x3fffffff) & cond) | (a->val[j] & ~cond); a->val[j] = ((tmp & 0x3fffffff) & cond) | (a->val[j] & ~cond);
assert(a->val[8] < 0x20000);
} }
typedef struct jacobian_curve_point { typedef struct jacobian_curve_point {
@ -207,13 +209,10 @@ void curve_to_jacobian(const curve_point *p, jacobian_curve_point *jp, const big
bn_multiply(&p->x, &jp->x, prime); bn_multiply(&p->x, &jp->x, prime);
bn_multiply(&p->y, &jp->y, prime); bn_multiply(&p->y, &jp->y, prime);
bn_mod(&jp->x, prime);
bn_mod(&jp->y, prime);
} }
void jacobian_to_curve(const jacobian_curve_point *jp, curve_point *p, const bignum256 *prime) { void jacobian_to_curve(const jacobian_curve_point *jp, curve_point *p, const bignum256 *prime) {
p->y = jp->z; p->y = jp->z;
bn_mod(&p->y, prime);
bn_inverse(&p->y, prime); bn_inverse(&p->y, prime);
// p->y = z^-1 // p->y = z^-1
p->x = p->y; p->x = p->y;
@ -229,90 +228,119 @@ void jacobian_to_curve(const jacobian_curve_point *jp, curve_point *p, const big
bn_mod(&p->y, prime); bn_mod(&p->y, prime);
} }
void point_jacobian_add(const curve_point *p1, jacobian_curve_point *p2, const bignum256 *prime) { void point_jacobian_add(const curve_point *p1, jacobian_curve_point *p2, const ecdsa_curve *curve) {
bignum256 r, h; bignum256 r, h, r2;
bignum256 rsq, hcb, hcby2, hsqx2; bignum256 hcby, hsqx;
bignum256 xz, yz, az;
int is_doubling;
const bignum256 *prime = &curve->prime;
int a = curve->a;
/* usual algorithm: assert (-3 <= a && a <= 0);
/* First we bring p1 to the same denominator:
* x1' := x1 * z2^2
* y1' := y1 * z2^3
*/
/*
* lambda = ((y1' - y2)/z2^3) / ((x1' - x2)/z2^2)
* = (y1' - y2) / (x1' - x2) z2
* x3/z3^2 = lambda^2 - (x1' + x2)/z2^2
* y3/z3^3 = 1/2 lambda * (2x3/z3^2 - (x1' + x2)/z2^2) + (y1'+y2)/z2^3
* *
* lambda = (y1 - y2/z2^3) / (x1 - x2/z2^2) * For the special case x1=x2, y1=y2 (doubling) we have
* x3/z3^2 = lambda^2 - x1 - x2/z2^2 * lambda = 3/2 ((x2/z2^2)^2 + a) / (y2/z2^3)
* y3/z3^3 = lambda * (x2/z2^2 - x3/z3^2) - y2/z2^3 * = 3/2 (x2^2 + a*z2^4) / y2*z2)
* *
* to get rid of fraction we set * to get rid of fraction we write lambda as
* r = (y1 * z2^3 - y2) (the numerator of lambda * z2^3) * lambda = r / (h*z2)
* h = (x1 * z2^2 - x2) (the denominator of lambda * z2^2) * with r = is_doubling ? 3/2 x2^2 + az2^4 : (y1 - y2)
* Hence, * h = is_doubling ? y1+y2 : (x1 - x2)
* lambda = r / (h*z2)
* *
* With z3 = h*z2 (the denominator of lambda) * With z3 = h*z2 (the denominator of lambda)
* we get x3 = lambda^2*z3^2 - x1*z3^2 - x2/z2^2*z3^2 * we get x3 = lambda^2*z3^2 - (x1' + x2)/z2^2*z3^2
* = r^2 - x1*h^2*z2^2 - x2*h^2 * = r^2 - h^2 * (x1' + x2)
* = r^2 - h^2*(x1*z2^2 + x2) * and y3 = 1/2 r * (2x3 - h^2*(x1' + x2)) + h^3*(y1' + y2)
* = r^2 - h^2*(h + 2*x2)
* = r^2 - h^3 - 2*h^2*x2
* and y3 = (lambda * (x2/z2^2 - x3/z3^2) - y2/z2^3) * z3^3
* = r * (h^2*x2 - x3) - h^3*y2
*/ */
/* h = x1*z2^2 - x2 /* h = x1 - x2
* r = y1*z2^3 - y2 * r = y1 - y2
* x3 = r^2 - h^3 - 2*h^2*x2 * x3 = r^2 - h^3 - 2*h^2*x2
* y3 = r*(h^2*x2 - x3) - h^3*y2 * y3 = r*(h^2*x2 - x3) - h^3*y2
* z3 = h*z2 * z3 = h*z2
*/ */
// h = x1 * z2^2 - x2; xz = p2->z;
// r = y1 * z2^3 - y2; bn_multiply(&xz, &xz, prime); // xz = z2^2
h = p2->z; yz = p2->z;
bn_multiply(&h, &h, prime); // h = z2^2 bn_multiply(&xz, &yz, prime); // yz = z2^3
r = p2->z;
bn_multiply(&h, &r, prime); // r = z2^3 if (a != 0) {
az = xz;
bn_multiply(&p1->x, &h, prime); bn_multiply(&az, &az, prime); // az = z2^4
bn_mult_k(&az, -a, prime); // az = -az2^4
}
bn_multiply(&p1->x, &xz, prime); // xz = x1' = x1*z2^2;
h = xz;
bn_subtractmod(&h, &p2->x, &h, prime); bn_subtractmod(&h, &p2->x, &h, prime);
// h = x1 * z2^2 - x2; bn_fast_mod(&h, prime);
// h = x1' - x2;
bn_multiply(&p1->y, &r, prime); bn_addmod(&xz, &p2->x, prime);
bn_subtractmod(&r, &p2->y, &r, prime); // xz = x1' + x2
// r = y1 * z2^3 - y2;
// hsqx2 = h^2 is_doubling = bn_is_zero(&h) | bn_is_equal(&h, prime);
hsqx2 = h;
bn_multiply(&hsqx2, &hsqx2, prime);
// hcb = h^3 bn_multiply(&p1->y, &yz, prime); // yz = y1' = y1*z2^3;
hcb = h; bn_subtractmod(&yz, &p2->y, &r, prime);
bn_multiply(&hsqx2, &hcb, prime); // r = y1' - y2;
// hsqx2 = h^2 * x2 bn_addmod(&yz, &p2->y, prime);
bn_multiply(&p2->x, &hsqx2, prime); // yz = y1' + y2
// hcby2 = h^3 * y2 r2 = p2->x;
hcby2 = hcb; bn_multiply(&r2, &r2, prime);
bn_multiply(&p2->y, &hcby2, prime); bn_mult_k(&r2, 3, prime);
if (a != 0) {
// subtract -a z2^4, i.e, add a z2^4
bn_subtractmod(&r2, &az, &r2, prime);
}
bn_cmov(&r, is_doubling, &r2, &r);
bn_cmov(&h, is_doubling, &yz, &h);
// rsq = r^2 // hsqx = h^2
rsq = r; hsqx = h;
bn_multiply(&rsq, &rsq, prime); bn_multiply(&hsqx, &hsqx, prime);
// hcby = h^3
hcby = h;
bn_multiply(&hsqx, &hcby, prime);
// hsqx = h^2 * (x1 + x2)
bn_multiply(&xz, &hsqx, prime);
// hcby = h^3 * (y1 + y2)
bn_multiply(&yz, &hcby, prime);
// z3 = h*z2 // z3 = h*z2
bn_multiply(&h, &p2->z, prime); bn_multiply(&h, &p2->z, prime);
bn_mod(&p2->z, prime);
// x3 = r^2 - h^3 - 2h^2x2 // x3 = r^2 - h^2 (x1 + x2)
bn_addmod(&hcb, &hsqx2, prime); p2->x = r;
bn_addmod(&hcb, &hsqx2, prime); bn_multiply(&p2->x, &p2->x, prime);
bn_subtractmod(&rsq, &hcb, &p2->x, prime); bn_subtractmod(&p2->x, &hsqx, &p2->x, prime);
bn_fast_mod(&p2->x, prime); bn_fast_mod(&p2->x, prime);
bn_mod(&p2->x, prime);
// y3 = r*(h^2x2 - x3) - y2*h^3 // y3 = 1/2 (r*(h^2 (x1 + x2) - 2x3) - h^3 (y1 + y2))
bn_subtractmod(&hsqx2, &p2->x, &p2->y, prime); bn_subtractmod(&hsqx, &p2->x, &p2->y, prime);
bn_subtractmod(&p2->y, &p2->x, &p2->y, prime);
bn_multiply(&r, &p2->y, prime); bn_multiply(&r, &p2->y, prime);
bn_subtractmod(&p2->y, &hcby2, &p2->y, prime); bn_subtractmod(&p2->y, &hcby, &p2->y, prime);
bn_mult_half(&p2->y, prime);
bn_fast_mod(&p2->y, prime); bn_fast_mod(&p2->y, prime);
bn_mod(&p2->y, prime);
} }
void point_jacobian_double(jacobian_curve_point *p, const ecdsa_curve *curve) { void point_jacobian_double(jacobian_curve_point *p, const ecdsa_curve *curve) {
@ -350,8 +378,8 @@ void point_jacobian_double(jacobian_curve_point *p, const ecdsa_curve *curve) {
az4 = p->z; az4 = p->z;
bn_multiply(&az4, &az4, prime); bn_multiply(&az4, &az4, prime);
bn_multiply(&az4, &az4, prime); bn_multiply(&az4, &az4, prime);
bn_multiply(&curve->a, &az4, prime); bn_mult_k(&az4, -curve->a, prime);
bn_addmod(&m, &az4, prime); bn_subtractmod(&m, &az4, &m, prime);
bn_mult_half(&m, prime); bn_mult_half(&m, prime);
// msq = m^2 // msq = m^2
@ -366,15 +394,13 @@ void point_jacobian_double(jacobian_curve_point *p, const ecdsa_curve *curve) {
// z3 = yz // z3 = yz
bn_multiply(&p->y, &p->z, prime); bn_multiply(&p->y, &p->z, prime);
bn_mod(&p->z, prime);
// x3 = m^2 - 2*xy^2 // x3 = m^2 - 2*xy^2
p->x = xysq; p->x = xysq;
bn_mod(&p->x, prime);
bn_lshift(&p->x); bn_lshift(&p->x);
bn_fast_mod(&p->x, prime);
bn_subtractmod(&msq, &p->x, &p->x, prime); bn_subtractmod(&msq, &p->x, &p->x, prime);
bn_fast_mod(&p->x, prime); bn_fast_mod(&p->x, prime);
bn_mod(&p->x, prime);
// y3 = m*(xy^2 - x3) - y^4 // y3 = m*(xy^2 - x3) - y^4
bn_subtractmod(&xysq, &p->x, &p->y, prime); bn_subtractmod(&xysq, &p->x, &p->y, prime);
@ -382,7 +408,6 @@ void point_jacobian_double(jacobian_curve_point *p, const ecdsa_curve *curve) {
bn_multiply(&ysq, &ysq, prime); bn_multiply(&ysq, &ysq, prime);
bn_subtractmod(&p->y, &ysq, &p->y, prime); bn_subtractmod(&p->y, &ysq, &p->y, prime);
bn_fast_mod(&p->y, prime); bn_fast_mod(&p->y, prime);
bn_mod(&p->y, prime);
} }
// res = k * p // res = k * p
@ -482,7 +507,7 @@ void point_multiply(const ecdsa_curve *curve, const bignum256 *k, const curve_po
conditional_negate(sign ^ nsign, &jres.z, prime); conditional_negate(sign ^ nsign, &jres.z, prime);
// add odd factor // add odd factor
point_jacobian_add(&pmult[bits >> 1], &jres, prime); point_jacobian_add(&pmult[bits >> 1], &jres, curve);
sign = nsign; sign = nsign;
} }
conditional_negate(sign, &jres.z, prime); conditional_negate(sign, &jres.z, prime);
@ -569,7 +594,7 @@ void scalar_multiply(const ecdsa_curve *curve, const bignum256 *k, curve_point *
conditional_negate((lowbits & 1) - 1, &jres.y, prime); conditional_negate((lowbits & 1) - 1, &jres.y, prime);
// add odd factor // add odd factor
point_jacobian_add(&curve->cp[i][lowbits >> 1], &jres, prime); point_jacobian_add(&curve->cp[i][lowbits >> 1], &jres, curve);
} }
conditional_negate(((a.val[0] >> 4) & 1) - 1, &jres.y, prime); conditional_negate(((a.val[0] >> 4) & 1) - 1, &jres.y, prime);
jacobian_to_curve(&jres, res, prime); jacobian_to_curve(&jres, res, prime);
@ -832,10 +857,11 @@ int ecdsa_address_decode(const char *addr, uint8_t *out)
void uncompress_coords(const ecdsa_curve *curve, uint8_t odd, const bignum256 *x, bignum256 *y) void uncompress_coords(const ecdsa_curve *curve, uint8_t odd, const bignum256 *x, bignum256 *y)
{ {
// y^2 = x^3 + 0*x + 7 // y^2 = x^3 + 0*x + 7
memcpy(y, x, sizeof(bignum256)); // y is x memcpy(y, x, sizeof(bignum256)); // y is x
bn_multiply(x, y, &curve->prime); // y is x^2 bn_multiply(x, y, &curve->prime); // y is x^2
bn_multiply(x, y, &curve->prime); // y is x^3 bn_subi(y, -curve->a, &curve->prime); // y is x^2 + a
bn_addmodi(y, 7, &curve->prime); // y is x^3 + 7 bn_multiply(x, y, &curve->prime); // y is x^3 + ax
bn_addmod(y, &curve->b, &curve->prime); // y is x^3 + ax + b
bn_sqrt(y, &curve->prime); // y = sqrt(y) bn_sqrt(y, &curve->prime); // y = sqrt(y)
if ((odd & 0x01) != (y->val[0] & 1)) { if ((odd & 0x01) != (y->val[0] & 1)) {
bn_subtract(&curve->prime, y, y); // y = -y bn_subtract(&curve->prime, y, y); // y = -y
@ -882,10 +908,11 @@ int ecdsa_validate_pubkey(const ecdsa_curve *curve, const curve_point *pub)
bn_multiply(&(pub->y), &y_2, &curve->prime); bn_multiply(&(pub->y), &y_2, &curve->prime);
bn_mod(&y_2, &curve->prime); bn_mod(&y_2, &curve->prime);
// x^3 + b // x^3 + ax + b
bn_multiply(&(pub->x), &x_3_b, &curve->prime); bn_multiply(&(pub->x), &x_3_b, &curve->prime); // x^2
bn_multiply(&(pub->x), &x_3_b, &curve->prime); bn_subi(&x_3_b, -curve->a, &curve->prime); // x^2 + a
bn_addmodi(&x_3_b, 7, &curve->prime); bn_multiply(&(pub->x), &x_3_b, &curve->prime); // x^3 + ax
bn_addmod(&x_3_b, &curve->b, &curve->prime); // x^3 + ax + b
if (!bn_is_equal(&x_3_b, &y_2)) { if (!bn_is_equal(&x_3_b, &y_2)) {
return 0; return 0;

View File

@ -39,7 +39,7 @@ typedef struct {
curve_point G; // initial curve point curve_point G; // initial curve point
bignum256 order; // order of G bignum256 order; // order of G
bignum256 order_half; // order of G divided by 2 bignum256 order_half; // order of G divided by 2
bignum256 a; // coefficient 'a' of the elliptic curve int a; // coefficient 'a' of the elliptic curve
bignum256 b; // coefficient 'b' of the elliptic curve bignum256 b; // coefficient 'b' of the elliptic curve
#if USE_PRECOMPUTED_CP #if USE_PRECOMPUTED_CP
@ -68,6 +68,7 @@ void ecdsa_get_pubkeyhash(const uint8_t *pub_key, uint8_t *pubkeyhash);
void ecdsa_get_address_raw(const uint8_t *pub_key, uint8_t version, uint8_t *addr_raw); void ecdsa_get_address_raw(const uint8_t *pub_key, uint8_t version, uint8_t *addr_raw);
void ecdsa_get_address(const uint8_t *pub_key, uint8_t version, char *addr, int addrsize); void ecdsa_get_address(const uint8_t *pub_key, uint8_t version, char *addr, int addrsize);
void ecdsa_get_wif(const uint8_t *priv_key, uint8_t version, char *wif, int wifsize); void ecdsa_get_wif(const uint8_t *priv_key, uint8_t version, char *wif, int wifsize);
int ecdsa_address_decode(const char *addr, uint8_t *out); int ecdsa_address_decode(const char *addr, uint8_t *out);
int ecdsa_read_pubkey(const ecdsa_curve *curve, const uint8_t *pub_key, curve_point *pub); int ecdsa_read_pubkey(const ecdsa_curve *curve, const uint8_t *pub_key, curve_point *pub);
int ecdsa_validate_pubkey(const ecdsa_curve *curve, const curve_point *pub); int ecdsa_validate_pubkey(const ecdsa_curve *curve, const curve_point *pub);

View File

@ -41,9 +41,7 @@ const ecdsa_curve nist256p1 = {
/*.val =*/{0x3e3192a8, 0x27739585, 0x38bcf427, 0x1cdf55b4, 0x3fffffde, 0x3fffffff, 0x7ff, 0x3fffe000, 0x7fff} /*.val =*/{0x3e3192a8, 0x27739585, 0x38bcf427, 0x1cdf55b4, 0x3fffffde, 0x3fffffff, 0x7ff, 0x3fffe000, 0x7fff}
}, },
/* a */ { /* a */ -3,
/*.val =*/{0x3ffffffc, 0x3fffffff, 0x3fffffff, 0x3f, 0x0, 0x0, 0x1000, 0x3fffc000, 0xffff}
},
/* b */ { /* b */ {
/*.val =*/{0x27d2604b, 0x2f38f0f8, 0x53b0f63, 0x741ac33, 0x1886bc65, 0x2ef555da, 0x293e7b3e, 0xd762a8e, 0x5ac6} /*.val =*/{0x27d2604b, 0x2f38f0f8, 0x53b0f63, 0x741ac33, 0x1886bc65, 0x2ef555da, 0x293e7b3e, 0xd762a8e, 0x5ac6}

View File

@ -41,9 +41,7 @@ const ecdsa_curve secp256k1 = {
/*.val =*/{0x281b20a0, 0x3fa4bd19, 0x3a4501dd, 0x15db9cd5, 0x3fffff5d, 0x3fffffff, 0x3fffffff, 0x3fffffff, 0x7fff} /*.val =*/{0x281b20a0, 0x3fa4bd19, 0x3a4501dd, 0x15db9cd5, 0x3fffff5d, 0x3fffffff, 0x3fffffff, 0x3fffffff, 0x7fff}
}, },
/* a */ { /* a */ 0,
/*.val =*/{0}
},
/* b */ { /* b */ {
/*.val =*/{7} /*.val =*/{7}

View File

@ -199,7 +199,7 @@ def test_multiply2(curve, r):
prime = int2bn(curve.p) prime = int2bn(curve.p)
lib.bn_multiply_reduce(x, res, prime) lib.bn_multiply_reduce(x, res, prime)
x = bn2int(x) x = bn2int(x) % curve.p
x_ = s % curve.p x_ = s % curve.p
assert x == x_ assert x == x_
@ -263,17 +263,17 @@ def test_point_double(curve, r):
def test_point_to_jacobian(curve, r): def test_point_to_jacobian(curve, r):
p = r.randpoint(curve) p = r.randpoint(curve)
jp = JACOBIAN() jp = JACOBIAN()
lib.curve_to_jacobian(to_POINT(p), jp, int2bn(curve.p)) lib.curve_to_jacobian(to_POINT(p), jp, int2bn(curve.p))
jx, jy, jz = from_JACOBIAN(jp) jx, jy, jz = from_JACOBIAN(jp)
assert jx == (p.x() * jz ** 2) % curve.p assert jx % curve.p == (p.x() * jz ** 2) % curve.p
assert jy == (p.y() * jz ** 3) % curve.p assert jy % curve.p == (p.y() * jz ** 3) % curve.p
q = POINT() q = POINT()
lib.jacobian_to_curve(jp, q, int2bn(curve.p)) lib.jacobian_to_curve(jp, q, int2bn(curve.p))
q = from_POINT(q) q = from_POINT(q)
assert q == (p.x(), p.y()) assert q == (p.x(), p.y())
def test_cond_negate(curve, r): def test_cond_negate(curve, r):
@ -282,7 +282,7 @@ def test_cond_negate(curve, r):
lib.conditional_negate(0, a, int2bn(curve.p)) lib.conditional_negate(0, a, int2bn(curve.p))
assert bn2int(a) == x assert bn2int(a) == x
lib.conditional_negate(-1, a, int2bn(curve.p)) lib.conditional_negate(-1, a, int2bn(curve.p))
assert bn2int(a) == curve.p - x assert bn2int(a) == 2*curve.p - x
def test_jacobian_add(curve, r): def test_jacobian_add(curve, r):
@ -292,7 +292,20 @@ def test_jacobian_add(curve, r):
q = POINT() q = POINT()
jp2 = JACOBIAN() jp2 = JACOBIAN()
lib.curve_to_jacobian(to_POINT(p2), jp2, prime) lib.curve_to_jacobian(to_POINT(p2), jp2, prime)
lib.point_jacobian_add(to_POINT(p1), jp2, prime) lib.point_jacobian_add(to_POINT(p1), jp2, curve.ptr)
lib.jacobian_to_curve(jp2, q, prime)
q = from_POINT(q)
p_ = p1 + p2
assert (p_.x(), p_.y()) == q
def test_jacobian_add_double(curve, r):
p1 = r.randpoint(curve)
p2 = p1
prime = int2bn(curve.p)
q = POINT()
jp2 = JACOBIAN()
lib.curve_to_jacobian(to_POINT(p2), jp2, prime)
lib.point_jacobian_add(to_POINT(p1), jp2, curve.ptr)
lib.jacobian_to_curve(jp2, q, prime) lib.jacobian_to_curve(jp2, q, prime)
q = from_POINT(q) q = from_POINT(q)
p_ = p1 + p2 p_ = p1 + p2
@ -330,3 +343,7 @@ def test_sign(curve, r):
assert binascii.hexlify(sig) == binascii.hexlify(sig_ref) assert binascii.hexlify(sig) == binascii.hexlify(sig_ref)
assert vk.verify_digest(sig, digest, sigdecode) assert vk.verify_digest(sig, digest, sigdecode)
def test_validate_pubkey(curve, r):
p = r.randpoint(curve)
assert lib.ecdsa_validate_pubkey(curve.ptr, to_POINT(p))

218
tests.c
View File

@ -38,12 +38,7 @@
#include "sha2.h" #include "sha2.h"
#include "options.h" #include "options.h"
#include "secp256k1.h" #include "secp256k1.h"
#include "nist256p1.h"
#define CURVE (&secp256k1)
#define prime256k1 (secp256k1.prime)
#define G256k1 (secp256k1.G)
#define order256k1 (secp256k1.order)
#define secp256k1_cp (secp256k1.cp)
uint8_t *fromhex(const char *str) uint8_t *fromhex(const char *str)
{ {
@ -509,7 +504,7 @@ END_TEST
#define test_deterministic(KEY, MSG, K) do { \ #define test_deterministic(KEY, MSG, K) do { \
sha256_Raw((uint8_t *)MSG, strlen(MSG), buf); \ sha256_Raw((uint8_t *)MSG, strlen(MSG), buf); \
res = generate_k_rfc6979(CURVE, &k, fromhex(KEY), buf); \ res = generate_k_rfc6979(curve, &k, fromhex(KEY), buf); \
ck_assert_int_eq(res, 0); \ ck_assert_int_eq(res, 0); \
bn_write_be(&k, buf); \ bn_write_be(&k, buf); \
ck_assert_mem_eq(buf, fromhex(K), 32); \ ck_assert_mem_eq(buf, fromhex(K), 32); \
@ -520,6 +515,7 @@ START_TEST(test_rfc6979)
int res; int res;
bignum256 k; bignum256 k;
uint8_t buf[32]; uint8_t buf[32];
const ecdsa_curve *curve = &secp256k1;
test_deterministic("cca9fbcc1b41e5a95d369eaa6ddcff73b61a4efaa279cfc6567e8daa39cbaf50", "sample", "2df40ca70e639d89528a6b670d9d48d9165fdc0febc0974056bdce192b8e16a3"); test_deterministic("cca9fbcc1b41e5a95d369eaa6ddcff73b61a4efaa279cfc6567e8daa39cbaf50", "sample", "2df40ca70e639d89528a6b670d9d48d9165fdc0febc0974056bdce192b8e16a3");
test_deterministic("0000000000000000000000000000000000000000000000000000000000000001", "Satoshi Nakamoto", "8f8a276c19f4149656b280621e358cce24f5f52542772691ee69063b74f15d15"); test_deterministic("0000000000000000000000000000000000000000000000000000000000000001", "Satoshi Nakamoto", "8f8a276c19f4149656b280621e358cce24f5f52542772691ee69063b74f15d15");
@ -535,6 +531,7 @@ START_TEST(test_sign_speed)
uint8_t sig[64], priv_key[32], msg[256]; uint8_t sig[64], priv_key[32], msg[256];
size_t i; size_t i;
int res; int res;
const ecdsa_curve *curve = &secp256k1;
for (i = 0; i < sizeof(msg); i++) { for (i = 0; i < sizeof(msg); i++) {
msg[i] = i * 1103515245; msg[i] = i * 1103515245;
@ -544,13 +541,13 @@ START_TEST(test_sign_speed)
memcpy(priv_key, fromhex("c55ece858b0ddd5263f96810fe14437cd3b5e1fbd7c6a2ec1e031f05e86d8bd5"), 32); memcpy(priv_key, fromhex("c55ece858b0ddd5263f96810fe14437cd3b5e1fbd7c6a2ec1e031f05e86d8bd5"), 32);
for (i = 0 ; i < 250; i++) { for (i = 0 ; i < 250; i++) {
res = ecdsa_sign(CURVE, priv_key, msg, sizeof(msg), sig, 0); res = ecdsa_sign(curve, priv_key, msg, sizeof(msg), sig, 0);
ck_assert_int_eq(res, 0); ck_assert_int_eq(res, 0);
} }
memcpy(priv_key, fromhex("509a0382ff5da48e402967a671bdcde70046d07f0df52cff12e8e3883b426a0a"), 32); memcpy(priv_key, fromhex("509a0382ff5da48e402967a671bdcde70046d07f0df52cff12e8e3883b426a0a"), 32);
for (i = 0 ; i < 250; i++) { for (i = 0 ; i < 250; i++) {
res = ecdsa_sign(CURVE, priv_key, msg, sizeof(msg), sig, 0); res = ecdsa_sign(curve, priv_key, msg, sizeof(msg), sig, 0);
ck_assert_int_eq(res, 0); ck_assert_int_eq(res, 0);
} }
@ -563,6 +560,7 @@ START_TEST(test_verify_speed)
uint8_t sig[64], pub_key33[33], pub_key65[65], msg[256]; uint8_t sig[64], pub_key33[33], pub_key65[65], msg[256];
size_t i; size_t i;
int res; int res;
const ecdsa_curve *curve = &secp256k1;
for (i = 0; i < sizeof(msg); i++) { for (i = 0; i < sizeof(msg); i++) {
msg[i] = i * 1103515245; msg[i] = i * 1103515245;
@ -575,9 +573,9 @@ START_TEST(test_verify_speed)
memcpy(pub_key65, fromhex("044054fd18aeb277aeedea01d3f3986ff4e5be18092a04339dcf4e524e2c0a09746c7083ed2097011b1223a17a644e81f59aa3de22dac119fd980b36a8ff29a244"), 65); memcpy(pub_key65, fromhex("044054fd18aeb277aeedea01d3f3986ff4e5be18092a04339dcf4e524e2c0a09746c7083ed2097011b1223a17a644e81f59aa3de22dac119fd980b36a8ff29a244"), 65);
for (i = 0 ; i < 25; i++) { for (i = 0 ; i < 25; i++) {
res = ecdsa_verify(CURVE, pub_key65, sig, msg, sizeof(msg)); res = ecdsa_verify(curve, pub_key65, sig, msg, sizeof(msg));
ck_assert_int_eq(res, 0); ck_assert_int_eq(res, 0);
res = ecdsa_verify(CURVE, pub_key33, sig, msg, sizeof(msg)); res = ecdsa_verify(curve, pub_key33, sig, msg, sizeof(msg));
ck_assert_int_eq(res, 0); ck_assert_int_eq(res, 0);
} }
@ -586,9 +584,9 @@ START_TEST(test_verify_speed)
memcpy(pub_key65, fromhex("04ff45a5561a76be930358457d113f25fac790794ec70317eff3b97d7080d457196235193a15778062ddaa44aef7e6901b781763e52147f2504e268b2d572bf197"), 65); memcpy(pub_key65, fromhex("04ff45a5561a76be930358457d113f25fac790794ec70317eff3b97d7080d457196235193a15778062ddaa44aef7e6901b781763e52147f2504e268b2d572bf197"), 65);
for (i = 0 ; i < 25; i++) { for (i = 0 ; i < 25; i++) {
res = ecdsa_verify(CURVE, pub_key65, sig, msg, sizeof(msg)); res = ecdsa_verify(curve, pub_key65, sig, msg, sizeof(msg));
ck_assert_int_eq(res, 0); ck_assert_int_eq(res, 0);
res = ecdsa_verify(CURVE, pub_key33, sig, msg, sizeof(msg)); res = ecdsa_verify(curve, pub_key33, sig, msg, sizeof(msg));
ck_assert_int_eq(res, 0); ck_assert_int_eq(res, 0);
} }
@ -1039,45 +1037,46 @@ START_TEST(test_pubkey_validity)
uint8_t pub_key[65]; uint8_t pub_key[65];
curve_point pub; curve_point pub;
int res; int res;
const ecdsa_curve *curve = &secp256k1;
memcpy(pub_key, fromhex("0226659c1cf7321c178c07437150639ff0c5b7679c7ea195253ed9abda2e081a37"), 33); memcpy(pub_key, fromhex("0226659c1cf7321c178c07437150639ff0c5b7679c7ea195253ed9abda2e081a37"), 33);
res = ecdsa_read_pubkey(CURVE, pub_key, &pub); res = ecdsa_read_pubkey(curve, pub_key, &pub);
ck_assert_int_eq(res, 1); ck_assert_int_eq(res, 1);
memcpy(pub_key, fromhex("025b1654a0e78d28810094f6c5a96b8efb8a65668b578f170ac2b1f83bc63ba856"), 33); memcpy(pub_key, fromhex("025b1654a0e78d28810094f6c5a96b8efb8a65668b578f170ac2b1f83bc63ba856"), 33);
res = ecdsa_read_pubkey(CURVE, pub_key, &pub); res = ecdsa_read_pubkey(curve, pub_key, &pub);
ck_assert_int_eq(res, 1); ck_assert_int_eq(res, 1);
memcpy(pub_key, fromhex("03433f246a12e6486a51ff08802228c61cf895175a9b49ed4766ea9a9294a3c7fe"), 33); memcpy(pub_key, fromhex("03433f246a12e6486a51ff08802228c61cf895175a9b49ed4766ea9a9294a3c7fe"), 33);
res = ecdsa_read_pubkey(CURVE, pub_key, &pub); res = ecdsa_read_pubkey(curve, pub_key, &pub);
ck_assert_int_eq(res, 1); ck_assert_int_eq(res, 1);
memcpy(pub_key, fromhex("03aeb03abeee0f0f8b4f7a5d65ce31f9570cef9f72c2dd8a19b4085a30ab033d48"), 33); memcpy(pub_key, fromhex("03aeb03abeee0f0f8b4f7a5d65ce31f9570cef9f72c2dd8a19b4085a30ab033d48"), 33);
res = ecdsa_read_pubkey(CURVE, pub_key, &pub); res = ecdsa_read_pubkey(curve, pub_key, &pub);
ck_assert_int_eq(res, 1); ck_assert_int_eq(res, 1);
memcpy(pub_key, fromhex("0496e8f2093f018aff6c2e2da5201ee528e2c8accbf9cac51563d33a7bb74a016054201c025e2a5d96b1629b95194e806c63eb96facaedc733b1a4b70ab3b33e3a"), 65); memcpy(pub_key, fromhex("0496e8f2093f018aff6c2e2da5201ee528e2c8accbf9cac51563d33a7bb74a016054201c025e2a5d96b1629b95194e806c63eb96facaedc733b1a4b70ab3b33e3a"), 65);
res = ecdsa_read_pubkey(CURVE, pub_key, &pub); res = ecdsa_read_pubkey(curve, pub_key, &pub);
ck_assert_int_eq(res, 1); ck_assert_int_eq(res, 1);
memcpy(pub_key, fromhex("0498010f8a687439ff497d3074beb4519754e72c4b6220fb669224749591dde416f3961f8ece18f8689bb32235e436874d2174048b86118a00afbd5a4f33a24f0f"), 65); memcpy(pub_key, fromhex("0498010f8a687439ff497d3074beb4519754e72c4b6220fb669224749591dde416f3961f8ece18f8689bb32235e436874d2174048b86118a00afbd5a4f33a24f0f"), 65);
res = ecdsa_read_pubkey(CURVE, pub_key, &pub); res = ecdsa_read_pubkey(curve, pub_key, &pub);
ck_assert_int_eq(res, 1); ck_assert_int_eq(res, 1);
memcpy(pub_key, fromhex("04f80490839af36d13701ec3f9eebdac901b51c362119d74553a3c537faff31b17e2a59ebddbdac9e87b816307a7ed5b826b8f40b92719086238e1bebf19b77a4d"), 65); memcpy(pub_key, fromhex("04f80490839af36d13701ec3f9eebdac901b51c362119d74553a3c537faff31b17e2a59ebddbdac9e87b816307a7ed5b826b8f40b92719086238e1bebf19b77a4d"), 65);
res = ecdsa_read_pubkey(CURVE, pub_key, &pub); res = ecdsa_read_pubkey(curve, pub_key, &pub);
ck_assert_int_eq(res, 1); ck_assert_int_eq(res, 1);
memcpy(pub_key, fromhex("04f80490839af36d13701ec3f9eebdac901b51c362119d74553a3c537faff31b17e2a59ebddbdac9e87b816307a7ed5b826b8f40b92719086238e1bebf00000000"), 65); memcpy(pub_key, fromhex("04f80490839af36d13701ec3f9eebdac901b51c362119d74553a3c537faff31b17e2a59ebddbdac9e87b816307a7ed5b826b8f40b92719086238e1bebf00000000"), 65);
res = ecdsa_read_pubkey(CURVE, pub_key, &pub); res = ecdsa_read_pubkey(curve, pub_key, &pub);
ck_assert_int_eq(res, 0); ck_assert_int_eq(res, 0);
memcpy(pub_key, fromhex("04f80490839af36d13701ec3f9eebdac901b51c362119d74553a3c537faff31b17e2a59ebddbdac9e87b816307a7ed5b8211111111111111111111111111111111"), 65); memcpy(pub_key, fromhex("04f80490839af36d13701ec3f9eebdac901b51c362119d74553a3c537faff31b17e2a59ebddbdac9e87b816307a7ed5b8211111111111111111111111111111111"), 65);
res = ecdsa_read_pubkey(CURVE, pub_key, &pub); res = ecdsa_read_pubkey(curve, pub_key, &pub);
ck_assert_int_eq(res, 0); ck_assert_int_eq(res, 0);
memcpy(pub_key, fromhex("00"), 1); memcpy(pub_key, fromhex("00"), 1);
res = ecdsa_read_pubkey(CURVE, pub_key, &pub); res = ecdsa_read_pubkey(curve, pub_key, &pub);
ck_assert_int_eq(res, 0); ck_assert_int_eq(res, 0);
} }
END_TEST END_TEST
@ -1210,7 +1209,7 @@ START_TEST(test_ecdsa_der)
} }
END_TEST END_TEST
START_TEST(test_secp256k1_cp) { static void test_codepoints_curve(const ecdsa_curve *curve) {
int i, j; int i, j;
bignum256 a; bignum256 a;
curve_point p, p1; curve_point p, p1;
@ -1221,108 +1220,130 @@ START_TEST(test_secp256k1_cp) {
bn_normalize(&a); bn_normalize(&a);
// note that this is not a trivial test. We add 64 curve // note that this is not a trivial test. We add 64 curve
// points in the table to get that particular curve point. // points in the table to get that particular curve point.
scalar_multiply(CURVE, &a, &p); scalar_multiply(curve, &a, &p);
ck_assert_mem_eq(&p, &secp256k1_cp[i][j], sizeof(curve_point)); ck_assert_mem_eq(&p, &curve->cp[i][j], sizeof(curve_point));
bn_zero(&p.y); // test that point_multiply CURVE, is not a noop bn_zero(&p.y); // test that point_multiply curve, is not a noop
point_multiply(CURVE, &a, &G256k1, &p); point_multiply(curve, &a, &curve->G, &p);
ck_assert_mem_eq(&p, &secp256k1_cp[i][j], sizeof(curve_point)); ck_assert_mem_eq(&p, &curve->cp[i][j], sizeof(curve_point));
// mul 2 test. this should catch bugs
// even/odd has different behaviour; bn_lshift(&a);
// increment by one and test again bn_mod(&a, &curve->order);
p1 = p; p1 = curve->cp[i][j];
point_add(CURVE, &G256k1, &p1); point_double(curve, &p1);
bn_addmodi(&a, 1, &order256k1); // note that this is not a trivial test. We add 64 curve
scalar_multiply(CURVE, &a, &p); // points in the table to get that particular curve point.
scalar_multiply(curve, &a, &p);
ck_assert_mem_eq(&p, &p1, sizeof(curve_point)); ck_assert_mem_eq(&p, &p1, sizeof(curve_point));
bn_zero(&p.y); // test that point_multiply CURVE, is not a noop bn_zero(&p.y); // test that point_multiply curve, is not a noop
point_multiply(CURVE, &a, &G256k1, &p); point_multiply(curve, &a, &curve->G, &p);
ck_assert_mem_eq(&p, &p1, sizeof(curve_point)); ck_assert_mem_eq(&p, &p1, sizeof(curve_point));
} }
} }
} }
START_TEST(test_codepoints) {
test_codepoints_curve(&secp256k1);
test_codepoints_curve(&nist256p1);
}
END_TEST END_TEST
START_TEST(test_mult_border_cases) { static void test_mult_border_cases_curve(const ecdsa_curve *curve) {
bignum256 a; bignum256 a;
curve_point p; curve_point p;
curve_point expected; curve_point expected;
bn_zero(&a); // a == 0 bn_zero(&a); // a == 0
scalar_multiply(CURVE, &a, &p); scalar_multiply(curve, &a, &p);
ck_assert(point_is_infinity(&p)); ck_assert(point_is_infinity(&p));
point_multiply(CURVE, &a, &p, &p); point_multiply(curve, &a, &p, &p);
ck_assert(point_is_infinity(&p)); ck_assert(point_is_infinity(&p));
point_multiply(CURVE, &a, &G256k1, &p); point_multiply(curve, &a, &curve->G, &p);
ck_assert(point_is_infinity(&p)); ck_assert(point_is_infinity(&p));
bn_addmodi(&a, 1, &order256k1); // a == 1 bn_addi(&a, 1); // a == 1
scalar_multiply(CURVE, &a, &p); scalar_multiply(curve, &a, &p);
ck_assert_mem_eq(&p, &G256k1, sizeof(curve_point)); ck_assert_mem_eq(&p, &curve->G, sizeof(curve_point));
point_multiply(CURVE, &a, &G256k1, &p); point_multiply(curve, &a, &curve->G, &p);
ck_assert_mem_eq(&p, &G256k1, sizeof(curve_point)); ck_assert_mem_eq(&p, &curve->G, sizeof(curve_point));
bn_subtract(&order256k1, &a, &a); // a == -1 bn_subtract(&curve->order, &a, &a); // a == -1
expected = G256k1; expected = curve->G;
bn_subtract(&prime256k1, &expected.y, &expected.y); bn_subtract(&curve->prime, &expected.y, &expected.y);
scalar_multiply(CURVE, &a, &p); scalar_multiply(curve, &a, &p);
ck_assert_mem_eq(&p, &expected, sizeof(curve_point)); ck_assert_mem_eq(&p, &expected, sizeof(curve_point));
point_multiply(CURVE, &a, &G256k1, &p); point_multiply(curve, &a, &curve->G, &p);
ck_assert_mem_eq(&p, &expected, sizeof(curve_point)); ck_assert_mem_eq(&p, &expected, sizeof(curve_point));
bn_subtract(&order256k1, &a, &a); bn_subtract(&curve->order, &a, &a);
bn_addmodi(&a, 1, &order256k1); // a == 2 bn_addi(&a, 1); // a == 2
expected = G256k1; expected = curve->G;
point_add(CURVE, &expected, &expected); point_add(curve, &expected, &expected);
scalar_multiply(CURVE, &a, &p); scalar_multiply(curve, &a, &p);
ck_assert_mem_eq(&p, &expected, sizeof(curve_point)); ck_assert_mem_eq(&p, &expected, sizeof(curve_point));
point_multiply(CURVE, &a, &G256k1, &p); point_multiply(curve, &a, &curve->G, &p);
ck_assert_mem_eq(&p, &expected, sizeof(curve_point)); ck_assert_mem_eq(&p, &expected, sizeof(curve_point));
bn_subtract(&order256k1, &a, &a); // a == -2 bn_subtract(&curve->order, &a, &a); // a == -2
expected = G256k1; expected = curve->G;
point_add(CURVE, &expected, &expected); point_add(curve, &expected, &expected);
bn_subtract(&prime256k1, &expected.y, &expected.y); bn_subtract(&curve->prime, &expected.y, &expected.y);
scalar_multiply(CURVE, &a, &p); scalar_multiply(curve, &a, &p);
ck_assert_mem_eq(&p, &expected, sizeof(curve_point)); ck_assert_mem_eq(&p, &expected, sizeof(curve_point));
point_multiply(CURVE, &a, &G256k1, &p); point_multiply(curve, &a, &curve->G, &p);
ck_assert_mem_eq(&p, &expected, sizeof(curve_point)); ck_assert_mem_eq(&p, &expected, sizeof(curve_point));
} }
START_TEST(test_mult_border_cases) {
test_mult_border_cases_curve(&secp256k1);
test_mult_border_cases_curve(&nist256p1);
}
END_TEST END_TEST
START_TEST(test_scalar_mult) { static void test_scalar_mult_curve(const ecdsa_curve *curve) {
int i; int i;
// get two "random" numbers // get two "random" numbers
bignum256 a = G256k1.x; bignum256 a = curve->G.x;
bignum256 b = G256k1.y; bignum256 b = curve->G.y;
curve_point p1, p2, p3; curve_point p1, p2, p3;
for (i = 0; i < 1000; i++) { for (i = 0; i < 1000; i++) {
/* test distributivity: (a + b)G = aG + bG */ /* test distributivity: (a + b)G = aG + bG */
scalar_multiply(CURVE, &a, &p1); bn_mod(&a, &curve->order);
scalar_multiply(CURVE, &b, &p2); bn_mod(&b, &curve->order);
bn_addmod(&a, &b, &order256k1); scalar_multiply(curve, &a, &p1);
scalar_multiply(CURVE, &a, &p3); scalar_multiply(curve, &b, &p2);
point_add(CURVE, &p1, &p2); bn_addmod(&a, &b, &curve->order);
bn_mod(&a, &curve->order);
scalar_multiply(curve, &a, &p3);
point_add(curve, &p1, &p2);
ck_assert_mem_eq(&p2, &p3, sizeof(curve_point)); ck_assert_mem_eq(&p2, &p3, sizeof(curve_point));
// new "random" numbers // new "random" numbers
a = p3.x; a = p3.x;
b = p3.y; b = p3.y;
} }
} }
START_TEST(test_scalar_mult) {
test_scalar_mult_curve(&secp256k1);
test_scalar_mult_curve(&nist256p1);
}
END_TEST END_TEST
START_TEST(test_point_mult) { static void test_point_mult_curve(const ecdsa_curve *curve) {
int i; int i;
// get two "random" numbers and a "random" point // get two "random" numbers and a "random" point
bignum256 a = G256k1.x; bignum256 a = curve->G.x;
bignum256 b = G256k1.y; bignum256 b = curve->G.y;
curve_point p = G256k1; curve_point p = curve->G;
curve_point p1, p2, p3; curve_point p1, p2, p3;
for (i = 0; i < 200; i++) { for (i = 0; i < 200; i++) {
/* test distributivity: (a + b)P = aP + bP */ /* test distributivity: (a + b)P = aP + bP */
point_multiply(CURVE, &a, &p, &p1); bn_mod(&a, &curve->order);
point_multiply(CURVE, &b, &p, &p2); bn_mod(&b, &curve->order);
bn_addmod(&a, &b, &order256k1); point_multiply(curve, &a, &p, &p1);
point_multiply(CURVE, &a, &p, &p3); point_multiply(curve, &b, &p, &p2);
point_add(CURVE, &p1, &p2); bn_addmod(&a, &b, &curve->order);
bn_mod(&a, &curve->order);
point_multiply(curve, &a, &p, &p3);
point_add(curve, &p1, &p2);
ck_assert_mem_eq(&p2, &p3, sizeof(curve_point)); ck_assert_mem_eq(&p2, &p3, sizeof(curve_point));
// new "random" numbers and a "random" point // new "random" numbers and a "random" point
a = p1.x; a = p1.x;
@ -1330,28 +1351,36 @@ START_TEST(test_point_mult) {
p = p3; p = p3;
} }
} }
START_TEST(test_point_mult) {
test_point_mult_curve(&secp256k1);
test_point_mult_curve(&nist256p1);
}
END_TEST END_TEST
START_TEST(test_scalar_point_mult) { static void test_scalar_point_mult_curve(const ecdsa_curve *curve) {
int i; int i;
// get two "random" numbers // get two "random" numbers
bignum256 a = G256k1.x; bignum256 a = curve->G.x;
bignum256 b = G256k1.y; bignum256 b = curve->G.y;
curve_point p1, p2; curve_point p1, p2;
for (i = 0; i < 200; i++) { for (i = 0; i < 200; i++) {
/* test commutativity and associativity: /* test commutativity and associativity:
* a(bG) = (ab)G = b(aG) * a(bG) = (ab)G = b(aG)
*/ */
scalar_multiply(CURVE, &a, &p1); bn_mod(&a, &curve->order);
point_multiply(CURVE, &b, &p1, &p1); bn_mod(&b, &curve->order);
scalar_multiply(curve, &a, &p1);
point_multiply(curve, &b, &p1, &p1);
scalar_multiply(CURVE, &b, &p2); scalar_multiply(curve, &b, &p2);
point_multiply(CURVE, &a, &p2, &p2); point_multiply(curve, &a, &p2, &p2);
ck_assert_mem_eq(&p1, &p2, sizeof(curve_point)); ck_assert_mem_eq(&p1, &p2, sizeof(curve_point));
bn_multiply(&a, &b, &order256k1); bn_multiply(&a, &b, &curve->order);
scalar_multiply(CURVE, &b, &p2); bn_mod(&b, &curve->order);
scalar_multiply(curve, &b, &p2);
ck_assert_mem_eq(&p1, &p2, sizeof(curve_point)); ck_assert_mem_eq(&p1, &p2, sizeof(curve_point));
@ -1360,6 +1389,11 @@ START_TEST(test_scalar_point_mult) {
b = p1.y; b = p1.y;
} }
} }
START_TEST(test_scalar_point_mult) {
test_scalar_point_mult_curve(&secp256k1);
test_scalar_point_mult_curve(&nist256p1);
}
END_TEST END_TEST
// define test suite and cases // define test suite and cases
@ -1423,8 +1457,8 @@ Suite *test_suite(void)
tcase_add_test(tc, test_pubkey_validity); tcase_add_test(tc, test_pubkey_validity);
suite_add_tcase(s, tc); suite_add_tcase(s, tc);
tc = tcase_create("secp256k1_cp"); tc = tcase_create("codepoints");
tcase_add_test(tc, test_secp256k1_cp); tcase_add_test(tc, test_codepoints);
suite_add_tcase(s, tc); suite_add_tcase(s, tc);
tc = tcase_create("mult_border_cases"); tc = tcase_create("mult_border_cases");