You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
489 lines
14 KiB
489 lines
14 KiB
from trezor.crypto.hashlib import sha256, ripemd160
|
|
from trezor.crypto.curve import secp256k1
|
|
from trezor.crypto import base58
|
|
|
|
from . import coins
|
|
|
|
from trezor.messages.CoinType import CoinType
|
|
from trezor.messages.SignTx import SignTx
|
|
from trezor.messages.TxOutputType import TxOutputType
|
|
from trezor.messages.TxOutputBinType import TxOutputBinType
|
|
from trezor.messages.TxInputType import TxInputType
|
|
from trezor.messages.TxRequest import TxRequest
|
|
from trezor.messages.RequestType import TXINPUT, TXOUTPUT, TXMETA, TXFINISHED
|
|
from trezor.messages.TxRequestSerializedType import TxRequestSerializedType
|
|
from trezor.messages.TxRequestDetailsType import TxRequestDetailsType
|
|
from trezor.messages import OutputScriptType, InputScriptType
|
|
|
|
|
|
# Machine instructions
|
|
# ===
|
|
|
|
|
|
def request_tx_meta(tx_req: TxRequest, tx_hash: bytes=None):
    """Ask the host for the metadata of a transaction.

    Generator meant to be awaited: it yields *tx_req* (turned into a TXMETA
    request for *tx_hash*, with no request index) and receives the host's
    acknowledgement back.

    :param tx_req: request object reused for every host round-trip.
    :param tx_hash: hash of the transaction whose metadata is wanted.
    :returns: the ``tx`` payload of the acknowledgement.
    """
    tx_req.type = TXMETA
    details = tx_req.details
    details.tx_hash = tx_hash
    details.request_index = None
    response = yield tx_req
    # anything serialized was delivered with this request; don't resend it
    tx_req.serialized = None
    return response.tx
|
|
|
|
|
|
def request_tx_input(tx_req: TxRequest, i: int, tx_hash: bytes=None):
    """Ask the host for input number *i* of a transaction.

    Generator meant to be awaited: it yields *tx_req* as a TXINPUT request
    and receives the acknowledgement. In this file, *tx_hash* is None when
    the transaction being signed is meant and a previous transaction's hash
    otherwise.

    :returns: the first (only) input carried by ``ack.tx.inputs``.
    """
    tx_req.type = TXINPUT
    details = tx_req.details
    details.request_index = i
    details.tx_hash = tx_hash
    response = yield tx_req
    # the serialized chunk (if any) travelled with this request
    tx_req.serialized = None
    inputs = response.tx.inputs
    return inputs[0]
|
|
|
|
|
|
def request_tx_output(tx_req: TxRequest, i: int, tx_hash: bytes=None):
    """Ask the host for output number *i* of a transaction.

    Generator meant to be awaited: it yields *tx_req* as a TXOUTPUT request
    and receives the acknowledgement.

    :param tx_req: request object reused for every host round-trip.
    :param i: index of the requested output.
    :param tx_hash: None for the transaction being signed, otherwise the
        hash of a previous transaction.
    :returns: ``ack.tx.outputs[0]`` for the transaction being signed, or
        ``ack.tx.bin_outputs[0]`` (already-serialized form) for a previous
        transaction.
    """
    tx_req.type = TXOUTPUT
    tx_req.details.request_index = i
    tx_req.details.tx_hash = tx_hash
    ack = yield tx_req
    tx_req.serialized = None
    # The outputs live in the ``tx`` payload of the acknowledgement, exactly
    # like the inputs read by request_tx_input (previously read from ``ack``).
    tx = ack.tx
    if tx_hash is None:
        return tx.outputs[0]
    return tx.bin_outputs[0]
|
|
|
|
|
|
def request_tx_finish(tx_req: TxRequest):
    """Tell the host that signing is complete.

    Generator meant to be awaited: it yields *tx_req* as a TXFINISHED
    request with its details cleared, so any pending ``serialized`` chunk is
    delivered along with it.
    """
    tx_req.type = TXFINISHED
    tx_req.details = None
    _ = yield tx_req  # the acknowledgement carries nothing we need
    tx_req.serialized = None
