diff --git a/setup.py b/setup.py index ea76ea7ce..24a639d69 100755 --- a/setup.py +++ b/setup.py @@ -1,37 +1,39 @@ #!/usr/bin/python from distutils.core import setup from distutils.extension import Extension + from Cython.Build import cythonize from Cython.Distutils import build_ext srcs = [ - 'nist256p1', - 'base58', - 'bignum', - 'bip32', - 'ecdsa', - 'curve25519', - 'hmac', - 'rand', - 'ripemd160', - 'secp256k1', - 'sha2', + "nist256p1", + "base58", + "bignum", + "bip32", + "ecdsa", + "curve25519", + "hmac", + "rand", + "ripemd160", + "secp256k1", + "sha2", ] extensions = [ - Extension('TrezorCrypto', - sources = ['TrezorCrypto.pyx', 'c.pxd'] + [ x + '.c' for x in srcs ], - extra_compile_args = [], - ) + Extension( + "TrezorCrypto", + sources=["TrezorCrypto.pyx", "c.pxd"] + [x + ".c" for x in srcs], + extra_compile_args=[], + ) ] setup( - name = 'TrezorCrypto', - version = '0.0.0', - description = 'Cython wrapper around trezor-crypto library', - author = 'Pavol Rusnak', - author_email = 'stick@satoshilabs.com', - url = 'https://github.com/trezor/trezor-crypto', - cmdclass = {'build_ext': build_ext}, - ext_modules = cythonize(extensions), + name="TrezorCrypto", + version="0.0.0", + description="Cython wrapper around trezor-crypto library", + author="Pavol Rusnak", + author_email="stick@satoshilabs.com", + url="https://github.com/trezor/trezor-crypto", + cmdclass={"build_ext": build_ext}, + ext_modules=cythonize(extensions), ) diff --git a/tests/test_curves.py b/tests/test_curves.py index 028ecef6a..a6db9cee6 100755 --- a/tests/test_curves.py +++ b/tests/test_curves.py @@ -1,13 +1,15 @@ #!/usr/bin/py.test -import ctypes as c -import curve25519 -import random -import ecdsa -import hashlib import binascii +import ctypes as c +import hashlib import os +import random + +import curve25519 +import ecdsa import pytest + def bytes2num(s): res = 0 for i, b in enumerate(reversed(bytearray(s))): @@ -15,10 +17,8 @@ def bytes2num(s): return res -curves = { - 'nist256p1': ecdsa.curves.NIST256p, - 'secp256k1': ecdsa.curves.SECP256k1 -} +curves = {"nist256p1": ecdsa.curves.NIST256p, "secp256k1": ecdsa.curves.SECP256k1} + class Point: def __init__(self, name, x, y): @@ -26,30 +26,70 @@ class Point: self.x = x self.y = y + points = [ - Point('secp256k1', 0x79be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798, 0x483ada7726a3c4655da4fbfc0e1108a8fd17b448a68554199c47d08ffb10d4b8), - Point('secp256k1', 0x1, 0x4218f20ae6c646b363db68605822fb14264ca8d2587fdd6fbc750d587e76a7ee), - Point('secp256k1', 0x2, 0x66fbe727b2ba09e09f5a98d70a5efce8424c5fa425bbda1c511f860657b8535e), - Point('secp256k1', 0x1b,0x1adcea1cf831b0ad1653e769d1a229091d0cc68d4b0328691b9caacc76e37c90), - Point('nist256p1', 0x6b17d1f2e12c4247f8bce6e563a440f277037d812deb33a0f4a13945d898c296, 0x4fe342e2fe1a7f9b8ee7eb4a7c0f9e162bce33576b315ececbb6406837bf51f5), - Point('nist256p1', 0x0, 0x66485c780e2f83d72433bd5d84a06bb6541c2af31dae871728bf856a174f93f4), - Point('nist256p1', 0x0, 0x99b7a386f1d07c29dbcc42a27b5f9449abe3d50de25178e8d7407a95e8b06c0b), - Point('nist256p1', 0xaf8bbdfe8cdd5577acbf345b543d28cf402f4e94d3865b97ea0787f2d3aa5d22,0x35802b8b376b995265918b078bc109c21a535176585c40f519aca52d6afc147c), - Point('nist256p1', 0x80000, 0x580610071f440f0dcc14a22e2d5d5afc1224c0cd11a3b4b51b8ecd2224ee1ce2) + Point( + "secp256k1", + 0x79be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798, + 0x483ada7726a3c4655da4fbfc0e1108a8fd17b448a68554199c47d08ffb10d4b8, + ), + Point( + "secp256k1", + 0x1, + 0x4218f20ae6c646b363db68605822fb14264ca8d2587fdd6fbc750d587e76a7ee, + ), + Point( + "secp256k1", + 0x2, + 0x66fbe727b2ba09e09f5a98d70a5efce8424c5fa425bbda1c511f860657b8535e, + ), + Point( + "secp256k1", + 0x1b, + 0x1adcea1cf831b0ad1653e769d1a229091d0cc68d4b0328691b9caacc76e37c90, + ), + Point( + "nist256p1", + 0x6b17d1f2e12c4247f8bce6e563a440f277037d812deb33a0f4a13945d898c296, + 0x4fe342e2fe1a7f9b8ee7eb4a7c0f9e162bce33576b315ececbb6406837bf51f5, + ), + Point( + "nist256p1", + 0x0, + 0x66485c780e2f83d72433bd5d84a06bb6541c2af31dae871728bf856a174f93f4, + ), + Point( + "nist256p1", + 0x0, + 0x99b7a386f1d07c29dbcc42a27b5f9449abe3d50de25178e8d7407a95e8b06c0b, + ), + Point( + "nist256p1", + 0xaf8bbdfe8cdd5577acbf345b543d28cf402f4e94d3865b97ea0787f2d3aa5d22, + 0x35802b8b376b995265918b078bc109c21a535176585c40f519aca52d6afc147c, + ), + Point( + "nist256p1", + 0x80000, + 0x580610071f440f0dcc14a22e2d5d5afc1224c0cd11a3b4b51b8ecd2224ee1ce2, + ), ] -random_iters = int(os.environ.get('ITERS', 1)) +random_iters = int(os.environ.get("ITERS", 1)) DIR = os.path.abspath(os.path.dirname(__file__)) -lib = c.cdll.LoadLibrary(os.path.join(DIR, 'libtrezor-crypto.so')) +lib = c.cdll.LoadLibrary(os.path.join(DIR, "libtrezor-crypto.so")) + class curve_info(c.Structure): - _fields_ = [("bip32_name", c.c_char_p), - ("params", c.c_void_p)] + _fields_ = [("bip32_name", c.c_char_p), ("params", c.c_void_p)] + + lib.get_curve_by_name.restype = c.POINTER(curve_info) BIGNUM = c.c_uint32 * 9 + class Random(random.Random): def randbytes(self, n): buf = (c.c_uint8 * n)() @@ -74,36 +114,40 @@ def int2bn(x, bn_type=BIGNUM): def bn2int(b): x = 0 for i in range(len(b)): - x += (b[i] << (30 * i)) + x += b[i] << (30 * i) return x @pytest.fixture(params=range(random_iters)) def r(request): seed = request.param - return Random(seed + int(os.environ.get('SEED', 0))) + return Random(seed + int(os.environ.get("SEED", 0))) @pytest.fixture(params=list(sorted(curves))) def curve(request): name = request.param curve_ptr = lib.get_curve_by_name(bytes(name, "ascii")).contents.params - assert curve_ptr, 'curve {} not found'.format(name) + assert curve_ptr, "curve {} not found".format(name) curve_obj = curves[name] curve_obj.ptr = c.c_void_p(curve_ptr) curve_obj.p = curve_obj.curve.p() # shorthand return curve_obj + @pytest.fixture(params=points) def point(request): name = request.param.curve curve_ptr = lib.get_curve_by_name(bytes(name, "ascii")).contents.params - assert curve_ptr, 'curve {} not found'.format(name) + assert curve_ptr, "curve {} not found".format(name) curve_obj = curves[name] curve_obj.ptr = c.c_void_p(curve_ptr) - curve_obj.p = ecdsa.ellipticcurve.Point(curve_obj.curve, request.param.x, request.param.y) + curve_obj.p = ecdsa.ellipticcurve.Point( + curve_obj.curve, request.param.x, request.param.y + ) return curve_obj + def test_inverse(curve, r): x = r.randrange(1, curve.p) y = int2bn(x) @@ -138,7 +182,7 @@ def test_is_equal(curve, r): def test_is_zero(curve, r): - x = r.randrange(0, curve.p); + x = r.randrange(0, curve.p) assert lib.bn_is_zero(int2bn(x)) == (not x) @@ -156,7 +200,7 @@ def test_simple_comparisons(): def test_mult_half(curve, r): - x = r.randrange(0, 2*curve.p) + x = r.randrange(0, 2 * curve.p) y = int2bn(x) lib.bn_mult_half(y, int2bn(curve.p)) y = bn2int(y) @@ -172,7 +216,7 @@ def test_subtractmod(curve, r): z = int2bn(0) lib.bn_subtractmod(int2bn(x), int2bn(y), z, int2bn(curve.p)) z = bn2int(z) - z_ = x + 2*curve.p - y + z_ = x + 2 * curve.p - y assert z == z_ @@ -219,7 +263,7 @@ def test_multiply(curve, r): p_ = int2bn(curve.p) lib.bn_multiply(k, z_, p_) z_ = bn2int(z_) - assert z_ < 2*curve.p + assert z_ < 2 * curve.p if z_ >= curve.p: z_ = z_ - curve.p assert z_ == z @@ -249,36 +293,51 @@ def test_multiply2(curve, r): def test_fast_mod(curve, r): - x = r.randrange(0, 128*curve.p) + x = r.randrange(0, 128 * curve.p) y = int2bn(x) lib.bn_fast_mod(y, int2bn(curve.p)) y = bn2int(y) - assert y < 2*curve.p + assert y < 2 * curve.p if y >= curve.p: y -= curve.p assert x % curve.p == y def test_mod(curve, r): - x = r.randrange(0, 2*curve.p) + x = r.randrange(0, 2 * curve.p) y = int2bn(x) lib.bn_mod(y, int2bn(curve.p)) assert bn2int(y) == x % curve.p + def test_mod_specific(curve): p = curve.p - for x in [0, 1, 2, p - 2, p - 1, p, p + 1, p + 2, 2*p - 2, 2*p - 1]: + for x in [0, 1, 2, p - 2, p - 1, p, p + 1, p + 2, 2 * p - 2, 2 * p - 1]: y = int2bn(x) lib.bn_mod(y, int2bn(curve.p)) assert bn2int(y) == x % p + POINT = BIGNUM * 2 -to_POINT = lambda p: POINT(int2bn(p.x()), int2bn(p.y())) -from_POINT = lambda p: (bn2int(p[0]), bn2int(p[1])) + + +def to_POINT(p): + return POINT(int2bn(p.x()), int2bn(p.y())) + + +def from_POINT(p): + return lambda p: (bn2int(p[0]), bn2int(p[1])) + JACOBIAN = BIGNUM * 3 -to_JACOBIAN = lambda jp: JACOBIAN(int2bn(jp[0]), int2bn(jp[1]), int2bn(jp[2])) -from_JACOBIAN = lambda p: (bn2int(p[0]), bn2int(p[1]), bn2int(p[2])) + + +def to_JACOBIAN(jp): + return JACOBIAN(int2bn(jp[0]), int2bn(jp[1]), int2bn(jp[2])) + + +def from_JACOBIAN(p): + return (bn2int(p[0]), bn2int(p[1]), bn2int(p[2])) def test_point_multiply(curve, r): @@ -294,7 +353,7 @@ def test_point_multiply(curve, r): def test_point_add(curve, r): p1 = r.randpoint(curve) p2 = r.randpoint(curve) - #print '-' * 80 + # print '-' * 80 q = p1 + p2 q1 = to_POINT(p1) q2 = to_POINT(p2) @@ -332,7 +391,7 @@ def test_cond_negate(curve, r): lib.conditional_negate(0, a, int2bn(curve.p)) assert bn2int(a) == x lib.conditional_negate(-1, a, int2bn(curve.p)) - assert bn2int(a) == 2*curve.p - x + assert bn2int(a) == 2 * curve.p - x def test_jacobian_add(curve, r): @@ -348,6 +407,7 @@ def test_jacobian_add(curve, r): p_ = p1 + p2 assert (p_.x(), p_.y()) == q + def test_jacobian_add_double(curve, r): p1 = r.randpoint(curve) p2 = p1 @@ -361,6 +421,7 @@ def test_jacobian_add_double(curve, r): p_ = p1 + p2 assert (p_.x(), p_.y()) == q + def test_jacobian_double(curve, r): p = r.randpoint(curve) p2 = p.double() @@ -373,6 +434,7 @@ def test_jacobian_double(curve, r): q = from_POINT(q) assert (p2.x(), p2.y()) == q + def sigdecode(sig, _): return map(bytes2num, [sig[:32], sig[32:]]) @@ -385,15 +447,17 @@ def test_sign(curve, r): lib.ecdsa_sign_digest(curve.ptr, priv, digest, sig, c.c_void_p(0), c.c_void_p(0)) exp = bytes2num(priv) - sk = ecdsa.SigningKey.from_secret_exponent(exp, curve, - hashfunc=hashlib.sha256) + sk = ecdsa.SigningKey.from_secret_exponent(exp, curve, hashfunc=hashlib.sha256) vk = sk.get_verifying_key() - sig_ref = sk.sign_digest_deterministic(digest, hashfunc=hashlib.sha256, sigencode=ecdsa.util.sigencode_string_canonize) + sig_ref = sk.sign_digest_deterministic( + digest, hashfunc=hashlib.sha256, sigencode=ecdsa.util.sigencode_string_canonize + ) assert binascii.hexlify(sig) == binascii.hexlify(sig_ref) assert vk.verify_digest(sig, digest, sigdecode) + def test_validate_pubkey(curve, r): p = r.randpoint(curve) assert lib.ecdsa_validate_pubkey(curve.ptr, to_POINT(p)) @@ -431,9 +495,13 @@ def test_curve25519_pubkey(r): def test_curve25519_scalarmult_from_gpg(r): - sec = binascii.unhexlify('4a1e76f133afb29dbc7860bcbc16d0e829009cc15c2f81ed26de1179b1d9c938') - pub = binascii.unhexlify('5d6fc75c016e85b17f54e0128a216d5f9229f25bac1ec85cecab8daf48621b31') + sec = binascii.unhexlify( + "4a1e76f133afb29dbc7860bcbc16d0e829009cc15c2f81ed26de1179b1d9c938" + ) + pub = binascii.unhexlify( + "5d6fc75c016e85b17f54e0128a216d5f9229f25bac1ec85cecab8daf48621b31" + ) res = r.randbytes(32) lib.curve25519_scalarmult(res, sec[::-1], pub[::-1]) - expected = 'a93dbdb23e5c99da743e203bd391af79f2b83fb8d0fd6ec813371c71f08f2d4d' + expected = "a93dbdb23e5c99da743e203bd391af79f2b83fb8d0fd6ec813371c71f08f2d4d" assert binascii.hexlify(bytearray(res)) == bytes(expected, "ascii") diff --git a/tests/test_wycheproof.py b/tests/test_wycheproof.py index 0cc8e3ace..08fe1ab9a 100755 --- a/tests/test_wycheproof.py +++ b/tests/test_wycheproof.py @@ -1,46 +1,47 @@ #!/usr/bin/python +import ctypes +import json import os +from binascii import hexlify, unhexlify + +import pytest +from pyasn1.codec.ber.decoder import decode as ber_decode from pyasn1.codec.der.decoder import decode as der_decode from pyasn1.codec.der.encoder import encode as der_encode -from pyasn1.codec.ber.decoder import decode as ber_decode -from pyasn1.type import univ, namedtype -from binascii import unhexlify, hexlify -import json -import ctypes -import pytest +from pyasn1.type import namedtype, univ class EcSignature(univ.Sequence): componentType = namedtype.NamedTypes( - namedtype.NamedType('r', univ.Integer()), - namedtype.NamedType('s', univ.Integer()) + namedtype.NamedType("r", univ.Integer()), + namedtype.NamedType("s", univ.Integer()), ) class EcKeyInfo(univ.Sequence): componentType = namedtype.NamedTypes( - namedtype.NamedType('key_type', univ.ObjectIdentifier()), - namedtype.NamedType('curve_name', univ.ObjectIdentifier()) + namedtype.NamedType("key_type", univ.ObjectIdentifier()), + namedtype.NamedType("curve_name", univ.ObjectIdentifier()), ) class EcPublicKey(univ.Sequence): componentType = namedtype.NamedTypes( - namedtype.NamedType('key_info', EcKeyInfo()), - namedtype.NamedType('public_key', univ.BitString()) + namedtype.NamedType("key_info", EcKeyInfo()), + namedtype.NamedType("public_key", univ.BitString()), ) class EdKeyInfo(univ.Sequence): componentType = namedtype.NamedTypes( - namedtype.NamedType('key_type', univ.ObjectIdentifier()), + namedtype.NamedType("key_type", univ.ObjectIdentifier()) ) class EdPublicKey(univ.Sequence): componentType = namedtype.NamedTypes( - namedtype.NamedType('key_info', EdKeyInfo()), - namedtype.NamedType('public_key', univ.BitString()) + namedtype.NamedType("key_info", EdKeyInfo()), + namedtype.NamedType("public_key", univ.BitString()), ) @@ -73,12 +74,12 @@ def parse_eddsa_signature(signature): def parse_ecdh256_privkey(private_key): if private_key < 0 or private_key.bit_length() > 256: raise ParseError("Not a valid 256 bit ECDH private key") - return private_key.to_bytes(32, byteorder='big') + return private_key.to_bytes(32, byteorder="big") def parse_signed_hex(string): if len(string) % 2 == 1: - string = '0' + string + string = "0" + string number = int(string, 16) if int(string[0], 16) & 8: return -number @@ -111,10 +112,10 @@ def parse_ed_pubkey(public_key): except Exception: raise ParseError("Not a BER encoded Edwards curve public key") - if not public_key['key_info']['key_type'] == univ.ObjectIdentifier('1.3.101.112'): + if not public_key["key_info"]["key_type"] == univ.ObjectIdentifier("1.3.101.112"): raise ParseError("Not a BER encoded Edwards curve public key") - public_key = bytes(public_key['public_key'].asOctets()) + public_key = bytes(public_key["public_key"].asOctets()) return public_key @@ -125,16 +126,20 @@ def parse_ec_pubkey(public_key): except Exception: raise ParseError("Not a BER encoded named elliptic curve public key") - if not public_key['key_info']['key_type'] == univ.ObjectIdentifier('1.2.840.10045.2.1'): + if not public_key["key_info"]["key_type"] == univ.ObjectIdentifier( + "1.2.840.10045.2.1" + ): raise ParseError("Not a BER encoded named elliptic curve public key") - curve_identifier = public_key['key_info']['curve_name'] + curve_identifier = public_key["key_info"]["curve_name"] curve_name = get_curve_name_by_identifier(curve_identifier) if curve_name is None: - raise NotSupported('Unsupported named elliptic curve: {}'.format(curve_identifier)) + raise NotSupported( + "Unsupported named elliptic curve: {}".format(curve_identifier) + ) try: - public_key = bytes(public_key['public_key'].asOctets()) + public_key = bytes(public_key["public_key"].asOctets()) except: raise ParseError("Not a BER encoded named elliptic curve public key") @@ -150,8 +155,8 @@ def parse_ecdsa256_signature(signature): except: raise ParseError("Not a valid DER encoded ECDSA signature") try: - r = int(signature['r']).to_bytes(32, byteorder='big') - s = int(signature['s']).to_bytes(32, byteorder='big') + r = int(signature["r"]).to_bytes(32, byteorder="big") + s = int(signature["s"]).to_bytes(32, byteorder="big") signature = r + s except: raise ParseError("Not a valid DER encoded 256 bit ECDSA signature") @@ -167,29 +172,29 @@ def parse_digest(name): def get_curve_by_name(name): lib.get_curve_by_name.restype = ctypes.c_void_p - curve = lib.get_curve_by_name(bytes(name, 'ascii')) - if curve == None: + curve = lib.get_curve_by_name(bytes(name, "ascii")) + if curve is None: return None curve = ctypes.cast(curve, ctypes.POINTER(curve_info)) return ctypes.c_void_p(curve.contents.params) def parse_curve_name(name): - if name == 'secp256r1': - return 'nist256p1' - elif name == 'secp256k1': - return 'secp256k1' - elif name == 'curve25519': - return 'curve25519' + if name == "secp256r1": + return "nist256p1" + elif name == "secp256k1": + return "secp256k1" + elif name == "curve25519": + return "curve25519" else: return None def get_curve_name_by_identifier(identifier): - if identifier == univ.ObjectIdentifier('1.3.132.0.10'): - return 'secp256k1' - elif identifier == univ.ObjectIdentifier('1.2.840.10045.3.1.7'): - return 'nist256p1' + if identifier == univ.ObjectIdentifier("1.3.132.0.10"): + return "secp256k1" + elif identifier == univ.ObjectIdentifier("1.2.840.10045.3.1.7"): + return "nist256p1" else: return None @@ -218,26 +223,29 @@ def chacha_poly_decrypt(key, iv, associated_data, ciphertext, tag): def add_pkcs_padding(data): padding_length = 16 - len(data) % 16 - return data + bytes([padding_length]*padding_length) + return data + bytes([padding_length] * padding_length) def remove_pkcs_padding(data): padding_length = data[-1] - if not (0 < padding_length <= 16 and data[-padding_length:] == bytes([padding_length]*padding_length)): + if not ( + 0 < padding_length <= 16 + and data[-padding_length:] == bytes([padding_length] * padding_length) + ): return False else: return data[:-padding_length] def aes_encrypt_initialise(key, context): - if len(key) == (128/8): + if len(key) == (128 / 8): lib.aes_encrypt_key128(key, context) - elif len(key) == (192/8): + elif len(key) == (192 / 8): lib.aes_encrypt_key192(key, context) - elif len(key) == (256/8): + elif len(key) == (256 / 8): lib.aes_encrypt_key256(key, context) else: - raise NotSupported("Unsupported key length: {}".format(len(key)*8)) + raise NotSupported("Unsupported key length: {}".format(len(key) * 8)) def aes_cbc_encrypt(key, iv, plaintext): @@ -245,19 +253,21 @@ def aes_cbc_encrypt(key, iv, plaintext): context = bytes(context_structure_length) ciphertext = bytes(len(plaintext)) aes_encrypt_initialise(key, context) - lib.aes_cbc_encrypt(plaintext, ciphertext, len(plaintext), bytes(bytearray(iv)), context) + lib.aes_cbc_encrypt( + plaintext, ciphertext, len(plaintext), bytes(bytearray(iv)), context + ) return ciphertext def aes_decrypt_initialise(key, context): - if len(key) == (128/8): + if len(key) == (128 / 8): lib.aes_decrypt_key128(key, context) - elif len(key) == (192/8): + elif len(key) == (192 / 8): lib.aes_decrypt_key192(key, context) - elif len(key) == (256/8): + elif len(key) == (256 / 8): lib.aes_decrypt_key256(key, context) else: - raise NotSupported("Unsupported AES key length: {}".format(len(key)*8)) + raise NotSupported("Unsupported AES key length: {}".format(len(key) * 8)) def aes_cbc_decrypt(key, iv, ciphertext): @@ -281,35 +291,43 @@ def generate_aes(filename): data = load_json_testvectors(filename) - if not keys_in_dict(data, {'algorithm', 'testGroups'}): + if not keys_in_dict(data, {"algorithm", "testGroups"}): raise DataError() - if data['algorithm'] != 'AES-CBC-PKCS5': + if data["algorithm"] != "AES-CBC-PKCS5": raise DataError() - for test_group in data['testGroups']: - if not keys_in_dict(test_group, {'tests'}): + for test_group in data["testGroups"]: + if not keys_in_dict(test_group, {"tests"}): raise DataError() - for test in test_group['tests']: - if not keys_in_dict(test, {'key', 'iv', 'msg', 'ct', 'result'}): + for test in test_group["tests"]: + if not keys_in_dict(test, {"key", "iv", "msg", "ct", "result"}): raise DataError() try: - key = unhexlify(test['key']) - iv = unhexlify(test['iv']) - plaintext = unhexlify(test['msg']) - ciphertext = unhexlify(test['ct']) - result = parse_result(test['result']) + key = unhexlify(test["key"]) + iv = unhexlify(test["iv"]) + plaintext = unhexlify(test["msg"]) + ciphertext = unhexlify(test["ct"]) + result = parse_result(test["result"]) except: raise DataError() - if len(key) not in [128/8, 192/8, 256/8]: + if len(key) not in [128 / 8, 192 / 8, 256 / 8]: continue if result is None: continue - vectors.append((hexlify(key), hexlify(iv), hexlify(plaintext), hexlify(ciphertext), result)) + vectors.append( + ( + hexlify(key), + hexlify(iv), + hexlify(plaintext), + hexlify(ciphertext), + result, + ) + ) return vectors @@ -318,34 +336,46 @@ def generate_chacha_poly(filename): data = load_json_testvectors(filename) - if not keys_in_dict(data, {'algorithm', 'testGroups'}): + if not keys_in_dict(data, {"algorithm", "testGroups"}): raise DataError() - if data['algorithm'] != 'CHACHA20-POLY1305': + if data["algorithm"] != "CHACHA20-POLY1305": raise DataError() - for test_group in data['testGroups']: - if not keys_in_dict(test_group, {'tests'}): + for test_group in data["testGroups"]: + if not keys_in_dict(test_group, {"tests"}): raise DataError() - for test in test_group['tests']: - if not keys_in_dict(test, {'key', 'iv', 'aad', 'msg', 'ct', 'tag', 'result'}): + for test in test_group["tests"]: + if not keys_in_dict( + test, {"key", "iv", "aad", "msg", "ct", "tag", "result"} + ): raise DataError() try: - key = unhexlify(test['key']) - iv = unhexlify(test['iv']) - associated_data = unhexlify(test['aad']) - plaintext = unhexlify(test['msg']) - ciphertext = unhexlify(test['ct']) - tag = unhexlify(test['tag']) - result = parse_result(test['result']) + key = unhexlify(test["key"]) + iv = unhexlify(test["iv"]) + associated_data = unhexlify(test["aad"]) + plaintext = unhexlify(test["msg"]) + ciphertext = unhexlify(test["ct"]) + tag = unhexlify(test["tag"]) + result = parse_result(test["result"]) except: raise DataError() if result is None: continue - vectors.append((hexlify(key), hexlify(iv), hexlify(associated_data), hexlify(plaintext), hexlify(ciphertext), hexlify(tag), result)) + vectors.append( + ( + hexlify(key), + hexlify(iv), + hexlify(associated_data), + hexlify(plaintext), + hexlify(ciphertext), + hexlify(tag), + result, + ) + ) return vectors @@ -354,63 +384,70 @@ def generate_curve25519_dh(filename): data = load_json_testvectors(filename) - if not keys_in_dict(data, {'algorithm', 'testGroups'}): + if not keys_in_dict(data, {"algorithm", "testGroups"}): raise DataError() - if data['algorithm'] != 'X25519': + if data["algorithm"] != "X25519": raise DataError() - for test_group in data['testGroups']: - if not keys_in_dict(test_group, {'tests'}): + for test_group in data["testGroups"]: + if not keys_in_dict(test_group, {"tests"}): raise DataError() - for test in test_group['tests']: - if not keys_in_dict(test, {'public', 'private', 'shared', 'result', 'curve'}): + for test in test_group["tests"]: + if not keys_in_dict( + test, {"public", "private", "shared", "result", "curve"} + ): raise DataError() try: - public_key = unhexlify(test['public']) - curve_name = parse_curve_name(test['curve']) - private_key = unhexlify(test['private']) - shared = unhexlify(test['shared']) - result = parse_result(test['result']) + public_key = unhexlify(test["public"]) + curve_name = parse_curve_name(test["curve"]) + private_key = unhexlify(test["private"]) + shared = unhexlify(test["shared"]) + result = parse_result(test["result"]) except: raise DataError() - if curve_name != 'curve25519': + if curve_name != "curve25519": continue if result is None: continue - vectors.append((hexlify(public_key), hexlify(private_key), hexlify(shared), result)) + vectors.append( + (hexlify(public_key), hexlify(private_key), hexlify(shared), result) + ) return vectors + def generate_ecdh(filename): vectors = [] data = load_json_testvectors(filename) - if not keys_in_dict(data, {'algorithm', 'testGroups'}): + if not keys_in_dict(data, {"algorithm", "testGroups"}): raise DataError() - if data['algorithm'] != 'ECDH': + if data["algorithm"] != "ECDH": raise DataError() - for test_group in data['testGroups']: - if not keys_in_dict(test_group, {'tests'}): + for test_group in data["testGroups"]: + if not keys_in_dict(test_group, {"tests"}): raise DataError() - for test in test_group['tests']: - if not keys_in_dict(test, {'public', 'private', 'shared', 'result', 'curve'}): + for test in test_group["tests"]: + if not keys_in_dict( + test, {"public", "private", "shared", "result", "curve"} + ): raise DataError() try: - public_key = unhexlify(test['public']) - curve_name = parse_curve_name(test['curve']) - private_key = parse_signed_hex(test['private']) - shared = unhexlify(test['shared']) - result = parse_result(test['result']) + public_key = unhexlify(test["public"]) + curve_name = parse_curve_name(test["curve"]) + private_key = parse_signed_hex(test["private"]) + shared = unhexlify(test["shared"]) + result = parse_result(test["result"]) except: raise DataError() @@ -431,7 +468,15 @@ def generate_ecdh(filename): if result is None: continue - vectors.append((curve_name, hexlify(public_key), hexlify(private_key), hexlify(shared), result)) + vectors.append( + ( + curve_name, + hexlify(public_key), + hexlify(private_key), + hexlify(shared), + result, + ) + ) return vectors @@ -441,18 +486,18 @@ def generate_ecdsa(filename): data = load_json_testvectors(filename) - if not keys_in_dict(data, {'algorithm', 'testGroups'}): + if not keys_in_dict(data, {"algorithm", "testGroups"}): raise DataError() - if data['algorithm'] != 'ECDSA': + if data["algorithm"] != "ECDSA": raise DataError() - for test_group in data['testGroups']: - if not keys_in_dict(test_group, {'tests', 'keyDer', 'sha'}): + for test_group in data["testGroups"]: + if not keys_in_dict(test_group, {"tests", "keyDer", "sha"}): raise DataError() try: - public_key = unhexlify(test_group['keyDer']) + public_key = unhexlify(test_group["keyDer"]) except: raise DataError() @@ -464,18 +509,18 @@ def generate_ecdsa(filename): continue try: - hasher = parse_digest(test_group['sha']) + hasher = parse_digest(test_group["sha"]) except NotSupported: continue - for test in test_group['tests']: - if not keys_in_dict(test, {'sig', 'msg', 'result'}): + for test in test_group["tests"]: + if not keys_in_dict(test, {"sig", "msg", "result"}): raise DataError() try: - signature = unhexlify(test['sig']) - message = unhexlify(test['msg']) - result = parse_result(test['result']) + signature = unhexlify(test["sig"]) + message = unhexlify(test["msg"]) + result = parse_result(test["result"]) except: raise DataError() @@ -487,7 +532,16 @@ def generate_ecdsa(filename): except ParseError: continue - vectors.append((curve_name, hexlify(public_key), hasher, hexlify(message), hexlify(signature), result)) + vectors.append( + ( + curve_name, + hexlify(public_key), + hasher, + hexlify(message), + hexlify(signature), + result, + ) + ) return vectors @@ -497,19 +551,18 @@ def generate_eddsa(filename): data = load_json_testvectors(filename) - - if not keys_in_dict(data, {'algorithm', 'testGroups'}): + if not keys_in_dict(data, {"algorithm", "testGroups"}): raise DataError() - if data['algorithm'] != 'EDDSA': + if data["algorithm"] != "EDDSA": raise DataError() - for test_group in data['testGroups']: - if not keys_in_dict(test_group, {'tests', 'keyDer'}): + for test_group in data["testGroups"]: + if not keys_in_dict(test_group, {"tests", "keyDer"}): raise DataError() try: - public_key = unhexlify(test_group['keyDer']) + public_key = unhexlify(test_group["keyDer"]) except: raise DataError() @@ -518,14 +571,14 @@ def generate_eddsa(filename): except ParseError: continue - for test in test_group['tests']: - if not keys_in_dict(test, {'sig', 'msg', 'result'}): + for test in test_group["tests"]: + if not keys_in_dict(test, {"sig", "msg", "result"}): raise DataError() try: - signature = unhexlify(test['sig']) - message = unhexlify(test['msg']) - result = parse_result(test['result']) + signature = unhexlify(test["sig"]) + message = unhexlify(test["msg"]) + result = parse_result(test["result"]) except: raise DataError() @@ -537,21 +590,31 @@ def generate_eddsa(filename): except ParseError: continue - vectors.append((hexlify(public_key), hexlify(message), hexlify(signature), result)) + vectors.append( + (hexlify(public_key), hexlify(message), hexlify(signature), result) + ) return vectors dir = os.path.abspath(os.path.dirname(__file__)) -lib = ctypes.cdll.LoadLibrary(os.path.join(dir, 'libtrezor-crypto.so')) -testvectors_directory = os.path.join(dir, 'wycheproof/testvectors') +lib = ctypes.cdll.LoadLibrary(os.path.join(dir, "libtrezor-crypto.so")) +testvectors_directory = os.path.join(dir, "wycheproof/testvectors") context_structure_length = 1024 ecdh_vectors = generate_ecdh("ecdh_test.json") curve25519_dh_vectors = generate_curve25519_dh("x25519_test.json") eddsa_vectors = generate_eddsa("eddsa_test.json") -ecdsa_vectors = generate_ecdsa("ecdsa_test.json") + generate_ecdsa("ecdsa_secp256k1_sha256_test.json") + generate_ecdsa("ecdsa_secp256r1_sha256_test.json") -ecdh_vectors = generate_ecdh("ecdh_test.json") + generate_ecdh("ecdh_secp256k1_test.json") + generate_ecdh("ecdh_secp256r1_test.json") +ecdsa_vectors = ( + generate_ecdsa("ecdsa_test.json") + + generate_ecdsa("ecdsa_secp256k1_sha256_test.json") + + generate_ecdsa("ecdsa_secp256r1_sha256_test.json") +) +ecdh_vectors = ( + generate_ecdh("ecdh_test.json") + + generate_ecdh("ecdh_secp256k1_test.json") + + generate_ecdh("ecdh_secp256r1_test.json") +) chacha_poly_vectors = generate_chacha_poly("chacha20_poly1305_test.json") aes_vectors = generate_aes("aes_cbc_pkcs5_test.json") @@ -562,37 +625,48 @@ def test_eddsa(public_key, message, signature, result): signature = unhexlify(signature) message = unhexlify(message) - computed_result = lib.ed25519_sign_open(message, len(message), public_key, signature) == 0 + computed_result = ( + lib.ed25519_sign_open(message, len(message), public_key, signature) == 0 + ) assert result == computed_result -@pytest.mark.parametrize("curve_name, public_key, hasher, message, signature, result", ecdsa_vectors) +@pytest.mark.parametrize( + "curve_name, public_key, hasher, message, signature, result", ecdsa_vectors +) def test_ecdsa(curve_name, public_key, hasher, message, signature, result): curve = get_curve_by_name(curve_name) if curve is None: - raise NotSupported("Curve not supported: {}".format(curve_Name)) + raise NotSupported("Curve not supported: {}".format(curve_name)) public_key = unhexlify(public_key) signature = unhexlify(signature) message = unhexlify(message) - computed_result = lib.ecdsa_verify(curve, hasher, public_key, signature, message, len(message)) == 0 + computed_result = ( + lib.ecdsa_verify(curve, hasher, public_key, signature, message, len(message)) + == 0 + ) assert result == computed_result -@pytest.mark.parametrize("public_key, private_key, shared, result", curve25519_dh_vectors) +@pytest.mark.parametrize( + "public_key, private_key, shared, result", curve25519_dh_vectors +) def test_curve25519_dh(public_key, private_key, shared, result): public_key = unhexlify(public_key) private_key = unhexlify(private_key) shared = unhexlify(shared) - computed_shared = bytes([0]*32) + computed_shared = bytes([0] * 32) lib.curve25519_scalarmult(computed_shared, private_key, public_key) computed_result = shared == computed_shared assert result == computed_result -@pytest.mark.parametrize("curve_name, public_key, private_key, shared, result", ecdh_vectors) +@pytest.mark.parametrize( + "curve_name, public_key, private_key, shared, result", ecdh_vectors +) def test_ecdh(curve_name, public_key, private_key, shared, result): curve = get_curve_by_name(curve_name) if curve is None: @@ -602,14 +676,16 @@ def test_ecdh(curve_name, public_key, private_key, shared, result): private_key = unhexlify(private_key) shared = unhexlify(shared) - computed_shared = bytes([0]*2*32) + computed_shared = bytes([0] * 2 * 32) lib.ecdh_multiply(curve, private_key, public_key, computed_shared) computed_shared = computed_shared[1:33] computed_result = shared == computed_shared assert result == computed_result -@pytest.mark.parametrize("key, iv, associated_data, plaintext, ciphertext, tag, result", chacha_poly_vectors) +@pytest.mark.parametrize( + "key, iv, associated_data, plaintext, ciphertext, tag, result", chacha_poly_vectors +) def test_chacha_poly(key, iv, associated_data, plaintext, ciphertext, tag, result): key = unhexlify(key) iv = unhexlify(iv) @@ -618,7 +694,9 @@ def test_chacha_poly(key, iv, associated_data, plaintext, ciphertext, tag, resul ciphertext = unhexlify(ciphertext) tag = unhexlify(tag) - computed_ciphertext, computed_tag = chacha_poly_encrypt(key, iv, associated_data, plaintext) + computed_ciphertext, computed_tag = chacha_poly_encrypt( + key, iv, associated_data, plaintext + ) computed_result = ciphertext == computed_ciphertext and tag == computed_tag assert result == computed_result