From d2c71b3a583d90b2186a5e0172eb51f9b2edddef Mon Sep 17 00:00:00 2001 From: Jan Pochyla Date: Tue, 8 Nov 2016 18:49:58 +0100 Subject: [PATCH] signtx: add first test, make it all work --- src/apps/common/coins.py | 9 ++- src/apps/common/signtx.py | 115 ++++++++++++++++--------------- src/lib/protobuf.py | 11 +++ tests/test_apps_common.signtx.py | 95 +++++++++++++++++++++++++ 4 files changed, 173 insertions(+), 57 deletions(-) create mode 100644 tests/test_apps_common.signtx.py diff --git a/src/apps/common/coins.py b/src/apps/common/coins.py index 3fee1198b1..2689dadbf6 100644 --- a/src/apps/common/coins.py +++ b/src/apps/common/coins.py @@ -106,20 +106,23 @@ _coins = [ }, ] + def by_shortcut(shortcut): - for c in _couns: + for c in _coins: if c['coin_shortcut'] == shortcut: return c raise Exception('Unknown coin shortcut "%s"' % shortcut) + def by_name(name): - for c in _couns: + for c in _coins: if c['coin_name'] == name: return c raise Exception('Unknown coin name "%s"' % name) + def by_address_type(version): - for c in _couns: + for c in _coins: if c['address_type'] == version: return c raise Exception('Unknown coin address type %d' % version) diff --git a/src/apps/common/signtx.py b/src/apps/common/signtx.py index 2c6e585aa4..c5f76b4c99 100644 --- a/src/apps/common/signtx.py +++ b/src/apps/common/signtx.py @@ -1,6 +1,6 @@ from trezor.crypto.hashlib import sha256, ripemd160 from trezor.crypto.curve import secp256k1 -from trezor.crypto import base58 +from trezor.crypto import base58, der from . import coins @@ -21,7 +21,7 @@ from trezor.messages import OutputScriptType, InputScriptType def request_tx_meta(tx_req: TxRequest, tx_hash: bytes=None): - tx_req.type = TXMETA + tx_req.request_type = TXMETA tx_req.details.tx_hash = tx_hash tx_req.details.request_index = None ack = yield tx_req @@ -30,7 +30,7 @@ def request_tx_meta(tx_req: TxRequest, tx_hash: bytes=None): def request_tx_input(tx_req: TxRequest, i: int, tx_hash: bytes=None): - tx_req.type = TXINPUT + tx_req.request_type = TXINPUT tx_req.details.request_index = i tx_req.details.tx_hash = tx_hash ack = yield tx_req @@ -39,19 +39,19 @@ def request_tx_input(tx_req: TxRequest, i: int, tx_hash: bytes=None): def request_tx_output(tx_req: TxRequest, i: int, tx_hash: bytes=None): - tx_req.type = TXOUTPUT + tx_req.request_type = TXOUTPUT tx_req.details.request_index = i tx_req.details.tx_hash = tx_hash ack = yield tx_req tx_req.serialized = None if tx_hash is None: - return ack.outputs[0] + return ack.tx.outputs[0] else: - return ack.bin_outputs[0] + return ack.tx.bin_outputs[0] def request_tx_finish(tx_req: TxRequest): - tx_req.type = TXFINISHED + tx_req.request_type = TXFINISHED tx_req.details = None yield tx_req tx_req.serialized = None @@ -62,8 +62,8 @@ def request_tx_finish(tx_req: TxRequest): async def sign_tx(tx: SignTx, root): - tx_version = getattr(tx, 'version', 0) - tx_lock_time = getattr(tx, 'lock_time', 1) + tx_version = getattr(tx, 'version', 1) + tx_lock_time = getattr(tx, 'lock_time', 0) tx_inputs_count = getattr(tx, 'inputs_count', 0) tx_outputs_count = getattr(tx, 'outputs_count', 0) coin_name = getattr(tx, 'coin_name', 'Bitcoin') @@ -91,7 +91,7 @@ async def sign_tx(tx: SignTx, root): for i in range(tx_inputs_count): # STAGE_REQUEST_1_INPUT txi = await request_tx_input(tx_req, i) - write_tx_input(h_first, txi) + write_tx_input_check(h_first, txi) total_in += await get_prevtx_output_value( tx_req, txi.prev_hash, txi.prev_index) @@ -118,9 +118,10 @@ async def sign_tx(tx: SignTx, root): tx_ser = TxRequestSerializedType() - for i_sign in range(tx.inputs_count): + for i_sign in range(tx_inputs_count): # hash of what we are signing with this input h_sign = HashWriter(sha256) + # h_sign = BufferWriter() # same as h_first, checked at the end of this iteration h_second = HashWriter(sha256) @@ -130,10 +131,10 @@ async def sign_tx(tx: SignTx, root): write_tx_header(h_sign, tx_version, tx_inputs_count) - for i in range(tx.inputs_count): + for i in range(tx_inputs_count): # STAGE_REQUEST_4_INPUT txi = await request_tx_input(tx_req, i) - write_tx_input(h_second, txi) + write_tx_input_check(h_second, txi) if i == i_sign: txi_sign = txi key_sign = node_derive(root, txi.address_n) @@ -146,7 +147,7 @@ async def sign_tx(tx: SignTx, root): write_tx_middle(h_sign, tx_outputs_count) - for o in range(tx.outputs_count): + for o in range(tx_outputs_count): # STAGE_REQUEST_4_OUTPUT txo = await request_tx_output(tx_req, o) txo_bin.amount = txo.amount @@ -156,30 +157,29 @@ async def sign_tx(tx: SignTx, root): write_tx_footer(h_sign, tx_lock_time, True) + import ubinascii + # check the control digests - h_first_dig = tx_hash_digest(h_first, False) - h_second_dig = tx_hash_digest(h_second, False) - if h_first_dig != h_second_dig: + if tx_hash_digest(h_first, False) != tx_hash_digest(h_second, False): raise ValueError('Transaction has changed during signing') # compute the signature from the tx digest - h_sign_dig = tx_hash_digest(h_sign, True) - signature = ecdsa_sign(key_sign, h_sign_dig) + signature = ecdsa_sign(key_sign, tx_hash_digest(h_sign, True)) tx_ser.signature_index = i_sign tx_ser.signature = signature # serialize input with correct signature txi_sign.script_sig = input_derive_script_post_sign( txi_sign, key_sign_pub, signature) - txi_sign_w = BufferWriter() + w_txi_sign = BufferWriter() if i_sign == 0: - write_tx_header(txi_sign_w, tx_version, tx_inputs_count) - write_tx_input(txi_sign_w, txi_sign) - tx_ser.serialized_tx = txi_sign_w.getvalue() + write_tx_header(w_txi_sign, tx_version, tx_inputs_count) + write_tx_input(w_txi_sign, txi_sign) + tx_ser.serialized_tx = w_txi_sign.getvalue() tx_req.serialized = tx_ser - for o in range(tx.outputs_count): + for o in range(tx_outputs_count): # STAGE_REQUEST_5_OUTPUT txo = await request_tx_output(tx_req, o) txo_bin.amount = txo.amount @@ -190,7 +190,7 @@ async def sign_tx(tx: SignTx, root): if o == 0: write_tx_middle(w_txo_bin, tx_outputs_count) write_tx_output(w_txo_bin, txo_bin) - if o == tx_outputs_count: + if o == tx_outputs_count - 1: write_tx_footer(w_txo_bin, tx_lock_time, False) tx_ser.signature_index = None tx_ser.signature = None @@ -205,12 +205,12 @@ async def get_prevtx_output_value(tx_req: TxRequest, prev_hash: bytes, prev_inde total_out = 0 # sum of output amounts # STAGE_REQUEST_2_PREV_META - tx = await request_tx_meta(prev_hash) + tx = await request_tx_meta(tx_req, prev_hash) tx_version = getattr(tx, 'version', 0) tx_lock_time = getattr(tx, 'lock_time', 1) - tx_inputs_count = getattr(tx, 'inputs_count', 0) - tx_outputs_count = getattr(tx, 'outputs_count', 0) + tx_inputs_count = getattr(tx, 'inputs_cnt', 0) + tx_outputs_count = getattr(tx, 'outputs_cnt', 0) txh = HashWriter(sha256) @@ -228,16 +228,17 @@ async def get_prevtx_output_value(tx_req: TxRequest, prev_hash: bytes, prev_inde txo_bin = await request_tx_output(tx_req, o, prev_hash) write_tx_output(txh, txo_bin) if o == prev_index: - total_out += txo_bin.value + total_out += txo_bin.amount write_tx_footer(txh, tx_lock_time, False) - if tx_hash_digest(txh, True) != prev_hash: + prev_hash_rev = bytes(reversed(prev_hash)) # TODO: improve performance + if tx_hash_digest(txh, True) != prev_hash_rev: raise ValueError('Encountered invalid prev_hash') return total_out -def tx_hash_digest(w, double: bool): +def tx_hash_digest(w, double: bool) -> bytes: d = w.getvalue() if double: d = sha256(d).digest() @@ -250,8 +251,8 @@ def tx_hash_digest(w, double: bool): def output_derive_script(o: TxOutputType, coin: CoinType, root) -> bytes: if o.script_type == OutputScriptType.PAYTOADDRESS: - return script_paytoaddress_new( - output_paytoaddress_extract_raw_address(o, coin, root)) + ra = output_paytoaddress_extract_raw_address(o, coin, root) + return script_paytoaddress_new(ra[1:]) else: raise ValueError('Invalid output script type') return @@ -269,7 +270,7 @@ def output_paytoaddress_extract_raw_address(o: TxOutputType, coin: CoinType, roo raw_address = base58.decode_check(o_address) else: raise ValueError('Missing address') - if raw_address[0] != coin.address_type: + if raw_address[0] != coin['address_type']: raise ValueError('Invalid address type') return raw_address @@ -284,14 +285,16 @@ def output_is_change(output: TxOutputType): def input_derive_script_pre_sign(i: TxInputType, pubkey: bytes) -> bytes: - if i.script_type == InputScriptType.SPENDADDRESS: + i_script_type = getattr(i, 'script_type', InputScriptType.SPENDADDRESS) + if i_script_type == InputScriptType.SPENDADDRESS: return script_paytoaddress_new(ecdsa_hash_pubkey(pubkey)) else: raise ValueError('Unknown input script type') def input_derive_script_post_sign(i: TxInputType, pubkey: bytes, signature: bytes) -> bytes: - if i.script_type == InputScriptType.SPENDADDRESS: + i_script_type = getattr(i, 'script_type', InputScriptType.SPENDADDRESS) + if i_script_type == InputScriptType.SPENDADDRESS: return script_spendaddress_new(pubkey, signature) else: raise ValueError('Unknown input script type') @@ -315,20 +318,23 @@ def ecdsa_hash_pubkey(pubkey: bytes) -> bytes: return h -def ecdsa_sign(privkey: bytes, digest: bytes) -> bytes: - return secp256k1.sign(privkey, digest) +def ecdsa_sign(node, digest: bytes) -> bytes: + sig = secp256k1.sign(node.private_key(), digest) + print(len(sig)) + sigder = der.convert_seq((sig[:32], sig[32:])) + return sigder # TX Scripts # === -def script_paytoaddress_new(raw_address: bytes) -> bytearray: +def script_paytoaddress_new(pubkeyhash: bytes) -> bytearray: s = bytearray(25) s[0] = 0x76 # OP_DUP s[1] = 0xA9 # OP_HASH_160 s[2] = 0x14 # pushing 20 bytes - s[3:23] = raw_address[1:] # TODO: do this properly + s[3:23] = pubkeyhash s[23] = 0x88 # OP_EQUALVERIFY s[24] = 0xAC # OP_CHECKSIG return s @@ -354,11 +360,22 @@ def write_tx_header(w, version: int, inputs_count: int): def write_tx_input(w, i: TxInputType): + i_sequence = getattr(i, 'sequence', 4294967295) write_bytes_rev(w, i.prev_hash) write_uint32(w, i.prev_index) write_varint(w, len(i.script_sig)) write_bytes(w, i.script_sig) - write_uint32(w, i.sequence) + write_uint32(w, i_sequence) + + +def write_tx_input_check(w, i: TxInputType): + i_sequence = getattr(i, 'sequence', 4294967295) + write_bytes(w, i.prev_hash) + write_uint32(w, i.prev_index) + write_uint32(w, len(i.address_n)) + for n in i.address_n: + write_uint32(w, n) + write_uint32(w, i_sequence) def write_tx_middle(w, outputs_count: int): @@ -446,26 +463,16 @@ def write_bytes_rev(w, buf: bytearray): class BufferWriter: - def __init__(self, buf: bytearray=None, ofs: int=0): - # TODO: re-think the use of bytearrays, buffers, and other byte IO - # i think we should just pass a pre-allocation size here, allocate the - # bytearray and then trim it to zero. in this case, write() would - # correspond to extend(), and writebyte() to append(). of course, the - # the use-case of non-destructively writing to existing bytearray still - # exists. + def __init__(self, buf: bytearray=None): if buf is None: buf = bytearray() self.buf = buf - self.ofs = ofs def write(self, buf: bytearray): - n = len(buf) - self.buf[self.ofs:self.ofs + n] = buf - self.ofs += n + self.buf.extend(buf) def writebyte(self, b: int): - self.buf[self.ofs] = b - self.ofs += 1 + self.buf.append(b) def getvalue(self) -> bytearray: return self.buf diff --git a/src/lib/protobuf.py b/src/lib/protobuf.py index f688629610..03b115c643 100644 --- a/src/lib/protobuf.py +++ b/src/lib/protobuf.py @@ -140,6 +140,17 @@ class MessageType(Type): WIRE_TYPE = 2 FIELDS = {} + def __init__(self, **kwargs): + for kw in kwargs: + setattr(self, kw, kwargs[kw]) + + def __eq__(self, rhs): + return (self.__class__ is rhs.__class__ and + self.__dict__ == rhs.__dict__) + + def __repr__(self): + return '<%s: %s>' % (self.__class__.__name__, self.__dict__) + @classmethod async def load(cls, source=None, target=None): if target is None: diff --git a/tests/test_apps_common.signtx.py b/tests/test_apps_common.signtx.py new file mode 100644 index 0000000000..102e8f1624 --- /dev/null +++ b/tests/test_apps_common.signtx.py @@ -0,0 +1,95 @@ +from common import * + +from trezor.crypto import bip32, bip39 +from trezor.messages.SignTx import SignTx +from trezor.messages.TxInputType import TxInputType +from trezor.messages.TxOutputType import TxOutputType +from trezor.messages.TxOutputBinType import TxOutputBinType +from trezor.messages.TxRequest import TxRequest +from trezor.messages.TxAck import TxAck +from trezor.messages.TransactionType import TransactionType +from trezor.messages.RequestType import TXINPUT, TXOUTPUT, TXMETA, TXFINISHED +from trezor.messages.TxRequestDetailsType import TxRequestDetailsType +from trezor.messages.TxRequestSerializedType import TxRequestSerializedType +from trezor.messages import OutputScriptType, InputScriptType + +from apps.common import signtx + + +class TestSignTx(unittest.TestCase): + # pylint: disable=C0301 + + def test_one_one_fee(self): + # tx: d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882 + # input 0: 0.0039 BTC + + ptx1 = TransactionType(version=1, lock_time=0, inputs_cnt=2, outputs_cnt=1) + pinp1 = TxInputType(script_sig=unhexlify(b'483045022072ba61305fe7cb542d142b8f3299a7b10f9ea61f6ffaab5dca8142601869d53c0221009a8027ed79eb3b9bc13577ac2853269323434558528c6b6a7e542be46e7e9a820141047a2d177c0f3626fc68c53610b0270fa6156181f46586c679ba6a88b34c6f4874686390b4d92e5769fbb89c8050b984f4ec0b257a0e5c4ff8bd3b035a51709503'), + prev_hash=unhexlify(b'c16a03f1cf8f99f6b5297ab614586cacec784c2d259af245909dedb0e39eddcf'), + prev_index=1) + pinp2 = TxInputType(script_sig=unhexlify(b'48304502200fd63adc8f6cb34359dc6cca9e5458d7ea50376cbd0a74514880735e6d1b8a4c0221008b6ead7fe5fbdab7319d6dfede3a0bc8e2a7c5b5a9301636d1de4aa31a3ee9b101410486ad608470d796236b003635718dfc07c0cac0cfc3bfc3079e4f491b0426f0676e6643a39198e8e7bdaffb94f4b49ea21baa107ec2e237368872836073668214'), + prev_hash=unhexlify(b'1ae39a2f8d59670c8fc61179148a8e61e039d0d9e8ab08610cb69b4a19453eaf'), + prev_index=1) + pout1 = TxOutputBinType(script_pubkey=unhexlify(b'76a91424a56db43cf6f2b02e838ea493f95d8d6047423188ac'), + amount=390000) + + inp1 = TxInputType(address_n=[0], # 14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e + # amount=390000, + prev_hash=unhexlify(b'd5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882'), + prev_index=0) + out1 = TxOutputType(address='1MJ2tj2ThBE62zXbBYA5ZaN3fdve5CPAz1', + amount=390000 - 10000, + script_type=OutputScriptType.PAYTOADDRESS) + tx = SignTx(inputs_count=1, outputs_count=1) + + messages = [ + None, + TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None)), + TxAck(tx=TransactionType(inputs=[inp1])), + TxRequest(request_type=TXMETA, details=TxRequestDetailsType(request_index=None, tx_hash=unhexlify(b"d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882")), serialized=None), + TxAck(tx=ptx1), + TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=unhexlify(b"d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882")), serialized=None), + TxAck(tx=TransactionType(inputs=[pinp1])), + TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=1, tx_hash=unhexlify(b"d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882")), serialized=None), + TxAck(tx=TransactionType(inputs=[pinp2])), + TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=unhexlify(b"d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882")), serialized=None), + TxAck(tx=TransactionType(bin_outputs=[pout1])), + TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None), + TxAck(tx=TransactionType(outputs=[out1])), + # ButtonRequest(code=ButtonRequest_ConfirmOutput), + # ButtonRequest(code=ButtonRequest_SignTx), + TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None), + TxAck(tx=TransactionType(inputs=[inp1])), + TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None), + TxAck(tx=TransactionType(outputs=[out1])), + TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=TxRequestSerializedType( + signature_index=0, + signature=unhexlify(b'30450221009a0b7be0d4ed3146ee262b42202841834698bb3ee39c24e7437df208b8b7077102202b79ab1e7736219387dffe8d615bbdba87e11477104b867ef47afed1a5ede781'), + serialized_tx=unhexlify(b'010000000182488650ef25a58fef6788bd71b8212038d7f2bbe4750bc7bcb44701e85ef6d5000000006b4830450221009a0b7be0d4ed3146ee262b42202841834698bb3ee39c24e7437df208b8b7077102202b79ab1e7736219387dffe8d615bbdba87e11477104b867ef47afed1a5ede7810121023230848585885f63803a0a8aecdd6538792d5c539215c91698e315bf0253b43dffffffff'))), + TxAck(tx=TransactionType(outputs=[out1])), + TxRequest(request_type=TXFINISHED, details=None, serialized=TxRequestSerializedType( + signature_index=None, + signature=None, + serialized_tx=unhexlify(b'0160cc0500000000001976a914de9b2a8da088824e8fe51debea566617d851537888ac00000000'), + )), + ] + + seed = bip39.seed('alcohol woman abuse must during monitor noble actual mixed trade anger aisle', '') + root = bip32.from_seed(seed, 'secp256k1') + + signer = signtx.sign_tx(tx, root) + i = 0 + try: + for i in range(0, len(messages) - 1, 2): + res = signer.send(messages[i]) + self.assertEqual(res, messages[i + 1]) + except StopIteration: + pass + self.assertEqual(i, len(messages) - 2) + + # Accepted by network: tx fd79435246dee76b2f159d2db08032d666c95adc544de64c8c49f474df4a7fee + # self.assertEqual(hexlify(serialized_tx), b'010000000182488650ef25a58fef6788bd71b8212038d7f2bbe4750bc7bcb44701e85ef6d5000000006b4830450221009a0b7be0d4ed3146ee262b42202841834698bb3ee39c24e7437df208b8b7077102202b79ab1e7736219387dffe8d615bbdba87e11477104b867ef47afed1a5ede7810121023230848585885f63803a0a8aecdd6538792d5c539215c91698e315bf0253b43dffffffff0160cc0500000000001976a914de9b2a8da088824e8fe51debea566617d851537888ac00000000') + + +if __name__ == '__main__': + unittest.main()