You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
trezor-firmware/core/src/apps/bitcoin/sign_tx/helpers.py

474 lines
16 KiB

from trezor import utils, wire
from trezor.messages import InputScriptType, OutputScriptType
from trezor.messages.PrevInput import PrevInput
from trezor.messages.PrevOutput import PrevOutput
from trezor.messages.PrevTx import PrevTx
from trezor.messages.RequestType import (
TXEXTRADATA,
TXFINISHED,
TXINPUT,
TXMETA,
TXORIGINPUT,
TXORIGOUTPUT,
TXOUTPUT,
)
from trezor.messages.SignTx import SignTx
from trezor.messages.TxAckInput import TxAckInput
from trezor.messages.TxAckOutput import TxAckOutput
from trezor.messages.TxAckPrevExtraData import TxAckPrevExtraData
from trezor.messages.TxAckPrevInput import TxAckPrevInput
from trezor.messages.TxAckPrevMeta import TxAckPrevMeta
from trezor.messages.TxAckPrevOutput import TxAckPrevOutput
from trezor.messages.TxInput import TxInput
from trezor.messages.TxOutput import TxOutput
from trezor.messages.TxRequest import TxRequest
from apps.common import paths
from apps.common.coininfo import CoinInfo
from .. import common
from ..writers import TX_HASH_SIZE
from . import layout
if False:
from typing import Any, Awaitable
from trezor.messages.SignTx import EnumTypeAmountUnit
# Machine instructions
# ===
class UiConfirm:
    """Base class of the UI confirmation instructions defined in this module.

    Subclasses override `confirm_dialog`; the base implementation only
    raises NotImplementedError.
    """

    def confirm_dialog(self, ctx: wire.Context) -> Awaitable[Any]:
        """Return the awaitable for this confirmation (not implemented here)."""
        raise NotImplementedError
class UiConfirmOutput(UiConfirm):
    """Instruction that confirms one output via layout.confirm_output."""

    # Equality is delegated to the shared helper utils.obj_eq.
    __eq__ = utils.obj_eq

    def __init__(
        self, output: TxOutput, coin: CoinInfo, amount_unit: EnumTypeAmountUnit
    ):
        self.output, self.coin, self.amount_unit = output, coin, amount_unit

    def confirm_dialog(self, ctx: wire.Context) -> Awaitable[Any]:
        """Forward the stored output, coin and amount unit to the layout."""
        dialog = layout.confirm_output(ctx, self.output, self.coin, self.amount_unit)
        return dialog
class UiConfirmDecredSSTXSubmission(UiConfirm):
    """Instruction that calls layout.confirm_decred_sstx_submission for an output."""

    # Equality is delegated to the shared helper utils.obj_eq.
    __eq__ = utils.obj_eq

    def __init__(
        self, output: TxOutput, coin: CoinInfo, amount_unit: EnumTypeAmountUnit
    ):
        self.output, self.coin, self.amount_unit = output, coin, amount_unit

    def confirm_dialog(self, ctx: wire.Context) -> Awaitable[Any]:
        """Forward the stored output, coin and amount unit to the layout."""
        output, coin, amount_unit = self.output, self.coin, self.amount_unit
        return layout.confirm_decred_sstx_submission(ctx, output, coin, amount_unit)
class UiConfirmReplacement(UiConfirm):
    """Instruction carrying a description and txid for layout.confirm_replacement."""

    # Equality is delegated to the shared helper utils.obj_eq.
    __eq__ = utils.obj_eq

    def __init__(self, description: str, txid: bytes):
        self.description, self.txid = description, txid

    def confirm_dialog(self, ctx: wire.Context) -> Awaitable[Any]:
        """Forward the stored description and txid to the layout."""
        dialog = layout.confirm_replacement(ctx, self.description, self.txid)
        return dialog
class UiConfirmModifyOutput(UiConfirm):
    """Instruction pairing an output with its original for layout.confirm_modify_output."""

    # Equality is delegated to the shared helper utils.obj_eq.
    __eq__ = utils.obj_eq

    def __init__(
        self,
        txo: TxOutput,
        orig_txo: TxOutput,
        coin: CoinInfo,
        amount_unit: EnumTypeAmountUnit,
    ):
        self.txo, self.orig_txo = txo, orig_txo
        self.coin, self.amount_unit = coin, amount_unit

    def confirm_dialog(self, ctx: wire.Context) -> Awaitable[Any]:
        """Forward both outputs, the coin and the amount unit to the layout."""
        txo, orig_txo = self.txo, self.orig_txo
        return layout.confirm_modify_output(
            ctx, txo, orig_txo, self.coin, self.amount_unit
        )
