1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-13 19:18:56 +00:00

wallet/signing: segwit first test passing

This commit is contained in:
Tomas Susanka 2017-10-30 11:42:22 +01:00
parent bcef961059
commit e63d0adc23
4 changed files with 237 additions and 32 deletions

View File

@ -1,5 +1,10 @@
from trezor.crypto.hashlib import sha256
from trezor.messages.SignTx import SignTx
from trezor.messages import InputScriptType, FailureType
class Bip143Error(ValueError):
pass
class Bip143:
@ -10,7 +15,7 @@ class Bip143:
self.h_outputs = HashWriter(sha256)
def add_prevouts(self, txi: TxInputType):
write_bytes(self.h_prevouts, txi.prev_hash)
write_bytes_rev(self.h_prevouts, txi.prev_hash)
write_uint32(self.h_prevouts, txi.prev_index)
def get_prevouts_hash(self) -> bytes:
@ -28,17 +33,18 @@ class Bip143:
def get_outputs_hash(self) -> bytes:
return get_tx_hash(self.h_outputs, True)
def preimage(self, tx: SignTx, txi: TxInputType, script_code) -> bytes:
def preimage_hash(self, tx: SignTx, txi: TxInputType, pubkeyhash) -> bytes:
h_preimage = HashWriter(sha256)
write_uint32(h_preimage, tx.version) # nVersion
write_bytes(h_preimage, bytearray(self.get_prevouts_hash())) # hashPrevouts
write_bytes(h_preimage, bytearray(self.get_sequence_hash())) # hashSequence
write_bytes(h_preimage, txi.prev_hash) # outpoint
write_bytes_rev(h_preimage, txi.prev_hash) # outpoint
write_uint32(h_preimage, txi.prev_index) # outpoint
script_code = self.derive_script_code(txi, pubkeyhash)
write_varint(h_preimage, len(script_code)) # scriptCode length
write_bytes(h_preimage, bytearray(script_code)) # scriptCode
write_bytes(h_preimage, script_code) # scriptCode
write_uint64(h_preimage, txi.amount) # amount
write_uint32(h_preimage, txi.sequence) # nSequence
@ -48,3 +54,19 @@ class Bip143:
write_uint32(h_preimage, 0x00000001) # nHashType todo
return get_tx_hash(h_preimage, True)
# this not redeemScript nor scriptPubKey
# for P2WPKH this is always 0x1976a914{20-byte-pubkey-hash}88ac
def derive_script_code(self, txi: TxInputType, pubkeyhash: bytes) -> bytearray:
if txi.script_type == InputScriptType.SPENDP2SHWITNESS:
s = bytearray(25)
s[0] = 0x76 # OP_DUP
s[1] = 0xA9 # OP_HASH_160
s[2] = 0x14 # pushing 20 bytes
s[3:23] = pubkeyhash
s[23] = 0x88 # OP_EQUALVERIFY
s[24] = 0xAC # OP_CHECKSIG
return s
else:
raise Bip143Error(FailureType.SyntaxError,
'Unknown input script type for bip143 script code')

View File

