1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-05 04:50:57 +00:00

bignum: use constant time comparisons

This commit is contained in:
Roman Zeyde 2015-08-02 22:01:49 +03:00
parent bfa812441d
commit 793234a0ec
2 changed files with 53 additions and 7 deletions

View File

@ -88,28 +88,32 @@ void bn_zero(bignum256 *a)
int bn_is_zero(const bignum256 *a) int bn_is_zero(const bignum256 *a)
{ {
int i; int i;
uint32_t result = 0;
for (i = 0; i < 9; i++) { for (i = 0; i < 9; i++) {
if (a->val[i] != 0) return 0; result |= a->val[i];
} }
return 1; return !result;
} }
int bn_is_less(const bignum256 *a, const bignum256 *b) int bn_is_less(const bignum256 *a, const bignum256 *b)
{ {
int i; int i;
uint32_t res1 = 0;
uint32_t res2 = 0;
for (i = 8; i >= 0; i--) { for (i = 8; i >= 0; i--) {
if (a->val[i] < b->val[i]) return 1; res1 = (res1 << 1) | (a->val[i] < b->val[i]);
if (a->val[i] > b->val[i]) return 0; res2 = (res2 << 1) | (a->val[i] > b->val[i]);
} }
return 0; return res1 > res2;
} }
int bn_is_equal(const bignum256 *a, const bignum256 *b) { int bn_is_equal(const bignum256 *a, const bignum256 *b) {
int i; int i;
uint32_t result = 0;
for (i = 0; i < 9; i++) { for (i = 0; i < 9; i++) {
if (a->val[i] != b->val[i]) return 0; result |= (a->val[i] ^ b->val[i]);
} }
return 1; return !result;
} }
int bn_bitlen(const bignum256 *a) { int bn_bitlen(const bignum256 *a) {

View File

@ -97,6 +97,48 @@ def test_inverse(curve, r):
assert y == y_ assert y == y_
def test_is_less(curve, r):
x = r.randrange(0, curve.p)
y = r.randrange(0, curve.p)
x_ = int2bn(x)
y_ = int2bn(y)
res = lib.bn_is_less(x_, y_)
assert res == (x < y)
res = lib.bn_is_less(y_, x_)
assert res == (y < x)
def test_is_equal(curve, r):
x = r.randrange(0, curve.p)
y = r.randrange(0, curve.p)
x_ = int2bn(x)
y_ = int2bn(y)
assert lib.bn_is_equal(x_, y_) == (x == y)
assert lib.bn_is_equal(x_, x_) == 1
assert lib.bn_is_equal(y_, y_) == 1
def test_is_zero(curve, r):
x = r.randrange(0, curve.p);
assert lib.bn_is_zero(int2bn(x)) == (not x)
def test_simple_comparisons():
assert lib.bn_is_zero(int2bn(0)) == 1
assert lib.bn_is_zero(int2bn(1)) == 0
assert lib.bn_is_less(int2bn(0), int2bn(0)) == 0
assert lib.bn_is_less(int2bn(1), int2bn(0)) == 0
assert lib.bn_is_less(int2bn(0), int2bn(1)) == 1
assert lib.bn_is_equal(int2bn(0), int2bn(0)) == 1
assert lib.bn_is_equal(int2bn(1), int2bn(0)) == 0
assert lib.bn_is_equal(int2bn(0), int2bn(1)) == 0
def test_mult_half(curve, r): def test_mult_half(curve, r):
x = r.randrange(0, 2*curve.p) x = r.randrange(0, 2*curve.p)
y = int2bn(x) y = int2bn(x)