|
|
|
|
|
|
# Transaction signing
|
|
# ===
|
|
|
|
|
|
async def sign_tx(tx: SignTx, root):
    """Sign every input of the transaction described by *tx*.

    Inputs and outputs are not sent up front; they are requested from the
    host one at a time through the request_tx_* helpers.

    Phase 1 streams all inputs (verifying each one's previous transaction
    with get_prevtx_output_value to learn the amount spent) and all outputs,
    feeding them into the control hash ``h_first``.

    Phase 2 re-streams the whole transaction once per input. For input
    ``i_sign`` it builds the signing preimage (that input's script_sig
    derived from its public key, every other script_sig empty, footer with
    hash type 1), checks the re-streamed data against Phase 1, signs the
    double SHA-256 of the preimage and hands the serialized signed input to
    the host. Finally the outputs (plus the locktime footer after the last
    one) are serialized for the host and TXFINISHED is sent.

    :param tx: SignTx message; missing ``version``/``lock_time``/
        ``coin_name`` fall back to 1 / 0 / 'Bitcoin'.
    :param root: root HD node that input and output paths are derived from.
    :raises ValueError: on a second change output, when the host changes the
        transaction between the two phases, or when a helper rejects an
        input/output.
    """
    # Fallbacks follow the SignTx defaults (version=1, lock_time=0); they
    # were previously swapped. NOTE(review): confirm against messages.proto.
    tx_version = getattr(tx, 'version', 1)
    tx_lock_time = getattr(tx, 'lock_time', 0)
    tx_inputs_count = getattr(tx, 'inputs_count', 0)
    tx_outputs_count = getattr(tx, 'outputs_count', 0)
    coin_name = getattr(tx, 'coin_name', 'Bitcoin')

    coin = coins.by_name(coin_name)

    # Phase 1
    # - check inputs, previous transactions, and outputs
    # - ask for confirmations
    # - check fee

    total_in = 0  # sum of input amounts
    total_out = 0  # sum of output amounts
    change_out = 0  # change output amount
    change_seen = False  # a flag, so a zero-amount change still counts

    # h_first is used to make sure the inputs and outputs streamed in Phase 1
    # are the same as in Phase 2. it is thus not required to fully hash the
    # tx, as the SignTx info is streamed only once
    h_first = HashWriter(sha256)  # not a real tx hash

    txo_bin = TxOutputBinType()
    tx_req = TxRequest()
    tx_req.details = TxRequestDetailsType()

    for i in range(tx_inputs_count):
        # STAGE_REQUEST_1_INPUT
        txi = await request_tx_input(tx_req, i)
        write_tx_input(h_first, txi)
        total_in += await get_prevtx_output_value(
            tx_req, txi.prev_hash, txi.prev_index)

    for o in range(tx_outputs_count):
        # STAGE_REQUEST_3_OUTPUT
        txo = await request_tx_output(tx_req, o)
        if output_is_change(txo):
            if change_seen:
                raise ValueError('Only one change output is valid')
            change_seen = True
            change_out = txo.amount
        txo_bin.amount = txo.amount
        txo_bin.script_pubkey = output_derive_script(txo, coin, root)
        write_tx_output(h_first, txo_bin)
        total_out += txo_bin.amount
        # TODO: display output
        # TODO: confirm output

    # TODO: check funds and tx fee
    # TODO: ask for confirmation

    # h_first is complete, so its control digest only needs computing once
    h_first_dig = tx_hash_digest(h_first, False)

    # Phase 2
    # - sign inputs
    # - check that nothing changed

    tx_ser = TxRequestSerializedType()

    for i_sign in range(tx_inputs_count):
        # hash of what we are signing with this input
        h_sign = HashWriter(sha256)
        # same as h_first, checked at the end of this iteration
        h_second = HashWriter(sha256)

        txi_sign = None
        key_sign = None
        key_sign_pub = None

        write_tx_header(h_sign, tx_version, tx_inputs_count)

        for i in range(tx_inputs_count):
            # STAGE_REQUEST_4_INPUT
            txi = await request_tx_input(tx_req, i)
            write_tx_input(h_second, txi)
            if i == i_sign:
                txi_sign = txi
                key_sign = node_derive(root, txi.address_n)
                key_sign_pub = key_sign.public_key()
                txi.script_sig = input_derive_script_pre_sign(
                    txi, key_sign_pub)
            else:
                txi.script_sig = bytes()
            write_tx_input(h_sign, txi)

        write_tx_middle(h_sign, tx_outputs_count)

        for o in range(tx_outputs_count):
            # STAGE_REQUEST_4_OUTPUT
            txo = await request_tx_output(tx_req, o)
            txo_bin.amount = txo.amount
            txo_bin.script_pubkey = output_derive_script(txo, coin, root)
            write_tx_output(h_second, txo_bin)
            write_tx_output(h_sign, txo_bin)

        write_tx_footer(h_sign, tx_lock_time, True)

        # check the control digests
        if tx_hash_digest(h_second, False) != h_first_dig:
            raise ValueError('Transaction has changed during signing')

        # compute the signature from the tx digest; ecdsa_sign() takes the
        # raw private key, not the HD node (the node was passed before).
        # NOTE(review): scriptSig normally carries a DER-encoded signature;
        # confirm the format secp256k1.sign returns.
        h_sign_dig = tx_hash_digest(h_sign, True)
        signature = ecdsa_sign(key_sign.private_key(), h_sign_dig)
        tx_ser.signature_index = i_sign
        tx_ser.signature = signature

        # serialize input with correct signature
        txi_sign.script_sig = input_derive_script_post_sign(
            txi_sign, key_sign_pub, signature)
        txi_sign_w = BufferWriter()
        if i_sign == 0:
            write_tx_header(txi_sign_w, tx_version, tx_inputs_count)
        write_tx_input(txi_sign_w, txi_sign)
        tx_ser.serialized_tx = txi_sign_w.getvalue()

        # delivered to the host together with the next request
        tx_req.serialized = tx_ser

    for o in range(tx_outputs_count):
        # STAGE_REQUEST_5_OUTPUT
        txo = await request_tx_output(tx_req, o)
        txo_bin.amount = txo.amount
        txo_bin.script_pubkey = output_derive_script(txo, coin, root)

        # serialize output
        w_txo_bin = BufferWriter()
        if o == 0:
            write_tx_middle(w_txo_bin, tx_outputs_count)
        write_tx_output(w_txo_bin, txo_bin)
        # the locktime footer follows the LAST output; the old comparison
        # against tx_outputs_count was never true, so it was never written
        if o == tx_outputs_count - 1:
            write_tx_footer(w_txo_bin, tx_lock_time, False)
        tx_ser.signature_index = None
        tx_ser.signature = None
        tx_ser.serialized_tx = w_txo_bin.getvalue()

        tx_req.serialized = tx_ser

    await request_tx_finish(tx_req)
