1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-22 07:28:10 +00:00

Refactored code for point doubling.

New function `bn_mult_3_2` that multiplies by 3/2.
This function is used in point_double and point_jacobian_double.
Cleaned up point_add and point_double, more comments.
This commit is contained in:
Jochen Hoenicke 2015-03-22 17:55:01 +01:00
parent edf0fc4902
commit 56f5777b68
3 changed files with 56 additions and 39 deletions

View File

@ -139,6 +139,26 @@ void bn_rshift(bignum256 *a)
a->val[8] >>= 1;
}
// multiply x by 3/2 modulo prime.
// assumes x < 2*prime,
// guarantees x < 4*prime on exit.
void bn_mult_3_2(bignum256 * x, const bignum256 *prime)
{
int j;
uint32_t xodd = -(x->val[0] & 1);
// compute x = 3*x/2 mod prime
// if x is odd compute (3*x+prime)/2
uint32_t tmp1 = (3*x->val[0] + (prime->val[0] & xodd)) >> 1;
for (j = 0; j < 8; j++) {
uint32_t tmp2 = (3*x->val[j+1] + (prime->val[j+1] & xodd));
tmp1 += (tmp2 & 1) << 29;
x->val[j] = tmp1 & 0x3fffffff;
tmp1 >>= 30;
tmp1 += tmp2 >> 1;
}
x->val[8] = tmp1;
}
// assumes x < 2*prime, result < prime
void bn_mod(bignum256 *x, const bignum256 *prime)
{

View File

@ -57,6 +57,8 @@ void bn_lshift(bignum256 *a);
void bn_rshift(bignum256 *a);
void bn_mult_3_2(bignum256 *x, const bignum256 *prime);
void bn_mod(bignum256 *x, const bignum256 *prime);
void bn_multiply(const bignum256 *k, bignum256 *x, const bignum256 *prime);

73
ecdsa.c
View File

@ -37,8 +37,7 @@
// Set cp2 = cp1
void point_copy(const curve_point *cp1, curve_point *cp2)
{
memcpy(&(cp2->x), &(cp1->x), sizeof(bignum256));
memcpy(&(cp2->y), &(cp1->y), sizeof(bignum256));
*cp2 = *cp1;
}
// cp2 = cp1 + cp2
@ -68,7 +67,9 @@ void point_add(const curve_point *cp1, curve_point *cp2)
bn_inverse(&inv, &prime256k1);
bn_subtractmod(&(cp2->y), &(cp1->y), &lambda, &prime256k1);
bn_multiply(&inv, &lambda, &prime256k1);
memcpy(&xr, &lambda, sizeof(bignum256));
// xr = lambda^2 - x1 - x2
xr = lambda;
bn_multiply(&xr, &xr, &prime256k1);
temp = 1;
for (i = 0; i < 9; i++) {
@ -77,16 +78,17 @@ void point_add(const curve_point *cp1, curve_point *cp2)
temp >>= 30;
}
bn_fast_mod(&xr, &prime256k1);
bn_mod(&xr, &prime256k1);
// yr = lambda (x1 - xr) - y1
bn_subtractmod(&(cp1->x), &xr, &yr, &prime256k1);
// no need to fast_mod here
// bn_fast_mod(&yr);
bn_multiply(&lambda, &yr, &prime256k1);
bn_subtractmod(&yr, &(cp1->y), &yr, &prime256k1);
bn_fast_mod(&yr, &prime256k1);
memcpy(&(cp2->x), &xr, sizeof(bignum256));
memcpy(&(cp2->y), &yr, sizeof(bignum256));
bn_mod(&(cp2->x), &prime256k1);
bn_mod(&(cp2->y), &prime256k1);
bn_mod(&yr, &prime256k1);
cp2->x = xr;
cp2->y = yr;
}
// cp = cp + cp
@ -94,7 +96,7 @@ void point_double(curve_point *cp)
{
int i;
uint32_t temp;
bignum256 lambda, inverse_y, xr, yr;
bignum256 lambda, xr, yr;
if (point_is_infinity(cp)) {
return;
@ -104,13 +106,15 @@ void point_double(curve_point *cp)
return;
}
memcpy(&inverse_y, &(cp->y), sizeof(bignum256));
bn_inverse(&inverse_y, &prime256k1);
memcpy(&lambda, &three_over_two256k1, sizeof(bignum256));
bn_multiply(&inverse_y, &lambda, &prime256k1);
bn_multiply(&(cp->x), &lambda, &prime256k1);
bn_multiply(&(cp->x), &lambda, &prime256k1);
memcpy(&xr, &lambda, sizeof(bignum256));
// lambda = 3/2 x^2 / y
lambda = cp->y;
bn_inverse(&lambda, &prime256k1);
bn_multiply(&cp->x, &lambda, &prime256k1);
bn_multiply(&cp->x, &lambda, &prime256k1);
bn_mult_3_2(&lambda, &prime256k1);
// xr = lambda^2 - 2*x
xr = lambda;
bn_multiply(&xr, &xr, &prime256k1);
temp = 1;
for (i = 0; i < 9; i++) {
@ -119,16 +123,17 @@ void point_double(curve_point *cp)
temp >>= 30;
}
bn_fast_mod(&xr, &prime256k1);
bn_mod(&xr, &prime256k1);
// yr = lambda (x - xr) - y
bn_subtractmod(&(cp->x), &xr, &yr, &prime256k1);
// no need to fast_mod here
// bn_fast_mod(&yr);
bn_multiply(&lambda, &yr, &prime256k1);
bn_subtractmod(&yr, &(cp->y), &yr, &prime256k1);
bn_fast_mod(&yr, &prime256k1);
memcpy(&(cp->x), &xr, sizeof(bignum256));
memcpy(&(cp->y), &yr, sizeof(bignum256));
bn_mod(&(cp->x), &prime256k1);
bn_mod(&(cp->y), &prime256k1);
bn_mod(&yr, &prime256k1);
cp->x = xr;
cp->y = yr;
}
// set point to internal representation of point at infinity
@ -322,8 +327,7 @@ static void point_jacobian_add(const curve_point *p1, jacobian_curve_point *p2)
static void point_jacobian_double(jacobian_curve_point *p) {
bignum256 m, msq, ysq, xysq;
int j;
uint32_t tmp1, tmp2;
uint32_t modd;
uint32_t tmp1;
/* usual algorithm:
*
@ -336,7 +340,7 @@ static void point_jacobian_double(jacobian_curve_point *p) {
* Hence,
* lambda = m / yz
*
* With z3 = 2yz (the denominator of lambda)
* With z3 = yz (the denominator of lambda)
* we get x3 = lambda^2*z3^2 - 2*x/z^2*z3^2
* = m^2 - 2*xy^2
* and y3 = (lambda * (x/z^2 - x3/z3^2) - y/z^3) * z3^3
@ -352,18 +356,7 @@ static void point_jacobian_double(jacobian_curve_point *p) {
m = p->x;
bn_multiply(&m, &m, &prime256k1);
modd = -(m.val[0] & 1);
// compute m = 3*m/2 mod prime
// if m is odd compute (3*m+prime)/2
tmp1 = (3*m.val[0] + (prime256k1.val[0] & modd)) >> 1;
for (j = 0; j < 8; j++) {
tmp2 = (3*m.val[j+1] + (prime256k1.val[j+1] & modd));
tmp1 += (tmp2 & 1) << 29;
m.val[j] = tmp1 & 0x3fffffff;
tmp1 >>= 30;
tmp1 += tmp2 >> 1;
}
m.val[8] = tmp1;
bn_mult_3_2(&m, &prime256k1);
// msq = m^2
msq = m;
@ -374,6 +367,8 @@ static void point_jacobian_double(jacobian_curve_point *p) {
// xysq = xy^2
xysq = p->x;
bn_multiply(&ysq, &xysq, &prime256k1);
// z3 = yz
bn_multiply(&p->y, &p->z, &prime256k1);
bn_mod(&p->z, &prime256k1);
@ -387,7 +382,7 @@ static void point_jacobian_double(jacobian_curve_point *p) {
bn_fast_mod(&p->x, &prime256k1);
bn_mod(&p->x, &prime256k1);
// y = m*(xy^2 - x3) - y^4
// y3 = m*(xy^2 - x3) - y^4
bn_subtractmod(&xysq, &p->x, &p->y, &prime256k1);
bn_multiply(&m, &p->y, &prime256k1);
bn_multiply(&ysq, &ysq, &prime256k1);