diff --git a/extmod/modtrezorcrypto/modtrezorcrypto-nist256p1.h b/extmod/modtrezorcrypto/modtrezorcrypto-nist256p1.h index 2cfa2f6b49..02b90bbf4e 100644 --- a/extmod/modtrezorcrypto/modtrezorcrypto-nist256p1.h +++ b/extmod/modtrezorcrypto/modtrezorcrypto-nist256p1.h @@ -62,14 +62,15 @@ STATIC mp_obj_t mod_TrezorCrypto_Nist256p1_publickey(size_t n_args, const mp_obj } STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_TrezorCrypto_Nist256p1_publickey_obj, 2, 3, mod_TrezorCrypto_Nist256p1_publickey); -/// def trezor.crypto.curve.nist256p1.sign(secret_key: bytes, digest: bytes) -> bytes: +/// def trezor.crypto.curve.nist256p1.sign(secret_key: bytes, digest: bytes, compressed: bool=True) -> bytes: /// ''' /// Uses secret key to produce the signature of the digest. /// ''' -STATIC mp_obj_t mod_TrezorCrypto_Nist256p1_sign(mp_obj_t self, mp_obj_t secret_key, mp_obj_t digest) { +STATIC mp_obj_t mod_TrezorCrypto_Nist256p1_sign(size_t n_args, const mp_obj_t *args) { mp_buffer_info_t sk, dig; - mp_get_buffer_raise(secret_key, &sk, MP_BUFFER_READ); - mp_get_buffer_raise(digest, &dig, MP_BUFFER_READ); + mp_get_buffer_raise(args[1], &sk, MP_BUFFER_READ); + mp_get_buffer_raise(args[2], &dig, MP_BUFFER_READ); + bool compressed = n_args > 3 && args[3] == mp_const_true; if (sk.len != 32) { mp_raise_ValueError("Invalid length of secret key"); } @@ -79,13 +80,13 @@ STATIC mp_obj_t mod_TrezorCrypto_Nist256p1_sign(mp_obj_t self, mp_obj_t secret_k vstr_t vstr; vstr_init_len(&vstr, 65); uint8_t pby; - if (0 != ecdsa_sign_digest(&nist256p1, (const uint8_t *)sk.buf, (const uint8_t *)dig.buf, (uint8_t *)vstr.buf + 1, &pby, NULL)) { // TODO: is_canonical + if (0 != ecdsa_sign_digest(&nist256p1, (const uint8_t *)sk.buf, (const uint8_t *)dig.buf, (uint8_t *)vstr.buf + 1, &pby, NULL)) { mp_raise_ValueError("Signing failed"); } - vstr.buf[0] = 27 + pby + 4; + vstr.buf[0] = 27 + pby + compressed * 4; return mp_obj_new_str_from_vstr(&mp_type_bytes, &vstr); } -STATIC MP_DEFINE_CONST_FUN_OBJ_3(mod_TrezorCrypto_Nist256p1_sign_obj, mod_TrezorCrypto_Nist256p1_sign); +STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_TrezorCrypto_Nist256p1_sign_obj, 3, 4, mod_TrezorCrypto_Nist256p1_sign); /// def trezor.crypto.curve.nist256p1.verify(public_key: bytes, signature: bytes, digest: bytes) -> bool: /// ''' @@ -111,6 +112,41 @@ STATIC mp_obj_t mod_TrezorCrypto_Nist256p1_verify(size_t n_args, const mp_obj_t } STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_TrezorCrypto_Nist256p1_verify_obj, 4, 4, mod_TrezorCrypto_Nist256p1_verify); +/// def trezor.crypto.curve.nist256p1.verify_recover(signature: bytes, digest: bytes) -> bytes: +/// ''' +/// Uses signature of the digest to verify the digest and recover the public key. +/// Returns public key on success, None on failure. +/// ''' +STATIC mp_obj_t mod_TrezorCrypto_Nist256p1_verify_recover(mp_obj_t self, mp_obj_t signature, mp_obj_t digest) { + mp_buffer_info_t sig, dig; + mp_get_buffer_raise(signature, &sig, MP_BUFFER_READ); + mp_get_buffer_raise(digest, &dig, MP_BUFFER_READ); + if (sig.len != 65) { + mp_raise_ValueError("Invalid length of signature"); + } + if (dig.len != 32) { + mp_raise_ValueError("Invalid length of digest"); + } + uint8_t recid = ((const uint8_t *)sig.buf)[0] - 27; + if (recid >= 8) { + mp_raise_ValueError("Invalid recid in signature"); + } + bool compressed = (recid >= 4); + recid &= 3; + vstr_t vstr; + vstr_init_len(&vstr, 65); + if (0 == ecdsa_verify_digest_recover(&nist256p1, (uint8_t *)vstr.buf, (const uint8_t *)sig.buf + 1, (const uint8_t *)dig.buf, recid)) { + if (compressed) { + vstr.buf[0] = 0x02 | (vstr.buf[64] & 1); + vstr.len = 33; + } + return mp_obj_new_str_from_vstr(&mp_type_bytes, &vstr); + } else { + return mp_const_none; + } +} +STATIC MP_DEFINE_CONST_FUN_OBJ_3(mod_TrezorCrypto_Nist256p1_verify_recover_obj, mod_TrezorCrypto_Nist256p1_verify_recover); + /// def trezor.crypto.curve.nist256p1.multiply(secret_key: bytes, public_key: bytes) -> bytes: /// ''' /// Multiplies point defined by public_key with scalar defined by secret_key @@ -140,6 +176,7 @@ STATIC const mp_rom_map_elem_t mod_TrezorCrypto_Nist256p1_locals_dict_table[] = { MP_ROM_QSTR(MP_QSTR_publickey), MP_ROM_PTR(&mod_TrezorCrypto_Nist256p1_publickey_obj) }, { MP_ROM_QSTR(MP_QSTR_sign), MP_ROM_PTR(&mod_TrezorCrypto_Nist256p1_sign_obj) }, { MP_ROM_QSTR(MP_QSTR_verify), MP_ROM_PTR(&mod_TrezorCrypto_Nist256p1_verify_obj) }, + { MP_ROM_QSTR(MP_QSTR_verify_recover), MP_ROM_PTR(&mod_TrezorCrypto_Nist256p1_verify_recover_obj) }, { MP_ROM_QSTR(MP_QSTR_multiply), MP_ROM_PTR(&mod_TrezorCrypto_Nist256p1_multiply_obj) }, }; STATIC MP_DEFINE_CONST_DICT(mod_TrezorCrypto_Nist256p1_locals_dict, mod_TrezorCrypto_Nist256p1_locals_dict_table); diff --git a/extmod/modtrezorcrypto/modtrezorcrypto-secp256k1.h b/extmod/modtrezorcrypto/modtrezorcrypto-secp256k1.h index 45d98e12fe..3f8ed2412b 100644 --- a/extmod/modtrezorcrypto/modtrezorcrypto-secp256k1.h +++ b/extmod/modtrezorcrypto/modtrezorcrypto-secp256k1.h @@ -62,14 +62,15 @@ STATIC mp_obj_t mod_TrezorCrypto_Secp256k1_publickey(size_t n_args, const mp_obj } STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_TrezorCrypto_Secp256k1_publickey_obj, 2, 3, mod_TrezorCrypto_Secp256k1_publickey); -/// def trezor.crypto.curve.secp256k1.sign(secret_key: bytes, digest: bytes) -> bytes: +/// def trezor.crypto.curve.secp256k1.sign(secret_key: bytes, digest: bytes, compressed: bool=True) -> bytes: /// ''' /// Uses secret key to produce the signature of the digest. /// ''' -STATIC mp_obj_t mod_TrezorCrypto_Secp256k1_sign(mp_obj_t self, mp_obj_t secret_key, mp_obj_t digest) { +STATIC mp_obj_t mod_TrezorCrypto_Secp256k1_sign(size_t n_args, const mp_obj_t *args) { mp_buffer_info_t sk, dig; - mp_get_buffer_raise(secret_key, &sk, MP_BUFFER_READ); - mp_get_buffer_raise(digest, &dig, MP_BUFFER_READ); + mp_get_buffer_raise(args[1], &sk, MP_BUFFER_READ); + mp_get_buffer_raise(args[2], &dig, MP_BUFFER_READ); + bool compressed = n_args > 3 && args[3] == mp_const_true; if (sk.len != 32) { mp_raise_ValueError("Invalid length of secret key"); } @@ -79,13 +80,13 @@ STATIC mp_obj_t mod_TrezorCrypto_Secp256k1_sign(mp_obj_t self, mp_obj_t secret_k vstr_t vstr; vstr_init_len(&vstr, 65); uint8_t pby; - if (0 != ecdsa_sign_digest(&secp256k1, (const uint8_t *)sk.buf, (const uint8_t *)dig.buf, (uint8_t *)vstr.buf + 1, &pby, NULL)) { // TODO: is_canonical + if (0 != ecdsa_sign_digest(&secp256k1, (const uint8_t *)sk.buf, (const uint8_t *)dig.buf, (uint8_t *)vstr.buf + 1, &pby, NULL)) { mp_raise_ValueError("Signing failed"); } - vstr.buf[0] = 27 + pby + 4; + vstr.buf[0] = 27 + pby + compressed * 4; return mp_obj_new_str_from_vstr(&mp_type_bytes, &vstr); } -STATIC MP_DEFINE_CONST_FUN_OBJ_3(mod_TrezorCrypto_Secp256k1_sign_obj, mod_TrezorCrypto_Secp256k1_sign); +STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_TrezorCrypto_Secp256k1_sign_obj, 3, 4, mod_TrezorCrypto_Secp256k1_sign); /// def trezor.crypto.curve.secp256k1.verify(public_key: bytes, signature: bytes, digest: bytes) -> bool: /// ''' @@ -111,6 +112,41 @@ STATIC mp_obj_t mod_TrezorCrypto_Secp256k1_verify(size_t n_args, const mp_obj_t } STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_TrezorCrypto_Secp256k1_verify_obj, 4, 4, mod_TrezorCrypto_Secp256k1_verify); +/// def trezor.crypto.curve.secp256k1.verify_recover(signature: bytes, digest: bytes) -> bytes: +/// ''' +/// Uses signature of the digest to verify the digest and recover the public key. +/// Returns public key on success, None on failure. +/// ''' +STATIC mp_obj_t mod_TrezorCrypto_Secp256k1_verify_recover(mp_obj_t self, mp_obj_t signature, mp_obj_t digest) { + mp_buffer_info_t sig, dig; + mp_get_buffer_raise(signature, &sig, MP_BUFFER_READ); + mp_get_buffer_raise(digest, &dig, MP_BUFFER_READ); + if (sig.len != 65) { + mp_raise_ValueError("Invalid length of signature"); + } + if (dig.len != 32) { + mp_raise_ValueError("Invalid length of digest"); + } + uint8_t recid = ((const uint8_t *)sig.buf)[0] - 27; + if (recid >= 8) { + mp_raise_ValueError("Invalid recid in signature"); + } + bool compressed = (recid >= 4); + recid &= 3; + vstr_t vstr; + vstr_init_len(&vstr, 65); + if (0 == ecdsa_verify_digest_recover(&secp256k1, (uint8_t *)vstr.buf, (const uint8_t *)sig.buf + 1, (const uint8_t *)dig.buf, recid)) { + if (compressed) { + vstr.buf[0] = 0x02 | (vstr.buf[64] & 1); + vstr.len = 33; + } + return mp_obj_new_str_from_vstr(&mp_type_bytes, &vstr); + } else { + return mp_const_none; + } +} +STATIC MP_DEFINE_CONST_FUN_OBJ_3(mod_TrezorCrypto_Secp256k1_verify_recover_obj, mod_TrezorCrypto_Secp256k1_verify_recover); + /// def trezor.crypto.curve.secp256k1.multiply(secret_key: bytes, public_key: bytes) -> bytes: /// ''' /// Multiplies point defined by public_key with scalar defined by secret_key @@ -140,6 +176,7 @@ STATIC const mp_rom_map_elem_t mod_TrezorCrypto_Secp256k1_locals_dict_table[] = { MP_ROM_QSTR(MP_QSTR_publickey), MP_ROM_PTR(&mod_TrezorCrypto_Secp256k1_publickey_obj) }, { MP_ROM_QSTR(MP_QSTR_sign), MP_ROM_PTR(&mod_TrezorCrypto_Secp256k1_sign_obj) }, { MP_ROM_QSTR(MP_QSTR_verify), MP_ROM_PTR(&mod_TrezorCrypto_Secp256k1_verify_obj) }, + { MP_ROM_QSTR(MP_QSTR_verify_recover), MP_ROM_PTR(&mod_TrezorCrypto_Secp256k1_verify_recover_obj) }, { MP_ROM_QSTR(MP_QSTR_multiply), MP_ROM_PTR(&mod_TrezorCrypto_Secp256k1_multiply_obj) }, }; STATIC MP_DEFINE_CONST_DICT(mod_TrezorCrypto_Secp256k1_locals_dict, mod_TrezorCrypto_Secp256k1_locals_dict_table); diff --git a/tests/test_trezor.crypto.curve.nist256p1.py b/tests/test_trezor.crypto.curve.nist256p1.py index e72cbd8917..a1e2579026 100644 --- a/tests/test_trezor.crypto.curve.nist256p1.py +++ b/tests/test_trezor.crypto.curve.nist256p1.py @@ -111,5 +111,15 @@ class TestCryptoNist256p1(unittest.TestCase): self.assertTrue(nist256p1.verify(pk, sig, dig)) self.assertTrue(nist256p1.verify(pk, sig[1:], dig)) + def test_verify_recover(self): + for compressed in [False, True]: + for _ in range(100): + sk = nist256p1.generate_secret() + pk = nist256p1.publickey(sk, compressed) + dig = random.bytes(32) + sig = nist256p1.sign(sk, dig, compressed) + pk2 = nist256p1.verify_recover(sig, dig) + self.assertEqual(pk, pk2) + if __name__ == '__main__': unittest.main() diff --git a/tests/test_trezor.crypto.curve.secp256k1.py b/tests/test_trezor.crypto.curve.secp256k1.py index 6ceff9e70c..63c672d046 100644 --- a/tests/test_trezor.crypto.curve.secp256k1.py +++ b/tests/test_trezor.crypto.curve.secp256k1.py @@ -100,5 +100,15 @@ class TestCryptoSecp256k1(unittest.TestCase): sig = secp256k1.sign(sk, dig) self.assertTrue(secp256k1.verify(pk, sig, dig)) + def test_verify_recover(self): + for compressed in [False, True]: + for _ in range(100): + sk = secp256k1.generate_secret() + pk = secp256k1.publickey(sk, compressed) + dig = random.bytes(32) + sig = secp256k1.sign(sk, dig, compressed) + pk2 = secp256k1.verify_recover(sig, dig) + self.assertEqual(pk, pk2) + if __name__ == '__main__': unittest.main()