diff --git a/src/apps/common/address_type.py b/src/apps/common/address_type.py index cc54c34d2c..ff57d8e479 100644 --- a/src/apps/common/address_type.py +++ b/src/apps/common/address_type.py @@ -9,6 +9,17 @@ def length(address_type): return 4 +def addrtype_bytes(address_type: int): + if address_type <= 0xFF: + return bytes([address_type]) + if address_type <= 0xFFFF: + return bytes([(address_type >> 8), (address_type & 0xFF)]) + if address_type <= 0xFFFFFF: + return bytes([(address_type >> 16), (address_type >> 8), (address_type & 0xFF)]) + # else + return bytes([(address_type >> 24), (address_type >> 16), (address_type >> 8), (address_type & 0xFF)]) + + def check(address_type, raw_address): if address_type <= 0xFF: return address_type == raw_address[0] diff --git a/src/apps/wallet/sign_message.py b/src/apps/wallet/sign_message.py index ef1d820d74..e3064bfd7b 100644 --- a/src/apps/wallet/sign_message.py +++ b/src/apps/wallet/sign_message.py @@ -1,13 +1,14 @@ from trezor import ui from trezor.wire import FailureError from trezor.crypto.curve import secp256k1 -from trezor.messages.InputScriptType import SPENDADDRESS +from trezor.messages.InputScriptType import SPENDADDRESS, SPENDP2SHWITNESS, SPENDWITNESS from trezor.messages.FailureType import ProcessError from trezor.messages.MessageSignature import MessageSignature from trezor.ui.text import Text from apps.common import coins, seed from apps.common.confirm import require_confirm from apps.common.signverify import message_digest, split_message +from apps.wallet.sign_tx.addresses import get_address async def sign_message(ctx, msg): @@ -17,18 +18,24 @@ async def sign_message(ctx, msg): script_type = msg.script_type or 0 coin = coins.by_name(coin_name) - if script_type != SPENDADDRESS: - raise FailureError(ProcessError, 'Unsupported script type') - await confirm_sign_message(ctx, message) node = await seed.derive_node(ctx, address_n) seckey = node.private_key() - address = node.address(coin.address_type) + address = get_address(script_type, coin, node) digest = message_digest(coin, message) signature = secp256k1.sign(seckey, digest) + if script_type == SPENDADDRESS: + pass + elif script_type == SPENDP2SHWITNESS: + signature = bytes([signature[0] + 4]) + signature[1:] + elif script_type == SPENDWITNESS: + signature = bytes([signature[0] + 8]) + signature[1:] + else: + raise FailureError(ProcessError, 'Unsupported script type') + return MessageSignature(address=address, signature=signature) diff --git a/src/apps/wallet/sign_tx/addresses.py b/src/apps/wallet/sign_tx/addresses.py index 3ae0069319..7e601003b1 100644 --- a/src/apps/wallet/sign_tx/addresses.py +++ b/src/apps/wallet/sign_tx/addresses.py @@ -8,6 +8,7 @@ from trezor.messages.CoinType import CoinType from trezor.messages import FailureType from trezor.messages import InputScriptType +from apps.common.address_type import addrtype_bytes from apps.wallet.sign_tx.scripts import * from apps.wallet.sign_tx.multisig import * @@ -97,10 +98,13 @@ def address_multisig_p2wsh(pubkeys: bytes, m: int, hrp: str): return address_p2wsh(witness_script_hash, hrp) +def address_pkh(pubkey: bytes, addrtype: int) -> str: + s = addrtype_bytes(addrtype) + sha256_ripemd160_digest(pubkey) + return base58.encode_check(bytes(s)) + + def address_p2sh(redeem_script_hash: bytes, addrtype: int) -> str: - s = bytearray(21) - s[0] = addrtype - s[1:21] = redeem_script_hash + s = addrtype_bytes(addrtype) + redeem_script_hash return base58.encode_check(bytes(s)) diff --git a/src/apps/wallet/verify_message.py b/src/apps/wallet/verify_message.py index 9da60f5e46..f2ca994a5f 100644 --- a/src/apps/wallet/verify_message.py +++ b/src/apps/wallet/verify_message.py @@ -2,12 +2,15 @@ from trezor import ui, wire from trezor.crypto import base58 from trezor.crypto.curve import secp256k1 from trezor.crypto.hashlib import ripemd160, sha256 +from trezor.messages.InputScriptType import SPENDADDRESS, SPENDP2SHWITNESS, SPENDWITNESS from trezor.messages.FailureType import ProcessError from trezor.messages.Success import Success from trezor.ui.text import Text from apps.common import address_type, coins from apps.common.confirm import require_confirm from apps.common.signverify import message_digest, split_message +from apps.wallet.sign_tx.addresses import address_pkh, address_p2wpkh_in_p2sh, address_p2wpkh +from apps.wallet.get_address import _split_address async def verify_message(ctx, msg): @@ -17,25 +20,48 @@ async def verify_message(ctx, msg): coin_name = msg.coin_name or 'Bitcoin' coin = coins.by_name(coin_name) - await confirm_verify_message(ctx, message) - digest = message_digest(coin, message) + + script_type = None + recid = signature[0] + if recid >= 27 and recid <= 34: + script_type = SPENDADDRESS # p2pkh + elif recid >= 35 and recid <= 38: + script_type = SPENDP2SHWITNESS # segwit-in-p2sh + signature = bytes([signature[0] - 4]) + signature[1:] + elif recid >= 39 and recid <= 42: + script_type = SPENDWITNESS # native segwit + signature = bytes([signature[0] - 8]) + signature[1:] + else: + raise wire.FailureError(ProcessError, 'Invalid signature') + pubkey = secp256k1.verify_recover(signature, digest) if not pubkey: raise wire.FailureError(ProcessError, 'Invalid signature') - raw_address = base58.decode_check(address) - _, pkh = address_type.split(coin, raw_address) - pkh2 = ripemd160(sha256(pubkey).digest()).digest() - - if pkh != pkh2: + if script_type == SPENDADDRESS: + addr = address_pkh(pubkey, coin.address_type) + elif script_type == SPENDP2SHWITNESS: + addr = address_p2wpkh_in_p2sh(pubkey, coin.address_type_p2sh) + elif script_type == SPENDWITNESS: + addr = address_p2wpkh(pubkey, coin.bech32_prefix) + else: raise wire.FailureError(ProcessError, 'Invalid signature') + if addr != address: + raise wire.FailureError(ProcessError, 'Invalid signature') + + await confirm_verify_message(ctx, address, message) + return Success(message='Message verified') -async def confirm_verify_message(ctx, message): +async def confirm_verify_message(ctx, address, message): + lines = _split_address(address) + content = Text('Confirm address', ui.ICON_DEFAULT, ui.MONO, *lines) + await require_confirm(ctx, content) + message = split_message(message) content = Text('Verify message', ui.ICON_DEFAULT, max_lines=5, *message) await require_confirm(ctx, content)