|
|
|
|
|
|
async def get_prevtx_output_value(tx_req: TxRequest, prev_hash: bytes, prev_index: int) -> int:
    """Stream previous transaction *prev_hash* and return one output's amount.

    The whole previous transaction is requested from the host and re-hashed,
    so the host cannot misreport the amount: if the double SHA-256 of the
    streamed data does not match *prev_hash*, the input is rejected.

    :param tx_req: request object reused for every host round-trip.
    :param prev_hash: hash of the previous transaction, in the byte order
        that write_tx_input reverses onto the wire.
    :param prev_index: index of the output being spent.
    :returns: the amount of output *prev_index*.
    :raises ValueError: if *prev_index* is out of range or the streamed
        transaction does not hash to *prev_hash*.
    """
    # STAGE_REQUEST_2_PREV_META
    # (tx_req was previously not passed, shifting every argument by one)
    tx = await request_tx_meta(tx_req, prev_hash)

    tx_version = getattr(tx, 'version', 1)
    tx_lock_time = getattr(tx, 'lock_time', 0)
    # NOTE(review): the TransactionType message names the counts
    # inputs_cnt / outputs_cnt; fall back to the *_count spelling.
    tx_inputs_count = getattr(tx, 'inputs_cnt', None)
    if tx_inputs_count is None:
        tx_inputs_count = getattr(tx, 'inputs_count', 0)
    tx_outputs_count = getattr(tx, 'outputs_cnt', None)
    if tx_outputs_count is None:
        tx_outputs_count = getattr(tx, 'outputs_count', 0)

    # an index past the last output would otherwise silently yield 0
    if prev_index >= tx_outputs_count:
        raise ValueError('Invalid prev_index')

    txh = HashWriter(sha256)

    write_tx_header(txh, tx_version, tx_inputs_count)

    for i in range(tx_inputs_count):
        # STAGE_REQUEST_2_PREV_INPUT
        txi = await request_tx_input(tx_req, i, prev_hash)
        write_tx_input(txh, txi)

    write_tx_middle(txh, tx_outputs_count)

    value = 0
    for o in range(tx_outputs_count):
        # STAGE_REQUEST_2_PREV_OUTPUT
        txo_bin = await request_tx_output(tx_req, o, prev_hash)
        write_tx_output(txh, txo_bin)
        if o == prev_index:
            # TxOutputBinType carries ``amount`` (see write_tx_output)
            value = txo_bin.amount

    write_tx_footer(txh, tx_lock_time, False)

    # write_tx_input puts prev_hash on the wire byte-reversed, so prev_hash
    # is in the reversed (display) order of the raw double-SHA256 digest.
    digest = tx_hash_digest(txh, True)
    if bytearray(reversed(digest)) != prev_hash:
        raise ValueError('Encountered invalid prev_hash')

    return value
|
|
|
|
|
|
def tx_hash_digest(w, double: bool):
    """Return ``w.getvalue()``, optionally passed through SHA-256 once more.

    With a HashWriter(sha256) as *w*, ``double=True`` therefore yields the
    double SHA-256 of everything written to it.
    """
    value = w.getvalue()
    return sha256(value).digest() if double else value
|
|
|
|
|
|
# TX Outputs
|
|
# ===
|
|
|
|
|
|
def output_derive_script(o: TxOutputType, coin: CoinType, root) -> bytes:
    """Build the scriptPubKey for output *o*.

    Only OutputScriptType.PAYTOADDRESS is supported; the destination is
    resolved by output_paytoaddress_extract_raw_address.

    :raises ValueError: for any other script type.
    """
    if o.script_type != OutputScriptType.PAYTOADDRESS:
        raise ValueError('Invalid output script type')
    raw_address = output_paytoaddress_extract_raw_address(o, coin, root)
    return script_paytoaddress_new(raw_address)
|
|
|
|
|
|
def output_paytoaddress_extract_raw_address(o: TxOutputType, coin: CoinType, root) -> bytes:
|
|
o_address_n = getattr(o, 'address_n', None)
|
|
o_address = getattr(o, 'address', None)
|
|
# TODO: dont encode/decode more then necessary
|
|
# TODO: detect correct address type
|
|
if o_address_n is not None:
|
|
n = node_derive(root, o_address_n)
|
|
raw_address = base58.decode_check(n.address())
|
|
elif o_address:
|
|
raw_address = base58.decode_check(o_address)
|
|
else:
|
|
raise ValueError('Missing address')
|
|
if raw_address[0] != coin.address_type:
|
|
raise ValueError('Invalid address type')
|
|
return raw_address
|
|
|
|
|
|
def output_is_change(output: TxOutputType):
|
|
address_n = getattr(output, 'address_n', None)
|
|
return bool(address_n)
|
|
|
|
|
|
# Tx Inputs
|
|
# ===
|
|
|
|
|
|
def input_derive_script_pre_sign(i: TxInputType, pubkey: bytes) -> bytes:
    """Return the script placed in input *i* while computing its sighash.

    For SPENDADDRESS inputs this is a pay-to-address script built from the
    hash of *pubkey*.

    :raises ValueError: for any other input script type.
    """
    if i.script_type != InputScriptType.SPENDADDRESS:
        raise ValueError('Unknown input script type')
    # NOTE(review): ecdsa_hash_pubkey yields a bare 20-byte hash; confirm
    # script_paytoaddress_new accepts it without a version prefix.
    pubkey_hash = ecdsa_hash_pubkey(pubkey)
    return script_paytoaddress_new(pubkey_hash)