class UiConfirmModifyFee(UiConfirm):
    """Instruction carrying a fee change and new total fee for layout.confirm_modify_fee."""

    # Equality is delegated to the shared helper utils.obj_eq.
    __eq__ = utils.obj_eq

    def __init__(
        self,
        user_fee_change: int,
        total_fee_new: int,
        coin: CoinInfo,
        amount_unit: EnumTypeAmountUnit,
    ):
        self.user_fee_change, self.total_fee_new = user_fee_change, total_fee_new
        self.coin, self.amount_unit = coin, amount_unit

    def confirm_dialog(self, ctx: wire.Context) -> Awaitable[Any]:
        """Forward the stored fee values, coin and amount unit to the layout."""
        fee_change, new_total = self.user_fee_change, self.total_fee_new
        return layout.confirm_modify_fee(
            ctx, fee_change, new_total, self.coin, self.amount_unit
        )
class UiConfirmTotal(UiConfirm):
    """Instruction carrying spending and fee amounts for layout.confirm_total."""

    # Equality is delegated to the shared helper utils.obj_eq.
    __eq__ = utils.obj_eq

    def __init__(
        self, spending: int, fee: int, coin: CoinInfo, amount_unit: EnumTypeAmountUnit
    ):
        self.spending, self.fee = spending, fee
        self.coin, self.amount_unit = coin, amount_unit

    def confirm_dialog(self, ctx: wire.Context) -> Awaitable[Any]:
        """Forward the stored amounts, coin and amount unit to the layout."""
        spending, fee = self.spending, self.fee
        return layout.confirm_total(ctx, spending, fee, self.coin, self.amount_unit)
class UiConfirmJointTotal(UiConfirm):
    """Instruction carrying spending and total amounts for layout.confirm_joint_total."""

    # Equality is delegated to the shared helper utils.obj_eq.
    __eq__ = utils.obj_eq

    def __init__(
        self, spending: int, total: int, coin: CoinInfo, amount_unit: EnumTypeAmountUnit
    ):
        self.spending, self.total = spending, total
        self.coin, self.amount_unit = coin, amount_unit

    def confirm_dialog(self, ctx: wire.Context) -> Awaitable[Any]:
        """Forward the stored amounts, coin and amount unit to the layout."""
        spending, total = self.spending, self.total
        return layout.confirm_joint_total(
            ctx, spending, total, self.coin, self.amount_unit
        )
class UiConfirmFeeOverThreshold(UiConfirm):
    """Instruction carrying a fee for layout.confirm_feeoverthreshold."""

    # Equality is delegated to the shared helper utils.obj_eq.
    __eq__ = utils.obj_eq

    def __init__(self, fee: int, coin: CoinInfo, amount_unit: EnumTypeAmountUnit):
        self.fee, self.coin, self.amount_unit = fee, coin, amount_unit

    def confirm_dialog(self, ctx: wire.Context) -> Awaitable[Any]:
        """Forward the stored fee, coin and amount unit to the layout."""
        dialog = layout.confirm_feeoverthreshold(
            ctx, self.fee, self.coin, self.amount_unit
        )
        return dialog
class UiConfirmChangeCountOverThreshold(UiConfirm):
    """Instruction carrying a change-output count for the matching layout dialog."""

    # Equality is delegated to the shared helper utils.obj_eq.
    __eq__ = utils.obj_eq

    def __init__(self, change_count: int):
        self.change_count = change_count

    def confirm_dialog(self, ctx: wire.Context) -> Awaitable[Any]:
        """Forward the stored count to layout.confirm_change_count_over_threshold."""
        count = self.change_count
        return layout.confirm_change_count_over_threshold(ctx, count)
class UiConfirmForeignAddress(UiConfirm):
    """Instruction that shows a path warning for the stored address_n."""

    # Equality is delegated to the shared helper utils.obj_eq.
    __eq__ = utils.obj_eq

    def __init__(self, address_n: list):
        self.address_n = address_n

    def confirm_dialog(self, ctx: wire.Context) -> Awaitable[Any]:
        """Forward the stored path to paths.show_path_warning."""
        path = self.address_n
        return paths.show_path_warning(ctx, path)
