1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-08-04 12:56:25 +00:00

signtx: fixes, refactoring

This commit is contained in:
Jan Pochyla 2016-11-06 14:23:27 +01:00
parent adc3dde19e
commit 44a3b7f9f1

View File

@ -1,6 +1,6 @@
from trezor.crypto.hashlib import sha256, ripemd160 from trezor.crypto.hashlib import sha256, ripemd160
from trezor.crypto.curve import secp256k1 from trezor.crypto.curve import secp256k1
from trezor.crypto import HDNode, base58 from trezor.crypto import base58
from . import coins from . import coins
@ -14,19 +14,10 @@ from trezor.messages.RequestType import TXINPUT, TXOUTPUT, TXMETA, TXFINISHED
from trezor.messages.TxRequestSerializedType import TxRequestSerializedType from trezor.messages.TxRequestSerializedType import TxRequestSerializedType
from trezor.messages import OutputScriptType, InputScriptType from trezor.messages import OutputScriptType, InputScriptType
# pylint: disable=W0622
# Machine instructions # Machine instructions
# === # ===
# TODO: we might want to define these in terms of data instead
# - like TxRequest, but also for deriving keys for example
# - sign_tx would turn to more or less pure code
# - PROBLEM: async defs in python cannot yield. we could ignore that,
# or use wrappers anyway, or just make it an ordinary old-style coroutine
# and use yield / yield from everywhere
def request_tx_meta(prev_hash: bytes=None): def request_tx_meta(prev_hash: bytes=None):
ack = yield TxRequest(type=TXMETA, prev_hash=prev_hash) ack = yield TxRequest(type=TXMETA, prev_hash=prev_hash)
@ -40,10 +31,10 @@ def request_tx_input(index: int, prev_hash: bytes=None):
def request_tx_output(index: int, prev_hash: bytes=None): def request_tx_output(index: int, prev_hash: bytes=None):
ack = yield TxRequest(type=TXOUTPUT, prev_hash=prev_hash, index=index) ack = yield TxRequest(type=TXOUTPUT, prev_hash=prev_hash, index=index)
if prev_hash is not None: if prev_hash is None:
return ack.bin_outputs[0]
else:
return ack.outputs[0] return ack.outputs[0]
else:
return ack.bin_outputs[0]
def request_tx_finish(): def request_tx_finish():
@ -58,7 +49,7 @@ def send_serialized_tx(serialized: TxRequestSerializedType):
# === # ===
async def sign_tx(tx: SignTx, root: HDNode): async def sign_tx(tx: SignTx, root):
coin = coins.by_name(tx.coin_name) coin = coins.by_name(tx.coin_name)
# Phase 1 # Phase 1
@ -73,29 +64,31 @@ async def sign_tx(tx: SignTx, root: HDNode):
# h_first is used to make sure the inputs and outputs streamed in Phase 1 # h_first is used to make sure the inputs and outputs streamed in Phase 1
# are the same as in Phase 2. it is thus not required to fully hash the # are the same as in Phase 2. it is thus not required to fully hash the
# tx, as the SignTx info is streamed only once # tx, as the SignTx info is streamed only once
h_first = tx_hash_init() # not a real tx hash h_first = HashWriter(sha256) # not a real tx hash
# pre-allocate the serialization structure for outputs
txo_bin = TxOutputBinType()
for i in range(tx.inputs_count): for i in range(tx.inputs_count):
# STAGE_REQUEST_1_INPUT # STAGE_REQUEST_1_INPUT
input = await request_tx_input(i) txi = await request_tx_input(i)
tx_write_input(h_first, input) write_tx_input(h_first, txi)
total_in += await get_prevtx_output_value(input.prev_hash, input.prev_index) total_in += await get_prevtx_output_value(txi.prev_hash, txi.prev_index)
for o in range(tx.outputs_count): for o in range(tx.outputs_count):
# STAGE_REQUEST_3_OUTPUT # STAGE_REQUEST_3_OUTPUT
output = await request_tx_output(o) txo = await request_tx_output(o)
if output_is_change(output): if output_is_change(txo):
if change_out != 0: if change_out != 0:
raise ValueError('Only one change output allowed') raise ValueError('Only one change output is valid')
change_out = output.amount change_out = txo.amount
outputbin = output_compile(output, coin, root) txo_bin.amount = txo.amount
tx_write_output(h_first, outputbin) txo_bin.script_pubkey = output_derive_script(txo, coin, root)
total_out += outputbin.amount write_tx_output(h_first, txo_bin)
total_out += txo_bin.amount
# TODO: display output # TODO: display output
# TODO: confirm output # TODO: confirm output
h_first_dig = tx_hash_digest(h_first)
# TODO: check funds and tx fee # TODO: check funds and tx fee
# TODO: ask for confirmation # TODO: ask for confirmation
@ -103,148 +96,193 @@ async def sign_tx(tx: SignTx, root: HDNode):
# - sign inputs # - sign inputs
# - check that nothing changed # - check that nothing changed
for i_sign in range(tx.inputs_count): # pre-allocated result structure for streaming out the signatures and
h_sign = tx_hash_init() # hash of what we are signing with this input # parts of the serialized tx
h_second = tx_hash_init() # should be the same as h_first tx_ser = TxRequestSerializedType()
input_sign = None for i_sign in range(tx.inputs_count):
# hash of what we are signing with this input
h_sign = HashWriter(sha256)
# same as h_first, checked at the end of this iteration
h_second = HashWriter(sha256)
txi_sign = None
key_sign = None key_sign = None
key_sign_pub = None key_sign_pub = None
for i in range(tx.inputs_count): for i in range(tx.inputs_count):
# STAGE_REQUEST_4_INPUT # STAGE_REQUEST_4_INPUT
input = await request_tx_input(i) txi = await request_tx_input(i)
tx_write_input(h_second, input) write_tx_input(h_second, txi)
if i == i_sign: if i == i_sign:
key_sign = node_derive(root, input.address_n) txi_sign = txi
key_sign = node_derive(root, txi.address_n)
key_sign_pub = key_sign.public_key() key_sign_pub = key_sign.public_key()
script_sig = input_derive_script_pre_sign(input, key_sign_pub) txi.script_sig = input_derive_script_pre_sign(
input_sign = input txi, key_sign_pub)
else: else:
script_sig = bytes() txi.script_sig = bytes()
input.script_sig = script_sig write_tx_input(h_sign, txi)
tx_write_input(h_sign, input)
for o in range(tx.outputs_count): for o in range(tx.outputs_count):
# STAGE_REQUEST_4_OUTPUT # STAGE_REQUEST_4_OUTPUT
output = await request_tx_output(o) txo = await request_tx_output(o)
outputbin = output_compile(output, coin, root) txo_bin.amount = txo.amount
tx_write_output(h_second, outputbin) txo_bin.script_pubkey = output_derive_script(txo, coin, root)
tx_write_output(h_sign, outputbin) write_tx_output(h_second, txo_bin)
write_tx_output(h_sign, txo_bin)
if h_first_dig != tx_hash_digest(h_second): h_first_dig = tx_hash_digest(h_first, False)
h_second_dig = tx_hash_digest(h_second, False)
if h_first_dig != h_second_dig:
raise ValueError('Transaction has changed during signing') raise ValueError('Transaction has changed during signing')
signature = ecdsa_sign(key_sign, tx_hash_digest(h_sign)) h_sign_dig = tx_hash_digest(h_sign, True)
script_sig = input_derive_script_post_sign( signature = ecdsa_sign(key_sign, h_sign_dig)
input, key_sign_pub, signature) txi_sign.script_sig = input_derive_script_post_sign(
input_sign.script_sig = script_sig txi_sign, key_sign_pub, signature)
# TODO: serialize the whole input at once, including the script_sig # TODO: serialize the whole input at once, including the script_sig
input_sign_w = BufferWriter(bytearray(), 0) txi_sign_w = BufferWriter()
tx_write_input(input_sign_w, input_sign) write_tx_input(txi_sign_w, txi_sign)
input_sign_b = input_sign_w.getvalue() txi_sign_b = txi_sign_w.getvalue()
serialized = TxRequestSerializedType( tx_ser.signature_index = i_sign
signature_index=i_sign, signature=signature, serialized_tx=input_sign_b) tx_ser.signature = signature
await send_serialized_tx(serialized) tx_ser.serialized_tx = txi_sign_b
await send_serialized_tx(tx_ser)
del tx_ser.signature_index
del tx_ser.signature
for o in range(tx.outputs_count): for o in range(tx.outputs_count):
# STAGE_REQUEST_5_OUTPUT # STAGE_REQUEST_5_OUTPUT
output = await request_tx_output(o) txo = await request_tx_output(o)
outputbin = output_compile(output, coin, root) txo_bin.amount = txo.amount
txo_bin.script_pubkey = output_derive_script(txo, coin, root)
outputbin_w = BufferWriter(bytearray(), 0) w_txo_bin = BufferWriter()
tx_write_input(outputbin_w, outputbin) write_tx_output(w_txo_bin, txo_bin)
outputbin_b = outputbin_w.getvalue()
serialized = TxRequestSerializedType(serialized_tx=outputbin_b) tx_ser.serialized_tx = w_txo_bin.getvalue()
await send_serialized_tx(serialized) await send_serialized_tx(tx_ser)
await request_tx_finish() await request_tx_finish()
async def get_prevtx_output_value(prev_hash: bytes, prev_index: int) -> int: async def get_prevtx_output_value(prev_hash: bytes, prev_index: int) -> int:
total_out = 0 # sum of output amounts
total_in = 0
# STAGE_REQUEST_2_PREV_META # STAGE_REQUEST_2_PREV_META
tx = await request_tx_meta(prev_hash) tx = await request_tx_meta(prev_hash)
txh = HashWriter(sha256)
h = tx_hash_init() write_tx_header(txh, tx.version, tx.inputs_count)
tx_write_header(h, tx.version, tx.inputs_count)
for i in range(tx.inputs_count): for i in range(tx.inputs_count):
# STAGE_REQUEST_2_PREV_INPUT # STAGE_REQUEST_2_PREV_INPUT
input = await request_tx_input(i, prev_hash) txi = await request_tx_input(i, prev_hash)
tx_write_input(h, input) write_tx_input(txh, txi)
tx_write_middle(h, tx.outputs_count)
write_tx_middle(txh, tx.outputs_count)
for o in range(tx.outputs_count): for o in range(tx.outputs_count):
# STAGE_REQUEST_2_PREV_OUTPUT # STAGE_REQUEST_2_PREV_OUTPUT
outputbin = await request_tx_output(o, prev_hash) txo_bin = await request_tx_output(o, prev_hash)
tx_write_output(h, outputbin) write_tx_output(txh, txo_bin)
if o == prev_index: if o == prev_index:
total_in += outputbin.value total_out += txo_bin.value
tx_write_footer(h, tx.locktime, False) write_tx_footer(txh, tx.locktime, False)
if tx_hash_digest(txh, True) != prev_hash:
if tx_hash_digest(h) != prev_hash: raise ValueError('Encountered invalid prev_hash')
raise ValueError('PrevTx mismatch') return total_out
return total_in
# TX Hashing def tx_hash_digest(w, double: bool):
# === d = w.getvalue()
if double:
d = sha256(d).digest()
def tx_hash_init(): return d
return HashWriter(sha256)
def tx_hash_digest(w):
return sha256(w.getvalue()).digest()
# TX Outputs # TX Outputs
# === # ===
def output_compile(output: TxOutputType, coin: CoinType, root: HDNode) -> TxOutputBinType: def output_derive_script(o: TxOutputType, coin: CoinType, root) -> bytes:
bin = TxOutputBinType() if o.script_type == OutputScriptType.PAYTOADDRESS:
bin.amount = output.amount return script_paytoaddress_new(
bin.script_pubkey = output_derive_script(output, coin, root) output_paytoaddress_extract_raw_address(o, coin, root))
return bin
def output_derive_script(output: TxOutputType, coin: CoinType, root: HDNode) -> bytes:
if output.script_type == OutputScriptType.PAYTOADDRESS:
raw_address = output_paytoaddress_extract_raw_address(output, root)
if raw_address[0] != coin.address_type: # TODO: do this properly
raise ValueError('Invalid address type')
return script_paytoaddress_new(raw_address)
else: else:
# TODO: other output script types raise ValueError('Invalid output script type')
raise ValueError('Unknown output script type')
return return
def output_paytoaddress_extract_raw_address(o: TxOutputType, root: HDNode) -> bytes: def output_paytoaddress_extract_raw_address(o: TxOutputType, coin: CoinType, root) -> bytes:
o_address_n = getattr(o, 'address_n', None) o_address_n = getattr(o, 'address_n', None)
o_address = getattr(o, 'address', None) o_address = getattr(o, 'address', None)
if o_address_n: # TODO: dont encode/decode more then necessary
node = node_derive(root, o_address_n) # TODO: detect correct address type
# TODO: dont encode and decode again if o_address_n is not None:
raw_address = base58.decode_check(node.address()) n = node_derive(root, o_address_n)
raw_address = base58.decode_check(n.address())
elif o_address: elif o_address:
raw_address = base58.decode_check(o_address) raw_address = base58.decode_check(o_address)
else: else:
raise ValueError('Missing address') raise ValueError('Missing address')
if raw_address[0] != coin.address_type:
raise ValueError('Invalid address type')
return raw_address return raw_address
def output_is_change(output: TxOutputType):
address_n = getattr(output, 'address_n', None)
return bool(address_n)
# Tx Inputs
# ===
def input_derive_script_pre_sign(i: TxInputType, pubkey: bytes) -> bytes:
if i.script_type == InputScriptType.SPENDADDRESS:
return script_paytoaddress_new(ecdsa_hash_pubkey(pubkey))
else:
raise ValueError('Unknown input script type')
def input_derive_script_post_sign(i: TxInputType, pubkey: bytes, signature: bytes) -> bytes:
if i.script_type == InputScriptType.SPENDADDRESS:
return script_spendaddress_new(pubkey, signature)
else:
raise ValueError('Unknown input script type')
def node_derive(root, address_n: list):
node = root.clone()
node.derive_path(address_n)
return node
def ecdsa_hash_pubkey(pubkey: bytes) -> bytes:
if pubkey[0] == 0x04:
assert len(pubkey) == 65 # uncompressed format
elif pubkey[0] == 0x00:
assert len(pubkey) == 1 # point at infinity
else:
assert len(pubkey) == 33 # compresssed format
h = sha256(pubkey).digest()
h = ripemd160(h).digest()
return h
def ecdsa_sign(privkey: bytes, digest: bytes) -> bytes:
return secp256k1.sign(privkey, digest)
# TX Scripts
# ===
def script_paytoaddress_new(raw_address: bytes) -> bytearray: def script_paytoaddress_new(raw_address: bytes) -> bytearray:
s = bytearray(25) s = bytearray(25)
s[0] = 0x76 # OP_DUP s[0] = 0x76 # OP_DUP
@ -256,75 +294,26 @@ def script_paytoaddress_new(raw_address: bytes) -> bytearray:
return s return s
def output_is_change(output: TxOutputType):
address_n = getattr(output, 'address_n', None)
return bool(address_n)
# Tx Inputs
# ===
def input_derive_script_pre_sign(input: TxInputType, pubkey: bytes) -> bytes:
if input.script_type == InputScriptType.SPENDADDRESS:
return script_paytoaddress_new(ecdsa_get_pubkeyhash(pubkey))
else:
# TODO: other input script types
raise ValueError('Unknown input script type')
def input_derive_script_post_sign(input: TxInputType, pubkey: bytes, signature: bytes) -> bytes:
if input.script_type == InputScriptType.SPENDADDRESS:
return script_spendaddress_new(pubkey, signature)
else:
# TODO: other input script types
raise ValueError('Unknown input script type')
def script_spendaddress_new(pubkey: bytes, signature: bytes) -> bytearray: def script_spendaddress_new(pubkey: bytes, signature: bytes) -> bytearray:
s = bytearray(25) w = BufferWriter()
w = BufferWriter(s, 0)
write_op_push(w, len(signature) + 1) write_op_push(w, len(signature) + 1)
write_bytes(w, signature) write_bytes(w, signature)
w.writebyte(0x01) w.writebyte(0x01)
write_op_push(w, len(pubkey)) write_op_push(w, len(pubkey))
write_bytes(w, pubkey) write_bytes(w, pubkey)
return return w.getvalue()
def node_derive(root: HDNode, address_n: list) -> HDNode:
# TODO: this will probably need to be a part of the machine instructions
node = root.clone()
node.derive_path(address_n)
return node
def ecdsa_get_pubkeyhash(pubkey: bytes) -> bytes:
if pubkey[0] == 0x04:
assert len(pubkey) == 65 # uncompressed format
elif pubkey[0] == 0x00:
assert len(pubkey) == 1 # point at infinity
else:
assert len(pubkey) == 33 # compresssed format
h = sha256(pubkey).digest()
h = ripemd160(h).digest()
return h
async def ecdsa_sign(privkey: bytes, digest: bytes) -> bytes:
return secp256k1.sign(privkey, digest)
# TX Serialization # TX Serialization
# === # ===
def tx_write_header(w, version: int, inputs_count: int): def write_tx_header(w, version: int, inputs_count: int):
write_uint32(w, version) write_uint32(w, version)
write_varint(w, inputs_count) write_varint(w, inputs_count)
def tx_write_input(w, i: TxInputType): def write_tx_input(w, i: TxInputType):
write_bytes_rev(w, i.prev_hash) write_bytes_rev(w, i.prev_hash)
write_uint32(w, i.prev_index) write_uint32(w, i.prev_index)
write_varint(w, len(i.script_sig)) write_varint(w, len(i.script_sig))
@ -332,17 +321,17 @@ def tx_write_input(w, i: TxInputType):
write_uint32(w, i.sequence) write_uint32(w, i.sequence)
def tx_write_middle(w, outputs_count: int): def write_tx_middle(w, outputs_count: int):
write_varint(w, outputs_count) write_varint(w, outputs_count)
def tx_write_output(w, o: TxOutputBinType): def write_tx_output(w, o: TxOutputBinType):
write_uint64(w, o.amount) write_uint64(w, o.amount)
write_varint(w, len(o.script_pubkey)) write_varint(w, len(o.script_pubkey))
write_bytes(w, o.script_pubkey) write_bytes(w, o.script_pubkey)
def tx_write_footer(w, locktime: int, add_hash_type: bool): def write_tx_footer(w, locktime: int, add_hash_type: bool):
write_uint32(w, locktime) write_uint32(w, locktime)
if add_hash_type: if add_hash_type:
write_uint32(w, 1) write_uint32(w, 1)
@ -417,13 +406,15 @@ def write_bytes_rev(w, buf: bytearray):
class BufferWriter: class BufferWriter:
def __init__(self, buf: bytearray, ofs: int): def __init__(self, buf: bytearray=None, ofs: int=0):
# TODO: re-think the use of bytearrays, buffers, and other byte IO # TODO: re-think the use of bytearrays, buffers, and other byte IO
# i think we should just pass a pre-allocation size here, allocate the # i think we should just pass a pre-allocation size here, allocate the
# bytearray and then trim it to zero. in this case, write() would # bytearray and then trim it to zero. in this case, write() would
# correspond to extend(), and writebyte() to append(). of course, the # correspond to extend(), and writebyte() to append(). of course, the
# the use-case of non-destructively writing to existing bytearray still # the use-case of non-destructively writing to existing bytearray still
# exists. # exists.
if buf is None:
buf = bytearray()
self.buf = buf self.buf = buf
self.ofs = ofs self.ofs = ofs