|
|
|
|
|
|
def input_derive_script_post_sign(i: TxInputType, pubkey: bytes, signature: bytes) -> bytes:
    """Return the final script_sig for input *i* once it has been signed.

    :raises ValueError: unless *i* is a SPENDADDRESS input.
    """
    if i.script_type != InputScriptType.SPENDADDRESS:
        raise ValueError('Unknown input script type')
    return script_spendaddress_new(pubkey, signature)
|
|
|
|
|
|
def node_derive(root, address_n: list):
    """Return a clone of *root* derived along *address_n*.

    *root* itself is left untouched because the derivation is applied to
    the clone.
    """
    derived = root.clone()
    derived.derive_path(address_n)
    return derived
|
|
|
|
|
|
def ecdsa_hash_pubkey(pubkey: bytes) -> bytes:
    """Return RIPEMD160(SHA256(pubkey)) after checking the key's length.

    The first byte selects the expected length: 0x04 -> 65 (uncompressed),
    0x00 -> 1 (point at infinity), anything else -> 33 (compressed).

    :raises ValueError: if *pubkey* is empty or its length does not match
        its prefix byte (previously an ``assert``, which is skipped when
        assertions are disabled).
    """
    if not pubkey:
        raise ValueError('Empty public key')
    prefix = pubkey[0]
    if prefix == 0x04:
        expected_len = 65  # uncompressed format
    elif prefix == 0x00:
        expected_len = 1  # point at infinity
    else:
        expected_len = 33  # compressed format
    if len(pubkey) != expected_len:
        raise ValueError('Invalid public key length')
    h = sha256(pubkey).digest()
    h = ripemd160(h).digest()
    return h
|
|
|
|
|
|
def ecdsa_sign(privkey: bytes, digest: bytes) -> bytes:
    """Sign *digest* with *privkey* on secp256k1.

    Returns whatever ``secp256k1.sign`` produces, unmodified.
    NOTE(review): confirm that format matches what script_spendaddress_new
    expects (scriptSig normally carries a DER signature).
    """
    signature = secp256k1.sign(privkey, digest)
    return signature
|
|
|
|
|
|
# TX Scripts
|
|
# ===
|
|
|
|
|
|
def script_paytoaddress_new(raw_address: bytes) -> bytearray:
    """Build the 25-byte pay-to-address scriptPubKey.

    OP_DUP OP_HASH160 <20-byte hash> OP_EQUALVERIFY OP_CHECKSIG

    *raw_address* may be a decoded address (version prefix followed by the
    20-byte hash) or the bare 20-byte hash; in both cases the last 20 bytes
    are used. Previously the first byte was always dropped. That broke the
    bare hash from ecdsa_hash_pubkey (yielding a 24-byte script) and broke
    multi-byte version prefixes.

    :raises ValueError: if fewer than 20 bytes are supplied.
    """
    if len(raw_address) < 20:
        raise ValueError('Invalid address length')
    s = bytearray(25)
    s[0] = 0x76  # OP_DUP
    s[1] = 0xA9  # OP_HASH_160
    s[2] = 0x14  # pushing 20 bytes
    s[3:23] = raw_address[-20:]  # the hash always sits at the end
    s[23] = 0x88  # OP_EQUALVERIFY
    s[24] = 0xAC  # OP_CHECKSIG
    return s