class UiConfirmNonDefaultLocktime(UiConfirm):
    """Instruction carrying lock-time details for layout.confirm_nondefault_locktime."""

    # Equality is delegated to the shared helper utils.obj_eq.
    __eq__ = utils.obj_eq

    def __init__(self, lock_time: int, lock_time_disabled: bool):
        self.lock_time, self.lock_time_disabled = lock_time, lock_time_disabled

    def confirm_dialog(self, ctx: wire.Context) -> Awaitable[Any]:
        """Forward the stored lock time and its disabled flag to the layout."""
        lock_time, disabled = self.lock_time, self.lock_time_disabled
        return layout.confirm_nondefault_locktime(ctx, lock_time, disabled)
def confirm_output(  # type: ignore
    output: TxOutput, coin: CoinInfo, amount_unit: EnumTypeAmountUnit
) -> Awaitable[None]:
    """Yield a UiConfirmOutput instruction and return the value sent back."""
    instruction = UiConfirmOutput(output, coin, amount_unit)
    reply = yield instruction
    return reply
def confirm_decred_sstx_submission(  # type: ignore
    output: TxOutput, coin: CoinInfo, amount_unit: EnumTypeAmountUnit
) -> Awaitable[None]:
    """Yield a UiConfirmDecredSSTXSubmission instruction and return the value sent back."""
    instruction = UiConfirmDecredSSTXSubmission(output, coin, amount_unit)
    reply = yield instruction
    return reply
def confirm_replacement(  # type: ignore
    description: str, txid: bytes
) -> Awaitable[Any]:
    """Yield a UiConfirmReplacement instruction and return the value sent back."""
    instruction = UiConfirmReplacement(description, txid)
    reply = yield instruction
    return reply
def confirm_modify_output(  # type: ignore
    txo: TxOutput,
    orig_txo: TxOutput,
    coin: CoinInfo,
    amount_unit: EnumTypeAmountUnit,
) -> Awaitable[Any]:
    """Yield a UiConfirmModifyOutput instruction and return the value sent back."""
    instruction = UiConfirmModifyOutput(txo, orig_txo, coin, amount_unit)
    reply = yield instruction
    return reply
def confirm_modify_fee(  # type: ignore
    user_fee_change: int,
    total_fee_new: int,
    coin: CoinInfo,
    amount_unit: EnumTypeAmountUnit,
) -> Awaitable[Any]:
    """Yield a UiConfirmModifyFee instruction and return the value sent back."""
    instruction = UiConfirmModifyFee(user_fee_change, total_fee_new, coin, amount_unit)
    reply = yield instruction
    return reply
def confirm_total(  # type: ignore
    spending: int, fee: int, coin: CoinInfo, amount_unit: EnumTypeAmountUnit
) -> Awaitable[None]:
    """Yield a UiConfirmTotal instruction and return the value sent back."""
    instruction = UiConfirmTotal(spending, fee, coin, amount_unit)
    reply = yield instruction
    return reply
def confirm_joint_total(  # type: ignore
    spending: int, total: int, coin: CoinInfo, amount_unit: EnumTypeAmountUnit
) -> Awaitable[Any]:
    """Yield a UiConfirmJointTotal instruction and return the value sent back."""
    instruction = UiConfirmJointTotal(spending, total, coin, amount_unit)
    reply = yield instruction
    return reply
def confirm_feeoverthreshold(  # type: ignore
    fee: int, coin: CoinInfo, amount_unit: EnumTypeAmountUnit
) -> Awaitable[Any]:
    """Yield a UiConfirmFeeOverThreshold instruction and return the value sent back."""
    instruction = UiConfirmFeeOverThreshold(fee, coin, amount_unit)
    reply = yield instruction
    return reply
def confirm_change_count_over_threshold(  # type: ignore
    change_count: int,
) -> Awaitable[Any]:
    """Yield a UiConfirmChangeCountOverThreshold instruction and return the value sent back."""
    instruction = UiConfirmChangeCountOverThreshold(change_count)
    reply = yield instruction
    return reply
def confirm_foreign_address(  # type: ignore
    address_n: list,
) -> Awaitable[Any]:
    """Yield a UiConfirmForeignAddress instruction and return the value sent back."""
    instruction = UiConfirmForeignAddress(address_n)
    reply = yield instruction
    return reply
def confirm_nondefault_locktime(  # type: ignore
    lock_time: int, lock_time_disabled: bool
) -> Awaitable[Any]:
    """Yield a UiConfirmNonDefaultLocktime instruction and return the value sent back."""
    instruction = UiConfirmNonDefaultLocktime(lock_time, lock_time_disabled)
    reply = yield instruction
    return reply
