1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-13 19:18:56 +00:00

New jacobian_add that handles doubling.

Fix bug where jacobian_add is called with two identical points.
This commit is contained in:
Jochen Hoenicke 2015-08-05 19:32:20 +02:00
parent 60e36dac3b
commit f2081d88d8
6 changed files with 130 additions and 73 deletions

View File

@ -116,6 +116,19 @@ int bn_is_equal(const bignum256 *a, const bignum256 *b) {
return !result;
}
void bn_cmov(bignum256 *res, int cond, const bignum256 *truecase, const bignum256 *falsecase)
{
int i;
uint32_t tmask = (uint32_t) -cond;
uint32_t fmask = ~tmask;
assert (cond == 1 || cond == 0);
for (i = 0; i < 9; i++) {
res->val[i] = (truecase->val[i] & tmask) |
(falsecase->val[i] & fmask);
}
}
int bn_bitlen(const bignum256 *a) {
int i = 8, j;
while (i >= 0 && a->val[i] == 0) i--;
@ -688,6 +701,15 @@ void bn_addi(bignum256 *a, uint32_t b) {
bn_normalize(a);
}
void bn_subi(bignum256 *a, uint32_t b, const bignum256 *prime) {
int i;
for (i = 0; i < 9; i++) {
a->val[i] += prime->val[i];
}
a->val[0] -= b;
bn_fast_mod(a, prime);
}
// res = a - b mod prime. More exactly res = a + (2*prime - b).
// precondition: 0 <= b < 2*prime, 0 <= a < prime
// res < 3*prime

View File

@ -51,6 +51,8 @@ int bn_is_less(const bignum256 *a, const bignum256 *b);
int bn_is_equal(const bignum256 *a, const bignum256 *b);
void bn_cmov(bignum256 *res, int cond, const bignum256 *truecase, const bignum256 *falsecase);
int bn_bitlen(const bignum256 *a);
void bn_lshift(bignum256 *a);
@ -77,6 +79,8 @@ void bn_addmod(bignum256 *a, const bignum256 *b, const bignum256 *prime);
void bn_addi(bignum256 *a, uint32_t b);
void bn_subi(bignum256 *a, uint32_t b, const bignum256 *prime);
void bn_subtractmod(const bignum256 *a, const bignum256 *b, bignum256 *res, const bignum256 *prime);
void bn_subtract(const bignum256 *a, const bignum256 *b, bignum256 *res);

166
ecdsa.c
View File

@ -111,7 +111,7 @@ void point_double(const ecdsa_curve *curve, curve_point *cp)
xr = cp->x;
bn_multiply(&xr, &xr, &curve->prime);
bn_mult_k(&xr, 3, &curve->prime);
bn_addmod(&xr, &curve->a, &curve->prime);
bn_subi(&xr, -curve->a, &curve->prime);
bn_multiply(&xr, &lambda, &curve->prime);
// xr = lambda^2 - 2*x
@ -228,86 +228,118 @@ void jacobian_to_curve(const jacobian_curve_point *jp, curve_point *p, const big
bn_mod(&p->y, prime);
}
void point_jacobian_add(const curve_point *p1, jacobian_curve_point *p2, const bignum256 *prime) {
bignum256 r, h;
bignum256 rsq, hcb, hcby2, hsqx2;
void point_jacobian_add(const curve_point *p1, jacobian_curve_point *p2, const ecdsa_curve *curve) {
bignum256 r, h, r2;
bignum256 hcby, hsqx;
bignum256 xz, yz, az;
int is_doubling;
const bignum256 *prime = &curve->prime;
int a = curve->a;
/* usual algorithm:
assert (-3 <= a && a <= 0);
/* First we bring p1 to the same denominator:
* x1' := x1 * z2^2
* y1' := y1 * z2^3
*/
/*
* lambda = ((y1' - y2)/z2^3) / ((x1' - x2)/z2^2)
* = (y1' - y2) / (x1' - x2) z2
* x3/z3^2 = lambda^2 - (x1' + x2)/z2^2
* y3/z3^3 = 1/2 lambda * (2x3/z3^2 - (x1' + x2)/z2^2) + (y1'+y2)/z2^3
*
* lambda = (y1 - y2/z2^3) / (x1 - x2/z2^2)
* x3/z3^2 = lambda^2 - x1 - x2/z2^2
* y3/z3^3 = lambda * (x2/z2^2 - x3/z3^2) - y2/z2^3
* For the special case x1=x2, y1=y2 (doubling) we have
* lambda = 3/2 ((x2/z2^2)^2 + a) / (y2/z2^3)
* = 3/2 (x2^2 + a*z2^4) / y2*z2)
*
* to get rid of fraction we set
* r = (y1 * z2^3 - y2) (the numerator of lambda * z2^3)
* h = (x1 * z2^2 - x2) (the denominator of lambda * z2^2)
* Hence,
* lambda = r / (h*z2)
* to get rid of fraction we write lambda as
* lambda = r / (h*z2)
* with r = is_doubling ? 3/2 x2^2 + az2^4 : (y1 - y2)
* h = is_doubling ? y1+y2 : (x1 - x2)
*
* With z3 = h*z2 (the denominator of lambda)
* we get x3 = lambda^2*z3^2 - x1*z3^2 - x2/z2^2*z3^2
* = r^2 - x1*h^2*z2^2 - x2*h^2
* = r^2 - h^2*(x1*z2^2 + x2)
* = r^2 - h^2*(h + 2*x2)
* = r^2 - h^3 - 2*h^2*x2
* and y3 = (lambda * (x2/z2^2 - x3/z3^2) - y2/z2^3) * z3^3
* = r * (h^2*x2 - x3) - h^3*y2
* we get x3 = lambda^2*z3^2 - (x1' + x2)/z2^2*z3^2
* = r^2 - h^2 * (x1' + x2)
* and y3 = 1/2 r * (2x3 - h^2*(x1' + x2)) + h^3*(y1' + y2)
*/
/* h = x1*z2^2 - x2
* r = y1*z2^3 - y2
/* h = x1 - x2
* r = y1 - y2
* x3 = r^2 - h^3 - 2*h^2*x2
* y3 = r*(h^2*x2 - x3) - h^3*y2
* z3 = h*z2
*/
// h = x1 * z2^2 - x2;
// r = y1 * z2^3 - y2;
h = p2->z;
bn_multiply(&h, &h, prime); // h = z2^2
r = p2->z;
bn_multiply(&h, &r, prime); // r = z2^3
bn_multiply(&p1->x, &h, prime);
xz = p2->z;
bn_multiply(&xz, &xz, prime); // xz = z2^2
yz = p2->z;
bn_multiply(&xz, &yz, prime); // yz = z2^3
if (a != 0) {
az = xz;
bn_multiply(&az, &az, prime); // az = z2^4
bn_mult_k(&az, -a, prime); // az = -az2^4
}
bn_multiply(&p1->x, &xz, prime); // xz = x1' = x1*z2^2;
h = xz;
bn_subtractmod(&h, &p2->x, &h, prime);
// h = x1 * z2^2 - x2;
bn_fast_mod(&h, prime);
// h = x1' - x2;
bn_multiply(&p1->y, &r, prime);
bn_subtractmod(&r, &p2->y, &r, prime);
// r = y1 * z2^3 - y2;
bn_addmod(&xz, &p2->x, prime);
// xz = x1' + x2
// hsqx2 = h^2
hsqx2 = h;
bn_multiply(&hsqx2, &hsqx2, prime);
is_doubling = bn_is_zero(&h) | bn_is_equal(&h, prime);
// hcb = h^3
hcb = h;
bn_multiply(&hsqx2, &hcb, prime);
bn_multiply(&p1->y, &yz, prime); // yz = y1' = y1*z2^3;
bn_subtractmod(&yz, &p2->y, &r, prime);
// r = y1' - y2;
// hsqx2 = h^2 * x2
bn_multiply(&p2->x, &hsqx2, prime);
bn_addmod(&yz, &p2->y, prime);
// yz = y1' + y2
// hcby2 = h^3 * y2
hcby2 = hcb;
bn_multiply(&p2->y, &hcby2, prime);
r2 = p2->x;
bn_multiply(&r2, &r2, prime);
bn_mult_k(&r2, 3, prime);
if (a != 0) {
// subtract -a z2^4, i.e, add a z2^4
bn_subtractmod(&r2, &az, &r2, prime);
}
bn_cmov(&r, is_doubling, &r2, &r);
bn_cmov(&h, is_doubling, &yz, &h);
// rsq = r^2
rsq = r;
bn_multiply(&rsq, &rsq, prime);
// hsqx = h^2
hsqx = h;
bn_multiply(&hsqx, &hsqx, prime);
// hcby = h^3
hcby = h;
bn_multiply(&hsqx, &hcby, prime);
// hsqx = h^2 * (x1 + x2)
bn_multiply(&xz, &hsqx, prime);
// hcby = h^3 * (y1 + y2)
bn_multiply(&yz, &hcby, prime);
// z3 = h*z2
bn_multiply(&h, &p2->z, prime);
// x3 = r^2 - h^3 - 2h^2x2
bn_addmod(&hcb, &hsqx2, prime);
bn_addmod(&hcb, &hsqx2, prime);
bn_subtractmod(&rsq, &hcb, &p2->x, prime);
// x3 = r^2 - h^2 (x1 + x2)
p2->x = r;
bn_multiply(&p2->x, &p2->x, prime);
bn_subtractmod(&p2->x, &hsqx, &p2->x, prime);
bn_fast_mod(&p2->x, prime);
// y3 = r*(h^2x2 - x3) - y2*h^3
bn_subtractmod(&hsqx2, &p2->x, &p2->y, prime);
// y3 = 1/2 (r*(h^2 (x1 + x2) - 2x3) - h^3 (y1 + y2))
bn_subtractmod(&hsqx, &p2->x, &p2->y, prime);
bn_subtractmod(&p2->y, &p2->x, &p2->y, prime);
bn_multiply(&r, &p2->y, prime);
bn_subtractmod(&p2->y, &hcby2, &p2->y, prime);
bn_subtractmod(&p2->y, &hcby, &p2->y, prime);
bn_mult_half(&p2->y, prime);
bn_fast_mod(&p2->y, prime);
}
@ -346,8 +378,8 @@ void point_jacobian_double(jacobian_curve_point *p, const ecdsa_curve *curve) {
az4 = p->z;
bn_multiply(&az4, &az4, prime);
bn_multiply(&az4, &az4, prime);
bn_multiply(&curve->a, &az4, prime);
bn_addmod(&m, &az4, prime);
bn_mult_k(&az4, -curve->a, prime);
bn_subtractmod(&m, &az4, &m, prime);
bn_mult_half(&m, prime);
// msq = m^2
@ -475,7 +507,7 @@ void point_multiply(const ecdsa_curve *curve, const bignum256 *k, const curve_po
conditional_negate(sign ^ nsign, &jres.z, prime);
// add odd factor
point_jacobian_add(&pmult[bits >> 1], &jres, prime);
point_jacobian_add(&pmult[bits >> 1], &jres, curve);
sign = nsign;
}
conditional_negate(sign, &jres.z, prime);
@ -562,7 +594,7 @@ void scalar_multiply(const ecdsa_curve *curve, const bignum256 *k, curve_point *
conditional_negate((lowbits & 1) - 1, &jres.y, prime);
// add odd factor
point_jacobian_add(&curve->cp[i][lowbits >> 1], &jres, prime);
point_jacobian_add(&curve->cp[i][lowbits >> 1], &jres, curve);
}
conditional_negate(((a.val[0] >> 4) & 1) - 1, &jres.y, prime);
jacobian_to_curve(&jres, res, prime);
@ -825,10 +857,11 @@ int ecdsa_address_decode(const char *addr, uint8_t *out)
void uncompress_coords(const ecdsa_curve *curve, uint8_t odd, const bignum256 *x, bignum256 *y)
{
// y^2 = x^3 + 0*x + 7
memcpy(y, x, sizeof(bignum256)); // y is x
memcpy(y, x, sizeof(bignum256)); // y is x
bn_multiply(x, y, &curve->prime); // y is x^2
bn_multiply(x, y, &curve->prime); // y is x^3
bn_addi(y, 7); // y is x^3 + 7
bn_subi(y, -curve->a, &curve->prime); // y is x^2 + a
bn_multiply(x, y, &curve->prime); // y is x^3 + ax
bn_addmod(y, &curve->b, &curve->prime); // y is x^3 + ax + b
bn_sqrt(y, &curve->prime); // y = sqrt(y)
if ((odd & 0x01) != (y->val[0] & 1)) {
bn_subtract(&curve->prime, y, y); // y = -y
@ -875,10 +908,11 @@ int ecdsa_validate_pubkey(const ecdsa_curve *curve, const curve_point *pub)
bn_multiply(&(pub->y), &y_2, &curve->prime);
bn_mod(&y_2, &curve->prime);
// x^3 + b
bn_multiply(&(pub->x), &x_3_b, &curve->prime);
bn_multiply(&(pub->x), &x_3_b, &curve->prime);
bn_addi(&x_3_b, 7);
// x^3 + ax + b
bn_multiply(&(pub->x), &x_3_b, &curve->prime); // x^2
bn_subi(&x_3_b, -curve->a, &curve->prime); // x^2 + a
bn_multiply(&(pub->x), &x_3_b, &curve->prime); // x^3 + ax
bn_addmod(&x_3_b, &curve->b, &curve->prime); // x^3 + ax + b
if (!bn_is_equal(&x_3_b, &y_2)) {
return 0;

View File

@ -39,7 +39,7 @@ typedef struct {
curve_point G; // initial curve point
bignum256 order; // order of G
bignum256 order_half; // order of G divided by 2
bignum256 a; // coefficient 'a' of the elliptic curve
int a; // coefficient 'a' of the elliptic curve
bignum256 b; // coefficient 'b' of the elliptic curve
#if USE_PRECOMPUTED_CP
@ -68,6 +68,7 @@ void ecdsa_get_pubkeyhash(const uint8_t *pub_key, uint8_t *pubkeyhash);
void ecdsa_get_address_raw(const uint8_t *pub_key, uint8_t version, uint8_t *addr_raw);
void ecdsa_get_address(const uint8_t *pub_key, uint8_t version, char *addr, int addrsize);
void ecdsa_get_wif(const uint8_t *priv_key, uint8_t version, char *wif, int wifsize);
int ecdsa_address_decode(const char *addr, uint8_t *out);
int ecdsa_read_pubkey(const ecdsa_curve *curve, const uint8_t *pub_key, curve_point *pub);
int ecdsa_validate_pubkey(const ecdsa_curve *curve, const curve_point *pub);

View File

@ -41,9 +41,7 @@ const ecdsa_curve nist256p1 = {
/*.val =*/{0x3e3192a8, 0x27739585, 0x38bcf427, 0x1cdf55b4, 0x3fffffde, 0x3fffffff, 0x7ff, 0x3fffe000, 0x7fff}
},
/* a */ {
/*.val =*/{0x3ffffffc, 0x3fffffff, 0x3fffffff, 0x3f, 0x0, 0x0, 0x1000, 0x3fffc000, 0xffff}
},
/* a */ -3,
/* b */ {
/*.val =*/{0x27d2604b, 0x2f38f0f8, 0x53b0f63, 0x741ac33, 0x1886bc65, 0x2ef555da, 0x293e7b3e, 0xd762a8e, 0x5ac6}

View File

@ -41,9 +41,7 @@ const ecdsa_curve secp256k1 = {
/*.val =*/{0x281b20a0, 0x3fa4bd19, 0x3a4501dd, 0x15db9cd5, 0x3fffff5d, 0x3fffffff, 0x3fffffff, 0x3fffffff, 0x7fff}
},
/* a */ {
/*.val =*/{0}
},
/* a */ 0,
/* b */ {
/*.val =*/{7}