add curve25519.publickey and unittest for randomized multiply

pull/25/head
Pavol Rusnak 8 years ago
parent 7f5fa78f35
commit 3c5c685b8c
No known key found for this signature in database
GPG Key ID: 91F3B339B9A02A3D

@ -20,6 +20,23 @@ STATIC mp_obj_t mod_TrezorCrypto_Curve25519_make_new(const mp_obj_type_t *type,
return MP_OBJ_FROM_PTR(o);
}
/// def trezor.crypto.curve.curve25519.publickey(secret_key: bytes) -> bytes:
/// '''
/// Computes public key from secret key.
/// '''
STATIC mp_obj_t mod_TrezorCrypto_Curve25519_publickey(mp_obj_t self, mp_obj_t secret_key) {
mp_buffer_info_t sk;
mp_get_buffer_raise(secret_key, &sk, MP_BUFFER_READ);
if (sk.len != 32) {
mp_raise_ValueError("Invalid length of secret key");
}
vstr_t vstr;
vstr_init_len(&vstr, 32);
curve25519_publickey((uint8_t *)vstr.buf, (const uint8_t *)sk.buf);
return mp_obj_new_str_from_vstr(&mp_type_bytes, &vstr);
}
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_TrezorCrypto_Curve25519_publickey_obj, mod_TrezorCrypto_Curve25519_publickey);
/// def trezor.crypto.curve.curve25519.multiply(secret_key: bytes, public_key: bytes) -> bytes:
/// '''
/// Multiplies point defined by public_key with scalar defined by secret_key
@ -43,6 +60,7 @@ STATIC mp_obj_t mod_TrezorCrypto_Curve25519_multiply(mp_obj_t self, mp_obj_t sec
STATIC MP_DEFINE_CONST_FUN_OBJ_3(mod_TrezorCrypto_Curve25519_multiply_obj, mod_TrezorCrypto_Curve25519_multiply);
STATIC const mp_rom_map_elem_t mod_TrezorCrypto_Curve25519_locals_dict_table[] = {
{ MP_ROM_QSTR(MP_QSTR_publickey), MP_ROM_PTR(&mod_TrezorCrypto_Curve25519_publickey_obj) },
{ MP_ROM_QSTR(MP_QSTR_multiply), MP_ROM_PTR(&mod_TrezorCrypto_Curve25519_multiply_obj) },
};
STATIC MP_DEFINE_CONST_DICT(mod_TrezorCrypto_Curve25519_locals_dict, mod_TrezorCrypto_Curve25519_locals_dict_table);

@ -5,6 +5,7 @@ import unittest
from ubinascii import unhexlify
from trezor.crypto.curve import curve25519
from trezor.crypto import random
class TestCryptoCurve25519(unittest.TestCase):
@ -17,5 +18,18 @@ class TestCryptoCurve25519(unittest.TestCase):
session2 = curve25519.multiply(unhexlify(sk), unhexlify(pk))
self.assertEqual(session2, unhexlify(session))
def test_multiply_random(self):
for _ in range(100):
sk1 = bytearray(random.bytes(32))
sk2 = bytearray(random.bytes(32))
# taken from https://cr.yp.to/ecdh.html
sk1[0] &= 248 ; sk1[31] &= 127 ; sk1[31] |= 64
sk2[0] &= 248 ; sk2[31] &= 127 ; sk2[31] |= 64
pk1 = curve25519.publickey(sk1)
pk2 = curve25519.publickey(sk2)
session1 = curve25519.multiply(sk1, pk2)
session2 = curve25519.multiply(sk2, pk1)
self.assertEqual(session1, session2)
if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save