|
|
|
|
|
|
def script_spendaddress_new(pubkey: bytes, signature: bytes) -> bytearray:
    """Build the script_sig for a signed SPENDADDRESS input.

    Layout: push(signature + hash-type byte 0x01), push(pubkey).
    """
    out = BufferWriter()
    # first push: the signature followed by its one-byte hash type
    write_op_push(out, len(signature) + 1)
    write_bytes(out, signature)
    out.writebyte(0x01)
    # second push: the public key
    write_op_push(out, len(pubkey))
    write_bytes(out, pubkey)
    return out.getvalue()
|
|
|
|
|
|
# TX Serialization
|
|
# ===
|
|
|
|
|
|
def write_tx_header(w, version: int, inputs_count: int):
    """Serialize the transaction prefix into *w*.

    The 4-byte little-endian version is followed by the varint input count.
    """
    for write, value in ((write_uint32, version), (write_varint, inputs_count)):
        write(w, value)
|
|
|
|
|
|
def write_tx_input(w, i: TxInputType):
    """Serialize one transaction input into *w*.

    Writes, in order: prev_hash (byte-reversed), prev_index (uint32),
    script_sig length (varint), script_sig, and sequence (uint32).
    """
    write_bytes_rev(w, i.prev_hash)
    write_uint32(w, i.prev_index)
    script = i.script_sig
    write_varint(w, len(script))
    write_bytes(w, script)
    write_uint32(w, i.sequence)
|
|
|
|
|
|
def write_tx_middle(w, outputs_count: int):
    """Write the varint output count that separates inputs from outputs."""
    return write_varint(w, outputs_count)
|
|
|
|
|
|
def write_tx_output(w, o: TxOutputBinType):
    """Serialize one output into *w*.

    Writes the 8-byte little-endian amount, then the length-prefixed
    script_pubkey.
    """
    write_uint64(w, o.amount)
    script = o.script_pubkey
    write_varint(w, len(script))
    write_bytes(w, script)
|
|
|
|
|
|
def write_tx_footer(w, locktime: int, add_hash_type: bool):
    """Write the locktime and, when signing, the 4-byte hash type.

    *add_hash_type* is True only for signature preimages. The hash type
    written is 1, matching the 0x01 byte appended by script_spendaddress_new.
    """
    write_uint32(w, locktime)
    if not add_hash_type:
        return
    write_uint32(w, 1)
|
|
|
|
|
|
def write_op_push(w, n: int):
    """Write the script opcode(s) announcing a push of *n* bytes.

    Uses the minimal encoding:
      * n < 0x4C           -> the length itself is the opcode
      * n <= 0xFF          -> OP_PUSHDATA1 + 1-byte length
      * n <= 0xFFFF        -> OP_PUSHDATA2 + 2-byte little-endian length
      * otherwise          -> OP_PUSHDATA4 + 4-byte little-endian length
    The upper bounds are inclusive. The previous strict ``<`` comparisons
    sent lengths 255 and 65535 to the next, non-minimal opcode.
    """
    wb = w.writebyte
    if n < 0x4C:
        wb(n & 0xFF)
    elif n <= 0xFF:
        wb(0x4C)  # OP_PUSHDATA1
        wb(n & 0xFF)
    elif n <= 0xFFFF:
        wb(0x4D)  # OP_PUSHDATA2
        wb(n & 0xFF)
        wb((n >> 8) & 0xFF)
    else:
        wb(0x4E)  # OP_PUSHDATA4
        wb(n & 0xFF)
        wb((n >> 8) & 0xFF)
        wb((n >> 16) & 0xFF)
        wb((n >> 24) & 0xFF)
