From b145f8f309d5fe6fc89d0c39f4abda9c50516f6b Mon Sep 17 00:00:00 2001 From: Jan Pochyla Date: Thu, 24 Nov 2016 13:58:23 +0100 Subject: [PATCH] apps.common.signtx: add example sanitization --- src/apps/common/signtx.py | 107 +++++++++++++++++++++++--------------- 1 file changed, 66 insertions(+), 41 deletions(-) diff --git a/src/apps/common/signtx.py b/src/apps/common/signtx.py index 4b44c0fa9..6739cec03 100644 --- a/src/apps/common/signtx.py +++ b/src/apps/common/signtx.py @@ -11,6 +11,7 @@ from trezor.messages.TxOutputType import TxOutputType from trezor.messages.TxOutputBinType import TxOutputBinType from trezor.messages.TxInputType import TxInputType from trezor.messages.TxRequest import TxRequest +from trezor.messages.TransactionType import TransactionType from trezor.messages.RequestType import TXINPUT, TXOUTPUT, TXMETA, TXFINISHED from trezor.messages.TxRequestSerializedType import TxRequestSerializedType from trezor.messages.TxRequestDetailsType import TxRequestDetailsType @@ -65,7 +66,7 @@ def request_tx_meta(tx_req: TxRequest, tx_hash: bytes=None): tx_req.details.request_index = None ack = yield tx_req tx_req.serialized = None - return ack.tx + return sanitize_tx_meta(ack.tx) def request_tx_input(tx_req: TxRequest, i: int, tx_hash: bytes=None): @@ -74,7 +75,7 @@ def request_tx_input(tx_req: TxRequest, i: int, tx_hash: bytes=None): tx_req.details.tx_hash = tx_hash ack = yield tx_req tx_req.serialized = None - return ack.tx.inputs[0] + return sanitize_tx_input(ack.tx) def request_tx_output(tx_req: TxRequest, i: int, tx_hash: bytes=None): @@ -84,9 +85,9 @@ def request_tx_output(tx_req: TxRequest, i: int, tx_hash: bytes=None): ack = yield tx_req tx_req.serialized = None if tx_hash is None: - return ack.tx.outputs[0] + return sanitize_tx_output(ack.tx) else: - return ack.tx.bin_outputs[0] + return sanitize_tx_binoutput(ack.tx) def request_tx_finish(tx_req: TxRequest): @@ -96,18 +97,50 @@ def request_tx_finish(tx_req: TxRequest): tx_req.serialized = None +# Data sanitizers +# === + + +def sanitize_sign_tx(tx: SignTx) -> SignTx: + tx.version = tx.version if tx.version is not None else 1 + tx.lock_time = tx.lock_time if tx.lock_time is not None else 0 + tx.inputs_count = tx.inputs_count if tx.inputs_count is not None else 0 + tx.outputs_count = tx.outputs_count if tx.outputs_count is not None else 0 + tx.coin_name = tx.coin_name if tx.coin_name is not None else 'Bitcoin' + return tx + + +def sanitize_tx_meta(tx: TransactionType) -> TransactionType: + tx.version = tx.version if tx.version is not None else 1 + tx.lock_time = tx.lock_time if tx.lock_time is not None else 0 + tx.inputs_cnt = tx.inputs_cnt if tx.inputs_cnt is not None else 0 + tx.outputs_cnt = tx.outputs_cnt if tx.outputs_cnt is not None else 0 + return tx + + +def sanitize_tx_input(tx: TransactionType) -> TxInputType: + txi = tx.inputs[0] + txi.script_type = ( + txi.script_type if txi.script_type is not None else InputScriptType.SPENDADDRESS) + return txi + + +def sanitize_tx_output(tx: TransactionType) -> TxOutputType: + return tx.outputs[0] + + +def sanitize_tx_binoutput(tx: TransactionType) -> TxOutputBinType: + return tx.bin_outputs[0] + + # Transaction signing # === async def sign_tx(tx: SignTx, root): - tx_version = tx.version if tx.version is not None else 1 - tx_lock_time = tx.lock_time or 0 - tx_inputs_count = tx.inputs_count or 0 - tx_outputs_count = tx.outputs_count or 0 - coin_name = tx.coin_name or 'Bitcoin' - coin = coins.by_name(coin_name) + tx = sanitize_sign_tx(tx) + coin = coins.by_name(tx.coin_name) # Phase 1 # - check inputs, previous transactions, and outputs @@ -127,14 +160,14 @@ async def sign_tx(tx: SignTx, root): tx_req = TxRequest() tx_req.details = TxRequestDetailsType() - for i in range(tx_inputs_count): + for i in range(tx.inputs_count): # STAGE_REQUEST_1_INPUT txi = await request_tx_input(tx_req, i) write_tx_input_check(h_first, txi) total_in += await get_prevtx_output_value( tx_req, txi.prev_hash, txi.prev_index) - for o in range(tx_outputs_count): + for o in range(tx.outputs_count): # STAGE_REQUEST_3_OUTPUT txo = await request_tx_output(tx_req, o) if output_is_change(txo): @@ -157,7 +190,7 @@ async def sign_tx(tx: SignTx, root): raise SigningError(FailureType.NotEnoughFunds, 'Not enough funds') - if fee > coin.maxfee_kb * ((estimate_tx_size(tx_inputs_count, tx_outputs_count) + 999) // 1000): + if fee > coin.maxfee_kb * ((estimate_tx_size(tx.inputs_count, tx.outputs_count) + 999) // 1000): if not await confirm_feeoverthreshold(fee, coin): raise SigningError(FailureType.ActionCancelled, 'Signing cancelled') @@ -172,7 +205,7 @@ async def sign_tx(tx: SignTx, root): tx_ser = TxRequestSerializedType() - for i_sign in range(tx_inputs_count): + for i_sign in range(tx.inputs_count): # hash of what we are signing with this input h_sign = HashWriter(sha256) # same as h_first, checked at the end of this iteration @@ -182,11 +215,11 @@ async def sign_tx(tx: SignTx, root): key_sign = None key_sign_pub = None - write_uint32(h_sign, tx_version) + write_uint32(h_sign, tx.version) - write_varint(h_sign, tx_inputs_count) + write_varint(h_sign, tx.inputs_count) - for i in range(tx_inputs_count): + for i in range(tx.inputs_count): # STAGE_REQUEST_4_INPUT txi = await request_tx_input(tx_req, i) write_tx_input_check(h_second, txi) @@ -199,9 +232,9 @@ async def sign_tx(tx: SignTx, root): txi.script_sig = bytes() write_tx_input(h_sign, txi) - write_varint(h_sign, tx_outputs_count) + write_varint(h_sign, tx.outputs_count) - for o in range(tx_outputs_count): + for o in range(tx.outputs_count): # STAGE_REQUEST_4_OUTPUT txo = await request_tx_output(tx_req, o) txo_bin.amount = txo.amount @@ -209,7 +242,7 @@ async def sign_tx(tx: SignTx, root): write_tx_output(h_second, txo_bin) write_tx_output(h_sign, txo_bin) - write_uint32(h_sign, tx_lock_time) + write_uint32(h_sign, tx.lock_time) write_uint32(h_sign, 0x00000001) # SIGHASH_ALL hash_type @@ -229,14 +262,14 @@ async def sign_tx(tx: SignTx, root): w_txi_sign = bytearray_with_cap( len(txi_sign.prev_hash) + 4 + 5 + len(txi_sign.script_sig) + 4) if i_sign == 0: # serializing first input => prepend tx version and inputs count - write_uint32(w_txi_sign, tx_version) - write_varint(w_txi_sign, tx_inputs_count) + write_uint32(w_txi_sign, tx.version) + write_varint(w_txi_sign, tx.inputs_count) write_tx_input(w_txi_sign, txi_sign) tx_ser.serialized_tx = w_txi_sign tx_req.serialized = tx_ser - for o in range(tx_outputs_count): + for o in range(tx.outputs_count): # STAGE_REQUEST_5_OUTPUT txo = await request_tx_output(tx_req, o) txo_bin.amount = txo.amount @@ -246,10 +279,10 @@ async def sign_tx(tx: SignTx, root): w_txo_bin = bytearray_with_cap( 5 + 8 + 5 + len(txo_bin.script_pubkey) + 4) if o == 0: # serializing first output => prepend outputs count - write_varint(w_txo_bin, tx_outputs_count) + write_varint(w_txo_bin, tx.outputs_count) write_tx_output(w_txo_bin, txo_bin) - if o == tx_outputs_count - 1: # serializing last output => append tx lock_time - write_uint32(w_txo_bin, tx_lock_time) + if o == tx.outputs_count - 1: # serializing last output => append tx lock_time + write_uint32(w_txo_bin, tx.lock_time) tx_ser.signature_index = None tx_ser.signature = None tx_ser.serialized_tx = w_txo_bin @@ -265,32 +298,26 @@ async def get_prevtx_output_value(tx_req: TxRequest, prev_hash: bytes, prev_inde # STAGE_REQUEST_2_PREV_META tx = await request_tx_meta(tx_req, prev_hash) - tx_version = tx.version if tx.version is not None else 1 - tx_lock_time = tx.lock_time or 0 - tx_inputs_count = tx.inputs_cnt or 0 - tx_outputs_count = tx.outputs_cnt or 0 - txh = HashWriter(sha256) - write_uint32(txh, tx_version) + write_uint32(txh, tx.version) + write_varint(txh, tx.inputs_cnt) - write_varint(txh, tx_inputs_count) - - for i in range(tx_inputs_count): + for i in range(tx.inputs_cnt): # STAGE_REQUEST_2_PREV_INPUT txi = await request_tx_input(tx_req, i, prev_hash) write_tx_input(txh, txi) - write_varint(txh, tx_outputs_count) + write_varint(txh, tx.outputs_cnt) - for o in range(tx_outputs_count): + for o in range(tx.outputs_cnt): # STAGE_REQUEST_2_PREV_OUTPUT txo_bin = await request_tx_output(tx_req, o, prev_hash) write_tx_output(txh, txo_bin) if o == prev_index: total_out += txo_bin.amount - write_uint32(txh, tx_lock_time) + write_uint32(txh, tx.lock_time) if get_tx_hash(txh, True, True) != prev_hash: raise SigningError(FailureType.Other, @@ -366,9 +393,7 @@ def output_is_change(o: TxOutputType) -> bool: def input_derive_script(i: TxInputType, pubkey: bytes, signature: bytes=None) -> bytes: - script_type = i.script_type if i.script_type is not None else InputScriptType.SPENDADDRESS - - if script_type == InputScriptType.SPENDADDRESS: + if i.script_type == InputScriptType.SPENDADDRESS: if signature is None: return script_paytoaddress_new(ecdsa_hash_pubkey(pubkey)) else: