You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
trezor-firmware/src/apps/common/signtx.py

489 lines
14 KiB

from trezor.crypto.hashlib import sha256, ripemd160
from trezor.crypto.curve import secp256k1
from trezor.crypto import base58
from . import coins
from trezor.messages.CoinType import CoinType
from trezor.messages.SignTx import SignTx
from trezor.messages.TxOutputType import TxOutputType
from trezor.messages.TxOutputBinType import TxOutputBinType
from trezor.messages.TxInputType import TxInputType
from trezor.messages.TxRequest import TxRequest
from trezor.messages.RequestType import TXINPUT, TXOUTPUT, TXMETA, TXFINISHED
from trezor.messages.TxRequestSerializedType import TxRequestSerializedType
from trezor.messages.TxRequestDetailsType import TxRequestDetailsType
from trezor.messages import OutputScriptType, InputScriptType
# Machine instructions
# ===
def request_tx_meta(tx_req: TxRequest, tx_hash: bytes=None):
    """
    Ask the host for the metadata (TXMETA) of transaction `tx_hash`.

    This is a generator meant to be awaited: it yields `tx_req` to the host
    and returns the `tx` field of the acknowledgement it is resumed with.
    """
    details = tx_req.details
    details.tx_hash = tx_hash
    details.request_index = None
    tx_req.type = TXMETA
    response = yield tx_req
    # any serialized data attached to the request has been sent with it
    tx_req.serialized = None
    return response.tx
def request_tx_input(tx_req: TxRequest, i: int, tx_hash: bytes=None):
    """
    Ask the host for input number `i` (TXINPUT).

    `tx_hash` selects a previous transaction; callers leave it None when
    requesting inputs of the transaction being signed. Returns
    `tx.inputs[0]` of the acknowledgement the generator is resumed with.
    """
    details = tx_req.details
    tx_req.type = TXINPUT
    details.request_index = i
    details.tx_hash = tx_hash
    response = yield tx_req
    # any serialized data attached to the request has been sent with it
    tx_req.serialized = None
    return response.tx.inputs[0]
def request_tx_output(tx_req: TxRequest, i: int, tx_hash: bytes=None):
    """
    Ask the host for output number `i` (TXOUTPUT).

    When `tx_hash` is None the output belongs to the transaction being
    signed and `tx.outputs[0]` (a TxOutputType) is returned; otherwise it
    belongs to the previous transaction `tx_hash` and `tx.bin_outputs[0]`
    (a TxOutputBinType) is returned.
    """
    tx_req.type = TXOUTPUT
    tx_req.details.request_index = i
    tx_req.details.tx_hash = tx_hash
    ack = yield tx_req
    # any serialized data attached to the request has been sent with it
    tx_req.serialized = None
    # outputs are carried inside the acknowledged `tx`, exactly like the
    # inputs read by request_tx_input
    if tx_hash is None:
        return ack.tx.outputs[0]
    return ack.tx.bin_outputs[0]
def request_tx_finish(tx_req: TxRequest):
    """
    Send the final TXFINISHED request to the host.

    The details section is dropped; any serialized data still attached to
    `tx_req` travels with this last request and is cleared afterwards.
    """
    tx_req.details = None
    tx_req.type = TXFINISHED
    yield tx_req
    tx_req.serialized = None
# Transaction signing
# ===
async def sign_tx(tx: SignTx, root):
    """
    Sign every input of `tx` with keys derived from `root`, streaming the
    signatures and the serialized transaction back to the host.

    Phase 1 requests all inputs (each one's previous transaction is checked
    by get_prevtx_output_value) and all outputs, hashing them into the
    control digest `h_first`.
    Phase 2 re-requests everything once per input being signed, builds the
    digest that input signs, and verifies that the host streamed the same
    data as in Phase 1. Each signed input is serialized; afterwards the
    outputs are serialized once, and TXFINISHED is sent.

    Raises ValueError on inconsistent host data (more than one change
    output, data changed between phases) or on errors from the helpers.
    """
    # protobuf defaults for SignTx: version 1 (version 0 is nonstandard),
    # lock_time 0
    tx_version = getattr(tx, 'version', 1)
    tx_lock_time = getattr(tx, 'lock_time', 0)
    tx_inputs_count = getattr(tx, 'inputs_count', 0)
    tx_outputs_count = getattr(tx, 'outputs_count', 0)
    coin_name = getattr(tx, 'coin_name', 'Bitcoin')
    coin = coins.by_name(coin_name)

    # Phase 1
    # - check inputs, previous transactions, and outputs
    # - ask for confirmations
    # - check fee

    total_in = 0  # sum of input amounts
    total_out = 0  # sum of output amounts
    change_out = 0  # change output amount
    change_seen = False  # catches a second change output even if amount is 0

    # h_first is used to make sure the inputs and outputs streamed in Phase 1
    # are the same as in Phase 2. It is thus not required to fully hash the
    # tx, as the SignTx info is streamed only once.
    h_first = HashWriter(sha256)  # not a real tx hash
    txo_bin = TxOutputBinType()
    tx_req = TxRequest()
    tx_req.details = TxRequestDetailsType()

    for i in range(tx_inputs_count):
        # STAGE_REQUEST_1_INPUT
        txi = await request_tx_input(tx_req, i)
        # NOTE(review): this hashes the script_sig exactly as sent by the
        # host; confirm hosts always set it for inputs being signed
        write_tx_input(h_first, txi)
        total_in += await get_prevtx_output_value(
            tx_req, txi.prev_hash, txi.prev_index)

    for o in range(tx_outputs_count):
        # STAGE_REQUEST_3_OUTPUT
        txo = await request_tx_output(tx_req, o)
        if output_is_change(txo):
            if change_seen:
                raise ValueError('Only one change output is valid')
            change_seen = True
            change_out = txo.amount
        txo_bin.amount = txo.amount
        txo_bin.script_pubkey = output_derive_script(txo, coin, root)
        write_tx_output(h_first, txo_bin)
        total_out += txo_bin.amount
        # TODO: display output
        # TODO: confirm output

    # TODO: check funds and tx fee
    # TODO: ask for confirmation

    # h_first receives no more data, so its digest is computed only once
    h_first_dig = tx_hash_digest(h_first, False)

    # Phase 2
    # - sign inputs
    # - check that nothing changed

    tx_ser = TxRequestSerializedType()

    for i_sign in range(tx_inputs_count):
        # hash of what we are signing with this input
        h_sign = HashWriter(sha256)
        # same as h_first, checked at the end of this iteration
        h_second = HashWriter(sha256)

        txi_sign = None
        key_sign = None
        key_sign_pub = None

        write_tx_header(h_sign, tx_version, tx_inputs_count)

        for i in range(tx_inputs_count):
            # STAGE_REQUEST_4_INPUT
            txi = await request_tx_input(tx_req, i)
            write_tx_input(h_second, txi)
            if i == i_sign:
                txi_sign = txi
                key_sign = node_derive(root, txi.address_n)
                key_sign_pub = key_sign.public_key()
                # the signed input carries the pay-to-address script built
                # from its own public key; all other inputs carry an empty one
                txi.script_sig = input_derive_script_pre_sign(
                    txi, key_sign_pub)
            else:
                txi.script_sig = bytes()
            write_tx_input(h_sign, txi)

        write_tx_middle(h_sign, tx_outputs_count)

        for o in range(tx_outputs_count):
            # STAGE_REQUEST_4_OUTPUT
            txo = await request_tx_output(tx_req, o)
            txo_bin.amount = txo.amount
            txo_bin.script_pubkey = output_derive_script(txo, coin, root)
            write_tx_output(h_second, txo_bin)
            write_tx_output(h_sign, txo_bin)

        write_tx_footer(h_sign, tx_lock_time, True)

        # check the control digests
        if tx_hash_digest(h_second, False) != h_first_dig:
            raise ValueError('Transaction has changed during signing')

        # compute the signature from the tx digest; ecdsa_sign takes the
        # private key bytes, not the derived node object
        h_sign_dig = tx_hash_digest(h_sign, True)
        signature = ecdsa_sign(key_sign.private_key(), h_sign_dig)
        tx_ser.signature_index = i_sign
        tx_ser.signature = signature

        # serialize input with correct signature
        txi_sign.script_sig = input_derive_script_post_sign(
            txi_sign, key_sign_pub, signature)
        txi_sign_w = BufferWriter()
        if i_sign == 0:  # first input => prepend the tx header
            write_tx_header(txi_sign_w, tx_version, tx_inputs_count)
        write_tx_input(txi_sign_w, txi_sign)
        tx_ser.serialized_tx = txi_sign_w.getvalue()

        # sent to the host together with the next request
        tx_req.serialized = tx_ser

    # outputs are serialized exactly once, after all signed inputs
    for o in range(tx_outputs_count):
        # STAGE_REQUEST_5_OUTPUT
        txo = await request_tx_output(tx_req, o)
        txo_bin.amount = txo.amount
        txo_bin.script_pubkey = output_derive_script(txo, coin, root)

        # serialize output
        w_txo_bin = BufferWriter()
        if o == 0:  # first output => prepend the outputs count
            write_tx_middle(w_txo_bin, tx_outputs_count)
        write_tx_output(w_txo_bin, txo_bin)
        if o == tx_outputs_count - 1:  # last output => append the footer
            write_tx_footer(w_txo_bin, tx_lock_time, False)
        tx_ser.signature_index = None
        tx_ser.signature = None
        tx_ser.serialized_tx = w_txo_bin.getvalue()
        tx_req.serialized = tx_ser

    await request_tx_finish(tx_req)
async def get_prevtx_output_value(tx_req: TxRequest, prev_hash: bytes, prev_index: int) -> int:
    """
    Stream the previous transaction `prev_hash` from the host, verify that
    it really hashes to `prev_hash`, and return the amount of its output
    number `prev_index` (0 if no output has that index).

    Raises ValueError when the streamed transaction does not hash to
    `prev_hash`.
    """
    total_out = 0  # amount of the output being spent

    # STAGE_REQUEST_2_PREV_META
    tx = await request_tx_meta(tx_req, prev_hash)
    tx_version = getattr(tx, 'version', 1)
    tx_lock_time = getattr(tx, 'lock_time', 0)
    # the TXMETA answer (TransactionType) names its counts
    # inputs_cnt / outputs_cnt, unlike SignTx
    tx_inputs_count = getattr(tx, 'inputs_cnt', 0)
    tx_outputs_count = getattr(tx, 'outputs_cnt', 0)

    txh = HashWriter(sha256)
    write_tx_header(txh, tx_version, tx_inputs_count)

    for i in range(tx_inputs_count):
        # STAGE_REQUEST_2_PREV_INPUT
        txi = await request_tx_input(tx_req, i, prev_hash)
        write_tx_input(txh, txi)

    write_tx_middle(txh, tx_outputs_count)

    for o in range(tx_outputs_count):
        # STAGE_REQUEST_2_PREV_OUTPUT
        txo_bin = await request_tx_output(tx_req, o, prev_hash)
        write_tx_output(txh, txo_bin)
        if o == prev_index:
            # TxOutputBinType carries its value in `amount`
            total_out += txo_bin.amount

    write_tx_footer(txh, tx_lock_time, False)

    # prev_hash is kept byte-reversed relative to the double-SHA256 of the
    # serialized tx (write_tx_input reverses it when serializing), so the
    # computed digest must be reversed before comparing
    if bytes(reversed(tx_hash_digest(txh, True))) != prev_hash:
        raise ValueError('Encountered invalid prev_hash')
    return total_out
def tx_hash_digest(w, double: bool):
    """
    Return `w.getvalue()`, hashed once more with sha256 when `double` is set.

    For a HashWriter, getvalue() is already a digest, so `double=True` over
    a sha256 HashWriter gives the double-SHA256 of everything written.
    """
    value = w.getvalue()
    return sha256(value).digest() if double else value
