1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-04-20 09:09:02 +00:00

refactor(core): introduce compressed in verify_recover()

[no changelog]
This commit is contained in:
Ondřej Vejpustek 2025-03-25 22:03:53 +01:00
parent 8c807a16b4
commit 6f2130f1ee
8 changed files with 30 additions and 25 deletions

View File

@ -151,17 +151,18 @@ STATIC mp_obj_t mod_trezorcrypto_nist256p1_verify(mp_obj_t public_key,
}
STATIC MP_DEFINE_CONST_FUN_OBJ_3(mod_trezorcrypto_nist256p1_verify_obj,
mod_trezorcrypto_nist256p1_verify);
/// def verify_recover(signature: bytes, digest: bytes) -> bytes:
/// def verify_recover(signature: bytes, digest: bytes, compressed: bool = True)
/// -> bytes:
/// """
/// Uses signature of the digest to verify the digest and recover the public
/// key. Returns public key on success, None if the signature is invalid.
/// """
STATIC mp_obj_t mod_trezorcrypto_nist256p1_verify_recover(mp_obj_t signature,
mp_obj_t digest) {
STATIC mp_obj_t
mod_trezorcrypto_nist256p1_verify_recover(size_t n_args, const mp_obj_t *args) {
mp_buffer_info_t sig = {0}, dig = {0};
mp_get_buffer_raise(signature, &sig, MP_BUFFER_READ);
mp_get_buffer_raise(digest, &dig, MP_BUFFER_READ);
mp_get_buffer_raise(args[0], &sig, MP_BUFFER_READ);
mp_get_buffer_raise(args[1], &dig, MP_BUFFER_READ);
bool compressed = n_args < 3 || args[2] == mp_const_true;
if (sig.len != 65) {
return mp_const_none;
}
@ -172,7 +173,6 @@ STATIC mp_obj_t mod_trezorcrypto_nist256p1_verify_recover(mp_obj_t signature,
if (recid >= 8) {
return mp_const_none;
}
bool compressed = (recid >= 4);
recid &= 3;
vstr_t pk = {0};
vstr_init_len(&pk, 65);
@ -188,8 +188,9 @@ STATIC mp_obj_t mod_trezorcrypto_nist256p1_verify_recover(mp_obj_t signature,
return mp_const_none;
}
}
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorcrypto_nist256p1_verify_recover_obj,
mod_trezorcrypto_nist256p1_verify_recover);
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(
mod_trezorcrypto_nist256p1_verify_recover_obj, 2, 3,
mod_trezorcrypto_nist256p1_verify_recover);
/// def multiply(secret_key: bytes, public_key: bytes) -> bytes:
/// """

View File

@ -194,16 +194,18 @@ STATIC mp_obj_t mod_trezorcrypto_secp256k1_verify(mp_obj_t public_key,
STATIC MP_DEFINE_CONST_FUN_OBJ_3(mod_trezorcrypto_secp256k1_verify_obj,
mod_trezorcrypto_secp256k1_verify);
/// def verify_recover(signature: bytes, digest: bytes) -> bytes:
/// def verify_recover(signature: bytes, digest: bytes, compressed: bool = True)
/// -> bytes:
/// """
/// Uses signature of the digest to verify the digest and recover the public
/// key. Returns public key on success, None if the signature is invalid.
/// """
STATIC mp_obj_t mod_trezorcrypto_secp256k1_verify_recover(mp_obj_t signature,
mp_obj_t digest) {
STATIC mp_obj_t
mod_trezorcrypto_secp256k1_verify_recover(size_t n_args, const mp_obj_t *args) {
mp_buffer_info_t sig = {0}, dig = {0};
mp_get_buffer_raise(signature, &sig, MP_BUFFER_READ);
mp_get_buffer_raise(digest, &dig, MP_BUFFER_READ);
mp_get_buffer_raise(args[0], &sig, MP_BUFFER_READ);
mp_get_buffer_raise(args[1], &dig, MP_BUFFER_READ);
bool compressed = n_args < 3 || args[2] == mp_const_true;
if (sig.len != 65) {
return mp_const_none;
}
@ -214,7 +216,6 @@ STATIC mp_obj_t mod_trezorcrypto_secp256k1_verify_recover(mp_obj_t signature,
if (recid >= 8) {
return mp_const_none;
}
bool compressed = (recid >= 4);
recid &= 3;
vstr_t pk = {0};
vstr_init_len(&pk, 65);
@ -230,9 +231,9 @@ STATIC mp_obj_t mod_trezorcrypto_secp256k1_verify_recover(mp_obj_t signature,
return mp_const_none;
}
}
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorcrypto_secp256k1_verify_recover_obj,
mod_trezorcrypto_secp256k1_verify_recover);
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(
mod_trezorcrypto_secp256k1_verify_recover_obj, 2, 3,
mod_trezorcrypto_secp256k1_verify_recover);
/// def multiply(secret_key: bytes, public_key: bytes) -> bytes:
/// """
/// Multiplies point defined by public_key with scalar defined by

View File

@ -33,7 +33,8 @@ def verify(public_key: bytes, signature: bytes, digest: bytes) -> bool:
# upymod/modtrezorcrypto/modtrezorcrypto-nist256p1.h
def verify_recover(signature: bytes, digest: bytes) -> bytes:
def verify_recover(signature: bytes, digest: bytes, compressed: bool = True)
-> bytes:
"""
Uses signature of the digest to verify the digest and recover the public
key. Returns public key on success, None if the signature is invalid.

View File

@ -38,7 +38,8 @@ def verify(public_key: bytes, signature: bytes, digest: bytes) -> bool:
# upymod/modtrezorcrypto/modtrezorcrypto-secp256k1.h
def verify_recover(signature: bytes, digest: bytes) -> bytes:
def verify_recover(signature: bytes, digest: bytes, compressed: bool = True)
-> bytes:
"""
Uses signature of the digest to verify the digest and recover the public
key. Returns public key on success, None if the signature is invalid.

View File

@ -100,6 +100,7 @@ async def verify_message(msg: VerifyMessage) -> Success:
pubkey = secp256k1.verify_recover(
recoverable_signature,
digest,
signature_script_type != InputScriptType.SPENDADDRESS_UNCOMPRESSED,
)
if not pubkey:

View File

@ -22,7 +22,7 @@ async def verify_message(msg: EthereumVerifyMessage) -> Success:
raise DataError("Invalid signature")
sig = decode_signature(msg.signature)
pubkey = secp256k1.verify_recover(sig, digest)
pubkey = secp256k1.verify_recover(sig, digest, False)
if not pubkey:
raise DataError("Invalid signature")

View File

@ -279,8 +279,8 @@ class TestCryptoNist256p1(unittest.TestCase):
sk = nist256p1.generate_secret()
pk = nist256p1.publickey(sk, compressed)
dig = random.bytes(32)
sig = nist256p1.sign_recoverable(sk, dig, compressed)
pk2 = nist256p1.verify_recover(sig, dig)
sig = nist256p1.sign_recoverable(sk, dig)
pk2 = nist256p1.verify_recover(sig, dig, compressed)
self.assertEqual(pk, pk2)

View File

@ -246,8 +246,8 @@ class TestCryptoSecp256k1(unittest.TestCase):
sk = secp256k1.generate_secret()
pk = secp256k1.publickey(sk, compressed)
dig = random.bytes(32)
sig = secp256k1.sign_recoverable(sk, dig, compressed)
pk2 = secp256k1.verify_recover(sig, dig)
sig = secp256k1.sign_recoverable(sk, dig)
pk2 = secp256k1.verify_recover(sig, dig, compressed)
self.assertEqual(pk, pk2)
def test_ecdh(self):