1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-05 13:01:12 +00:00

signtx: serialize tx meta

This commit is contained in:
Jan Pochyla 2016-11-07 17:00:00 +01:00
parent 0012883984
commit 1bb20c2521

View File

@ -12,6 +12,7 @@ from trezor.messages.TxInputType import TxInputType
from trezor.messages.TxRequest import TxRequest from trezor.messages.TxRequest import TxRequest
from trezor.messages.RequestType import TXINPUT, TXOUTPUT, TXMETA, TXFINISHED from trezor.messages.RequestType import TXINPUT, TXOUTPUT, TXMETA, TXFINISHED
from trezor.messages.TxRequestSerializedType import TxRequestSerializedType from trezor.messages.TxRequestSerializedType import TxRequestSerializedType
from trezor.messages.TxRequestDetailsType import TxRequestDetailsType
from trezor.messages import OutputScriptType, InputScriptType from trezor.messages import OutputScriptType, InputScriptType
@ -19,30 +20,41 @@ from trezor.messages import OutputScriptType, InputScriptType
# === # ===
def request_tx_meta(prev_hash: bytes=None): def request_tx_meta(tx_req: TxRequest, tx_hash: bytes=None):
ack = yield TxRequest(type=TXMETA, prev_hash=prev_hash) tx_req.type = TXMETA
tx_req.details.tx_hash = tx_hash
tx_req.details.request_index = None
ack = yield tx_req
tx_req.serialized = None
return ack.tx return ack.tx
def request_tx_input(index: int, prev_hash: bytes=None): def request_tx_input(tx_req: TxRequest, i: int, tx_hash: bytes=None):
ack = yield TxRequest(type=TXINPUT, prev_hash=prev_hash, index=index) tx_req.type = TXINPUT
tx_req.details.request_index = i
tx_req.details.tx_hash = tx_hash
ack = yield tx_req
tx_req.serialized = None
return ack.tx.inputs[0] return ack.tx.inputs[0]
def request_tx_output(index: int, prev_hash: bytes=None): def request_tx_output(tx_req: TxRequest, i: int, tx_hash: bytes=None):
ack = yield TxRequest(type=TXOUTPUT, prev_hash=prev_hash, index=index) tx_req.type = TXOUTPUT
if prev_hash is None: tx_req.details.request_index = i
tx_req.details.tx_hash = tx_hash
ack = yield tx_req
tx_req.serialized = None
if tx_hash is None:
return ack.outputs[0] return ack.outputs[0]
else: else:
return ack.bin_outputs[0] return ack.bin_outputs[0]
def request_tx_finish(): def request_tx_finish(tx_req: TxRequest):
yield TxRequest(type=TXFINISHED) tx_req.type = TXFINISHED
tx_req.details = None
yield tx_req
def send_serialized_tx(serialized: TxRequestSerializedType): tx_req.serialized = None
yield serialized
# Transaction signing # Transaction signing
@ -50,7 +62,13 @@ def send_serialized_tx(serialized: TxRequestSerializedType):
async def sign_tx(tx: SignTx, root): async def sign_tx(tx: SignTx, root):
coin = coins.by_name(tx.coin_name) tx_version = getattr(tx, 'version', 0)
tx_lock_time = getattr(tx, 'lock_time', 1)
tx_inputs_count = getattr(tx, 'inputs_count', 0)
tx_outputs_count = getattr(tx, 'outputs_count', 0)
coin_name = getattr(tx, 'coin_name', 'Bitcoin')
coin = coins.by_name(coin_name)
# Phase 1 # Phase 1
# - check inputs, previous transactions, and outputs # - check inputs, previous transactions, and outputs
@ -66,18 +84,20 @@ async def sign_tx(tx: SignTx, root):
# tx, as the SignTx info is streamed only once # tx, as the SignTx info is streamed only once
h_first = HashWriter(sha256) # not a real tx hash h_first = HashWriter(sha256) # not a real tx hash
# pre-allocate the serialization structure for outputs
txo_bin = TxOutputBinType() txo_bin = TxOutputBinType()
tx_req = TxRequest()
tx_req.details = TxRequestDetailsType()
for i in range(tx.inputs_count): for i in range(tx_inputs_count):
# STAGE_REQUEST_1_INPUT # STAGE_REQUEST_1_INPUT
txi = await request_tx_input(i) txi = await request_tx_input(tx_req, i)
write_tx_input(h_first, txi) write_tx_input(h_first, txi)
total_in += await get_prevtx_output_value(txi.prev_hash, txi.prev_index) total_in += await get_prevtx_output_value(
tx_req, txi.prev_hash, txi.prev_index)
for o in range(tx.outputs_count): for o in range(tx_outputs_count):
# STAGE_REQUEST_3_OUTPUT # STAGE_REQUEST_3_OUTPUT
txo = await request_tx_output(o) txo = await request_tx_output(tx_req, o)
if output_is_change(txo): if output_is_change(txo):
if change_out != 0: if change_out != 0:
raise ValueError('Only one change output is valid') raise ValueError('Only one change output is valid')
@ -96,8 +116,6 @@ async def sign_tx(tx: SignTx, root):
# - sign inputs # - sign inputs
# - check that nothing changed # - check that nothing changed
# pre-allocated result structure for streaming out the signatures and
# parts of the serialized tx
tx_ser = TxRequestSerializedType() tx_ser = TxRequestSerializedType()
for i_sign in range(tx.inputs_count): for i_sign in range(tx.inputs_count):
@ -110,9 +128,11 @@ async def sign_tx(tx: SignTx, root):
key_sign = None key_sign = None
key_sign_pub = None key_sign_pub = None
write_tx_header(h_sign, tx_version, tx_inputs_count)
for i in range(tx.inputs_count): for i in range(tx.inputs_count):
# STAGE_REQUEST_4_INPUT # STAGE_REQUEST_4_INPUT
txi = await request_tx_input(i) txi = await request_tx_input(tx_req, i)
write_tx_input(h_second, txi) write_tx_input(h_second, txi)
if i == i_sign: if i == i_sign:
txi_sign = txi txi_sign = txi
@ -124,74 +144,94 @@ async def sign_tx(tx: SignTx, root):
txi.script_sig = bytes() txi.script_sig = bytes()
write_tx_input(h_sign, txi) write_tx_input(h_sign, txi)
write_tx_middle(h_sign, tx_outputs_count)
for o in range(tx.outputs_count): for o in range(tx.outputs_count):
# STAGE_REQUEST_4_OUTPUT # STAGE_REQUEST_4_OUTPUT
txo = await request_tx_output(o) txo = await request_tx_output(tx_req, o)
txo_bin.amount = txo.amount txo_bin.amount = txo.amount
txo_bin.script_pubkey = output_derive_script(txo, coin, root) txo_bin.script_pubkey = output_derive_script(txo, coin, root)
write_tx_output(h_second, txo_bin) write_tx_output(h_second, txo_bin)
write_tx_output(h_sign, txo_bin) write_tx_output(h_sign, txo_bin)
write_tx_footer(h_sign, tx_lock_time, True)
# check the control digests
h_first_dig = tx_hash_digest(h_first, False) h_first_dig = tx_hash_digest(h_first, False)
h_second_dig = tx_hash_digest(h_second, False) h_second_dig = tx_hash_digest(h_second, False)
if h_first_dig != h_second_dig: if h_first_dig != h_second_dig:
raise ValueError('Transaction has changed during signing') raise ValueError('Transaction has changed during signing')
# compute the signature from the tx digest
h_sign_dig = tx_hash_digest(h_sign, True) h_sign_dig = tx_hash_digest(h_sign, True)
signature = ecdsa_sign(key_sign, h_sign_dig) signature = ecdsa_sign(key_sign, h_sign_dig)
txi_sign.script_sig = input_derive_script_post_sign(
txi_sign, key_sign_pub, signature)
# TODO: serialize the whole input at once, including the script_sig
txi_sign_w = BufferWriter()
write_tx_input(txi_sign_w, txi_sign)
txi_sign_b = txi_sign_w.getvalue()
tx_ser.signature_index = i_sign tx_ser.signature_index = i_sign
tx_ser.signature = signature tx_ser.signature = signature
tx_ser.serialized_tx = txi_sign_b
await send_serialized_tx(tx_ser)
del tx_ser.signature_index # serialize input with correct signature
del tx_ser.signature txi_sign.script_sig = input_derive_script_post_sign(
txi_sign, key_sign_pub, signature)
txi_sign_w = BufferWriter()
if i_sign == 0:
write_tx_header(txi_sign_w, tx_version, tx_inputs_count)
write_tx_input(txi_sign_w, txi_sign)
tx_ser.serialized_tx = txi_sign_w.getvalue()
tx_req.serialized = tx_ser
for o in range(tx.outputs_count): for o in range(tx.outputs_count):
# STAGE_REQUEST_5_OUTPUT # STAGE_REQUEST_5_OUTPUT
txo = await request_tx_output(o) txo = await request_tx_output(tx_req, o)
txo_bin.amount = txo.amount txo_bin.amount = txo.amount
txo_bin.script_pubkey = output_derive_script(txo, coin, root) txo_bin.script_pubkey = output_derive_script(txo, coin, root)
# serialize output
w_txo_bin = BufferWriter() w_txo_bin = BufferWriter()
if o == 0:
write_tx_middle(w_txo_bin, tx_outputs_count)
write_tx_output(w_txo_bin, txo_bin) write_tx_output(w_txo_bin, txo_bin)
if o == tx_outputs_count:
write_tx_footer(w_txo_bin, tx_lock_time, False)
tx_ser.signature_index = None
tx_ser.signature = None
tx_ser.serialized_tx = w_txo_bin.getvalue() tx_ser.serialized_tx = w_txo_bin.getvalue()
await send_serialized_tx(tx_ser)
await request_tx_finish() tx_req.serialized = tx_ser
await request_tx_finish(tx_req)
async def get_prevtx_output_value(prev_hash: bytes, prev_index: int) -> int: async def get_prevtx_output_value(tx_req: TxRequest, prev_hash: bytes, prev_index: int) -> int:
total_out = 0 # sum of output amounts total_out = 0 # sum of output amounts
# STAGE_REQUEST_2_PREV_META # STAGE_REQUEST_2_PREV_META
tx = await request_tx_meta(prev_hash) tx = await request_tx_meta(prev_hash)
tx_version = getattr(tx, 'version', 0)
tx_lock_time = getattr(tx, 'lock_time', 1)
tx_inputs_count = getattr(tx, 'inputs_count', 0)
tx_outputs_count = getattr(tx, 'outputs_count', 0)
txh = HashWriter(sha256) txh = HashWriter(sha256)
write_tx_header(txh, tx.version, tx.inputs_count) write_tx_header(txh, tx_version, tx_inputs_count)
for i in range(tx.inputs_count):
for i in range(tx_inputs_count):
# STAGE_REQUEST_2_PREV_INPUT # STAGE_REQUEST_2_PREV_INPUT
txi = await request_tx_input(i, prev_hash) txi = await request_tx_input(tx_req, i, prev_hash)
write_tx_input(txh, txi) write_tx_input(txh, txi)
write_tx_middle(txh, tx.outputs_count) write_tx_middle(txh, tx_outputs_count)
for o in range(tx.outputs_count):
for o in range(tx_outputs_count):
# STAGE_REQUEST_2_PREV_OUTPUT # STAGE_REQUEST_2_PREV_OUTPUT
txo_bin = await request_tx_output(o, prev_hash) txo_bin = await request_tx_output(tx_req, o, prev_hash)
write_tx_output(txh, txo_bin) write_tx_output(txh, txo_bin)
if o == prev_index: if o == prev_index:
total_out += txo_bin.value total_out += txo_bin.value
write_tx_footer(txh, tx.locktime, False) write_tx_footer(txh, tx_lock_time, False)
if tx_hash_digest(txh, True) != prev_hash: if tx_hash_digest(txh, True) != prev_hash:
raise ValueError('Encountered invalid prev_hash') raise ValueError('Encountered invalid prev_hash')
return total_out return total_out