def request_tx_meta(  # type: ignore
    tx_req: TxRequest, coin: CoinInfo, tx_hash: bytes | None = None
) -> Awaitable[PrevTx]:
    """Request transaction metadata and return it sanitized.

    Marks `tx_req` as a TXMETA request for `tx_hash`, yields
    `(TxAckPrevMeta, tx_req)`, clears the request once a reply arrives and
    returns `sanitize_tx_meta(reply.tx, coin)`.
    """
    details = tx_req.details
    assert details is not None
    tx_req.request_type = TXMETA
    details.tx_hash = tx_hash
    reply = yield (TxAckPrevMeta, tx_req)
    _clear_tx_request(tx_req)
    prev_tx = sanitize_tx_meta(reply.tx, coin)
    return prev_tx
def request_tx_extra_data(  # type: ignore
    tx_req: TxRequest, offset: int, size: int, tx_hash: bytes | None = None
) -> Awaitable[bytearray]:
    """Request a chunk of extra data and return it.

    Marks `tx_req` as a TXEXTRADATA request for `size` bytes at `offset` of
    the transaction `tx_hash`, yields `(TxAckPrevExtraData, tx_req)`, clears
    the request and returns `reply.tx.extra_data_chunk` unchanged.
    """
    details = tx_req.details
    assert details is not None
    tx_req.request_type = TXEXTRADATA
    details.extra_data_offset = offset
    details.extra_data_len = size
    details.tx_hash = tx_hash
    reply = yield (TxAckPrevExtraData, tx_req)
    _clear_tx_request(tx_req)
    chunk = reply.tx.extra_data_chunk
    return chunk
def request_tx_input(  # type: ignore
    tx_req: TxRequest, i: int, coin: CoinInfo, tx_hash: bytes | None = None
) -> Awaitable[TxInput]:
    """Request input `i` and return it sanitized.

    With a truthy `tx_hash` the request type is TXORIGINPUT and the hash is
    stored in the request details; otherwise the type is TXINPUT and the
    hash field is left untouched. Yields `(TxAckInput, tx_req)`, clears the
    request and returns `sanitize_tx_input(reply.tx.input, coin)`.
    """
    details = tx_req.details
    assert details is not None
    if not tx_hash:
        tx_req.request_type = TXINPUT
    else:
        tx_req.request_type = TXORIGINPUT
        details.tx_hash = tx_hash
    details.request_index = i
    reply = yield (TxAckInput, tx_req)
    _clear_tx_request(tx_req)
    txi = sanitize_tx_input(reply.tx.input, coin)
    return txi
def request_tx_prev_input(  # type: ignore
    tx_req: TxRequest, i: int, coin: CoinInfo, tx_hash: bytes | None = None
) -> Awaitable[PrevInput]:
    """Request input `i` of the transaction `tx_hash` and return it sanitized.

    Marks `tx_req` as a TXINPUT request, yields `(TxAckPrevInput, tx_req)`,
    clears the request and returns
    `sanitize_tx_prev_input(reply.tx.input, coin)`.
    """
    details = tx_req.details
    assert details is not None
    tx_req.request_type = TXINPUT
    details.request_index = i
    details.tx_hash = tx_hash
    reply = yield (TxAckPrevInput, tx_req)
    _clear_tx_request(tx_req)
    prev_input = sanitize_tx_prev_input(reply.tx.input, coin)
    return prev_input
def request_tx_output(  # type: ignore
    tx_req: TxRequest, i: int, coin: CoinInfo, tx_hash: bytes | None = None
) -> Awaitable[TxOutput]:
    """Request output `i` and return it sanitized.

    With a truthy `tx_hash` the request type is TXORIGOUTPUT and the hash is
    stored in the request details; otherwise the type is TXOUTPUT and the
    hash field is left untouched. Yields `(TxAckOutput, tx_req)`, clears the
    request and returns `sanitize_tx_output(reply.tx.output, coin)`.
    """
    details = tx_req.details
    assert details is not None
    if not tx_hash:
        tx_req.request_type = TXOUTPUT
    else:
        tx_req.request_type = TXORIGOUTPUT
        details.tx_hash = tx_hash
    details.request_index = i
    reply = yield (TxAckOutput, tx_req)
    _clear_tx_request(tx_req)
    txo = sanitize_tx_output(reply.tx.output, coin)
    return txo