# TX Outputs
# ===
def output_derive_script(o: TxOutputType, coin: CoinType, root) -> bytes:
    """
    Build the scriptPubKey for output `o`.

    Only PAYTOADDRESS outputs are supported: the destination address is
    resolved (from `address_n` via `root`, or from `address`), checked
    against `coin`, and turned into a pay-to-address script.

    Raises ValueError for any other script type, or from address
    extraction.
    """
    if o.script_type != OutputScriptType.PAYTOADDRESS:
        raise ValueError('Invalid output script type')
    raw_address = output_paytoaddress_extract_raw_address(o, coin, root)
    return script_paytoaddress_new(raw_address)
def output_paytoaddress_extract_raw_address(o: TxOutputType, coin: CoinType, root) -> bytes:
o_address_n = getattr(o, 'address_n', None)
o_address = getattr(o, 'address', None)
# TODO: dont encode/decode more then necessary
# TODO: detect correct address type
if o_address_n is not None:
n = node_derive(root, o_address_n)
raw_address = base58.decode_check(n.address())
elif o_address:
raw_address = base58.decode_check(o_address)
else:
raise ValueError('Missing address')
if raw_address[0] != coin.address_type:
raise ValueError('Invalid address type')
return raw_address
def output_is_change(output: TxOutputType):
address_n = getattr(output, 'address_n', None)
return bool(address_n)
# Tx Inputs
# ===
def input_derive_script_pre_sign(i: TxInputType, pubkey: bytes) -> bytes:
    """
    Return the script placed into input `i` while computing its signature
    digest: a pay-to-address script over the hash160 of `pubkey`.

    Raises ValueError for script types other than SPENDADDRESS.
    """
    if i.script_type != InputScriptType.SPENDADDRESS:
        raise ValueError('Unknown input script type')
    pubkey_hash = ecdsa_hash_pubkey(pubkey)
    return script_paytoaddress_new(pubkey_hash)
def input_derive_script_post_sign(i: TxInputType, pubkey: bytes, signature: bytes) -> bytes:
    """
    Return the final scriptSig for input `i` once `signature` is known.

    Raises ValueError for script types other than SPENDADDRESS.
    """
    if i.script_type != InputScriptType.SPENDADDRESS:
        raise ValueError('Unknown input script type')
    script = script_spendaddress_new(pubkey, signature)
    return script
def node_derive(root, address_n: list):
    """
    Return `root.clone()` derived along the path `address_n`.

    Derivation happens on the clone, so `root` is not modified (assuming
    clone() returns an independent copy).
    """
    derived = root.clone()
    derived.derive_path(address_n)
    return derived
def ecdsa_hash_pubkey(pubkey: bytes) -> bytes:
    """
    Return RIPEMD160(SHA256(pubkey)), the hash used in pay-to-address
    scripts.

    The length of `pubkey` is validated against its prefix byte:
    0x04 -> 65 bytes (uncompressed), 0x00 -> 1 byte (point at infinity),
    anything else -> 33 bytes (compressed).

    Raises ValueError for an empty or wrongly sized key.
    """
    if not pubkey:
        raise ValueError('Invalid public key')
    prefix = pubkey[0]
    if prefix == 0x04:
        expected = 65  # uncompressed format
    elif prefix == 0x00:
        expected = 1  # point at infinity
    else:
        expected = 33  # compressed format
    # explicit check instead of assert: asserts can be compiled out
    if len(pubkey) != expected:
        raise ValueError('Invalid public key')
    h = sha256(pubkey).digest()
    h = ripemd160(h).digest()
    return h
def ecdsa_sign(privkey: bytes, digest: bytes) -> bytes:
    """
    Sign `digest` with the secp256k1 private key `privkey`.

    NOTE(review): the value is returned exactly as secp256k1.sign produces
    it, and script_spendaddress_new pushes it into a scriptSig where Bitcoin
    expects a DER-encoded signature -- confirm whether an encoding step is
    missing.
    """
    sig = secp256k1.sign(privkey, digest)
    return sig
