diff --git a/src/apps/common/signtx.py b/src/apps/common/signtx.py index 86b081526b..2c6e585aa4 100644 --- a/src/apps/common/signtx.py +++ b/src/apps/common/signtx.py @@ -12,6 +12,7 @@ from trezor.messages.TxInputType import TxInputType from trezor.messages.TxRequest import TxRequest from trezor.messages.RequestType import TXINPUT, TXOUTPUT, TXMETA, TXFINISHED from trezor.messages.TxRequestSerializedType import TxRequestSerializedType +from trezor.messages.TxRequestDetailsType import TxRequestDetailsType from trezor.messages import OutputScriptType, InputScriptType @@ -19,30 +20,41 @@ from trezor.messages import OutputScriptType, InputScriptType # === -def request_tx_meta(prev_hash: bytes=None): - ack = yield TxRequest(type=TXMETA, prev_hash=prev_hash) +def request_tx_meta(tx_req: TxRequest, tx_hash: bytes=None): + tx_req.type = TXMETA + tx_req.details.tx_hash = tx_hash + tx_req.details.request_index = None + ack = yield tx_req + tx_req.serialized = None return ack.tx -def request_tx_input(index: int, prev_hash: bytes=None): - ack = yield TxRequest(type=TXINPUT, prev_hash=prev_hash, index=index) +def request_tx_input(tx_req: TxRequest, i: int, tx_hash: bytes=None): + tx_req.type = TXINPUT + tx_req.details.request_index = i + tx_req.details.tx_hash = tx_hash + ack = yield tx_req + tx_req.serialized = None return ack.tx.inputs[0] -def request_tx_output(index: int, prev_hash: bytes=None): - ack = yield TxRequest(type=TXOUTPUT, prev_hash=prev_hash, index=index) - if prev_hash is None: +def request_tx_output(tx_req: TxRequest, i: int, tx_hash: bytes=None): + tx_req.type = TXOUTPUT + tx_req.details.request_index = i + tx_req.details.tx_hash = tx_hash + ack = yield tx_req + tx_req.serialized = None + if tx_hash is None: return ack.outputs[0] else: return ack.bin_outputs[0] -def request_tx_finish(): - yield TxRequest(type=TXFINISHED) - - -def send_serialized_tx(serialized: TxRequestSerializedType): - yield serialized +def request_tx_finish(tx_req: TxRequest): + tx_req.type = TXFINISHED + tx_req.details = None + yield tx_req + tx_req.serialized = None # Transaction signing @@ -50,7 +62,13 @@ def send_serialized_tx(serialized: TxRequestSerializedType): async def sign_tx(tx: SignTx, root): - coin = coins.by_name(tx.coin_name) + tx_version = getattr(tx, 'version', 0) + tx_lock_time = getattr(tx, 'lock_time', 1) + tx_inputs_count = getattr(tx, 'inputs_count', 0) + tx_outputs_count = getattr(tx, 'outputs_count', 0) + coin_name = getattr(tx, 'coin_name', 'Bitcoin') + + coin = coins.by_name(coin_name) # Phase 1 # - check inputs, previous transactions, and outputs @@ -66,18 +84,20 @@ async def sign_tx(tx: SignTx, root): # tx, as the SignTx info is streamed only once h_first = HashWriter(sha256) # not a real tx hash - # pre-allocate the serialization structure for outputs txo_bin = TxOutputBinType() + tx_req = TxRequest() + tx_req.details = TxRequestDetailsType() - for i in range(tx.inputs_count): + for i in range(tx_inputs_count): # STAGE_REQUEST_1_INPUT - txi = await request_tx_input(i) + txi = await request_tx_input(tx_req, i) write_tx_input(h_first, txi) - total_in += await get_prevtx_output_value(txi.prev_hash, txi.prev_index) + total_in += await get_prevtx_output_value( + tx_req, txi.prev_hash, txi.prev_index) - for o in range(tx.outputs_count): + for o in range(tx_outputs_count): # STAGE_REQUEST_3_OUTPUT - txo = await request_tx_output(o) + txo = await request_tx_output(tx_req, o) if output_is_change(txo): if change_out != 0: raise ValueError('Only one change output is valid') @@ -96,8 +116,6 @@ async def sign_tx(tx: SignTx, root): # - sign inputs # - check that nothing changed - # pre-allocated result structure for streaming out the signatures and - # parts of the serialized tx tx_ser = TxRequestSerializedType() for i_sign in range(tx.inputs_count): @@ -110,9 +128,11 @@ async def sign_tx(tx: SignTx, root): key_sign = None key_sign_pub = None + write_tx_header(h_sign, tx_version, tx_inputs_count) + for i in range(tx.inputs_count): # STAGE_REQUEST_4_INPUT - txi = await request_tx_input(i) + txi = await request_tx_input(tx_req, i) write_tx_input(h_second, txi) if i == i_sign: txi_sign = txi @@ -124,74 +144,94 @@ async def sign_tx(tx: SignTx, root): txi.script_sig = bytes() write_tx_input(h_sign, txi) + write_tx_middle(h_sign, tx_outputs_count) + for o in range(tx.outputs_count): # STAGE_REQUEST_4_OUTPUT - txo = await request_tx_output(o) + txo = await request_tx_output(tx_req, o) txo_bin.amount = txo.amount txo_bin.script_pubkey = output_derive_script(txo, coin, root) write_tx_output(h_second, txo_bin) write_tx_output(h_sign, txo_bin) + write_tx_footer(h_sign, tx_lock_time, True) + + # check the control digests h_first_dig = tx_hash_digest(h_first, False) h_second_dig = tx_hash_digest(h_second, False) if h_first_dig != h_second_dig: raise ValueError('Transaction has changed during signing') + # compute the signature from the tx digest h_sign_dig = tx_hash_digest(h_sign, True) signature = ecdsa_sign(key_sign, h_sign_dig) - txi_sign.script_sig = input_derive_script_post_sign( - txi_sign, key_sign_pub, signature) - - # TODO: serialize the whole input at once, including the script_sig - txi_sign_w = BufferWriter() - write_tx_input(txi_sign_w, txi_sign) - txi_sign_b = txi_sign_w.getvalue() - tx_ser.signature_index = i_sign tx_ser.signature = signature - tx_ser.serialized_tx = txi_sign_b - await send_serialized_tx(tx_ser) - del tx_ser.signature_index - del tx_ser.signature + # serialize input with correct signature + txi_sign.script_sig = input_derive_script_post_sign( + txi_sign, key_sign_pub, signature) + txi_sign_w = BufferWriter() + if i_sign == 0: + write_tx_header(txi_sign_w, tx_version, tx_inputs_count) + write_tx_input(txi_sign_w, txi_sign) + tx_ser.serialized_tx = txi_sign_w.getvalue() + + tx_req.serialized = tx_ser for o in range(tx.outputs_count): # STAGE_REQUEST_5_OUTPUT - txo = await request_tx_output(o) + txo = await request_tx_output(tx_req, o) txo_bin.amount = txo.amount txo_bin.script_pubkey = output_derive_script(txo, coin, root) + # serialize output w_txo_bin = BufferWriter() + if o == 0: + write_tx_middle(w_txo_bin, tx_outputs_count) write_tx_output(w_txo_bin, txo_bin) - + if o == tx_outputs_count: + write_tx_footer(w_txo_bin, tx_lock_time, False) + tx_ser.signature_index = None + tx_ser.signature = None tx_ser.serialized_tx = w_txo_bin.getvalue() - await send_serialized_tx(tx_ser) - await request_tx_finish() + tx_req.serialized = tx_ser + + await request_tx_finish(tx_req) -async def get_prevtx_output_value(prev_hash: bytes, prev_index: int) -> int: +async def get_prevtx_output_value(tx_req: TxRequest, prev_hash: bytes, prev_index: int) -> int: total_out = 0 # sum of output amounts # STAGE_REQUEST_2_PREV_META tx = await request_tx_meta(prev_hash) + + tx_version = getattr(tx, 'version', 0) + tx_lock_time = getattr(tx, 'lock_time', 1) + tx_inputs_count = getattr(tx, 'inputs_count', 0) + tx_outputs_count = getattr(tx, 'outputs_count', 0) + txh = HashWriter(sha256) - write_tx_header(txh, tx.version, tx.inputs_count) - for i in range(tx.inputs_count): + write_tx_header(txh, tx_version, tx_inputs_count) + + for i in range(tx_inputs_count): # STAGE_REQUEST_2_PREV_INPUT - txi = await request_tx_input(i, prev_hash) + txi = await request_tx_input(tx_req, i, prev_hash) write_tx_input(txh, txi) - write_tx_middle(txh, tx.outputs_count) - for o in range(tx.outputs_count): + write_tx_middle(txh, tx_outputs_count) + + for o in range(tx_outputs_count): # STAGE_REQUEST_2_PREV_OUTPUT - txo_bin = await request_tx_output(o, prev_hash) + txo_bin = await request_tx_output(tx_req, o, prev_hash) write_tx_output(txh, txo_bin) if o == prev_index: total_out += txo_bin.value - write_tx_footer(txh, tx.locktime, False) + write_tx_footer(txh, tx_lock_time, False) + if tx_hash_digest(txh, True) != prev_hash: raise ValueError('Encountered invalid prev_hash') return total_out