1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-18 04:18:10 +00:00

test(core): fix secp256k1 unit tests

This commit is contained in:
Ondřej Vejpustek 2021-11-05 16:52:59 +01:00
parent 3419961797
commit 046beb4fde

View File

@ -3,16 +3,8 @@ from common import *
from trezor.crypto import random from trezor.crypto import random
from trezor.crypto.curve import secp256k1 from trezor.crypto.curve import secp256k1
if not utils.BITCOIN_ONLY:
try:
from trezor.crypto.curve import secp256k1_zkp
except ImportError:
secp256k1_zkp = None
class Secp256k1Common(object):
impl = None
class TestCryptoSecp256k1(unittest.TestCase):
# vectors from https://crypto.stackexchange.com/questions/784/are-there-any-secp256k1-ecdsa-test-examples-available # vectors from https://crypto.stackexchange.com/questions/784/are-there-any-secp256k1-ecdsa-test-examples-available
vectors = [ vectors = [
(1, '79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8'), (1, '79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8'),
@ -64,7 +56,7 @@ class Secp256k1Common(object):
def test_generate_secret(self): def test_generate_secret(self):
for _ in range(100): for _ in range(100):
sk = self.impl.generate_secret() sk = secp256k1.generate_secret()
self.assertTrue(len(sk) == 32) self.assertTrue(len(sk) == 32)
self.assertTrue(sk != b'\x00' * 32) self.assertTrue(sk != b'\x00' * 32)
self.assertTrue(sk < b'\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFE\xBA\xAE\xDC\xE6\xAF\x48\xA0\x3B\xBF\xD2\x5E\x8C\xD0\x36\x41\x41') self.assertTrue(sk < b'\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFE\xBA\xAE\xDC\xE6\xAF\x48\xA0\x3B\xBF\xD2\x5E\x8C\xD0\x36\x41\x41')
@ -75,55 +67,55 @@ class Secp256k1Common(object):
if len(sk) < 64: if len(sk) < 64:
sk = '0' * (64 - len(sk)) + sk sk = '0' * (64 - len(sk)) + sk
pk = pk.lower() pk = pk.lower()
pk65 = hexlify(self.impl.publickey(unhexlify(sk), False)).decode() # uncompressed pk65 = hexlify(secp256k1.publickey(unhexlify(sk), False)).decode() # uncompressed
self.assertEqual(str(pk65), '04' + pk) self.assertEqual(str(pk65), '04' + pk)
pk33 = hexlify(self.impl.publickey(unhexlify(sk))).decode() pk33 = hexlify(secp256k1.publickey(unhexlify(sk))).decode()
if pk[-1] in '02468ace': if pk[-1] in '02468ace':
self.assertEqual(pk33, '02' + pk[:64]) self.assertEqual(pk33, '02' + pk[:64])
else: else:
self.assertEqual(pk33, '03' + pk[:64]) self.assertEqual(pk33, '03' + pk[:64])
def test_sign_verify_min_max(self): def test_sign_verify_min_max(self):
sk = self.impl.generate_secret() sk = secp256k1.generate_secret()
pk = self.impl.publickey(sk) pk = secp256k1.publickey(sk)
dig = bytes([1] + [0] * 31) dig = bytes([1] + [0] * 31)
sig = self.impl.sign(sk, dig) sig = secp256k1.sign(sk, dig)
self.assertTrue(self.impl.verify(pk, sig, dig)) self.assertTrue(secp256k1.verify(pk, sig, dig))
dig = bytes([0] * 31 + [1]) dig = bytes([0] * 31 + [1])
sig = self.impl.sign(sk, dig) sig = secp256k1.sign(sk, dig)
self.assertTrue(self.impl.verify(pk, sig, dig)) self.assertTrue(secp256k1.verify(pk, sig, dig))
dig = bytes([0xFF] * 32) dig = bytes([0xFF] * 32)
sig = self.impl.sign(sk, dig) sig = secp256k1.sign(sk, dig)
self.assertTrue(self.impl.verify(pk, sig, dig)) self.assertTrue(secp256k1.verify(pk, sig, dig))
def test_sign_verify_random(self): def test_sign_verify_random(self):
for _ in range(100): for _ in range(100):
sk = self.impl.generate_secret() sk = secp256k1.generate_secret()
pk = self.impl.publickey(sk) pk = secp256k1.publickey(sk)
dig = random.bytes(32) dig = random.bytes(32)
sig = self.impl.sign(sk, dig) sig = secp256k1.sign(sk, dig)
self.assertTrue(self.impl.verify(pk, sig, dig)) self.assertTrue(secp256k1.verify(pk, sig, dig))
def test_verify_recover(self): def test_verify_recover(self):
for compressed in [False, True]: for compressed in [False, True]:
for _ in range(100): for _ in range(100):
sk = self.impl.generate_secret() sk = secp256k1.generate_secret()
pk = self.impl.publickey(sk, compressed) pk = secp256k1.publickey(sk, compressed)
dig = random.bytes(32) dig = random.bytes(32)
sig = self.impl.sign(sk, dig, compressed) sig = secp256k1.sign(sk, dig, compressed)
pk2 = self.impl.verify_recover(sig, dig) pk2 = secp256k1.verify_recover(sig, dig)
self.assertEqual(pk, pk2) self.assertEqual(pk, pk2)
def test_ecdh(self): def test_ecdh(self):
for _ in range(100): for _ in range(100):
sk1 = self.impl.generate_secret() sk1 = secp256k1.generate_secret()
pk1 = self.impl.publickey(sk1, False) pk1 = secp256k1.publickey(sk1, False)
sk2 = self.impl.generate_secret() sk2 = secp256k1.generate_secret()
pk2 = self.impl.publickey(sk2, True) pk2 = secp256k1.publickey(sk2, True)
self.assertEqual(self.impl.multiply(sk1, pk2), self.impl.multiply(sk2, pk1)) self.assertEqual(secp256k1.multiply(sk1, pk2), secp256k1.multiply(sk2, pk1))
(sk, pk) = self.vectors[0] (sk, pk) = self.vectors[0]
sk = hex(sk)[2:] sk = hex(sk)[2:]
@ -131,23 +123,14 @@ class Secp256k1Common(object):
sk = '0' * (64 - len(sk)) + sk sk = '0' * (64 - len(sk)) + sk
sk = unhexlify(sk) sk = unhexlify(sk)
pk = pk.lower() pk = pk.lower()
pk33 = self.impl.publickey(sk) pk33 = secp256k1.publickey(sk)
pk65 = self.impl.publickey(sk, False) pk65 = secp256k1.publickey(sk, False)
fixed_vector_hex = b"0479be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798483ada7726a3c4655da4fbfc0e1108a8fd17b448a68554199c47d08ffb10d4b8" fixed_vector_hex = b"0479be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798483ada7726a3c4655da4fbfc0e1108a8fd17b448a68554199c47d08ffb10d4b8"
fixed_vector1 = self.impl.multiply(sk, pk65) fixed_vector1 = secp256k1.multiply(sk, pk65)
fixed_vector2 = self.impl.multiply(sk, pk33) fixed_vector2 = secp256k1.multiply(sk, pk33)
self.assertEqual(fixed_vector1, fixed_vector2) self.assertEqual(fixed_vector1, fixed_vector2)
self.assertEqual(hexlify(fixed_vector1), fixed_vector_hex) self.assertEqual(hexlify(fixed_vector1), fixed_vector_hex)
class TestCryptoSecp256k1(Secp256k1Common, unittest.TestCase):
def __init__(self):
self.impl = secp256k1
@unittest.skipUnless(secp256k1_zkp is not None, "altcoin")
class TestCryptoSecp256k1Zkp(Secp256k1Common, unittest.TestCase):
def __init__(self):
self.impl = secp256k1_zkp.Context()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()