# TX Scripts
# ===
def script_paytoaddress_new(raw_address: bytes) -> bytearray:
    """
    Build a 25-byte pay-to-address script:
    OP_DUP OP_HASH160 <20-byte hash> OP_EQUALVERIFY OP_CHECKSIG.

    `raw_address` may be a bare 20-byte hash (as passed by
    input_derive_script_pre_sign) or a decoded address payload, i.e. a
    version prefix followed by the hash. The trailing 20 bytes are used in
    both cases.

    Raises ValueError if `raw_address` is shorter than 20 bytes.
    """
    if len(raw_address) < 20:
        raise ValueError('Invalid address length')
    s = bytearray(25)
    s[0] = 0x76  # OP_DUP
    s[1] = 0xA9  # OP_HASH_160
    s[2] = 0x14  # pushing 20 bytes
    # exactly 20 bytes go into the 20-byte slot, so the script keeps its
    # length (a shorter slice would shrink the bytearray)
    s[3:23] = raw_address[-20:]
    s[23] = 0x88  # OP_EQUALVERIFY
    s[24] = 0xAC  # OP_CHECKSIG
    return s
def script_spendaddress_new(pubkey: bytes, signature: bytes) -> bytearray:
    """
    Build the scriptSig that spends a pay-to-address output:
    push(<signature> 0x01) push(<pubkey>).
    """
    sighash_all = 0x01
    out = BufferWriter()
    # the signature is pushed together with the trailing hash-type byte
    write_op_push(out, len(signature) + 1)
    write_bytes(out, signature)
    out.writebyte(sighash_all)
    # followed by the public key
    write_op_push(out, len(pubkey))
    write_bytes(out, pubkey)
    return out.getvalue()
# TX Serialization
# ===
def write_tx_header(w, version: int, inputs_count: int):
    """
    Serialize the start of a transaction into `w`: the version as a
    little-endian uint32, then the number of inputs as a varint.
    """
    write_uint32(w, version)
    write_varint(w, inputs_count)
def write_tx_input(w, i: TxInputType):
    """
    Serialize input `i` into `w`: the byte-reversed previous tx hash, the
    output index (uint32), the length-prefixed script_sig and the sequence
    number (uint32).
    """
    write_bytes_rev(w, i.prev_hash)
    write_uint32(w, i.prev_index)
    script = i.script_sig
    write_varint(w, len(script))
    write_bytes(w, script)
    write_uint32(w, i.sequence)
def write_tx_middle(w, outputs_count: int):
    """Serialize the output count (varint) that follows the inputs."""
    count = outputs_count
    write_varint(w, count)
def write_tx_output(w, o: TxOutputBinType):
    """
    Serialize output `o` into `w`: the amount as a little-endian uint64,
    followed by the length-prefixed script_pubkey.
    """
    write_uint64(w, o.amount)
    script = o.script_pubkey
    write_varint(w, len(script))
    write_bytes(w, script)
def write_tx_footer(w, locktime: int, add_hash_type: bool):
    """
    Serialize the end of a transaction into `w`: the lock time as a uint32,
    optionally followed by the uint32 hash type 1, which sign_tx appends
    when building the digest to be signed.
    """
    write_uint32(w, locktime)
    if not add_hash_type:
        return
    write_uint32(w, 1)
def write_op_push(w, n: int):
    """
    Write the script opcode(s) announcing a data push of `n` bytes, using
    the smallest encoding:

    - n < 0x4C: the length itself is the opcode
    - n <= 0xFF: OP_PUSHDATA1 + 1-byte length
    - n <= 0xFFFF: OP_PUSHDATA2 + 2-byte little-endian length
    - otherwise: OP_PUSHDATA4 + 4-byte little-endian length

    The upper bounds are inclusive, so a 255-byte push uses OP_PUSHDATA1;
    non-minimal pushes are rejected by Bitcoin's standardness rules.
    """
    wb = w.writebyte
    if n < 0x4C:
        wb(n & 0xFF)
    elif n <= 0xFF:
        wb(0x4C)  # OP_PUSHDATA1
        wb(n & 0xFF)
    elif n <= 0xFFFF:
        wb(0x4D)  # OP_PUSHDATA2
        wb(n & 0xFF)
        wb((n >> 8) & 0xFF)
    else:
        wb(0x4E)  # OP_PUSHDATA4
        wb(n & 0xFF)
        wb((n >> 8) & 0xFF)
        wb((n >> 16) & 0xFF)
        wb((n >> 24) & 0xFF)
