1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-22 07:28:10 +00:00

signtx: WIP

This commit is contained in:
Jan Pochyla 2016-11-03 18:56:21 +01:00
parent 6a98aff8bb
commit 3b742aa5dc

View File

@ -1,9 +1,8 @@
from trezor.crypto.hashlib import sha256
from trezor.crypto import HDNode
from trezor.utils import memcpy, memcpy_rev
from trezor.crypto.hashlib import sha256, ripemd160
from trezor.crypto.curve import secp256k1
from trezor.crypto import HDNode, base58
from . import coins
from . import seed
from trezor.messages.CoinType import CoinType
from trezor.messages.SignTx import SignTx
@ -108,17 +107,22 @@ async def sign_tx(tx: SignTx, root: HDNode):
h_sign = tx_hash_init() # hash of what we are signing with this input
h_second = tx_hash_init() # should be the same as h_first
input_sign = None
key_sign = None
key_sign_pub = None
for i in range(tx.inputs_count):
# STAGE_REQUEST_4_INPUT
input = await request_tx_input(i)
tx_write_input(h_second, input)
if i == i_sign:
signing_key = node_derive(root, input.address_n)
signing_key_pub = signing_key.public_key()
input.script_sig = input_derive_scriptsig_for_signing(
input, signing_key_pub)
key_sign = node_derive(root, input.address_n)
key_sign_pub = key_sign.public_key()
script_sig = input_derive_script_pre_sign(input, key_sign_pub)
input_sign = input
else:
input.script_sig = bytes()
script_sig = bytes()
input.script_sig = script_sig
tx_write_input(h_sign, input)
for o in range(tx.outputs_count):
@ -131,17 +135,30 @@ async def sign_tx(tx: SignTx, root: HDNode):
if h_first_dig != tx_hash_digest(h_second):
raise ValueError('Transaction has changed during signing')
sig = sign(signing_key, tx_hash_digest(h_sign))
# TODO: serialize scriptsig again
# TODO: serialize input
serialized = xxx
signature = ecdsa_sign(key_sign, tx_hash_digest(h_sign))
script_sig = input_derive_script_post_sign(
input, key_sign_pub, signature)
input_sign.script_sig = script_sig
# TODO: serialize the whole input at once, including the script_sig
input_sign_w = BufferWriter(bytearray(), 0)
tx_write_input(input_sign_w, input_sign)
input_sign_b = input_sign_w.getvalue()
serialized = TxRequestSerializedType(
signature_index=i_sign, signature=signature, serialized_tx=input_sign_b)
await send_serialized_tx(serialized)
for o in range(tx.outputs_count):
# STAGE_REQUEST_5_OUTPUT
output = await request_tx_output(o)
outputbin = output_compile(output, coin, root)
serialized = xxx
outputbin_w = BufferWriter(bytearray(), 0)
tx_write_input(outputbin_w, outputbin)
outputbin_b = outputbin_w.getvalue()
serialized = TxRequestSerializedType(serialized_tx=outputbin_b)
await send_serialized_tx(serialized)
await request_tx_finish()
@ -152,21 +169,21 @@ async def get_prevtx_output_value(prev_hash: bytes, prev_index: int) -> int:
total_in = 0
# STAGE_REQUEST_2_PREV_META
tx = await TxRequest(type=TXMETA, hash=prev_hash)
tx = await request_tx_meta(prev_hash)
h = tx_hash_init()
tx_write_header(h, tx.version, tx.inputs_count)
for i in range(tx.inputs_count):
# STAGE_REQUEST_2_PREV_INPUT
input = await TxRequest(type=TXINPUT, hash=prev_hash, index=i)
input = await request_tx_input(i, prev_hash)
tx_write_input(h, input)
tx_write_middle(h, tx.outputs_count)
for o in range(tx.outputs_count):
# STAGE_REQUEST_2_PREV_OUTPUT
outputbin = await TxRequest(type=TXOUTPUT, hash=prev_hash, index=o)
outputbin = await request_tx_output(o, prev_hash)
tx_write_output(h, outputbin)
if o == prev_index:
total_in += outputbin.value
@ -198,29 +215,31 @@ def tx_hash_digest(w):
def output_compile(output: TxOutputType, coin: CoinType, root: HDNode) -> TxOutputBinType:
bin = TxOutputBinType()
bin.amount = output.amount
if output.script_type == OutputScriptType.PAYTOADDRESS:
raw_address = output_paytoaddress_extract_raw_address(output, root)
if raw_address[0] != coin.address_type:
raise ValueError('Invalid address type')
bin.script_pubkey = script_paytoaddress_new(raw_address)
else:
# TODO: other output script types
raise ValueError('Unknown output script type')
bin.script_pubkey = output_derive_script(output, coin, root)
return bin
def output_paytoaddress_extract_raw_address(output: TxOutputType, root: HDNode) -> bytes:
output_address_n = getattr(output, 'address_n', None)
output_address = getattr(output, 'address_n', None)
if output_address_n:
node = node_derive(root, output_address_n)
def output_derive_script(output: TxOutputType, coin: CoinType, root: HDNode) -> bytes:
if output.script_type == OutputScriptType.PAYTOADDRESS:
raw_address = output_paytoaddress_extract_raw_address(output, root)
if raw_address[0] != coin.address_type: # TODO: do this properly
raise ValueError('Invalid address type')
return script_paytoaddress_new(raw_address)
else:
# TODO: other output script types
raise ValueError('Unknown output script type')
return
def output_paytoaddress_extract_raw_address(o: TxOutputType, root: HDNode) -> bytes:
o_address_n = getattr(o, 'address_n', None)
o_address = getattr(o, 'address', None)
if o_address_n:
node = node_derive(root, o_address_n)
# TODO: dont encode and decode again
raw_address = address_decode(node.address())
elif output_address:
raw_address = address_decode(output_address)
raw_address = base58.decode_check(node.address())
elif o_address:
raw_address = base58.decode_check(o_address)
else:
raise ValueError('Missing address')
return raw_address
@ -231,7 +250,7 @@ def script_paytoaddress_new(raw_address: bytes) -> bytearray:
s[0] = 0x76 # OP_DUP
s[1] = 0xA9 # OP_HASH_160
s[2] = 0x14 # pushing 20 bytes
s[3:23] = raw_address
s[3:23] = raw_address[1:] # TODO: do this properly
s[23] = 0x88 # OP_EQUALVERIFY
s[24] = 0xAC # OP_CHECKSIG
return s
@ -246,34 +265,31 @@ def output_is_change(output: TxOutputType):
# ===
def input_derive_scriptsig_for_signing(input: TxInputType, pubkey: bytes) -> bytes:
def input_derive_script_pre_sign(input: TxInputType, pubkey: bytes) -> bytes:
if input.script_type == InputScriptType.SPENDADDRESS:
pubkeyhash = xxx
return script_spendaddress_new(pubkeyhash)
return script_paytoaddress_new(ecdsa_get_pubkeyhash(pubkey))
else:
# TODO: other input script types
raise ValueError('Unknown input script type')
def script_spendaddress_new(pubkeyhash: bytes) -> bytearray:
def input_derive_script_post_sign(input: TxInputType, pubkey: bytes, signature: bytes) -> bytes:
if input.script_type == InputScriptType.SPENDADDRESS:
return script_spendaddress_new(pubkey, signature)
else:
# TODO: other input script types
raise ValueError('Unknown input script type')
def script_spendaddress_new(pubkey: bytes, signature: bytes) -> bytearray:
s = bytearray(25)
s[0] = 0x76 # OP_DUP
s[1] = 0xA9 # OP_HASH_160
s[2] = 0x14 # pushing 20 bytes
s[3:23] = pubkeyhash
s[23] = 0x88 # OP_EQUALVERIFY
s[24] = 0xAC # OP_CHECKSIG
return s
async def sign(privkey: bytes, digest: bytes) -> bytes:
# TODO: ecdsa secp256k1 digest sign
return b''
# Addresses, HDNodes
# ===
w = BufferWriter(s, 0)
write_op_push(w, len(signature) + 1)
write_bytes(w, signature)
w.writebyte(0x01)
write_op_push(w, len(pubkey))
write_bytes(w, pubkey)
return
def node_derive(root: HDNode, address_n: list) -> HDNode:
@ -283,9 +299,20 @@ def node_derive(root: HDNode, address_n: list) -> HDNode:
return node
def address_decode(address: str) -> bytes:
# TODO: decode the address from base58
return b''
def ecdsa_get_pubkeyhash(pubkey: bytes) -> bytes:
if pubkey[0] == 0x04:
assert len(pubkey) == 65 # uncompressed format
elif pubkey[0] == 0x00:
assert len(pubkey) == 1 # point at infinity
else:
assert len(pubkey) == 33 # compresssed format
h = sha256(pubkey).digest()
h = ripemd160(h).digest()
return h
async def ecdsa_sign(privkey: bytes, digest: bytes) -> bytes:
return secp256k1.sign(privkey, digest)
# TX Serialization
@ -321,6 +348,25 @@ def tx_write_footer(w, locktime: int, add_hash_type: bool):
write_uint32(w, 1)
def write_op_push(w, n: int):
wb = w.writebyte
if n < 0x4C:
wb(n & 0xFF)
elif n < 0xFF:
wb(0x4C)
wb(n & 0xFF)
elif n < 0xFFFF:
wb(0x4D)
wb(n & 0xFF)
wb((n >> 8) & 0xFF)
else:
wb(0x4E)
wb(n & 0xFF)
wb((n >> 8) & 0xFF)
wb((n >> 16) & 0xFF)
wb((n >> 24) & 0xFF)
# Buffer IO & Serialization
# ===
@ -329,7 +375,7 @@ def write_varint(w, n: int):
wb = w.writebyte
if n < 253:
wb(n & 0xFF)
elif n < 0x10000:
elif n < 65536:
wb(253)
wb(n & 0xFF)
wb((n >> 8) & 0xFF)
@ -372,20 +418,25 @@ def write_bytes_rev(w, buf: bytearray):
class BufferWriter:
def __init__(self, buf: bytearray, ofs: int):
# TODO: re-think the use of bytearrays, buffers, and other byte IO
# i think we should just pass a pre-allocation size here, allocate the
# bytearray and then trim it to zero. in this case, write() would
# correspond to extend(), and writebyte() to append(). of course, the
# the use-case of non-destructively writing to existing bytearray still
# exists.
self.buf = buf
self.ofs = ofs
def write(self, buf):
def write(self, buf: bytearray):
n = len(buf)
w = memcpy(self.buf, self.ofs, buf, 0, n)
self.ofs += w
return w
self.buf[self.ofs:self.ofs + n] = buf
self.ofs += n
def writebyte(self, b):
def writebyte(self, b: int):
self.buf[self.ofs] = b
self.ofs += 1
def getvalue(self):
def getvalue(self) -> bytearray:
return self.buf
@ -393,12 +444,14 @@ class HashWriter:
def __init__(self, hashfunc):
self.ctx = hashfunc()
self.buf = bytearray(1) # used in writebyte()
def write(self, buf):
def write(self, buf: bytearray):
self.ctx.update(buf)
def writebyte(self, b):
self.ctx.update(bytes(b))
def writebyte(self, b: int):
self.buf[0] = b
self.ctx.update(self.buf)
def getvalue(self):
def getvalue(self) -> bytes:
return self.ctx.digest()