test(core): fix secp256k1 unit tests

pull/1953/head
Ondřej Vejpustek 3 years ago
parent 3419961797
commit 046beb4fde

@ -3,16 +3,8 @@ from common import *
from trezor.crypto import random
from trezor.crypto.curve import secp256k1
if not utils.BITCOIN_ONLY:
try:
from trezor.crypto.curve import secp256k1_zkp
except ImportError:
secp256k1_zkp = None
class Secp256k1Common(object):
impl = None
class TestCryptoSecp256k1(unittest.TestCase):
# vectors from https://crypto.stackexchange.com/questions/784/are-there-any-secp256k1-ecdsa-test-examples-available
vectors = [
(1, '79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8'),
@ -64,7 +56,7 @@ class Secp256k1Common(object):
def test_generate_secret(self):
for _ in range(100):
sk = self.impl.generate_secret()
sk = secp256k1.generate_secret()
self.assertTrue(len(sk) == 32)
self.assertTrue(sk != b'\x00' * 32)
self.assertTrue(sk < b'\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFE\xBA\xAE\xDC\xE6\xAF\x48\xA0\x3B\xBF\xD2\x5E\x8C\xD0\x36\x41\x41')
@ -75,55 +67,55 @@ class Secp256k1Common(object):
if len(sk) < 64:
sk = '0' * (64 - len(sk)) + sk
pk = pk.lower()
pk65 = hexlify(self.impl.publickey(unhexlify(sk), False)).decode() # uncompressed
pk65 = hexlify(secp256k1.publickey(unhexlify(sk), False)).decode() # uncompressed
self.assertEqual(str(pk65), '04' + pk)
pk33 = hexlify(self.impl.publickey(unhexlify(sk))).decode()
pk33 = hexlify(secp256k1.publickey(unhexlify(sk))).decode()
if pk[-1] in '02468ace':
self.assertEqual(pk33, '02' + pk[:64])
else:
self.assertEqual(pk33, '03' + pk[:64])
def test_sign_verify_min_max(self):
sk = self.impl.generate_secret()
pk = self.impl.publickey(sk)
sk = secp256k1.generate_secret()
pk = secp256k1.publickey(sk)
dig = bytes([1] + [0] * 31)
sig = self.impl.sign(sk, dig)
self.assertTrue(self.impl.verify(pk, sig, dig))
sig = secp256k1.sign(sk, dig)
self.assertTrue(secp256k1.verify(pk, sig, dig))
dig = bytes([0] * 31 + [1])
sig = self.impl.sign(sk, dig)
self.assertTrue(self.impl.verify(pk, sig, dig))
sig = secp256k1.sign(sk, dig)
self.assertTrue(secp256k1.verify(pk, sig, dig))
dig = bytes([0xFF] * 32)
sig = self.impl.sign(sk, dig)
self.assertTrue(self.impl.verify(pk, sig, dig))
sig = secp256k1.sign(sk, dig)
self.assertTrue(secp256k1.verify(pk, sig, dig))
def test_sign_verify_random(self):
for _ in range(100):
sk = self.impl.generate_secret()
pk = self.impl.publickey(sk)
sk = secp256k1.generate_secret()
pk = secp256k1.publickey(sk)
dig = random.bytes(32)
sig = self.impl.sign(sk, dig)
self.assertTrue(self.impl.verify(pk, sig, dig))
sig = secp256k1.sign(sk, dig)
self.assertTrue(secp256k1.verify(pk, sig, dig))
def test_verify_recover(self):
for compressed in [False, True]:
for _ in range(100):
sk = self.impl.generate_secret()
pk = self.impl.publickey(sk, compressed)
sk = secp256k1.generate_secret()
pk = secp256k1.publickey(sk, compressed)
dig = random.bytes(32)
sig = self.impl.sign(sk, dig, compressed)
pk2 = self.impl.verify_recover(sig, dig)
sig = secp256k1.sign(sk, dig, compressed)
pk2 = secp256k1.verify_recover(sig, dig)
self.assertEqual(pk, pk2)
def test_ecdh(self):
for _ in range(100):
sk1 = self.impl.generate_secret()
pk1 = self.impl.publickey(sk1, False)
sk2 = self.impl.generate_secret()
pk2 = self.impl.publickey(sk2, True)
self.assertEqual(self.impl.multiply(sk1, pk2), self.impl.multiply(sk2, pk1))
sk1 = secp256k1.generate_secret()
pk1 = secp256k1.publickey(sk1, False)
sk2 = secp256k1.generate_secret()
pk2 = secp256k1.publickey(sk2, True)
self.assertEqual(secp256k1.multiply(sk1, pk2), secp256k1.multiply(sk2, pk1))
(sk, pk) = self.vectors[0]
sk = hex(sk)[2:]
@ -131,23 +123,14 @@ class Secp256k1Common(object):
sk = '0' * (64 - len(sk)) + sk
sk = unhexlify(sk)
pk = pk.lower()
pk33 = self.impl.publickey(sk)
pk65 = self.impl.publickey(sk, False)
pk33 = secp256k1.publickey(sk)
pk65 = secp256k1.publickey(sk, False)
fixed_vector_hex = b"0479be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798483ada7726a3c4655da4fbfc0e1108a8fd17b448a68554199c47d08ffb10d4b8"
fixed_vector1 = self.impl.multiply(sk, pk65)
fixed_vector2 = self.impl.multiply(sk, pk33)
fixed_vector1 = secp256k1.multiply(sk, pk65)
fixed_vector2 = secp256k1.multiply(sk, pk33)
self.assertEqual(fixed_vector1, fixed_vector2)
self.assertEqual(hexlify(fixed_vector1), fixed_vector_hex)
class TestCryptoSecp256k1(Secp256k1Common, unittest.TestCase):
def __init__(self):
self.impl = secp256k1
@unittest.skipUnless(secp256k1_zkp is not None, "altcoin")
class TestCryptoSecp256k1Zkp(Secp256k1Common, unittest.TestCase):
def __init__(self):
self.impl = secp256k1_zkp.Context()
if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save