def request_tx_prev_output(  # type: ignore
    tx_req: TxRequest, i: int, coin: CoinInfo, tx_hash: bytes | None = None
) -> Awaitable[PrevOutput]:
    """Request output `i` of the transaction `tx_hash` and return it as received.

    Marks `tx_req` as a TXOUTPUT request, yields `(TxAckPrevOutput, tx_req)`,
    clears the request and returns `reply.tx.output`. `coin` is accepted but
    not used by this function.
    """
    details = tx_req.details
    assert details is not None
    tx_req.request_type = TXOUTPUT
    details.request_index = i
    details.tx_hash = tx_hash
    reply = yield (TxAckPrevOutput, tx_req)
    _clear_tx_request(tx_req)
    # Previous outputs are returned without sanitization; none is required.
    prev_output = reply.tx.output
    return prev_output
def request_tx_finish(  # type: ignore
    tx_req: TxRequest,
) -> Awaitable[None]:
    """Yield a TXFINISHED request and clear `tx_req` afterwards.

    The yielded pair has None in the position where the other request
    helpers put the expected acknowledgement type.
    """
    tx_req.request_type = TXFINISHED
    yield (None, tx_req)
    _clear_tx_request(tx_req)
def _clear_tx_request(tx_req: TxRequest) -> None:
    """Reset `tx_req` in place so the same object can carry the next request.

    Sets the request type, every details field written by the request
    helpers and the signature fields to None, and empties the
    `serialized_tx` buffer without replacing the buffer object.
    """
    details = tx_req.details
    assert details is not None
    serialized = tx_req.serialized
    assert serialized is not None
    assert serialized.serialized_tx is not None

    tx_req.request_type = None
    details.request_index = details.tx_hash = None
    details.extra_data_len = details.extra_data_offset = None
    serialized.signature = serialized.signature_index = None
    # mypy thinks serialized_tx is `bytes`, which doesn't support indexed assignment
    serialized.serialized_tx[:] = b""  # type: ignore
# Data sanitizers
# ===
def sanitize_sign_tx(tx: SignTx, coin: CoinInfo) -> SignTx:
    """Validate `tx` against the features of `coin` and return it.

    A missing expiry is defaulted to 0 on Decred and overwintered coins.
    Raises wire.DataError when expiry, timestamp, version group ID or
    branch ID is present on a coin that does not use it, or is absent on a
    coin that requires it (timestamp, version group ID, branch ID).
    """
    if coin.decred or coin.overwintered:
        if tx.expiry is None:
            tx.expiry = 0
    elif tx.expiry:
        raise wire.DataError("Expiry not enabled on this coin.")

    if coin.timestamp:
        if not tx.timestamp:
            raise wire.DataError("Timestamp must be set.")
    elif tx.timestamp:
        raise wire.DataError("Timestamp not enabled on this coin.")

    if coin.overwintered:
        if tx.version_group_id is None:
            raise wire.DataError("Version group ID must be set.")
        if tx.branch_id is None:
            raise wire.DataError("Branch ID must be set.")
    else:
        if tx.version_group_id is not None:
            raise wire.DataError("Version group ID not enabled on this coin.")
        if tx.branch_id is not None:
            raise wire.DataError("Branch ID not enabled on this coin.")

    return tx
def sanitize_tx_meta(tx: PrevTx, coin: CoinInfo) -> PrevTx:
    """Validate the metadata `tx` against the features of `coin` and return it.

    A missing expiry is defaulted to 0 on Decred and overwintered coins.
    Raises wire.DataError when extra data, expiry, timestamp, version group
    ID or branch ID is present on a coin that does not use it, or when the
    timestamp is missing on a coin that requires it.
    """
    if not coin.extra_data and tx.extra_data_len:
        raise wire.DataError("Extra data not enabled on this coin.")

    if coin.decred or coin.overwintered:
        if tx.expiry is None:
            tx.expiry = 0
    elif tx.expiry:
        raise wire.DataError("Expiry not enabled on this coin.")

    if coin.timestamp and not tx.timestamp:
        raise wire.DataError("Timestamp must be set.")
    elif not coin.timestamp and tx.timestamp:
        raise wire.DataError("Timestamp not enabled on this coin.")

    # This check used to hang off the timestamp chain as a third `elif`. It
    # still ran in every non-raising case, but only because both earlier
    # branches raise. It is an independent condition, so it gets its own `if`.
    if not coin.overwintered:
        if tx.version_group_id is not None:
            raise wire.DataError("Version group ID not enabled on this coin.")
        if tx.branch_id is not None:
            raise wire.DataError("Branch ID not enabled on this coin.")

    return tx
