apps.common.signtx: add example sanitization

pull/25/head
Jan Pochyla 8 years ago
parent d6ae782dfc
commit b145f8f309

@ -11,6 +11,7 @@ from trezor.messages.TxOutputType import TxOutputType
from trezor.messages.TxOutputBinType import TxOutputBinType
from trezor.messages.TxInputType import TxInputType
from trezor.messages.TxRequest import TxRequest
from trezor.messages.TransactionType import TransactionType
from trezor.messages.RequestType import TXINPUT, TXOUTPUT, TXMETA, TXFINISHED
from trezor.messages.TxRequestSerializedType import TxRequestSerializedType
from trezor.messages.TxRequestDetailsType import TxRequestDetailsType
@ -65,7 +66,7 @@ def request_tx_meta(tx_req: TxRequest, tx_hash: bytes=None):
tx_req.details.request_index = None
ack = yield tx_req
tx_req.serialized = None
return ack.tx
return sanitize_tx_meta(ack.tx)
def request_tx_input(tx_req: TxRequest, i: int, tx_hash: bytes=None):
@ -74,7 +75,7 @@ def request_tx_input(tx_req: TxRequest, i: int, tx_hash: bytes=None):
tx_req.details.tx_hash = tx_hash
ack = yield tx_req
tx_req.serialized = None
return ack.tx.inputs[0]
return sanitize_tx_input(ack.tx)
def request_tx_output(tx_req: TxRequest, i: int, tx_hash: bytes=None):
@ -84,9 +85,9 @@ def request_tx_output(tx_req: TxRequest, i: int, tx_hash: bytes=None):
ack = yield tx_req
tx_req.serialized = None
if tx_hash is None:
return ack.tx.outputs[0]
return sanitize_tx_output(ack.tx)
else:
return ack.tx.bin_outputs[0]
return sanitize_tx_binoutput(ack.tx)
def request_tx_finish(tx_req: TxRequest):
@ -96,18 +97,50 @@ def request_tx_finish(tx_req: TxRequest):
tx_req.serialized = None
# Data sanitizers
# ===
def sanitize_sign_tx(tx: SignTx) -> SignTx:
tx.version = tx.version if tx.version is not None else 1
tx.lock_time = tx.lock_time if tx.lock_time is not None else 0
tx.inputs_count = tx.inputs_count if tx.inputs_count is not None else 0
tx.outputs_count = tx.outputs_count if tx.outputs_count is not None else 0
tx.coin_name = tx.coin_name if tx.coin_name is not None else 'Bitcoin'
return tx
def sanitize_tx_meta(tx: TransactionType) -> TransactionType:
tx.version = tx.version if tx.version is not None else 1
tx.lock_time = tx.lock_time if tx.lock_time is not None else 0
tx.inputs_cnt = tx.inputs_cnt if tx.inputs_cnt is not None else 0
tx.outputs_cnt = tx.outputs_cnt if tx.outputs_cnt is not None else 0
return tx
def sanitize_tx_input(tx: TransactionType) -> TxInputType:
txi = tx.inputs[0]
txi.script_type = (
txi.script_type if txi.script_type is not None else InputScriptType.SPENDADDRESS)
return txi
def sanitize_tx_output(tx: TransactionType) -> TxOutputType:
return tx.outputs[0]
def sanitize_tx_binoutput(tx: TransactionType) -> TxOutputBinType:
return tx.bin_outputs[0]
# Transaction signing
# ===
async def sign_tx(tx: SignTx, root):
tx_version = tx.version if tx.version is not None else 1
tx_lock_time = tx.lock_time or 0
tx_inputs_count = tx.inputs_count or 0
tx_outputs_count = tx.outputs_count or 0
coin_name = tx.coin_name or 'Bitcoin'
coin = coins.by_name(coin_name)
tx = sanitize_sign_tx(tx)
coin = coins.by_name(tx.coin_name)
# Phase 1
# - check inputs, previous transactions, and outputs
@ -127,14 +160,14 @@ async def sign_tx(tx: SignTx, root):
tx_req = TxRequest()
tx_req.details = TxRequestDetailsType()
for i in range(tx_inputs_count):
for i in range(tx.inputs_count):
# STAGE_REQUEST_1_INPUT
txi = await request_tx_input(tx_req, i)
write_tx_input_check(h_first, txi)
total_in += await get_prevtx_output_value(
tx_req, txi.prev_hash, txi.prev_index)
for o in range(tx_outputs_count):
for o in range(tx.outputs_count):
# STAGE_REQUEST_3_OUTPUT
txo = await request_tx_output(tx_req, o)
if output_is_change(txo):
@ -157,7 +190,7 @@ async def sign_tx(tx: SignTx, root):
raise SigningError(FailureType.NotEnoughFunds,
'Not enough funds')
if fee > coin.maxfee_kb * ((estimate_tx_size(tx_inputs_count, tx_outputs_count) + 999) // 1000):
if fee > coin.maxfee_kb * ((estimate_tx_size(tx.inputs_count, tx.outputs_count) + 999) // 1000):
if not await confirm_feeoverthreshold(fee, coin):
raise SigningError(FailureType.ActionCancelled,
'Signing cancelled')
@ -172,7 +205,7 @@ async def sign_tx(tx: SignTx, root):
tx_ser = TxRequestSerializedType()
for i_sign in range(tx_inputs_count):
for i_sign in range(tx.inputs_count):
# hash of what we are signing with this input
h_sign = HashWriter(sha256)
# same as h_first, checked at the end of this iteration
@ -182,11 +215,11 @@ async def sign_tx(tx: SignTx, root):
key_sign = None
key_sign_pub = None
write_uint32(h_sign, tx_version)
write_uint32(h_sign, tx.version)
write_varint(h_sign, tx_inputs_count)
write_varint(h_sign, tx.inputs_count)
for i in range(tx_inputs_count):
for i in range(tx.inputs_count):
# STAGE_REQUEST_4_INPUT
txi = await request_tx_input(tx_req, i)
write_tx_input_check(h_second, txi)
@ -199,9 +232,9 @@ async def sign_tx(tx: SignTx, root):
txi.script_sig = bytes()
write_tx_input(h_sign, txi)
write_varint(h_sign, tx_outputs_count)
write_varint(h_sign, tx.outputs_count)
for o in range(tx_outputs_count):
for o in range(tx.outputs_count):
# STAGE_REQUEST_4_OUTPUT
txo = await request_tx_output(tx_req, o)
txo_bin.amount = txo.amount
@ -209,7 +242,7 @@ async def sign_tx(tx: SignTx, root):
write_tx_output(h_second, txo_bin)
write_tx_output(h_sign, txo_bin)
write_uint32(h_sign, tx_lock_time)
write_uint32(h_sign, tx.lock_time)
write_uint32(h_sign, 0x00000001) # SIGHASH_ALL hash_type
@ -229,14 +262,14 @@ async def sign_tx(tx: SignTx, root):
w_txi_sign = bytearray_with_cap(
len(txi_sign.prev_hash) + 4 + 5 + len(txi_sign.script_sig) + 4)
if i_sign == 0: # serializing first input => prepend tx version and inputs count
write_uint32(w_txi_sign, tx_version)
write_varint(w_txi_sign, tx_inputs_count)
write_uint32(w_txi_sign, tx.version)
write_varint(w_txi_sign, tx.inputs_count)
write_tx_input(w_txi_sign, txi_sign)
tx_ser.serialized_tx = w_txi_sign
tx_req.serialized = tx_ser
for o in range(tx_outputs_count):
for o in range(tx.outputs_count):
# STAGE_REQUEST_5_OUTPUT
txo = await request_tx_output(tx_req, o)
txo_bin.amount = txo.amount
@ -246,10 +279,10 @@ async def sign_tx(tx: SignTx, root):
w_txo_bin = bytearray_with_cap(
5 + 8 + 5 + len(txo_bin.script_pubkey) + 4)
if o == 0: # serializing first output => prepend outputs count
write_varint(w_txo_bin, tx_outputs_count)
write_varint(w_txo_bin, tx.outputs_count)
write_tx_output(w_txo_bin, txo_bin)
if o == tx_outputs_count - 1: # serializing last output => append tx lock_time
write_uint32(w_txo_bin, tx_lock_time)
if o == tx.outputs_count - 1: # serializing last output => append tx lock_time
write_uint32(w_txo_bin, tx.lock_time)
tx_ser.signature_index = None
tx_ser.signature = None
tx_ser.serialized_tx = w_txo_bin
@ -265,32 +298,26 @@ async def get_prevtx_output_value(tx_req: TxRequest, prev_hash: bytes, prev_inde
# STAGE_REQUEST_2_PREV_META
tx = await request_tx_meta(tx_req, prev_hash)
tx_version = tx.version if tx.version is not None else 1
tx_lock_time = tx.lock_time or 0
tx_inputs_count = tx.inputs_cnt or 0
tx_outputs_count = tx.outputs_cnt or 0
txh = HashWriter(sha256)
write_uint32(txh, tx_version)
write_uint32(txh, tx.version)
write_varint(txh, tx.inputs_cnt)
write_varint(txh, tx_inputs_count)
for i in range(tx_inputs_count):
for i in range(tx.inputs_cnt):
# STAGE_REQUEST_2_PREV_INPUT
txi = await request_tx_input(tx_req, i, prev_hash)
write_tx_input(txh, txi)
write_varint(txh, tx_outputs_count)
write_varint(txh, tx.outputs_cnt)
for o in range(tx_outputs_count):
for o in range(tx.outputs_cnt):
# STAGE_REQUEST_2_PREV_OUTPUT
txo_bin = await request_tx_output(tx_req, o, prev_hash)
write_tx_output(txh, txo_bin)
if o == prev_index:
total_out += txo_bin.amount
write_uint32(txh, tx_lock_time)
write_uint32(txh, tx.lock_time)
if get_tx_hash(txh, True, True) != prev_hash:
raise SigningError(FailureType.Other,
@ -366,9 +393,7 @@ def output_is_change(o: TxOutputType) -> bool:
def input_derive_script(i: TxInputType, pubkey: bytes, signature: bytes=None) -> bytes:
script_type = i.script_type if i.script_type is not None else InputScriptType.SPENDADDRESS
if script_type == InputScriptType.SPENDADDRESS:
if i.script_type == InputScriptType.SPENDADDRESS:
if signature is None:
return script_paytoaddress_new(ecdsa_hash_pubkey(pubkey))
else:

Loading…
Cancel
Save