|
|
|
|
|
|
# Buffer IO & Serialization
|
|
# ===
|
|
|
|
|
|
def write_varint(w, n: int):
    """Write *n* as a Bitcoin variable-length integer (CompactSize).

      * n < 0xFD           -> 1 byte
      * n <= 0xFFFF        -> 0xFD + 2-byte little-endian
      * n <= 0xFFFFFFFF    -> 0xFE + 4-byte little-endian
      * otherwise          -> 0xFF + 8-byte little-endian
    The 8-byte form was previously missing. Values of 2**32 and above were
    silently truncated into the 4-byte form.
    """
    wb = w.writebyte
    if n < 253:
        wb(n & 0xFF)
        return
    if n < 65536:
        wb(253)
        width = 2
    elif n <= 0xFFFFFFFF:
        wb(254)
        width = 4
    else:
        wb(255)
        width = 8
    for shift in range(0, width * 8, 8):
        wb((n >> shift) & 0xFF)
|
|
|
|
|
|
def write_uint32(w, n: int):
    """Write the low 32 bits of *n* to *w* as 4 little-endian bytes."""
    put = w.writebyte
    for shift in (0, 8, 16, 24):
        put((n >> shift) & 0xFF)
|
|
|
|
|
|
def write_uint64(w, n: int):
    """Write the low 64 bits of *n* to *w* as 8 little-endian bytes."""
    put = w.writebyte
    for shift in range(0, 64, 8):
        put((n >> shift) & 0xFF)
|
|
|
|
|
|
def write_bytes(w, buf: bytearray):
    """Pass *buf* unchanged to the writer's ``write`` method."""
    sink = w.write
    sink(buf)
|
|
|
|
|
|
def write_bytes_rev(w, buf: bytearray):
    """Write *buf* to *w* with its byte order reversed (as a new bytearray)."""
    flipped = bytearray(reversed(buf))
    w.write(flipped)
|
|
|
|
|
|
class BufferWriter:
    """Byte sink writing into a bytearray at a moving offset.

    Writes overwrite existing bytes from ``ofs`` onward and grow the buffer
    once they run past its end.
    """

    def __init__(self, buf: bytearray=None, ofs: int=0):
        # TODO: re-think the use of bytearrays, buffers, and other byte IO
        # i think we should just pass a pre-allocation size here, allocate the
        # bytearray and then trim it to zero. in this case, write() would
        # correspond to extend(), and writebyte() to append(). of course, the
        # the use-case of non-destructively writing to existing bytearray still
        # exists.
        if buf is None:
            buf = bytearray()
        self.buf = buf
        self.ofs = ofs

    def write(self, buf: bytearray):
        """Write *buf* at the current offset, extending the buffer if needed."""
        n = len(buf)
        self.buf[self.ofs:self.ofs + n] = buf
        self.ofs += n

    def writebyte(self, b: int):
        """Write a single byte at the current offset, extending if needed."""
        if self.ofs < len(self.buf):
            self.buf[self.ofs] = b
        else:
            # At (or past) the end, index assignment would raise IndexError,
            # e.g. on the very first writebyte() of a fresh BufferWriter().
            self.buf.append(b)
        self.ofs += 1

    def getvalue(self) -> bytearray:
        """Return the underlying buffer (not a copy)."""
        return self.buf
|
|
|
|
|
|
class HashWriter:
    """Byte sink that feeds everything written into a hash context.

    Nothing is stored; ``getvalue`` returns the digest of all data so far.
    """

    def __init__(self, hashfunc):
        self.ctx = hashfunc()
        # one-byte scratch buffer, reused by writebyte() to avoid allocations
        self.buf = bytearray(1)

    def write(self, buf: bytearray):
        """Feed *buf* into the hash."""
        self.ctx.update(buf)

    def writebyte(self, b: int):
        """Feed the single byte *b* into the hash."""
        scratch = self.buf
        scratch[0] = b
        self.ctx.update(scratch)

    def getvalue(self) -> bytes:
        """Return the digest of everything written so far."""
        return self.ctx.digest()
|