def sanitize_tx_input(txi: TxInput, coin: CoinInfo) -> TxInput:
    """Validate the input `txi` against `coin` and return it unchanged.

    Raises wire.DataError for a prev_hash of the wrong length, multisig or
    address_n fields that do not match the script type, Decred details on a
    non-Decred coin, segwit use on a coin without segwit, commitment_data
    without ownership_proof, or orig_hash without orig_index.
    """
    if len(txi.prev_hash) != TX_HASH_SIZE:
        raise wire.DataError("Provided prev_hash is invalid.")

    script_type = txi.script_type
    if txi.multisig:
        if script_type not in common.MULTISIG_INPUT_SCRIPT_TYPES:
            raise wire.DataError("Multisig field provided but not expected.")
    elif script_type == InputScriptType.SPENDMULTISIG:
        raise wire.DataError("Multisig details required.")

    if txi.address_n and script_type not in common.INTERNAL_INPUT_SCRIPT_TYPES:
        raise wire.DataError("Input's address_n provided but not expected.")

    if not coin.decred and txi.decred_tree is not None:
        raise wire.DataError("Decred details provided but Decred coin not specified.")

    uses_segwit = (
        script_type in common.SEGWIT_INPUT_SCRIPT_TYPES or txi.witness is not None
    )
    if uses_segwit and not coin.segwit:
        raise wire.DataError("Segwit not enabled on this coin.")

    if txi.commitment_data and not txi.ownership_proof:
        raise wire.DataError("commitment_data field provided but not expected.")

    if txi.orig_hash and txi.orig_index is None:
        raise wire.DataError("Missing orig_index field.")

    return txi
def sanitize_tx_prev_input(txi: PrevInput, coin: CoinInfo) -> PrevInput:
    """Validate the previous-transaction input `txi` and return it unchanged.

    Raises wire.DataError when prev_hash does not have TX_HASH_SIZE bytes or
    when decred_tree is set on a coin that is not Decred.
    """
    hash_length = len(txi.prev_hash)
    if hash_length != TX_HASH_SIZE:
        raise wire.DataError("Provided prev_hash is invalid.")

    has_decred_details = txi.decred_tree is not None
    if not coin.decred and has_decred_details:
        raise wire.DataError("Decred details provided but Decred coin not specified.")

    return txi
def sanitize_tx_output(txo: TxOutput, coin: CoinInfo) -> TxOutput:
    """Validate the output `txo` against `coin` and return it unchanged.

    Raises wire.DataError for multisig or address_n fields that do not match
    the script type, a missing amount, a segwit script type on a coin
    without segwit, a malformed OP_RETURN output, OP_RETURN data on another
    script type, an output with both or neither of address and address_n,
    or orig_hash without orig_index.
    """
    script_type = txo.script_type

    if txo.multisig:
        if script_type not in common.MULTISIG_OUTPUT_SCRIPT_TYPES:
            raise wire.DataError("Multisig field provided but not expected.")
    elif script_type == OutputScriptType.PAYTOMULTISIG:
        raise wire.DataError("Multisig details required.")

    if txo.address_n and script_type not in common.CHANGE_OUTPUT_SCRIPT_TYPES:
        raise wire.DataError("Output's address_n provided but not expected.")

    if txo.amount is None:
        raise wire.DataError("Missing amount field.")

    if script_type in common.SEGWIT_OUTPUT_SCRIPT_TYPES and not coin.segwit:
        raise wire.DataError("Segwit not enabled on this coin.")

    if script_type != OutputScriptType.PAYTOOPRETURN:
        # regular output: exactly one of address / address_n, no OP_RETURN data
        if txo.op_return_data:
            raise wire.DataError(
                "OP RETURN data provided but not OP RETURN script type."
            )
        if txo.address_n and txo.address:
            raise wire.DataError("Both address and address_n provided.")
        if not (txo.address_n or txo.address):
            raise wire.DataError("Missing address")
    else:
        # op_return output
        if txo.op_return_data is None:
            raise wire.DataError("OP_RETURN output without op_return_data")
        if txo.amount != 0:
            raise wire.DataError("OP_RETURN output with non-zero amount")
        if txo.address or txo.address_n or txo.multisig:
            raise wire.DataError("OP_RETURN output with address or multisig")

    if txo.orig_hash and txo.orig_index is None:
        raise wire.DataError("Missing orig_index field.")

    return txo