# Buffer IO & Serialization
# ===
def write_varint(w, n: int):
    """
    Write `n` as a Bitcoin CompactSize ("varint"):

    - n < 0xFD: a single byte
    - n <= 0xFFFF: 0xFD followed by 2 little-endian bytes
    - n <= 0xFFFFFFFF: 0xFE followed by 4 little-endian bytes
    - otherwise: 0xFF followed by 8 little-endian bytes
    """
    wb = w.writebyte
    if n < 253:
        wb(n & 0xFF)
        return
    if n < 65536:
        prefix, size = 253, 2
    elif n <= 0xFFFFFFFF:
        prefix, size = 254, 4
    else:
        # values wider than 32 bits need the 8-byte form to avoid truncation
        prefix, size = 255, 8
    wb(prefix)
    for shift in range(0, 8 * size, 8):
        wb((n >> shift) & 0xFF)
def write_uint32(w, n: int):
    """Write the low 32 bits of `n` to `w` as 4 little-endian bytes."""
    emit = w.writebyte
    for shift in (0, 8, 16, 24):
        emit((n >> shift) & 0xFF)
def write_uint64(w, n: int):
    """Write the low 64 bits of `n` to `w` as 8 little-endian bytes."""
    emit = w.writebyte
    for shift in range(0, 64, 8):
        emit((n >> shift) & 0xFF)
def write_bytes(w, buf: bytearray):
    """Pass `buf` unchanged to the writer's write() method."""
    sink = w.write
    sink(buf)
def write_bytes_rev(w, buf: bytearray):
    """Write `buf` to `w` with its byte order reversed, as a bytearray."""
    flipped = bytearray(reversed(buf))
    w.write(flipped)
class BufferWriter:
    """
    Byte writer over a bytearray, starting at offset `ofs`.

    Writes overwrite existing bytes at the current offset and grow the
    buffer once the offset reaches its end. getvalue() returns the whole
    underlying buffer.
    """

    def __init__(self, buf: bytearray=None, ofs: int=0):
        # TODO: re-think the use of bytearrays, buffers, and other byte IO
        # I think we should just pass a pre-allocation size here, allocate the
        # bytearray and then trim it to zero. In this case, write() would
        # correspond to extend(), and writebyte() to append(). Of course, the
        # use-case of non-destructively writing to an existing bytearray
        # still exists.
        if buf is None:
            buf = bytearray()
        self.buf = buf
        self.ofs = ofs

    def write(self, buf: bytearray):
        n = len(buf)
        # slice assignment overwrites in place and grows the buffer when the
        # slice runs past its end
        self.buf[self.ofs:self.ofs + n] = buf
        self.ofs += n

    def writebyte(self, b: int):
        if self.ofs < len(self.buf):
            self.buf[self.ofs] = b
        else:
            # item assignment past the end raises IndexError (e.g. on the
            # default empty buffer); grow the buffer like write() does
            self.buf.append(b)
        self.ofs += 1

    def getvalue(self) -> bytearray:
        return self.buf
class HashWriter:
    """
    Writer that feeds everything written to it into a hash context created
    by calling `hashfunc()`; getvalue() returns that context's digest().
    """

    def __init__(self, hashfunc):
        self.ctx = hashfunc()
        # one-byte scratch buffer reused by writebyte() to avoid allocations
        self.buf = bytearray(1)

    def write(self, buf: bytearray):
        self.ctx.update(buf)

    def writebyte(self, b: int):
        scratch = self.buf
        scratch[0] = b
        self.ctx.update(scratch)

    def getvalue(self) -> bytes:
        return self.ctx.digest()