diff --git a/core/.changelog.d/2077.added b/core/.changelog.d/2077.added new file mode 100644 index 000000000..6ca321ffa --- /dev/null +++ b/core/.changelog.d/2077.added @@ -0,0 +1 @@ +Add extra check for Taproot scripts validity diff --git a/core/embed/extmod/modtrezorcrypto/modtrezorcrypto-bip340.h b/core/embed/extmod/modtrezorcrypto/modtrezorcrypto-bip340.h index 8f1cc5201..fb036b419 100644 --- a/core/embed/extmod/modtrezorcrypto/modtrezorcrypto-bip340.h +++ b/core/embed/extmod/modtrezorcrypto/modtrezorcrypto-bip340.h @@ -113,6 +113,23 @@ STATIC mp_obj_t mod_trezorcrypto_bip340_sign(mp_obj_t secret_key, STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorcrypto_bip340_sign_obj, mod_trezorcrypto_bip340_sign); +/// def verify_publickey(public_key: bytes) -> bool: +/// """ +/// Verifies whether the public key is valid. +/// Returns True on success. +/// """ +STATIC mp_obj_t mod_trezorcrypto_bip340_verify_publickey(mp_obj_t public_key) { + mp_buffer_info_t pk = {0}; + mp_get_buffer_raise(public_key, &pk, MP_BUFFER_READ); + if (pk.len != 32) { + return mp_const_false; + } + int ret = zkp_bip340_verify_publickey((const uint8_t *)pk.buf); + return mp_obj_new_bool(ret == 0); +} +STATIC MP_DEFINE_CONST_FUN_OBJ_1(mod_trezorcrypto_bip340_verify_publickey_obj, + mod_trezorcrypto_bip340_verify_publickey); + /// def verify(public_key: bytes, signature: bytes, digest: bytes) -> bool: /// """ /// Uses public key to verify the signature of the digest. @@ -229,6 +246,8 @@ STATIC const mp_rom_map_elem_t mod_trezorcrypto_bip340_globals_table[] = { {MP_ROM_QSTR(MP_QSTR_publickey), MP_ROM_PTR(&mod_trezorcrypto_bip340_publickey_obj)}, {MP_ROM_QSTR(MP_QSTR_sign), MP_ROM_PTR(&mod_trezorcrypto_bip340_sign_obj)}, + {MP_ROM_QSTR(MP_QSTR_verify_publickey), + MP_ROM_PTR(&mod_trezorcrypto_bip340_verify_publickey_obj)}, {MP_ROM_QSTR(MP_QSTR_verify), MP_ROM_PTR(&mod_trezorcrypto_bip340_verify_obj)}, {MP_ROM_QSTR(MP_QSTR_tweak_public_key), diff --git a/core/mocks/generated/trezorcrypto/bip340.pyi b/core/mocks/generated/trezorcrypto/bip340.pyi index c92ab1a33..855cb2002 100644 --- a/core/mocks/generated/trezorcrypto/bip340.pyi +++ b/core/mocks/generated/trezorcrypto/bip340.pyi @@ -25,6 +25,14 @@ def sign( """ +# extmod/modtrezorcrypto/modtrezorcrypto-bip340.h +def verify_publickey(public_key: bytes) -> bool: + """ + Verifies whether the public key is valid. + Returns True on success. + """ + + # extmod/modtrezorcrypto/modtrezorcrypto-bip340.h def verify(public_key: bytes, signature: bytes, digest: bytes) -> bool: """ diff --git a/core/src/apps/bitcoin/common.py b/core/src/apps/bitcoin/common.py index e6b8b86d0..fe95edc86 100644 --- a/core/src/apps/bitcoin/common.py +++ b/core/src/apps/bitcoin/common.py @@ -130,10 +130,13 @@ def encode_bech32_address(prefix: str, witver: int, script: bytes) -> str: def decode_bech32_address(prefix: str, address: str) -> tuple[int, bytes]: witver, raw = bech32.decode(prefix, address) if witver not in _BECH32_WITVERS: - raise wire.ProcessError("Invalid address witness program") + raise wire.DataError("Invalid address witness program") assert witver is not None assert raw is not None - return witver, bytes(raw) + # check that P2TR address encodes a valid BIP340 public key + if witver == 1 and not bip340.verify_publickey(raw): + raise wire.DataError("Invalid Taproot witness program") + return witver, raw def input_is_segwit(txi: TxInput) -> bool: diff --git a/core/src/trezor/crypto/bech32.py b/core/src/trezor/crypto/bech32.py index f41e7cd92..1c299513e 100644 --- a/core/src/trezor/crypto/bech32.py +++ b/core/src/trezor/crypto/bech32.py @@ -154,7 +154,7 @@ def convertbits( return ret -def decode(hrp: str, addr: str) -> OptionalTuple2[int, list[int]]: +def decode(hrp: str, addr: str) -> OptionalTuple2[int, bytes]: """Decode a segwit address.""" hrpgot, data, spec = bech32_decode(addr) # the following two lines are strictly not required @@ -164,7 +164,7 @@ def decode(hrp: str, addr: str) -> OptionalTuple2[int, list[int]]: if hrpgot != hrp: return (None, None) try: - decoded = convertbits(data[1:], 5, 8, False) + decoded = bytes(convertbits(data[1:], 5, 8, False)) except ValueError: return (None, None) if not 2 <= len(decoded) <= 40: diff --git a/core/tests/test_trezor.crypto.bech32.py b/core/tests/test_trezor.crypto.bech32.py index 9885b9953..954c0e067 100644 --- a/core/tests/test_trezor.crypto.bech32.py +++ b/core/tests/test_trezor.crypto.bech32.py @@ -27,7 +27,7 @@ from trezor.crypto import bech32 def segwit_scriptpubkey(witver, witprog): """Construct a Segwit scriptPubKey for a given witness program.""" - return bytes([witver + 0x50 if witver else 0, len(witprog)] + witprog) + return bytes([witver + 0x50 if witver else 0, len(witprog)]) + witprog VALID_CHECKSUM = [