mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-02-08 13:42:41 +00:00
apps.common.signtx: add example sanitization
This commit is contained in:
parent
d6ae782dfc
commit
b145f8f309
@ -11,6 +11,7 @@ from trezor.messages.TxOutputType import TxOutputType
|
|||||||
from trezor.messages.TxOutputBinType import TxOutputBinType
|
from trezor.messages.TxOutputBinType import TxOutputBinType
|
||||||
from trezor.messages.TxInputType import TxInputType
|
from trezor.messages.TxInputType import TxInputType
|
||||||
from trezor.messages.TxRequest import TxRequest
|
from trezor.messages.TxRequest import TxRequest
|
||||||
|
from trezor.messages.TransactionType import TransactionType
|
||||||
from trezor.messages.RequestType import TXINPUT, TXOUTPUT, TXMETA, TXFINISHED
|
from trezor.messages.RequestType import TXINPUT, TXOUTPUT, TXMETA, TXFINISHED
|
||||||
from trezor.messages.TxRequestSerializedType import TxRequestSerializedType
|
from trezor.messages.TxRequestSerializedType import TxRequestSerializedType
|
||||||
from trezor.messages.TxRequestDetailsType import TxRequestDetailsType
|
from trezor.messages.TxRequestDetailsType import TxRequestDetailsType
|
||||||
@ -65,7 +66,7 @@ def request_tx_meta(tx_req: TxRequest, tx_hash: bytes=None):
|
|||||||
tx_req.details.request_index = None
|
tx_req.details.request_index = None
|
||||||
ack = yield tx_req
|
ack = yield tx_req
|
||||||
tx_req.serialized = None
|
tx_req.serialized = None
|
||||||
return ack.tx
|
return sanitize_tx_meta(ack.tx)
|
||||||
|
|
||||||
|
|
||||||
def request_tx_input(tx_req: TxRequest, i: int, tx_hash: bytes=None):
|
def request_tx_input(tx_req: TxRequest, i: int, tx_hash: bytes=None):
|
||||||
@ -74,7 +75,7 @@ def request_tx_input(tx_req: TxRequest, i: int, tx_hash: bytes=None):
|
|||||||
tx_req.details.tx_hash = tx_hash
|
tx_req.details.tx_hash = tx_hash
|
||||||
ack = yield tx_req
|
ack = yield tx_req
|
||||||
tx_req.serialized = None
|
tx_req.serialized = None
|
||||||
return ack.tx.inputs[0]
|
return sanitize_tx_input(ack.tx)
|
||||||
|
|
||||||
|
|
||||||
def request_tx_output(tx_req: TxRequest, i: int, tx_hash: bytes=None):
|
def request_tx_output(tx_req: TxRequest, i: int, tx_hash: bytes=None):
|
||||||
@ -84,9 +85,9 @@ def request_tx_output(tx_req: TxRequest, i: int, tx_hash: bytes=None):
|
|||||||
ack = yield tx_req
|
ack = yield tx_req
|
||||||
tx_req.serialized = None
|
tx_req.serialized = None
|
||||||
if tx_hash is None:
|
if tx_hash is None:
|
||||||
return ack.tx.outputs[0]
|
return sanitize_tx_output(ack.tx)
|
||||||
else:
|
else:
|
||||||
return ack.tx.bin_outputs[0]
|
return sanitize_tx_binoutput(ack.tx)
|
||||||
|
|
||||||
|
|
||||||
def request_tx_finish(tx_req: TxRequest):
|
def request_tx_finish(tx_req: TxRequest):
|
||||||
@ -96,18 +97,50 @@ def request_tx_finish(tx_req: TxRequest):
|
|||||||
tx_req.serialized = None
|
tx_req.serialized = None
|
||||||
|
|
||||||
|
|
||||||
|
# Data sanitizers
|
||||||
|
# ===
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_sign_tx(tx: SignTx) -> SignTx:
|
||||||
|
tx.version = tx.version if tx.version is not None else 1
|
||||||
|
tx.lock_time = tx.lock_time if tx.lock_time is not None else 0
|
||||||
|
tx.inputs_count = tx.inputs_count if tx.inputs_count is not None else 0
|
||||||
|
tx.outputs_count = tx.outputs_count if tx.outputs_count is not None else 0
|
||||||
|
tx.coin_name = tx.coin_name if tx.coin_name is not None else 'Bitcoin'
|
||||||
|
return tx
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_tx_meta(tx: TransactionType) -> TransactionType:
|
||||||
|
tx.version = tx.version if tx.version is not None else 1
|
||||||
|
tx.lock_time = tx.lock_time if tx.lock_time is not None else 0
|
||||||
|
tx.inputs_cnt = tx.inputs_cnt if tx.inputs_cnt is not None else 0
|
||||||
|
tx.outputs_cnt = tx.outputs_cnt if tx.outputs_cnt is not None else 0
|
||||||
|
return tx
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_tx_input(tx: TransactionType) -> TxInputType:
|
||||||
|
txi = tx.inputs[0]
|
||||||
|
txi.script_type = (
|
||||||
|
txi.script_type if txi.script_type is not None else InputScriptType.SPENDADDRESS)
|
||||||
|
return txi
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_tx_output(tx: TransactionType) -> TxOutputType:
|
||||||
|
return tx.outputs[0]
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_tx_binoutput(tx: TransactionType) -> TxOutputBinType:
|
||||||
|
return tx.bin_outputs[0]
|
||||||
|
|
||||||
|
|
||||||
# Transaction signing
|
# Transaction signing
|
||||||
# ===
|
# ===
|
||||||
|
|
||||||
|
|
||||||
async def sign_tx(tx: SignTx, root):
|
async def sign_tx(tx: SignTx, root):
|
||||||
tx_version = tx.version if tx.version is not None else 1
|
|
||||||
tx_lock_time = tx.lock_time or 0
|
|
||||||
tx_inputs_count = tx.inputs_count or 0
|
|
||||||
tx_outputs_count = tx.outputs_count or 0
|
|
||||||
coin_name = tx.coin_name or 'Bitcoin'
|
|
||||||
|
|
||||||
coin = coins.by_name(coin_name)
|
tx = sanitize_sign_tx(tx)
|
||||||
|
coin = coins.by_name(tx.coin_name)
|
||||||
|
|
||||||
# Phase 1
|
# Phase 1
|
||||||
# - check inputs, previous transactions, and outputs
|
# - check inputs, previous transactions, and outputs
|
||||||
@ -127,14 +160,14 @@ async def sign_tx(tx: SignTx, root):
|
|||||||
tx_req = TxRequest()
|
tx_req = TxRequest()
|
||||||
tx_req.details = TxRequestDetailsType()
|
tx_req.details = TxRequestDetailsType()
|
||||||
|
|
||||||
for i in range(tx_inputs_count):
|
for i in range(tx.inputs_count):
|
||||||
# STAGE_REQUEST_1_INPUT
|
# STAGE_REQUEST_1_INPUT
|
||||||
txi = await request_tx_input(tx_req, i)
|
txi = await request_tx_input(tx_req, i)
|
||||||
write_tx_input_check(h_first, txi)
|
write_tx_input_check(h_first, txi)
|
||||||
total_in += await get_prevtx_output_value(
|
total_in += await get_prevtx_output_value(
|
||||||
tx_req, txi.prev_hash, txi.prev_index)
|
tx_req, txi.prev_hash, txi.prev_index)
|
||||||
|
|
||||||
for o in range(tx_outputs_count):
|
for o in range(tx.outputs_count):
|
||||||
# STAGE_REQUEST_3_OUTPUT
|
# STAGE_REQUEST_3_OUTPUT
|
||||||
txo = await request_tx_output(tx_req, o)
|
txo = await request_tx_output(tx_req, o)
|
||||||
if output_is_change(txo):
|
if output_is_change(txo):
|
||||||
@ -157,7 +190,7 @@ async def sign_tx(tx: SignTx, root):
|
|||||||
raise SigningError(FailureType.NotEnoughFunds,
|
raise SigningError(FailureType.NotEnoughFunds,
|
||||||
'Not enough funds')
|
'Not enough funds')
|
||||||
|
|
||||||
if fee > coin.maxfee_kb * ((estimate_tx_size(tx_inputs_count, tx_outputs_count) + 999) // 1000):
|
if fee > coin.maxfee_kb * ((estimate_tx_size(tx.inputs_count, tx.outputs_count) + 999) // 1000):
|
||||||
if not await confirm_feeoverthreshold(fee, coin):
|
if not await confirm_feeoverthreshold(fee, coin):
|
||||||
raise SigningError(FailureType.ActionCancelled,
|
raise SigningError(FailureType.ActionCancelled,
|
||||||
'Signing cancelled')
|
'Signing cancelled')
|
||||||
@ -172,7 +205,7 @@ async def sign_tx(tx: SignTx, root):
|
|||||||
|
|
||||||
tx_ser = TxRequestSerializedType()
|
tx_ser = TxRequestSerializedType()
|
||||||
|
|
||||||
for i_sign in range(tx_inputs_count):
|
for i_sign in range(tx.inputs_count):
|
||||||
# hash of what we are signing with this input
|
# hash of what we are signing with this input
|
||||||
h_sign = HashWriter(sha256)
|
h_sign = HashWriter(sha256)
|
||||||
# same as h_first, checked at the end of this iteration
|
# same as h_first, checked at the end of this iteration
|
||||||
@ -182,11 +215,11 @@ async def sign_tx(tx: SignTx, root):
|
|||||||
key_sign = None
|
key_sign = None
|
||||||
key_sign_pub = None
|
key_sign_pub = None
|
||||||
|
|
||||||
write_uint32(h_sign, tx_version)
|
write_uint32(h_sign, tx.version)
|
||||||
|
|
||||||
write_varint(h_sign, tx_inputs_count)
|
write_varint(h_sign, tx.inputs_count)
|
||||||
|
|
||||||
for i in range(tx_inputs_count):
|
for i in range(tx.inputs_count):
|
||||||
# STAGE_REQUEST_4_INPUT
|
# STAGE_REQUEST_4_INPUT
|
||||||
txi = await request_tx_input(tx_req, i)
|
txi = await request_tx_input(tx_req, i)
|
||||||
write_tx_input_check(h_second, txi)
|
write_tx_input_check(h_second, txi)
|
||||||
@ -199,9 +232,9 @@ async def sign_tx(tx: SignTx, root):
|
|||||||
txi.script_sig = bytes()
|
txi.script_sig = bytes()
|
||||||
write_tx_input(h_sign, txi)
|
write_tx_input(h_sign, txi)
|
||||||
|
|
||||||
write_varint(h_sign, tx_outputs_count)
|
write_varint(h_sign, tx.outputs_count)
|
||||||
|
|
||||||
for o in range(tx_outputs_count):
|
for o in range(tx.outputs_count):
|
||||||
# STAGE_REQUEST_4_OUTPUT
|
# STAGE_REQUEST_4_OUTPUT
|
||||||
txo = await request_tx_output(tx_req, o)
|
txo = await request_tx_output(tx_req, o)
|
||||||
txo_bin.amount = txo.amount
|
txo_bin.amount = txo.amount
|
||||||
@ -209,7 +242,7 @@ async def sign_tx(tx: SignTx, root):
|
|||||||
write_tx_output(h_second, txo_bin)
|
write_tx_output(h_second, txo_bin)
|
||||||
write_tx_output(h_sign, txo_bin)
|
write_tx_output(h_sign, txo_bin)
|
||||||
|
|
||||||
write_uint32(h_sign, tx_lock_time)
|
write_uint32(h_sign, tx.lock_time)
|
||||||
|
|
||||||
write_uint32(h_sign, 0x00000001) # SIGHASH_ALL hash_type
|
write_uint32(h_sign, 0x00000001) # SIGHASH_ALL hash_type
|
||||||
|
|
||||||
@ -229,14 +262,14 @@ async def sign_tx(tx: SignTx, root):
|
|||||||
w_txi_sign = bytearray_with_cap(
|
w_txi_sign = bytearray_with_cap(
|
||||||
len(txi_sign.prev_hash) + 4 + 5 + len(txi_sign.script_sig) + 4)
|
len(txi_sign.prev_hash) + 4 + 5 + len(txi_sign.script_sig) + 4)
|
||||||
if i_sign == 0: # serializing first input => prepend tx version and inputs count
|
if i_sign == 0: # serializing first input => prepend tx version and inputs count
|
||||||
write_uint32(w_txi_sign, tx_version)
|
write_uint32(w_txi_sign, tx.version)
|
||||||
write_varint(w_txi_sign, tx_inputs_count)
|
write_varint(w_txi_sign, tx.inputs_count)
|
||||||
write_tx_input(w_txi_sign, txi_sign)
|
write_tx_input(w_txi_sign, txi_sign)
|
||||||
tx_ser.serialized_tx = w_txi_sign
|
tx_ser.serialized_tx = w_txi_sign
|
||||||
|
|
||||||
tx_req.serialized = tx_ser
|
tx_req.serialized = tx_ser
|
||||||
|
|
||||||
for o in range(tx_outputs_count):
|
for o in range(tx.outputs_count):
|
||||||
# STAGE_REQUEST_5_OUTPUT
|
# STAGE_REQUEST_5_OUTPUT
|
||||||
txo = await request_tx_output(tx_req, o)
|
txo = await request_tx_output(tx_req, o)
|
||||||
txo_bin.amount = txo.amount
|
txo_bin.amount = txo.amount
|
||||||
@ -246,10 +279,10 @@ async def sign_tx(tx: SignTx, root):
|
|||||||
w_txo_bin = bytearray_with_cap(
|
w_txo_bin = bytearray_with_cap(
|
||||||
5 + 8 + 5 + len(txo_bin.script_pubkey) + 4)
|
5 + 8 + 5 + len(txo_bin.script_pubkey) + 4)
|
||||||
if o == 0: # serializing first output => prepend outputs count
|
if o == 0: # serializing first output => prepend outputs count
|
||||||
write_varint(w_txo_bin, tx_outputs_count)
|
write_varint(w_txo_bin, tx.outputs_count)
|
||||||
write_tx_output(w_txo_bin, txo_bin)
|
write_tx_output(w_txo_bin, txo_bin)
|
||||||
if o == tx_outputs_count - 1: # serializing last output => append tx lock_time
|
if o == tx.outputs_count - 1: # serializing last output => append tx lock_time
|
||||||
write_uint32(w_txo_bin, tx_lock_time)
|
write_uint32(w_txo_bin, tx.lock_time)
|
||||||
tx_ser.signature_index = None
|
tx_ser.signature_index = None
|
||||||
tx_ser.signature = None
|
tx_ser.signature = None
|
||||||
tx_ser.serialized_tx = w_txo_bin
|
tx_ser.serialized_tx = w_txo_bin
|
||||||
@ -265,32 +298,26 @@ async def get_prevtx_output_value(tx_req: TxRequest, prev_hash: bytes, prev_inde
|
|||||||
# STAGE_REQUEST_2_PREV_META
|
# STAGE_REQUEST_2_PREV_META
|
||||||
tx = await request_tx_meta(tx_req, prev_hash)
|
tx = await request_tx_meta(tx_req, prev_hash)
|
||||||
|
|
||||||
tx_version = tx.version if tx.version is not None else 1
|
|
||||||
tx_lock_time = tx.lock_time or 0
|
|
||||||
tx_inputs_count = tx.inputs_cnt or 0
|
|
||||||
tx_outputs_count = tx.outputs_cnt or 0
|
|
||||||
|
|
||||||
txh = HashWriter(sha256)
|
txh = HashWriter(sha256)
|
||||||
|
|
||||||
write_uint32(txh, tx_version)
|
write_uint32(txh, tx.version)
|
||||||
|
write_varint(txh, tx.inputs_cnt)
|
||||||
|
|
||||||
write_varint(txh, tx_inputs_count)
|
for i in range(tx.inputs_cnt):
|
||||||
|
|
||||||
for i in range(tx_inputs_count):
|
|
||||||
# STAGE_REQUEST_2_PREV_INPUT
|
# STAGE_REQUEST_2_PREV_INPUT
|
||||||
txi = await request_tx_input(tx_req, i, prev_hash)
|
txi = await request_tx_input(tx_req, i, prev_hash)
|
||||||
write_tx_input(txh, txi)
|
write_tx_input(txh, txi)
|
||||||
|
|
||||||
write_varint(txh, tx_outputs_count)
|
write_varint(txh, tx.outputs_cnt)
|
||||||
|
|
||||||
for o in range(tx_outputs_count):
|
for o in range(tx.outputs_cnt):
|
||||||
# STAGE_REQUEST_2_PREV_OUTPUT
|
# STAGE_REQUEST_2_PREV_OUTPUT
|
||||||
txo_bin = await request_tx_output(tx_req, o, prev_hash)
|
txo_bin = await request_tx_output(tx_req, o, prev_hash)
|
||||||
write_tx_output(txh, txo_bin)
|
write_tx_output(txh, txo_bin)
|
||||||
if o == prev_index:
|
if o == prev_index:
|
||||||
total_out += txo_bin.amount
|
total_out += txo_bin.amount
|
||||||
|
|
||||||
write_uint32(txh, tx_lock_time)
|
write_uint32(txh, tx.lock_time)
|
||||||
|
|
||||||
if get_tx_hash(txh, True, True) != prev_hash:
|
if get_tx_hash(txh, True, True) != prev_hash:
|
||||||
raise SigningError(FailureType.Other,
|
raise SigningError(FailureType.Other,
|
||||||
@ -366,9 +393,7 @@ def output_is_change(o: TxOutputType) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def input_derive_script(i: TxInputType, pubkey: bytes, signature: bytes=None) -> bytes:
|
def input_derive_script(i: TxInputType, pubkey: bytes, signature: bytes=None) -> bytes:
|
||||||
script_type = i.script_type if i.script_type is not None else InputScriptType.SPENDADDRESS
|
if i.script_type == InputScriptType.SPENDADDRESS:
|
||||||
|
|
||||||
if script_type == InputScriptType.SPENDADDRESS:
|
|
||||||
if signature is None:
|
if signature is None:
|
||||||
return script_paytoaddress_new(ecdsa_hash_pubkey(pubkey))
|
return script_paytoaddress_new(ecdsa_hash_pubkey(pubkey))
|
||||||
else:
|
else:
|
||||||
|
Loading…
Reference in New Issue
Block a user