1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-06-29 03:12:34 +00:00

refactor(core): introduce compressed in verify_recover()

[no changelog]
This commit is contained in:
Ondřej Vejpustek 2025-03-25 22:03:53 +01:00
parent 8c807a16b4
commit 6f2130f1ee
8 changed files with 30 additions and 25 deletions

View File

@ -151,17 +151,18 @@ STATIC mp_obj_t mod_trezorcrypto_nist256p1_verify(mp_obj_t public_key,
} }
STATIC MP_DEFINE_CONST_FUN_OBJ_3(mod_trezorcrypto_nist256p1_verify_obj, STATIC MP_DEFINE_CONST_FUN_OBJ_3(mod_trezorcrypto_nist256p1_verify_obj,
mod_trezorcrypto_nist256p1_verify); mod_trezorcrypto_nist256p1_verify);
/// def verify_recover(signature: bytes, digest: bytes, compressed: bool = True)
/// def verify_recover(signature: bytes, digest: bytes) -> bytes: /// -> bytes:
/// """ /// """
/// Uses signature of the digest to verify the digest and recover the public /// Uses signature of the digest to verify the digest and recover the public
/// key. Returns public key on success, None if the signature is invalid. /// key. Returns public key on success, None if the signature is invalid.
/// """ /// """
STATIC mp_obj_t mod_trezorcrypto_nist256p1_verify_recover(mp_obj_t signature, STATIC mp_obj_t
mp_obj_t digest) { mod_trezorcrypto_nist256p1_verify_recover(size_t n_args, const mp_obj_t *args) {
mp_buffer_info_t sig = {0}, dig = {0}; mp_buffer_info_t sig = {0}, dig = {0};
mp_get_buffer_raise(signature, &sig, MP_BUFFER_READ); mp_get_buffer_raise(args[0], &sig, MP_BUFFER_READ);
mp_get_buffer_raise(digest, &dig, MP_BUFFER_READ); mp_get_buffer_raise(args[1], &dig, MP_BUFFER_READ);
bool compressed = n_args < 3 || args[2] == mp_const_true;
if (sig.len != 65) { if (sig.len != 65) {
return mp_const_none; return mp_const_none;
} }
@ -172,7 +173,6 @@ STATIC mp_obj_t mod_trezorcrypto_nist256p1_verify_recover(mp_obj_t signature,
if (recid >= 8) { if (recid >= 8) {
return mp_const_none; return mp_const_none;
} }
bool compressed = (recid >= 4);
recid &= 3; recid &= 3;
vstr_t pk = {0}; vstr_t pk = {0};
vstr_init_len(&pk, 65); vstr_init_len(&pk, 65);
@ -188,8 +188,9 @@ STATIC mp_obj_t mod_trezorcrypto_nist256p1_verify_recover(mp_obj_t signature,
return mp_const_none; return mp_const_none;
} }
} }
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorcrypto_nist256p1_verify_recover_obj, STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(
mod_trezorcrypto_nist256p1_verify_recover); mod_trezorcrypto_nist256p1_verify_recover_obj, 2, 3,
mod_trezorcrypto_nist256p1_verify_recover);
/// def multiply(secret_key: bytes, public_key: bytes) -> bytes: /// def multiply(secret_key: bytes, public_key: bytes) -> bytes:
/// """ /// """

View File

@ -194,16 +194,18 @@ STATIC mp_obj_t mod_trezorcrypto_secp256k1_verify(mp_obj_t public_key,
STATIC MP_DEFINE_CONST_FUN_OBJ_3(mod_trezorcrypto_secp256k1_verify_obj, STATIC MP_DEFINE_CONST_FUN_OBJ_3(mod_trezorcrypto_secp256k1_verify_obj,
mod_trezorcrypto_secp256k1_verify); mod_trezorcrypto_secp256k1_verify);
/// def verify_recover(signature: bytes, digest: bytes) -> bytes: /// def verify_recover(signature: bytes, digest: bytes, compressed: bool = True)
/// -> bytes:
/// """ /// """
/// Uses signature of the digest to verify the digest and recover the public /// Uses signature of the digest to verify the digest and recover the public
/// key. Returns public key on success, None if the signature is invalid. /// key. Returns public key on success, None if the signature is invalid.
/// """ /// """
STATIC mp_obj_t mod_trezorcrypto_secp256k1_verify_recover(mp_obj_t signature, STATIC mp_obj_t
mp_obj_t digest) { mod_trezorcrypto_secp256k1_verify_recover(size_t n_args, const mp_obj_t *args) {
mp_buffer_info_t sig = {0}, dig = {0}; mp_buffer_info_t sig = {0}, dig = {0};
mp_get_buffer_raise(signature, &sig, MP_BUFFER_READ); mp_get_buffer_raise(args[0], &sig, MP_BUFFER_READ);
mp_get_buffer_raise(digest, &dig, MP_BUFFER_READ); mp_get_buffer_raise(args[1], &dig, MP_BUFFER_READ);
bool compressed = n_args < 3 || args[2] == mp_const_true;
if (sig.len != 65) { if (sig.len != 65) {
return mp_const_none; return mp_const_none;
} }
@ -214,7 +216,6 @@ STATIC mp_obj_t mod_trezorcrypto_secp256k1_verify_recover(mp_obj_t signature,
if (recid >= 8) { if (recid >= 8) {
return mp_const_none; return mp_const_none;
} }
bool compressed = (recid >= 4);
recid &= 3; recid &= 3;
vstr_t pk = {0}; vstr_t pk = {0};
vstr_init_len(&pk, 65); vstr_init_len(&pk, 65);
@ -230,9 +231,9 @@ STATIC mp_obj_t mod_trezorcrypto_secp256k1_verify_recover(mp_obj_t signature,
return mp_const_none; return mp_const_none;
} }
} }
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorcrypto_secp256k1_verify_recover_obj, STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(
mod_trezorcrypto_secp256k1_verify_recover); mod_trezorcrypto_secp256k1_verify_recover_obj, 2, 3,
mod_trezorcrypto_secp256k1_verify_recover);
/// def multiply(secret_key: bytes, public_key: bytes) -> bytes: /// def multiply(secret_key: bytes, public_key: bytes) -> bytes:
/// """ /// """
/// Multiplies point defined by public_key with scalar defined by /// Multiplies point defined by public_key with scalar defined by