@ -4,14 +4,13 @@ from trezor.crypto import base58, der
from trezor.utils import ensure
from trezor.messages.CoinType import CoinType
from trezor.messages.SignTx import SignTx
from trezor.messages.TxOutputType import TxOutputType
from trezor.messages.TxRequest import TxRequest
from trezor.messages.TransactionType import TransactionType
from trezor.messages.RequestType import TXINPUT, TXOUTPUT, TXMETA, TXFINISHED
from trezor.messages.TxRequestSerializedType import TxRequestSerializedType
from trezor.messages.TxRequestDetailsType import TxRequestDetailsType
from trezor.messages import OutputScriptType, InputScriptType, FailureType
from trezor.messages import OutputScriptType
from apps.common import address_type
from apps.common import coins
@ -140,7 +139,7 @@ def sanitize_tx_binoutput(tx: TransactionType) -> TxOutputBinType:
# - check inputs, previous transactions, and outputs
# - ask for confirmations
# - check fee
async def check_tx_fee(tx: SignTx, root, segwit):
async def check_tx_fee(tx: SignTx, root):
coin = coins.by_name(tx.coin_name)
@ -157,17 +156,20 @@ async def check_tx_fee(tx: SignTx, root, segwit):
total_in = 0 # sum of input amounts
total_out = 0 # sum of output amounts
change_out = 0 # change output amount
segwit = {} # dict of booleans stating if input is segwit
for i in range(tx.inputs_count):
# STAGE_REQUEST_1_INPUT
txi = await request_tx_input(tx_req, i)
write_tx_input_check(h_first, txi)
if segwit:
if txi.script_type == InputScriptType.SPENDP2SHWITNESS:
segwit[i] = True
# Add I to segwit hash_prevouts, hash_sequence
bip143.add_prevouts(txi)
bip143.add_sequence(txi)
total_in += txi.amount
else:
segwit[i] = False
total_in += await get_prevtx_output_value(
tx_req, txi.prev_hash, txi.prev_index)
@ -203,16 +205,16 @@ async def check_tx_fee(tx: SignTx, root, segwit):
raise SigningError(FailureType.ActionCancelled,
'Total cancelled')
return h_first, tx_req, txo_bin, bip143
return h_first, tx_req, txo_bin, bip143, segwit
async def sign_tx(tx: SignTx, root, segwit=False):
async def sign_tx(tx: SignTx, root):
tx = sanitize_sign_tx(tx)
# Phase 1
h_first, tx_req, txo_bin, bip143 = await check_tx_fee(tx, root, segwit)
h_first, tx_req, txo_bin, bip143, segwit = await check_tx_fee(tx, root)
# Phase 2
# - sign inputs
@ -235,14 +237,25 @@ async def sign_tx(tx: SignTx, root, segwit=False):
write_varint(h_sign, tx.inputs_count)
if segwit:
txi = await request_tx_input(tx_req, i_sign)
# if hashType != ANYONE_CAN_PAY ? todo
if segwit[i_sign]:
# STAGE_REQUEST_SEGWIT_INPUT
txi_sign = await request_tx_input(tx_req, i_sign)
if txi_sign.script_type == InputScriptType.SPENDP2SHWITNESS:
key_sign = node_derive(root, txi_sign.address_n)
key_sign_pub = key_sign.public_key()
txi_sign.script_sig = input_derive_script(txi_sign, key_sign_pub)
w_txi = bytearray_with_cap(
7 + len(txi_sign.prev_hash) + 4 + len(txi_sign.script_sig) + 4)
if i_sign == 0: # serializing first input => prepend meta
write_uint32(w_txi, tx.version)
write_varint(w_txi, 0x00) # segwit witness marker
write_varint(w_txi, 0x01) # segwit witness flag
write_varint(w_txi, tx.inputs_count)
write_tx_input(w_txi, txi_sign)
tx_ser.serialized_tx = w_txi
tx_req.serialized = tx_ser
# todo: what to do with other types?
script_code = input_derive_script(txi, coin, root)
bip143.preimage(tx, txi, script_code)
# Return serialized input chunk ? todo
else:
for i in range(tx.inputs_count):
# STAGE_REQUEST_4_INPUT
@ -285,7 +298,7 @@ async def sign_tx(tx: SignTx, root, segwit=False):
txi_sign.script_sig = input_derive_script(
txi_sign, key_sign_pub, signature)
w_txi_sign = bytearray_with_cap(
len(txi_sign.prev_hash) + 4 + 5 + len(txi_sign.script_sig) + 4)
5 + len(txi_sign.prev_hash) + 4 + len(txi_sign.script_sig) + 4)
if i_sign == 0: # serializing first input => prepend tx version and inputs count
write_uint32(w_txi_sign, tx.version)
write_varint(w_txi_sign, tx.inputs_count)
@ -306,14 +319,35 @@ async def sign_tx(tx: SignTx, root, segwit=False):
if o == 0: # serializing first output => prepend outputs count
write_varint(w_txo_bin, tx.outputs_count)
write_tx_output(w_txo_bin, txo_bin)
if o == tx.outputs_count - 1: # serializing last output => append tx lock_time
write_uint32(w_txo_bin, tx.lock_time)
tx_ser.signature_index = None
tx_ser.signature_index = None # @todo delete?
tx_ser.signature = None
tx_ser.serialized_tx = w_txo_bin
tx_req.serialized = tx_ser
for i in range(tx.inputs_count):
if segwit[i]:
# STAGE_REQUEST_SEGWIT_WITNESS
txi = await request_tx_input(tx_req, i)
# todo check amount?
# if hashType != ANYONE_CAN_PAY ? todo
# todo: what to do with other types?
key_sign = node_derive(root, txi.address_n)
key_sign_pub = key_sign.public_key()
bip143_hash = bip143.preimage_hash(tx, txi, ecdsa_hash_pubkey(key_sign_pub))
signature = ecdsa_sign(key_sign, bip143_hash)
witness = get_p2wpkh_witness(signature, key_sign_pub)
tx_ser.serialized_tx = witness
tx_req.serialized = tx_ser
# else
# witness is 0x00
write_uint32(tx_ser.serialized_tx, tx.lock_time)
await request_tx_finish(tx_req)
@ -384,6 +418,7 @@ def output_derive_script(o: TxOutputType, coin: CoinType, root) -> bytes:
def output_paytoaddress_extract_raw_address(
o: TxOutputType, coin: CoinType, root, p2sh: bool=False) -> bytes:
# todo if segwit then addr_type = p2sh ?
addr_type = coin.address_type_p2sh if p2sh else coin.address_type
# TODO: dont encode/decode more then necessary
if o.address_n is not None:
@ -415,8 +450,8 @@ def input_derive_script(i: TxInputType, pubkey: bytes, signature: bytes=None) ->
else:
return script_spendaddress_new(pubkey, signature)
if i.script_type == InputScriptType.SPENDP2SHWITNESS: # todo
return script_paytoaddress_new(ecdsa_hash_pubkey(pubkey))
if i.script_type == InputScriptType.SPENDP2SHWITNESS: # p2wpkh using p2sh
return script_p2wpkh_in_p2sh(ecdsa_hash_pubkey(pubkey))
else:
raise SigningError(FailureType.SyntaxError,
@ -471,6 +506,18 @@ def script_paytoscripthash_new(scripthash: bytes) -> bytearray:
return s
# P2WPKH is nested in P2SH to be backwards compatible
# see https://github.com/bitcoin/bips/blob/master/bip-0141.mediawiki#witness-program
# this pushes 16 00 14 <pubkeyhash>
def script_p2wpkh_in_p2sh(pubkeyhash: bytes) -> bytearray:
w = bytearray_with_cap(3 + len(pubkeyhash))
write_op_push(w, len(pubkeyhash) + 2) # 0x16 - length of the redeemScript
w.append(0x00) # witness version byte
w.append(0x14) # P2WPKH witness program (pub key hash length + pub key hash)
write_bytes(w, pubkeyhash)
return w
def script_paytoopreturn_new(data: bytes) -> bytearray:
w = bytearray_with_cap(1 + 5 + len(data))
w.append(0x6A) # OP_RETURN
@ -481,9 +528,21 @@ def script_paytoopreturn_new(data: bytes) -> bytearray:
def script_spendaddress_new(pubkey: bytes, signature: bytes) -> bytearray:
w = bytearray_with_cap(5 + len(signature) + 1 + 5 + len(pubkey))
append_signature_and_pubkey(w, pubkey, signature)
return w
def get_p2wpkh_witness(signature: bytes, pubkey: bytes):
w = bytearray_with_cap(1 + 5 + len(signature) + 1 + 5 + len(pubkey))
write_varint(w, 0x02) # num of segwit items, in P2WPKH it's always 2
append_signature_and_pubkey(w, pubkey, signature)
return w
def append_signature_and_pubkey(w: bytearray, pubkey: bytes, signature: bytes) -> bytearray:
write_op_push(w, len(signature) + 1)
write_bytes(w, signature)
w.append(0x01)
w.append(0x01) # SIGHASH_ALL
write_op_push(w, len(pubkey))
write_bytes(w, pubkey)
return w

View File

@ -14,18 +14,19 @@ class TestSegwitBip143(unittest.TestCase):
tx = SignTx(coin_name='Bitcoin', version=1, lock_time=0x00000492, inputs_count=1, outputs_count=2)
inp1 = TxInputType(address_n=[0],
prev_hash=unhexlify('db6b1b20aa0fd7b23880be2ecbd4a98130974cf4748fb66092ac4d3ceb1a5477'),
# Trezor expects hash in reversed format
prev_hash=unhexlify('77541aeb3c4dac9260b68f74f44c973081a9d4cb2ebe8038b2d70faa201b6bdb'),
prev_index=1,
amount=1000000000, # 10 btc
script_type=InputScriptType.SPENDP2SHWITNESS, # todo is this correct?
sequence=0xfffffffe)
out1 = TxOutputType(address='1Fyxts6r24DpEieygQiNnWxUdb18ANa5p7',
amount=0x000000000bebb4b8,
script_type=OutputScriptType.PAYTOWITNESS,
script_type=OutputScriptType.PAYTOADDRESS,
address_n=None)
out2 = TxOutputType(address='1Q5YjKVj5yQWHBBsyEBamkfph3cA6G9KK8',
amount=0x000000002faf0800,
script_type=OutputScriptType.PAYTOWITNESS,
script_type=OutputScriptType.PAYTOADDRESS,
address_n=None)
def test_bip143_prevouts(self):
@ -72,10 +73,8 @@ class TestSegwitBip143(unittest.TestCase):
txo_bin.script_pubkey = output_derive_script(txo, coin, root)
bip143.add_output(txo_bin)
# test data public key
script_code = input_derive_script(self.inp1, unhexlify('03ad1d8e89212f0b92c74d23bb710c00662ad1470198ac48c43f7d6f93a2a26873'))
self.assertEqual(hexlify(script_code), b'76a91479091972186c449eb1ded22b78e40d009bdf008988ac')
result = bip143.preimage(self.tx, self.inp1, script_code)
# test data public key hash
result = bip143.preimage_hash(self.tx, self.inp1, unhexlify('79091972186c449eb1ded22b78e40d009bdf0089'))
self.assertEqual(hexlify(result), b'64f3b0f4dd2bb3aa1ce8566d220cc74dda9df97d8490cc81d89d735c92e59fb6')

View File

@ -0,0 +1,125 @@
from common import *
from trezor.utils import chunks
from trezor.crypto import bip32, bip39
from trezor.messages.SignTx import SignTx
from trezor.messages.TxInputType import TxInputType
from trezor.messages.TxOutputType import TxOutputType
from trezor.messages.TxRequest import TxRequest
from trezor.messages.TxAck import TxAck
from trezor.messages.TransactionType import TransactionType
from trezor.messages.RequestType import TXINPUT, TXOUTPUT, TXMETA, TXFINISHED
from trezor.messages.TxRequestDetailsType import TxRequestDetailsType
from trezor.messages.TxRequestSerializedType import TxRequestSerializedType
from trezor.messages import InputScriptType
from trezor.messages import OutputScriptType
from apps.common import coins
from apps.wallet.sign_tx import signing
class TestSignSegwitTx(unittest.TestCase):
# pylint: disable=C0301
def test_send_p2sh(self):
coin = coins.by_name('Testnet')
seed = bip39.seed(' '.join(['all'] * 12), '')
root = bip32.from_seed(seed, 'secp256k1')
inp1 = TxInputType(
# 49'/1'/0'/1/0" - 2N1LGaGg836mqSQqiuUBLfcyGBhyZbremDX
address_n=[49 | 0x80000000, 1 | 0x80000000, 0 | 0x80000000, 1, 0],
amount=123456789,
prev_hash=unhexlify('20912f98ea3ed849042efed0fdac8cb4fc301961c5988cba56902d8ffb61c337'),
prev_index=0,
script_type=InputScriptType.SPENDP2SHWITNESS,
sequence=0xffffffff,
)
out1 = TxOutputType(
address='mhRx1CeVfaayqRwq5zgRQmD7W5aWBfD5mC',
amount=12300000,
script_type=OutputScriptType.PAYTOADDRESS,
address_n=None,
)
out2 = TxOutputType(
address='2N1LGaGg836mqSQqiuUBLfcyGBhyZbremDX',
script_type=OutputScriptType.PAYTOSCRIPTHASH, # todo
amount=123456789 - 11000 - 12300000,
address_n=None,
)
tx = SignTx(coin_name='Testnet', version=None, lock_time=None, inputs_count=1, outputs_count=2)
messages = [
None,
# check fee
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None)),
TxAck(tx=TransactionType(inputs=[inp1])),
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None),
TxAck(tx=TransactionType(outputs=[out1])),
signing.UiConfirmOutput(out1, coin),
True,
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=1, tx_hash=None), serialized=None),
TxAck(tx=TransactionType(outputs=[out2])),
signing.UiConfirmOutput(out2, coin),
True,
signing.UiConfirmTotal(123445789, 11000, coin),
True,
# sign tx
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None),
TxAck(tx=TransactionType(inputs=[inp1])),
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=TxRequestSerializedType(
# returned serialized inp1
serialized_tx=unhexlify('0100000000010137c361fb8f2d9056ba8c98c5611930fcb48cacfdd0fe2e0449d83eea982f91200000000017160014d16b8c0680c61fc6ed2e407455715055e41052f5ffffffff'),
)),
TxAck(tx=TransactionType(outputs=[out1])),
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=1, tx_hash=None), serialized=TxRequestSerializedType(
# returned serialized out1
serialized_tx=unhexlify('02e0aebb00000000001976a91414fdede0ddc3be652a0ce1afbc1b509a55b6b94888ac'),
signature_index=None,
signature=None,
)),
TxAck(tx=TransactionType(outputs=[out2])),
# segwit
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=TxRequestSerializedType(
# returned serialized out2
serialized_tx=unhexlify('3df39f060000000017a91458b53ea7f832e8f096e896b8713a8c6df0e892ca87'),
signature_index=None,
signature=None,
)),
TxAck(tx=TransactionType(inputs=[inp1])),
TxRequest(request_type=TXFINISHED, details=None, serialized=TxRequestSerializedType(
serialized_tx=unhexlify('02483045022100ccd253bfdf8a5593cd7b6701370c531199f0f05a418cd547dfc7da3f21515f0f02203fa08a0753688871c220648f9edadbdb98af42e5d8269364a326572cf703895b012103e7bfe10708f715e8538c92d46ca50db6f657bbc455b7494e6a0303ccdb868b7900000000'),
signature_index=None,
signature=None,
)),
]
signer = signing.sign_tx(tx, root)
for request, response in chunks(messages, 2):
self.assertEqualEx(signer.send(request), response)
with self.assertRaises(StopIteration):
signer.send(None)
def assertEqualEx(self, a, b):
# hack to avoid adding __eq__ to signing.Ui* classes
if ((isinstance(a, signing.UiConfirmOutput) and isinstance(b, signing.UiConfirmOutput)) or
(isinstance(a, signing.UiConfirmTotal) and isinstance(b, signing.UiConfirmTotal))):
return self.assertEqual(a.__dict__, b.__dict__)
else:
return self.assertEqual(a, b)
if __name__ == '__main__':
unittest.main()