mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-02-01 18:30:56 +00:00
apps.common.signtx: add example sanitization
This commit is contained in:
parent
d6ae782dfc
commit
b145f8f309
@ -11,6 +11,7 @@ from trezor.messages.TxOutputType import TxOutputType
|
||||
from trezor.messages.TxOutputBinType import TxOutputBinType
|
||||
from trezor.messages.TxInputType import TxInputType
|
||||
from trezor.messages.TxRequest import TxRequest
|
||||
from trezor.messages.TransactionType import TransactionType
|
||||
from trezor.messages.RequestType import TXINPUT, TXOUTPUT, TXMETA, TXFINISHED
|
||||
from trezor.messages.TxRequestSerializedType import TxRequestSerializedType
|
||||
from trezor.messages.TxRequestDetailsType import TxRequestDetailsType
|
||||
@ -65,7 +66,7 @@ def request_tx_meta(tx_req: TxRequest, tx_hash: bytes=None):
|
||||
tx_req.details.request_index = None
|
||||
ack = yield tx_req
|
||||
tx_req.serialized = None
|
||||
return ack.tx
|
||||
return sanitize_tx_meta(ack.tx)
|
||||
|
||||
|
||||
def request_tx_input(tx_req: TxRequest, i: int, tx_hash: bytes=None):
|
||||
@ -74,7 +75,7 @@ def request_tx_input(tx_req: TxRequest, i: int, tx_hash: bytes=None):
|
||||
tx_req.details.tx_hash = tx_hash
|
||||
ack = yield tx_req
|
||||
tx_req.serialized = None
|
||||
return ack.tx.inputs[0]
|
||||
return sanitize_tx_input(ack.tx)
|
||||
|
||||
|
||||
def request_tx_output(tx_req: TxRequest, i: int, tx_hash: bytes=None):
|
||||
@ -84,9 +85,9 @@ def request_tx_output(tx_req: TxRequest, i: int, tx_hash: bytes=None):
|
||||
ack = yield tx_req
|
||||
tx_req.serialized = None
|
||||
if tx_hash is None:
|
||||
return ack.tx.outputs[0]
|
||||
return sanitize_tx_output(ack.tx)
|
||||
else:
|
||||
return ack.tx.bin_outputs[0]
|
||||
return sanitize_tx_binoutput(ack.tx)
|
||||
|
||||
|
||||
def request_tx_finish(tx_req: TxRequest):
|
||||
@ -96,18 +97,50 @@ def request_tx_finish(tx_req: TxRequest):
|
||||
tx_req.serialized = None
|
||||
|
||||
|
||||
# Data sanitizers
|
||||
# ===
|
||||
|
||||
|
||||
def sanitize_sign_tx(tx: SignTx) -> SignTx:
|
||||
tx.version = tx.version if tx.version is not None else 1
|
||||
tx.lock_time = tx.lock_time if tx.lock_time is not None else 0
|
||||
tx.inputs_count = tx.inputs_count if tx.inputs_count is not None else 0
|
||||
tx.outputs_count = tx.outputs_count if tx.outputs_count is not None else 0
|
||||
tx.coin_name = tx.coin_name if tx.coin_name is not None else 'Bitcoin'
|
||||
return tx
|
||||
|
||||
|
||||
def sanitize_tx_meta(tx: TransactionType) -> TransactionType:
|
||||
tx.version = tx.version if tx.version is not None else 1
|
||||
tx.lock_time = tx.lock_time if tx.lock_time is not None else 0
|
||||
tx.inputs_cnt = tx.inputs_cnt if tx.inputs_cnt is not None else 0
|
||||
tx.outputs_cnt = tx.outputs_cnt if tx.outputs_cnt is not None else 0
|
||||
return tx
|
||||
|
||||
|
||||
def sanitize_tx_input(tx: TransactionType) -> TxInputType:
|
||||
txi = tx.inputs[0]
|
||||
txi.script_type = (
|
||||
txi.script_type if txi.script_type is not None else InputScriptType.SPENDADDRESS)
|
||||
return txi
|
||||
|
||||
|
||||
def sanitize_tx_output(tx: TransactionType) -> TxOutputType:
|
||||
return tx.outputs[0]
|
||||
|
||||
|
||||
def sanitize_tx_binoutput(tx: TransactionType) -> TxOutputBinType:
|
||||
return tx.bin_outputs[0]
|
||||
|
||||
|
||||
# Transaction signing
|
||||
# ===
|
||||
|
||||
|
||||
async def sign_tx(tx: SignTx, root):
|
||||
tx_version = tx.version if tx.version is not None else 1
|
||||
tx_lock_time = tx.lock_time or 0
|
||||
tx_inputs_count = tx.inputs_count or 0
|
||||
tx_outputs_count = tx.outputs_count or 0
|
||||
coin_name = tx.coin_name or 'Bitcoin'
|
||||
|
||||
coin = coins.by_name(coin_name)
|
||||
tx = sanitize_sign_tx(tx)
|
||||
coin = coins.by_name(tx.coin_name)
|
||||
|
||||
# Phase 1
|
||||
# - check inputs, previous transactions, and outputs
|
||||
@ -127,14 +160,14 @@ async def sign_tx(tx: SignTx, root):
|
||||
tx_req = TxRequest()
|
||||
tx_req.details = TxRequestDetailsType()
|
||||
|
||||
for i in range(tx_inputs_count):
|
||||
for i in range(tx.inputs_count):
|
||||
# STAGE_REQUEST_1_INPUT
|
||||
txi = await request_tx_input(tx_req, i)
|
||||
write_tx_input_check(h_first, txi)
|
||||
total_in += await get_prevtx_output_value(
|
||||
tx_req, txi.prev_hash, txi.prev_index)
|
||||
|
||||
for o in range(tx_outputs_count):
|
||||
for o in range(tx.outputs_count):
|
||||
# STAGE_REQUEST_3_OUTPUT
|
||||
txo = await request_tx_output(tx_req, o)
|
||||
if output_is_change(txo):
|
||||
@ -157,7 +190,7 @@ async def sign_tx(tx: SignTx, root):
|
||||
raise SigningError(FailureType.NotEnoughFunds,
|
||||
'Not enough funds')
|
||||
|
||||
if fee > coin.maxfee_kb * ((estimate_tx_size(tx_inputs_count, tx_outputs_count) + 999) // 1000):
|
||||
if fee > coin.maxfee_kb * ((estimate_tx_size(tx.inputs_count, tx.outputs_count) + 999) // 1000):
|
||||
if not await confirm_feeoverthreshold(fee, coin):
|
||||
raise SigningError(FailureType.ActionCancelled,
|
||||
'Signing cancelled')
|
||||
@ -172,7 +205,7 @@ async def sign_tx(tx: SignTx, root):
|
||||
|
||||
tx_ser = TxRequestSerializedType()
|
||||
|
||||
for i_sign in range(tx_inputs_count):
|
||||
for i_sign in range(tx.inputs_count):
|
||||
# hash of what we are signing with this input
|
||||
h_sign = HashWriter(sha256)
|
||||
# same as h_first, checked at the end of this iteration
|
||||
@ -182,11 +215,11 @@ async def sign_tx(tx: SignTx, root):
|
||||
key_sign = None
|
||||
key_sign_pub = None
|
||||
|
||||
write_uint32(h_sign, tx_version)
|
||||
write_uint32(h_sign, tx.version)
|
||||
|
||||
write_varint(h_sign, tx_inputs_count)
|
||||
write_varint(h_sign, tx.inputs_count)
|
||||
|
||||
for i in range(tx_inputs_count):
|
||||
for i in range(tx.inputs_count):
|
||||
# STAGE_REQUEST_4_INPUT
|
||||
txi = await request_tx_input(tx_req, i)
|
||||
write_tx_input_check(h_second, txi)
|
||||
@ -199,9 +232,9 @@ async def sign_tx(tx: SignTx, root):
|
||||
txi.script_sig = bytes()
|
||||
write_tx_input(h_sign, txi)
|
||||
|
||||
write_varint(h_sign, tx_outputs_count)
|
||||
write_varint(h_sign, tx.outputs_count)
|
||||
|
||||
for o in range(tx_outputs_count):
|
||||
for o in range(tx.outputs_count):
|
||||
# STAGE_REQUEST_4_OUTPUT
|
||||
txo = await request_tx_output(tx_req, o)
|
||||
txo_bin.amount = txo.amount
|
||||
@ -209,7 +242,7 @@ async def sign_tx(tx: SignTx, root):
|
||||
write_tx_output(h_second, txo_bin)
|
||||
write_tx_output(h_sign, txo_bin)
|
||||
|
||||
write_uint32(h_sign, tx_lock_time)
|
||||
write_uint32(h_sign, tx.lock_time)
|
||||
|
||||
write_uint32(h_sign, 0x00000001) # SIGHASH_ALL hash_type
|
||||
|
||||
@ -229,14 +262,14 @@ async def sign_tx(tx: SignTx, root):
|
||||
w_txi_sign = bytearray_with_cap(
|
||||
len(txi_sign.prev_hash) + 4 + 5 + len(txi_sign.script_sig) + 4)
|
||||
if i_sign == 0: # serializing first input => prepend tx version and inputs count
|
||||
write_uint32(w_txi_sign, tx_version)
|
||||
write_varint(w_txi_sign, tx_inputs_count)
|
||||
write_uint32(w_txi_sign, tx.version)
|
||||
write_varint(w_txi_sign, tx.inputs_count)
|
||||
write_tx_input(w_txi_sign, txi_sign)
|
||||
tx_ser.serialized_tx = w_txi_sign
|
||||
|
||||
tx_req.serialized = tx_ser
|
||||
|
||||
for o in range(tx_outputs_count):
|
||||
for o in range(tx.outputs_count):
|
||||
# STAGE_REQUEST_5_OUTPUT
|
||||
txo = await request_tx_output(tx_req, o)
|
||||
txo_bin.amount = txo.amount
|
||||
@ -246,10 +279,10 @@ async def sign_tx(tx: SignTx, root):
|
||||
w_txo_bin = bytearray_with_cap(
|
||||
5 + 8 + 5 + len(txo_bin.script_pubkey) + 4)
|
||||
if o == 0: # serializing first output => prepend outputs count
|
||||
write_varint(w_txo_bin, tx_outputs_count)
|
||||
write_varint(w_txo_bin, tx.outputs_count)
|
||||
write_tx_output(w_txo_bin, txo_bin)
|
||||
if o == tx_outputs_count - 1: # serializing last output => append tx lock_time
|
||||
write_uint32(w_txo_bin, tx_lock_time)
|
||||
if o == tx.outputs_count - 1: # serializing last output => append tx lock_time
|
||||
write_uint32(w_txo_bin, tx.lock_time)
|
||||
tx_ser.signature_index = None
|
||||
tx_ser.signature = None
|
||||
tx_ser.serialized_tx = w_txo_bin
|
||||
@ -265,32 +298,26 @@ async def get_prevtx_output_value(tx_req: TxRequest, prev_hash: bytes, prev_inde
|
||||
# STAGE_REQUEST_2_PREV_META
|
||||
tx = await request_tx_meta(tx_req, prev_hash)
|
||||
|
||||
tx_version = tx.version if tx.version is not None else 1
|
||||
tx_lock_time = tx.lock_time or 0
|
||||
tx_inputs_count = tx.inputs_cnt or 0
|
||||
tx_outputs_count = tx.outputs_cnt or 0
|
||||
|
||||
txh = HashWriter(sha256)
|
||||
|
||||
write_uint32(txh, tx_version)
|
||||
write_uint32(txh, tx.version)
|
||||
write_varint(txh, tx.inputs_cnt)
|
||||
|
||||
write_varint(txh, tx_inputs_count)
|
||||
|
||||
for i in range(tx_inputs_count):
|
||||
for i in range(tx.inputs_cnt):
|
||||
# STAGE_REQUEST_2_PREV_INPUT
|
||||
txi = await request_tx_input(tx_req, i, prev_hash)
|
||||
write_tx_input(txh, txi)
|
||||
|
||||
write_varint(txh, tx_outputs_count)
|
||||
write_varint(txh, tx.outputs_cnt)
|
||||
|
||||
for o in range(tx_outputs_count):
|
||||
for o in range(tx.outputs_cnt):
|
||||
# STAGE_REQUEST_2_PREV_OUTPUT
|
||||
txo_bin = await request_tx_output(tx_req, o, prev_hash)
|
||||
write_tx_output(txh, txo_bin)
|
||||
if o == prev_index:
|
||||
total_out += txo_bin.amount
|
||||
|
||||
write_uint32(txh, tx_lock_time)
|
||||
write_uint32(txh, tx.lock_time)
|
||||
|
||||
if get_tx_hash(txh, True, True) != prev_hash:
|
||||
raise SigningError(FailureType.Other,
|
||||
@ -366,9 +393,7 @@ def output_is_change(o: TxOutputType) -> bool:
|
||||
|
||||
|
||||
def input_derive_script(i: TxInputType, pubkey: bytes, signature: bytes=None) -> bytes:
|
||||
script_type = i.script_type if i.script_type is not None else InputScriptType.SPENDADDRESS
|
||||
|
||||
if script_type == InputScriptType.SPENDADDRESS:
|
||||
if i.script_type == InputScriptType.SPENDADDRESS:
|
||||
if signature is None:
|
||||
return script_paytoaddress_new(ecdsa_hash_pubkey(pubkey))
|
||||
else:
|
||||
|
Loading…
Reference in New Issue
Block a user