View File

@ -33,7 +33,8 @@ def verify(public_key: bytes, signature: bytes, digest: bytes) -> bool:
# upymod/modtrezorcrypto/modtrezorcrypto-nist256p1.h # upymod/modtrezorcrypto/modtrezorcrypto-nist256p1.h
def verify_recover(signature: bytes, digest: bytes) -> bytes: def verify_recover(signature: bytes, digest: bytes, compressed: bool = True)
-> bytes:
""" """
Uses signature of the digest to verify the digest and recover the public Uses signature of the digest to verify the digest and recover the public
key. Returns public key on success, None if the signature is invalid. key. Returns public key on success, None if the signature is invalid.

View File

@ -38,7 +38,8 @@ def verify(public_key: bytes, signature: bytes, digest: bytes) -> bool:
# upymod/modtrezorcrypto/modtrezorcrypto-secp256k1.h # upymod/modtrezorcrypto/modtrezorcrypto-secp256k1.h
def verify_recover(signature: bytes, digest: bytes) -> bytes: def verify_recover(signature: bytes, digest: bytes, compressed: bool = True)
-> bytes:
""" """
Uses signature of the digest to verify the digest and recover the public Uses signature of the digest to verify the digest and recover the public
key. Returns public key on success, None if the signature is invalid. key. Returns public key on success, None if the signature is invalid.

View File

@ -100,6 +100,7 @@ async def verify_message(msg: VerifyMessage) -> Success:
pubkey = secp256k1.verify_recover( pubkey = secp256k1.verify_recover(
recoverable_signature, recoverable_signature,
digest, digest,
signature_script_type != InputScriptType.SPENDADDRESS_UNCOMPRESSED,
) )
if not pubkey: if not pubkey:

View File

@ -22,7 +22,7 @@ async def verify_message(msg: EthereumVerifyMessage) -> Success:
raise DataError("Invalid signature") raise DataError("Invalid signature")
sig = decode_signature(msg.signature) sig = decode_signature(msg.signature)
pubkey = secp256k1.verify_recover(sig, digest) pubkey = secp256k1.verify_recover(sig, digest, False)
if not pubkey: if not pubkey:
raise DataError("Invalid signature") raise DataError("Invalid signature")

View File

@ -279,8 +279,8 @@ class TestCryptoNist256p1(unittest.TestCase):
sk = nist256p1.generate_secret() sk = nist256p1.generate_secret()
pk = nist256p1.publickey(sk, compressed) pk = nist256p1.publickey(sk, compressed)
dig = random.bytes(32) dig = random.bytes(32)
sig = nist256p1.sign_recoverable(sk, dig, compressed) sig = nist256p1.sign_recoverable(sk, dig)
pk2 = nist256p1.verify_recover(sig, dig) pk2 = nist256p1.verify_recover(sig, dig, compressed)
self.assertEqual(pk, pk2) self.assertEqual(pk, pk2)

View File

@ -246,8 +246,8 @@ class TestCryptoSecp256k1(unittest.TestCase):
sk = secp256k1.generate_secret() sk = secp256k1.generate_secret()
pk = secp256k1.publickey(sk, compressed) pk = secp256k1.publickey(sk, compressed)
dig = random.bytes(32) dig = random.bytes(32)
sig = secp256k1.sign_recoverable(sk, dig, compressed) sig = secp256k1.sign_recoverable(sk, dig)
pk2 = secp256k1.verify_recover(sig, dig) pk2 = secp256k1.verify_recover(sig, dig, compressed)
self.assertEqual(pk, pk2) self.assertEqual(pk, pk2)
def test_ecdh(self): def test_ecdh(self):