From 44a3b7f9f111cbe4b88c8738e4f88c0a9feec5d7 Mon Sep 17 00:00:00 2001 From: Jan Pochyla Date: Sun, 6 Nov 2016 14:23:27 +0100 Subject: [PATCH] signtx: fixes, refactoring --- src/apps/common/{sign.py => signtx.py} | 323 ++++++++++++------------- 1 file changed, 157 insertions(+), 166 deletions(-) rename src/apps/common/{sign.py => signtx.py} (58%) diff --git a/src/apps/common/sign.py b/src/apps/common/signtx.py similarity index 58% rename from src/apps/common/sign.py rename to src/apps/common/signtx.py index 008a4f3d9..86b081526 100644 --- a/src/apps/common/sign.py +++ b/src/apps/common/signtx.py @@ -1,6 +1,6 @@ from trezor.crypto.hashlib import sha256, ripemd160 from trezor.crypto.curve import secp256k1 -from trezor.crypto import HDNode, base58 +from trezor.crypto import base58 from . import coins @@ -14,19 +14,10 @@ from trezor.messages.RequestType import TXINPUT, TXOUTPUT, TXMETA, TXFINISHED from trezor.messages.TxRequestSerializedType import TxRequestSerializedType from trezor.messages import OutputScriptType, InputScriptType -# pylint: disable=W0622 - # Machine instructions # === -# TODO: we might want to define these in terms of data instead -# - like TxRequest, but also for deriving keys for example -# - sign_tx would turn to more or less pure code -# - PROBLEM: async defs in python cannot yield. we could ignore that, -# or use wrappers anyway, or just make it an ordinary old-style coroutine -# and use yield / yield from everywhere - def request_tx_meta(prev_hash: bytes=None): ack = yield TxRequest(type=TXMETA, prev_hash=prev_hash) @@ -40,10 +31,10 @@ def request_tx_input(index: int, prev_hash: bytes=None): def request_tx_output(index: int, prev_hash: bytes=None): ack = yield TxRequest(type=TXOUTPUT, prev_hash=prev_hash, index=index) - if prev_hash is not None: - return ack.bin_outputs[0] - else: + if prev_hash is None: return ack.outputs[0] + else: + return ack.bin_outputs[0] def request_tx_finish(): @@ -58,7 +49,7 @@ def send_serialized_tx(serialized: TxRequestSerializedType): # === -async def sign_tx(tx: SignTx, root: HDNode): +async def sign_tx(tx: SignTx, root): coin = coins.by_name(tx.coin_name) # Phase 1 @@ -73,29 +64,31 @@ async def sign_tx(tx: SignTx, root: HDNode): # h_first is used to make sure the inputs and outputs streamed in Phase 1 # are the same as in Phase 2. it is thus not required to fully hash the # tx, as the SignTx info is streamed only once - h_first = tx_hash_init() # not a real tx hash + h_first = HashWriter(sha256) # not a real tx hash + + # pre-allocate the serialization structure for outputs + txo_bin = TxOutputBinType() for i in range(tx.inputs_count): # STAGE_REQUEST_1_INPUT - input = await request_tx_input(i) - tx_write_input(h_first, input) - total_in += await get_prevtx_output_value(input.prev_hash, input.prev_index) + txi = await request_tx_input(i) + write_tx_input(h_first, txi) + total_in += await get_prevtx_output_value(txi.prev_hash, txi.prev_index) for o in range(tx.outputs_count): # STAGE_REQUEST_3_OUTPUT - output = await request_tx_output(o) - if output_is_change(output): + txo = await request_tx_output(o) + if output_is_change(txo): if change_out != 0: - raise ValueError('Only one change output allowed') - change_out = output.amount - outputbin = output_compile(output, coin, root) - tx_write_output(h_first, outputbin) - total_out += outputbin.amount + raise ValueError('Only one change output is valid') + change_out = txo.amount + txo_bin.amount = txo.amount + txo_bin.script_pubkey = output_derive_script(txo, coin, root) + write_tx_output(h_first, txo_bin) + total_out += txo_bin.amount # TODO: display output # TODO: confirm output - h_first_dig = tx_hash_digest(h_first) - # TODO: check funds and tx fee # TODO: ask for confirmation @@ -103,148 +96,193 @@ async def sign_tx(tx: SignTx, root: HDNode): # - sign inputs # - check that nothing changed - for i_sign in range(tx.inputs_count): - h_sign = tx_hash_init() # hash of what we are signing with this input - h_second = tx_hash_init() # should be the same as h_first + # pre-allocated result structure for streaming out the signatures and + # parts of the serialized tx + tx_ser = TxRequestSerializedType() - input_sign = None + for i_sign in range(tx.inputs_count): + # hash of what we are signing with this input + h_sign = HashWriter(sha256) + # same as h_first, checked at the end of this iteration + h_second = HashWriter(sha256) + + txi_sign = None key_sign = None key_sign_pub = None for i in range(tx.inputs_count): # STAGE_REQUEST_4_INPUT - input = await request_tx_input(i) - tx_write_input(h_second, input) + txi = await request_tx_input(i) + write_tx_input(h_second, txi) if i == i_sign: - key_sign = node_derive(root, input.address_n) + txi_sign = txi + key_sign = node_derive(root, txi.address_n) key_sign_pub = key_sign.public_key() - script_sig = input_derive_script_pre_sign(input, key_sign_pub) - input_sign = input + txi.script_sig = input_derive_script_pre_sign( + txi, key_sign_pub) else: - script_sig = bytes() - input.script_sig = script_sig - tx_write_input(h_sign, input) + txi.script_sig = bytes() + write_tx_input(h_sign, txi) for o in range(tx.outputs_count): # STAGE_REQUEST_4_OUTPUT - output = await request_tx_output(o) - outputbin = output_compile(output, coin, root) - tx_write_output(h_second, outputbin) - tx_write_output(h_sign, outputbin) + txo = await request_tx_output(o) + txo_bin.amount = txo.amount + txo_bin.script_pubkey = output_derive_script(txo, coin, root) + write_tx_output(h_second, txo_bin) + write_tx_output(h_sign, txo_bin) - if h_first_dig != tx_hash_digest(h_second): + h_first_dig = tx_hash_digest(h_first, False) + h_second_dig = tx_hash_digest(h_second, False) + if h_first_dig != h_second_dig: raise ValueError('Transaction has changed during signing') - signature = ecdsa_sign(key_sign, tx_hash_digest(h_sign)) - script_sig = input_derive_script_post_sign( - input, key_sign_pub, signature) - input_sign.script_sig = script_sig + h_sign_dig = tx_hash_digest(h_sign, True) + signature = ecdsa_sign(key_sign, h_sign_dig) + txi_sign.script_sig = input_derive_script_post_sign( + txi_sign, key_sign_pub, signature) # TODO: serialize the whole input at once, including the script_sig - input_sign_w = BufferWriter(bytearray(), 0) - tx_write_input(input_sign_w, input_sign) - input_sign_b = input_sign_w.getvalue() + txi_sign_w = BufferWriter() + write_tx_input(txi_sign_w, txi_sign) + txi_sign_b = txi_sign_w.getvalue() - serialized = TxRequestSerializedType( - signature_index=i_sign, signature=signature, serialized_tx=input_sign_b) - await send_serialized_tx(serialized) + tx_ser.signature_index = i_sign + tx_ser.signature = signature + tx_ser.serialized_tx = txi_sign_b + await send_serialized_tx(tx_ser) + + del tx_ser.signature_index + del tx_ser.signature for o in range(tx.outputs_count): # STAGE_REQUEST_5_OUTPUT - output = await request_tx_output(o) - outputbin = output_compile(output, coin, root) + txo = await request_tx_output(o) + txo_bin.amount = txo.amount + txo_bin.script_pubkey = output_derive_script(txo, coin, root) - outputbin_w = BufferWriter(bytearray(), 0) - tx_write_input(outputbin_w, outputbin) - outputbin_b = outputbin_w.getvalue() + w_txo_bin = BufferWriter() + write_tx_output(w_txo_bin, txo_bin) - serialized = TxRequestSerializedType(serialized_tx=outputbin_b) - await send_serialized_tx(serialized) + tx_ser.serialized_tx = w_txo_bin.getvalue() + await send_serialized_tx(tx_ser) await request_tx_finish() async def get_prevtx_output_value(prev_hash: bytes, prev_index: int) -> int: - - total_in = 0 + total_out = 0 # sum of output amounts # STAGE_REQUEST_2_PREV_META tx = await request_tx_meta(prev_hash) + txh = HashWriter(sha256) - h = tx_hash_init() - tx_write_header(h, tx.version, tx.inputs_count) - + write_tx_header(txh, tx.version, tx.inputs_count) for i in range(tx.inputs_count): # STAGE_REQUEST_2_PREV_INPUT - input = await request_tx_input(i, prev_hash) - tx_write_input(h, input) - - tx_write_middle(h, tx.outputs_count) + txi = await request_tx_input(i, prev_hash) + write_tx_input(txh, txi) + write_tx_middle(txh, tx.outputs_count) for o in range(tx.outputs_count): # STAGE_REQUEST_2_PREV_OUTPUT - outputbin = await request_tx_output(o, prev_hash) - tx_write_output(h, outputbin) + txo_bin = await request_tx_output(o, prev_hash) + write_tx_output(txh, txo_bin) if o == prev_index: - total_in += outputbin.value + total_out += txo_bin.value - tx_write_footer(h, tx.locktime, False) - - if tx_hash_digest(h) != prev_hash: - raise ValueError('PrevTx mismatch') - - return total_in + write_tx_footer(txh, tx.locktime, False) + if tx_hash_digest(txh, True) != prev_hash: + raise ValueError('Encountered invalid prev_hash') + return total_out -# TX Hashing -# === - - -def tx_hash_init(): - return HashWriter(sha256) - - -def tx_hash_digest(w): - return sha256(w.getvalue()).digest() +def tx_hash_digest(w, double: bool): + d = w.getvalue() + if double: + d = sha256(d).digest() + return d # TX Outputs # === -def output_compile(output: TxOutputType, coin: CoinType, root: HDNode) -> TxOutputBinType: - bin = TxOutputBinType() - bin.amount = output.amount - bin.script_pubkey = output_derive_script(output, coin, root) - return bin - - -def output_derive_script(output: TxOutputType, coin: CoinType, root: HDNode) -> bytes: - if output.script_type == OutputScriptType.PAYTOADDRESS: - raw_address = output_paytoaddress_extract_raw_address(output, root) - if raw_address[0] != coin.address_type: # TODO: do this properly - raise ValueError('Invalid address type') - return script_paytoaddress_new(raw_address) +def output_derive_script(o: TxOutputType, coin: CoinType, root) -> bytes: + if o.script_type == OutputScriptType.PAYTOADDRESS: + return script_paytoaddress_new( + output_paytoaddress_extract_raw_address(o, coin, root)) else: - # TODO: other output script types - raise ValueError('Unknown output script type') + raise ValueError('Invalid output script type') return -def output_paytoaddress_extract_raw_address(o: TxOutputType, root: HDNode) -> bytes: +def output_paytoaddress_extract_raw_address(o: TxOutputType, coin: CoinType, root) -> bytes: o_address_n = getattr(o, 'address_n', None) o_address = getattr(o, 'address', None) - if o_address_n: - node = node_derive(root, o_address_n) - # TODO: dont encode and decode again - raw_address = base58.decode_check(node.address()) + # TODO: dont encode/decode more then necessary + # TODO: detect correct address type + if o_address_n is not None: + n = node_derive(root, o_address_n) + raw_address = base58.decode_check(n.address()) elif o_address: raw_address = base58.decode_check(o_address) else: raise ValueError('Missing address') + if raw_address[0] != coin.address_type: + raise ValueError('Invalid address type') return raw_address +def output_is_change(output: TxOutputType): + address_n = getattr(output, 'address_n', None) + return bool(address_n) + + +# Tx Inputs +# === + + +def input_derive_script_pre_sign(i: TxInputType, pubkey: bytes) -> bytes: + if i.script_type == InputScriptType.SPENDADDRESS: + return script_paytoaddress_new(ecdsa_hash_pubkey(pubkey)) + else: + raise ValueError('Unknown input script type') + + +def input_derive_script_post_sign(i: TxInputType, pubkey: bytes, signature: bytes) -> bytes: + if i.script_type == InputScriptType.SPENDADDRESS: + return script_spendaddress_new(pubkey, signature) + else: + raise ValueError('Unknown input script type') + + +def node_derive(root, address_n: list): + node = root.clone() + node.derive_path(address_n) + return node + + +def ecdsa_hash_pubkey(pubkey: bytes) -> bytes: + if pubkey[0] == 0x04: + assert len(pubkey) == 65 # uncompressed format + elif pubkey[0] == 0x00: + assert len(pubkey) == 1 # point at infinity + else: + assert len(pubkey) == 33 # compresssed format + h = sha256(pubkey).digest() + h = ripemd160(h).digest() + return h + + +def ecdsa_sign(privkey: bytes, digest: bytes) -> bytes: + return secp256k1.sign(privkey, digest) + + +# TX Scripts +# === + + def script_paytoaddress_new(raw_address: bytes) -> bytearray: s = bytearray(25) s[0] = 0x76 # OP_DUP @@ -256,75 +294,26 @@ def script_paytoaddress_new(raw_address: bytes) -> bytearray: return s -def output_is_change(output: TxOutputType): - address_n = getattr(output, 'address_n', None) - return bool(address_n) - - -# Tx Inputs -# === - - -def input_derive_script_pre_sign(input: TxInputType, pubkey: bytes) -> bytes: - if input.script_type == InputScriptType.SPENDADDRESS: - return script_paytoaddress_new(ecdsa_get_pubkeyhash(pubkey)) - else: - # TODO: other input script types - raise ValueError('Unknown input script type') - - -def input_derive_script_post_sign(input: TxInputType, pubkey: bytes, signature: bytes) -> bytes: - if input.script_type == InputScriptType.SPENDADDRESS: - return script_spendaddress_new(pubkey, signature) - else: - # TODO: other input script types - raise ValueError('Unknown input script type') - - def script_spendaddress_new(pubkey: bytes, signature: bytes) -> bytearray: - s = bytearray(25) - w = BufferWriter(s, 0) + w = BufferWriter() write_op_push(w, len(signature) + 1) write_bytes(w, signature) w.writebyte(0x01) write_op_push(w, len(pubkey)) write_bytes(w, pubkey) - return - - -def node_derive(root: HDNode, address_n: list) -> HDNode: - # TODO: this will probably need to be a part of the machine instructions - node = root.clone() - node.derive_path(address_n) - return node - - -def ecdsa_get_pubkeyhash(pubkey: bytes) -> bytes: - if pubkey[0] == 0x04: - assert len(pubkey) == 65 # uncompressed format - elif pubkey[0] == 0x00: - assert len(pubkey) == 1 # point at infinity - else: - assert len(pubkey) == 33 # compresssed format - h = sha256(pubkey).digest() - h = ripemd160(h).digest() - return h - - -async def ecdsa_sign(privkey: bytes, digest: bytes) -> bytes: - return secp256k1.sign(privkey, digest) + return w.getvalue() # TX Serialization # === -def tx_write_header(w, version: int, inputs_count: int): +def write_tx_header(w, version: int, inputs_count: int): write_uint32(w, version) write_varint(w, inputs_count) -def tx_write_input(w, i: TxInputType): +def write_tx_input(w, i: TxInputType): write_bytes_rev(w, i.prev_hash) write_uint32(w, i.prev_index) write_varint(w, len(i.script_sig)) @@ -332,17 +321,17 @@ def tx_write_input(w, i: TxInputType): write_uint32(w, i.sequence) -def tx_write_middle(w, outputs_count: int): +def write_tx_middle(w, outputs_count: int): write_varint(w, outputs_count) -def tx_write_output(w, o: TxOutputBinType): +def write_tx_output(w, o: TxOutputBinType): write_uint64(w, o.amount) write_varint(w, len(o.script_pubkey)) write_bytes(w, o.script_pubkey) -def tx_write_footer(w, locktime: int, add_hash_type: bool): +def write_tx_footer(w, locktime: int, add_hash_type: bool): write_uint32(w, locktime) if add_hash_type: write_uint32(w, 1) @@ -417,13 +406,15 @@ def write_bytes_rev(w, buf: bytearray): class BufferWriter: - def __init__(self, buf: bytearray, ofs: int): + def __init__(self, buf: bytearray=None, ofs: int=0): # TODO: re-think the use of bytearrays, buffers, and other byte IO # i think we should just pass a pre-allocation size here, allocate the # bytearray and then trim it to zero. in this case, write() would # correspond to extend(), and writebyte() to append(). of course, the # the use-case of non-destructively writing to existing bytearray still # exists. + if buf is None: + buf = bytearray() self.buf = buf self.ofs = ofs