diff --git a/src/apps/common/signtx.py b/src/apps/common/signtx.py index 6dad23929..cbb14cad8 100644 --- a/src/apps/common/signtx.py +++ b/src/apps/common/signtx.py @@ -129,7 +129,9 @@ async def sign_tx(tx: SignTx, root): key_sign = None key_sign_pub = None - write_tx_header(h_sign, tx_version, tx_inputs_count) + write_uint32(h_sign, tx_version) + + write_varint(h_sign, tx_inputs_count) for i in range(tx_inputs_count): # STAGE_REQUEST_4_INPUT @@ -145,7 +147,7 @@ async def sign_tx(tx: SignTx, root): txi.script_sig = bytes() write_tx_input(h_sign, txi) - write_tx_middle(h_sign, tx_outputs_count) + write_varint(h_sign, tx_outputs_count) for o in range(tx_outputs_count): # STAGE_REQUEST_4_OUTPUT @@ -155,7 +157,9 @@ async def sign_tx(tx: SignTx, root): write_tx_output(h_second, txo_bin) write_tx_output(h_sign, txo_bin) - write_tx_footer(h_sign, tx_lock_time, True) + write_uint32(h_sign, tx_lock_time) + + write_uint32(h_sign, 0x0000001) # hash_type import ubinascii @@ -172,8 +176,9 @@ async def sign_tx(tx: SignTx, root): txi_sign.script_sig = input_derive_script_post_sign( txi_sign, key_sign_pub, signature) w_txi_sign = BufferWriter() - if i_sign == 0: - write_tx_header(w_txi_sign, tx_version, tx_inputs_count) + if i_sign == 0: # serializing first input => prepend tx version and inputs count + write_uint32(w_txi_sign, tx_version) + write_varint(w_txi_sign, tx_inputs_count) write_tx_input(w_txi_sign, txi_sign) tx_ser.serialized_tx = w_txi_sign.getvalue() @@ -187,11 +192,11 @@ async def sign_tx(tx: SignTx, root): # serialize output w_txo_bin = BufferWriter() - if o == 0: - write_tx_middle(w_txo_bin, tx_outputs_count) + if o == 0: # serializing first output => prepend outputs count + write_varint(w_txo_bin, tx_outputs_count) write_tx_output(w_txo_bin, txo_bin) - if o == tx_outputs_count - 1: - write_tx_footer(w_txo_bin, tx_lock_time, False) + if o == tx_outputs_count - 1: # serializing last output => append tx lock_time + write_uint32(w_txo_bin, tx_lock_time) tx_ser.signature_index = None tx_ser.signature = None tx_ser.serialized_tx = w_txo_bin.getvalue() @@ -214,14 +219,16 @@ async def get_prevtx_output_value(tx_req: TxRequest, prev_hash: bytes, prev_inde txh = HashWriter(sha256) - write_tx_header(txh, tx_version, tx_inputs_count) + write_uint32(txh, tx_version) + + write_varint(txh, tx_inputs_count) for i in range(tx_inputs_count): # STAGE_REQUEST_2_PREV_INPUT txi = await request_tx_input(tx_req, i, prev_hash) write_tx_input(txh, txi) - write_tx_middle(txh, tx_outputs_count) + write_varint(txh, tx_outputs_count) for o in range(tx_outputs_count): # STAGE_REQUEST_2_PREV_OUTPUT @@ -230,11 +237,12 @@ async def get_prevtx_output_value(tx_req: TxRequest, prev_hash: bytes, prev_inde if o == prev_index: total_out += txo_bin.amount - write_tx_footer(txh, tx_lock_time, False) + write_uint32(txh, tx_lock_time) prev_hash_rev = bytes(reversed(prev_hash)) # TODO: improve performance if tx_hash_digest(txh, True) != prev_hash_rev: raise ValueError('Encountered invalid prev_hash') + return total_out @@ -352,12 +360,6 @@ def script_spendaddress_new(pubkey: bytes, signature: bytes) -> bytearray: # TX Serialization # === - -def write_tx_header(w, version: int, inputs_count: int): - write_uint32(w, version) - write_varint(w, inputs_count) - - def write_tx_input(w, i: TxInputType): i_sequence = getattr(i, 'sequence', 4294967295) write_bytes_rev(w, i.prev_hash) @@ -377,22 +379,12 @@ def write_tx_input_check(w, i: TxInputType): write_uint32(w, i_sequence) -def write_tx_middle(w, outputs_count: int): - write_varint(w, outputs_count) - - def write_tx_output(w, o: TxOutputBinType): write_uint64(w, o.amount) write_varint(w, len(o.script_pubkey)) write_bytes(w, o.script_pubkey) -def write_tx_footer(w, locktime: int, add_hash_type: bool): - write_uint32(w, locktime) - if add_hash_type: - write_uint32(w, 1) - - def write_op_push(w, n: int): wb = w.writebyte if n < 0x4C: