1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-22 06:18:07 +00:00

Extended comments, new function bn_add, a bug fix.

Describe normalized, partly reduced and reduced numbers.
Comment which function expects which kind of input.
Removed unused bn_bitlen.
Add bn_add that does not reduce.
Bug fix in ecdsa_validate_pubkey: bn_mod before bn_is_equal.
Bug fix in hdnode_private_ckd: bn_mod after bn_addmod.
This commit is contained in:
Jochen Hoenicke 2015-08-06 18:37:55 +02:00
parent 53fa580b81
commit f93b003cbc
4 changed files with 120 additions and 57 deletions

150
bignum.c
View File

@ -27,6 +27,31 @@
#include "bignum.h"
#include "macros.h"
/* big number library */
/* The structure bignum256 is an array of nine 32-bit values, which
* are digits in base 2^30 representation. I.e. the number
* bignum256 a;
* represents the value
* sum_{i=0}^8 a.val[i] * 2^{30 i}.
*
* The number is *normalized* iff every digit is < 2^30.
*
* As the name suggests, a bignum256 is intended to represent a 256
* bit number, but it can represent 270 bits. Numbers are usually
* reduced using a prime, either the group order or the field prime.
* The reduction is often partly done by bn_fast_mod, and similarly
* implicitly in bn_multiply. A *partly reduced number* is a
* normalized number between 0 (inclusive) and 2*prime (exclusive).
*
* A partly reduced number can be fully reduced by calling bn_mod.
* Only a fully reduced number is guaranteed to fit in 256 bit.
*
* All functions assume that the prime in question is slightly smaller
* than 2^256. In particular it must be between 2^256-2^224 and
* 2^256 and it must be a prime number.
*/
inline uint32_t read_be(const uint8_t *data)
{
return (((uint32_t)data[0]) << 24) |
@ -43,7 +68,8 @@ inline void write_be(uint8_t *data, uint32_t x)
data[3] = x;
}
// convert a raw bigendian 256 bit number to a normalized bignum
// convert a raw bigendian 256 bit value into a normalized bignum.
// out_number is partly reduced (since it fits in 256 bit).
void bn_read_be(const uint8_t *in_number, bignum256 *out_number)
{
int i;
@ -63,7 +89,7 @@ void bn_read_be(const uint8_t *in_number, bignum256 *out_number)
}
// convert a normalized bignum to a raw bigendian 256 bit number.
// in_number must be normalized and < 2^256.
// in_number must be fully reduced.
void bn_write_be(const bignum256 *in_number, uint8_t *out_number)
{
int i;
@ -77,6 +103,7 @@ void bn_write_be(const bignum256 *in_number, uint8_t *out_number)
}
}
// sets a bignum to zero.
void bn_zero(bignum256 *a)
{
int i;
@ -85,6 +112,9 @@ void bn_zero(bignum256 *a)
}
}
// checks that a bignum is zero.
// a must be normalized
// function is constant time (on some architectures, in particular ARM).
int bn_is_zero(const bignum256 *a)
{
int i;
@ -95,6 +125,9 @@ int bn_is_zero(const bignum256 *a)
return !result;
}
// Check whether a < b
// a and b must be normalized
// function is constant time (on some architectures, in particular ARM).
int bn_is_less(const bignum256 *a, const bignum256 *b)
{
int i;
@ -107,6 +140,9 @@ int bn_is_less(const bignum256 *a, const bignum256 *b)
return res1 > res2;
}
// Check whether a == b
// a and b must be normalized
// function is constant time (on some architectures, in particular ARM).
int bn_is_equal(const bignum256 *a, const bignum256 *b) {
int i;
uint32_t result = 0;
@ -116,6 +152,9 @@ int bn_is_equal(const bignum256 *a, const bignum256 *b) {
return !result;
}
// Assigns res = cond ? truecase : falsecase
// assumes that cond is either 0 or 1.
// function is constant time.
void bn_cmov(bignum256 *res, int cond, const bignum256 *truecase, const bignum256 *falsecase)
{
int i;
@ -129,15 +168,8 @@ void bn_cmov(bignum256 *res, int cond, const bignum256 *truecase, const bignum25
}
}
int bn_bitlen(const bignum256 *a) {
int i = 8, j;
while (i >= 0 && a->val[i] == 0) i--;
if (i == -1) return 0;
j = 29;
while ((a->val[i] & (1 << j)) == 0) j--;
return i * 30 + j + 1;
}
// shift number to the left, i.e multiply it by 2.
// a must be normalized. The result is normalized but not reduced.
void bn_lshift(bignum256 *a)
{
int i;
@ -147,6 +179,8 @@ void bn_lshift(bignum256 *a)
a->val[0] = (a->val[0] << 1) & 0x3FFFFFFF;
}
// shift number to the right, i.e divide by 2 while rounding down.
// a must be normalized. The result is normalized.
void bn_rshift(bignum256 *a)
{
int i;
@ -157,8 +191,10 @@ void bn_rshift(bignum256 *a)
}
// multiply x by 1/2 modulo prime.
// assumes x < 2*prime,
// guarantees x < 4*prime on exit.
// it computes x = (x & 1) ? (x + prime) >> 1 : x >> 1.
// assumes x is normalized.
// if x was partly reduced, it is also partly reduced on exit.
// function is constant time.
void bn_mult_half(bignum256 * x, const bignum256 *prime)
{
int j;
@ -177,8 +213,8 @@ void bn_mult_half(bignum256 * x, const bignum256 *prime)
}
// multiply x by k modulo prime.
// assumes x < prime,
// guarantees x < prime on exit.
// assumes x is normalized, 0 <= k <= 4.
// guarantees x is partly reduced.
void bn_mult_k(bignum256 *x, uint8_t k, const bignum256 *prime)
{
int j;
@ -188,7 +224,8 @@ void bn_mult_k(bignum256 *x, uint8_t k, const bignum256 *prime)
bn_fast_mod(x, prime);
}
// assumes x < 2*prime, result < prime
// compute x = x mod prime by computing x >= prime ? x - prime : x.
// assumes x partly reduced, guarantees x fully reduced.
void bn_mod(bignum256 *x, const bignum256 *prime)
{
int i = 8;
@ -214,6 +251,9 @@ void bn_mod(bignum256 *x, const bignum256 *prime)
}
}
// auxiliary function for multiplication.
// compute k * x as a 540 bit number in base 2^30 (normalized).
// assumes that k and x are normalized.
void bn_multiply_long(const bignum256 *k, const bignum256 *x, uint32_t res[18])
{
int i, j;
@ -223,6 +263,7 @@ void bn_multiply_long(const bignum256 *k, const bignum256 *x, uint32_t res[18])
for (i = 0; i < 9; i++)
{
for (j = 0; j <= i; j++) {
// no overflow, since 9*2^60 < 2^64
temp += k->val[j] * (uint64_t)x->val[i - j];
}
res[i] = temp & 0x3FFFFFFFu;
@ -232,6 +273,7 @@ void bn_multiply_long(const bignum256 *k, const bignum256 *x, uint32_t res[18])
for (; i < 17; i++)
{
for (j = i - 8; j < 9 ; j++) {
// no overflow, since 9*2^60 < 2^64
temp += k->val[j] * (uint64_t)x->val[i - j];
}
res[i] = temp & 0x3FFFFFFFu;
@ -240,13 +282,16 @@ void bn_multiply_long(const bignum256 *k, const bignum256 *x, uint32_t res[18])
res[17] = temp;
}
// auxiliary function for multiplication.
// reduces res modulo prime.
// assumes res normalized, res < 2^(30(i-7)) * 2 * prime
// guarantees res normalized, res < 2^(30(i-8)) * 2 * prime
void bn_multiply_reduce_step(uint32_t res[18], const bignum256 *prime, uint32_t i) {
// let k = i-8.
// invariants:
// res[0..(i+1)] = k * x (mod prime)
// 0 <= res < 2^(30k + 256) * (2^31)
// estimate (res / prime)
// coef = res / 2^(30k + 256) rounded down
// on entry:
// 0 <= res < 2^(30k + 31) * prime
// estimate coef = (res / prime / 2^30k)
// by coef = res / 2^(30k + 256) rounded down
// 0 <= coef < 2^31
// subtract (coef * 2^(30k) * prime) from res
// note that we unrolled the first iteration
@ -257,11 +302,11 @@ void bn_multiply_reduce_step(uint32_t res[18], const bignum256 *prime, uint32_t
res[i - 8] = temp & 0x3FFFFFFF;
for (j = 1; j < 9; j++) {
temp >>= 30;
// Note: coeff * prime->val <= (2^31-1) * (2^30-1)
// Note: coeff * prime->val[j] <= (2^31-1) * (2^30-1)
// Hence, this addition will not underflow.
temp += 0x1FFFFFFF80000000ull + res[i - 8 + j] - prime->val[j] * (uint64_t)coef;
res[i - 8 + j] = temp & 0x3FFFFFFF;
// 0 <= temp < 2^61
// 0 <= temp < 2^61 + 2^30
}
temp >>= 30;
temp += 0x1FFFFFFF80000000ull + res[i - 8 + j];
@ -271,18 +316,20 @@ void bn_multiply_reduce_step(uint32_t res[18], const bignum256 *prime, uint32_t
// and
// coef * 2^(30k + 256) <= oldres < (coef+1) * 2^(30k + 256)
// Hence, 0 <= res < 2^30k (2^256 + coef * (2^256 - prime))
// Since coef * (2^256 - prime) < 2^256, we get
// 0 <= res < 2^(30k + 226) (2^31)
// Thus the invariant holds again.
// < 2^30k (2^256 + 2^31 * 2^224)
// < 2^30k (2 * prime)
}
// auxiliary function for multiplication.
// reduces x = res modulo prime.
// assumes res normalized, res < 2^270 * 2 * prime
// guarantees x partly reduced, i.e., x < 2 * prime
void bn_multiply_reduce(bignum256 *x, uint32_t res[18], const bignum256 *prime)
{
int i;
// res = k * x is a normalized number (every limb < 2^30)
// 0 <= res < 2^526.
// compute modulo p division is only estimated so this may give result greater than prime but not bigger than 2 * prime
// 0 <= res < 2^270 * 2 * prime.
for (i = 16; i >= 8; i--) {
bn_multiply_reduce_step(res, prime, i);
assert(res[i + 1] == 0);
@ -294,11 +341,9 @@ void bn_multiply_reduce(bignum256 *x, uint32_t res[18], const bignum256 *prime)
}
// Compute x := k * x (mod prime)
// both inputs must be smaller than 2 * prime.
// result is reduced to 0 <= x < 2 * prime
// This only works for primes between 2^256-2^196 and 2^256.
// this particular implementation accepts inputs up to 2^263 or 128*prime.
// both inputs must be smaller than 180 * prime.
// result is partly reduced (0 <= x < 2 * prime)
// This only works for primes between 2^256-2^224 and 2^256.
void bn_multiply(const bignum256 *k, bignum256 *x, const bignum256 *prime)
{
uint32_t res[18] = {0};
@ -307,9 +352,11 @@ void bn_multiply(const bignum256 *k, bignum256 *x, const bignum256 *prime)
MEMSET_BZERO(res, sizeof(res));
}
// input x can be any normalized number that fits (0 <= x < 2^270).
// partly reduce x modulo prime
// input x does not have to be normalized.
// x can be any number that fits.
// prime must be between (2^256 - 2^224) and 2^256
// result is smaller than 2*prime
// result is partly reduced, smaller than 2*prime
void bn_fast_mod(bignum256 *x, const bignum256 *prime)
{
int j;
@ -330,6 +377,8 @@ void bn_fast_mod(bignum256 *x, const bignum256 *prime)
// square root of x = x^((p+1)/4)
// http://en.wikipedia.org/wiki/Quadratic_residue#Prime_or_prime_power_modulus
// assumes x is normalized but not necessarily reduced.
// guarantees x is reduced
void bn_sqrt(bignum256 *x, const bignum256 *prime)
{
// this method compute x^1/2 = x^(prime+1)/4
@ -678,10 +727,18 @@ void bn_inverse(bignum256 *x, const bignum256 *prime)
#endif
void bn_normalize(bignum256 *a) {
bn_addi(a, 0);
}
// add two numbers a = a + b
// assumes that a, b are normalized
// guarantees that a is normalized
void bn_add(bignum256 *a, const bignum256 *b)
{
int i;
uint32_t tmp = 0;
for (i = 0; i < 9; i++) {
tmp += a->val[i];
tmp += a->val[i] + b->val[i];
a->val[i] = tmp & 0x3FFFFFFF;
tmp >>= 30;
}
@ -697,22 +754,25 @@ void bn_addmod(bignum256 *a, const bignum256 *b, const bignum256 *prime)
}
void bn_addi(bignum256 *a, uint32_t b) {
a->val[0] += b;
bn_normalize(a);
int i;
uint32_t tmp = b;
for (i = 0; i < 9; i++) {
tmp += a->val[i];
a->val[i] = tmp & 0x3FFFFFFF;
tmp >>= 30;
}
}
void bn_subi(bignum256 *a, uint32_t b, const bignum256 *prime) {
int i;
for (i = 0; i < 9; i++) {
a->val[i] += prime->val[i];
}
assert (b <= prime->val[0]);
// the possible underflow will be taken care of when adding the prime
a->val[0] -= b;
bn_fast_mod(a, prime);
bn_add(a, prime);
}
// res = a - b mod prime. More exactly res = a + (2*prime - b).
// precondition: 0 <= b < 2*prime, 0 <= a < prime
// res < 3*prime
// b must be a partly reduced number
// result is normalized but not reduced.
void bn_subtractmod(const bignum256 *a, const bignum256 *b, bignum256 *res, const bignum256 *prime)
{
int i;

View File

@ -53,8 +53,6 @@ int bn_is_equal(const bignum256 *a, const bignum256 *b);
void bn_cmov(bignum256 *res, int cond, const bignum256 *truecase, const bignum256 *falsecase);
int bn_bitlen(const bignum256 *a);
void bn_lshift(bignum256 *a);
void bn_rshift(bignum256 *a);
@ -75,6 +73,8 @@ void bn_inverse(bignum256 *x, const bignum256 *prime);
void bn_normalize(bignum256 *a);
void bn_add(bignum256 *a, const bignum256 *b);
void bn_addmod(bignum256 *a, const bignum256 *b, const bignum256 *prime);
void bn_addi(bignum256 *a, uint32_t b);

View File

@ -143,6 +143,7 @@ int hdnode_private_ckd(HDNode *inout, uint32_t i)
}
if (!failed) {
bn_addmod(&a, &b, &default_curve->order);
bn_mod(&a, &default_curve->order);
if (bn_is_zero(&a)) {
failed = true;
}

22
ecdsa.c
View File

@ -287,7 +287,7 @@ void point_jacobian_add(const curve_point *p1, jacobian_curve_point *p2, const e
bn_fast_mod(&h, prime);
// h = x1' - x2;
bn_addmod(&xz, &p2->x, prime);
bn_add(&xz, &p2->x);
// xz = x1' + x2
is_doubling = bn_is_zero(&h) | bn_is_equal(&h, prime);
@ -296,7 +296,7 @@ void point_jacobian_add(const curve_point *p1, jacobian_curve_point *p2, const e
bn_subtractmod(&yz, &p2->y, &r, prime);
// r = y1' - y2;
bn_addmod(&yz, &p2->y, prime);
bn_add(&yz, &p2->y);
// yz = y1' + y2
r2 = p2->x;
@ -347,6 +347,7 @@ void point_jacobian_double(jacobian_curve_point *p, const ecdsa_curve *curve) {
bignum256 az4, m, msq, ysq, xysq;
const bignum256 *prime = &curve->prime;
assert (-3 <= curve->a && curve->a <= 0);
/* usual algorithm:
*
* lambda = (3((x/z^2)^2 + a) / 2y/z^3) = (3x^2 + az^4)/2yz
@ -861,7 +862,7 @@ void uncompress_coords(const ecdsa_curve *curve, uint8_t odd, const bignum256 *x
bn_multiply(x, y, &curve->prime); // y is x^2
bn_subi(y, -curve->a, &curve->prime); // y is x^2 + a
bn_multiply(x, y, &curve->prime); // y is x^3 + ax
bn_addmod(y, &curve->b, &curve->prime); // y is x^3 + ax + b
bn_add(y, &curve->b); // y is x^3 + ax + b
bn_sqrt(y, &curve->prime); // y = sqrt(y)
if ((odd & 0x01) != (y->val[0] & 1)) {
bn_subtract(&curve->prime, y, y); // y = -y
@ -891,7 +892,7 @@ int ecdsa_read_pubkey(const ecdsa_curve *curve, const uint8_t *pub_key, curve_po
int ecdsa_validate_pubkey(const ecdsa_curve *curve, const curve_point *pub)
{
bignum256 y_2, x_3_b;
bignum256 y_2, x3_ax_b;
if (point_is_infinity(pub)) {
return 0;
@ -902,19 +903,20 @@ int ecdsa_validate_pubkey(const ecdsa_curve *curve, const curve_point *pub)
}
memcpy(&y_2, &(pub->y), sizeof(bignum256));
memcpy(&x_3_b, &(pub->x), sizeof(bignum256));
memcpy(&x3_ax_b, &(pub->x), sizeof(bignum256));
// y^2
bn_multiply(&(pub->y), &y_2, &curve->prime);
bn_mod(&y_2, &curve->prime);
// x^3 + ax + b
bn_multiply(&(pub->x), &x_3_b, &curve->prime); // x^2
bn_subi(&x_3_b, -curve->a, &curve->prime); // x^2 + a
bn_multiply(&(pub->x), &x_3_b, &curve->prime); // x^3 + ax
bn_addmod(&x_3_b, &curve->b, &curve->prime); // x^3 + ax + b
bn_multiply(&(pub->x), &x3_ax_b, &curve->prime); // x^2
bn_subi(&x3_ax_b, -curve->a, &curve->prime); // x^2 + a
bn_multiply(&(pub->x), &x3_ax_b, &curve->prime); // x^3 + ax
bn_addmod(&x3_ax_b, &curve->b, &curve->prime); // x^3 + ax + b
bn_mod(&x3_ax_b, &curve->prime);
if (!bn_is_equal(&x_3_b, &y_2)) {
if (!bn_is_equal(&x3_ax_b, &y_2)) {
return 0;
}