1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-02-16 17:42:02 +00:00

src/apps/wallet: add support for zcash overwinter

This commit is contained in:
Pavol Rusnak 2018-06-05 16:04:23 +02:00
parent 4d1b2f0ca5
commit dfd02821af
No known key found for this signature in database
GPG Key ID: 91F3B339B9A02A3D
9 changed files with 59 additions and 32 deletions

View File

@ -167,7 +167,7 @@ COINS = [
segwit=False, segwit=False,
fork_id=None, fork_id=None,
force_bip143=False, force_bip143=False,
version_group_id=63210096, version_group_id=0x03c48270,
), ),
CoinInfo( CoinInfo(
coin_name='Zcash Testnet', coin_name='Zcash Testnet',
@ -182,7 +182,7 @@ COINS = [
segwit=False, segwit=False,
fork_id=None, fork_id=None,
force_bip143=False, force_bip143=False,
version_group_id=63210096, version_group_id=0x03c48270,
), ),
CoinInfo( CoinInfo(
coin_name='Bgold', coin_name='Bgold',

View File

@ -105,7 +105,7 @@ def sanitize_sign_tx(tx: SignTx) -> SignTx:
tx.inputs_count = tx.inputs_count if tx.inputs_count is not None else 0 tx.inputs_count = tx.inputs_count if tx.inputs_count is not None else 0
tx.outputs_count = tx.outputs_count if tx.outputs_count is not None else 0 tx.outputs_count = tx.outputs_count if tx.outputs_count is not None else 0
tx.coin_name = tx.coin_name if tx.coin_name is not None else 'Bitcoin' tx.coin_name = tx.coin_name if tx.coin_name is not None else 'Bitcoin'
tx.decred_expiry = tx.decred_expiry if tx.decred_expiry is not None else 0 tx.expiry = tx.expiry if tx.expiry is not None else 0
tx.overwintered = tx.overwintered if tx.overwintered is not None else False tx.overwintered = tx.overwintered if tx.overwintered is not None else False
return tx return tx
@ -116,7 +116,7 @@ def sanitize_tx_meta(tx: TransactionType) -> TransactionType:
tx.inputs_cnt = tx.inputs_cnt if tx.inputs_cnt is not None else 0 tx.inputs_cnt = tx.inputs_cnt if tx.inputs_cnt is not None else 0
tx.outputs_cnt = tx.outputs_cnt if tx.outputs_cnt is not None else 0 tx.outputs_cnt = tx.outputs_cnt if tx.outputs_cnt is not None else 0
tx.extra_data_len = tx.extra_data_len if tx.extra_data_len is not None else 0 tx.extra_data_len = tx.extra_data_len if tx.extra_data_len is not None else 0
tx.decred_expiry = tx.decred_expiry if tx.decred_expiry is not None else 0 tx.expiry = tx.expiry if tx.expiry is not None else 0
tx.overwintered = tx.overwintered if tx.overwintered is not None else False tx.overwintered = tx.overwintered if tx.overwintered is not None else False
return tx return tx

View File

@ -3,6 +3,7 @@ from trezor.messages.SignTx import SignTx
from trezor.messages import InputScriptType, FailureType from trezor.messages import InputScriptType, FailureType
from trezor.utils import HashWriter from trezor.utils import HashWriter
from apps.common.coininfo import CoinInfo
from apps.wallet.sign_tx.writers import * from apps.wallet.sign_tx.writers import *
from apps.wallet.sign_tx.scripts import output_script_p2pkh, output_script_multisig from apps.wallet.sign_tx.scripts import output_script_p2pkh, output_script_multisig
from apps.wallet.sign_tx.multisig import multisig_get_pubkeys from apps.wallet.sign_tx.multisig import multisig_get_pubkeys
@ -38,10 +39,15 @@ class Bip143:
def get_outputs_hash(self) -> bytes: def get_outputs_hash(self) -> bytes:
return get_tx_hash(self.h_outputs, True) return get_tx_hash(self.h_outputs, True)
def preimage_hash(self, tx: SignTx, txi: TxInputType, pubkeyhash: bytes, sighash: int) -> bytes: def preimage_hash(self, coin: CoinInfo, tx: SignTx, txi: TxInputType, pubkeyhash: bytes, sighash: int) -> bytes:
h_preimage = HashWriter(sha256) h_preimage = HashWriter(sha256)
write_uint32(h_preimage, tx.version) # nVersion if tx.overwintered:
write_uint32(h_preimage, tx.version | 0x80000000) # nVersion | fOverwintered
write_uint32(h_preimage, coin.version_group_id) # nVersionGroupId
else:
write_uint32(h_preimage, tx.version) # nVersion
write_bytes(h_preimage, bytearray(self.get_prevouts_hash())) # hashPrevouts write_bytes(h_preimage, bytearray(self.get_prevouts_hash())) # hashPrevouts
write_bytes(h_preimage, bytearray(self.get_sequence_hash())) # hashSequence write_bytes(h_preimage, bytearray(self.get_sequence_hash())) # hashSequence
write_bytes_rev(h_preimage, txi.prev_hash) # outpoint write_bytes_rev(h_preimage, txi.prev_hash) # outpoint
@ -56,6 +62,8 @@ class Bip143:
write_bytes(h_preimage, bytearray(self.get_outputs_hash())) # hashOutputs write_bytes(h_preimage, bytearray(self.get_outputs_hash())) # hashOutputs
write_uint32(h_preimage, tx.lock_time) # nLockTime write_uint32(h_preimage, tx.lock_time) # nLockTime
if tx.overwintered:
write_uint32(h_preimage, tx.expiry) # expiryHeight
write_uint32(h_preimage, sighash) # nHashType write_uint32(h_preimage, sighash) # nHashType
return get_tx_hash(h_preimage, True) return get_tx_hash(h_preimage, True)

View File

@ -97,7 +97,7 @@ async def check_tx_fee(tx: SignTx, root: bip32.HDNode):
elif txi.script_type in (InputScriptType.SPENDADDRESS, elif txi.script_type in (InputScriptType.SPENDADDRESS,
InputScriptType.SPENDMULTISIG): InputScriptType.SPENDMULTISIG):
if coin.force_bip143: if coin.force_bip143 or tx.overwintered:
if not txi.amount: if not txi.amount:
raise SigningError(FailureType.DataError, raise SigningError(FailureType.DataError,
'BIP 143 input without amount') 'BIP 143 input without amount')
@ -107,7 +107,7 @@ async def check_tx_fee(tx: SignTx, root: bip32.HDNode):
else: else:
segwit[i] = False segwit[i] = False
total_in += await get_prevtx_output_value( total_in += await get_prevtx_output_value(
tx_req, txi.prev_hash, txi.prev_index) coin, tx_req, txi.prev_hash, txi.prev_index)
else: else:
raise SigningError(FailureType.DataError, raise SigningError(FailureType.DataError,
@ -194,12 +194,12 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
w_txi = bytearray_with_cap( w_txi = bytearray_with_cap(
7 + len(txi_sign.prev_hash) + 4 + len(txi_sign.script_sig) + 4) 7 + len(txi_sign.prev_hash) + 4 + len(txi_sign.script_sig) + 4)
if i_sign == 0: # serializing first input => prepend headers if i_sign == 0: # serializing first input => prepend headers
write_bytes(w_txi, get_tx_header(tx, True)) write_bytes(w_txi, get_tx_header(coin, tx, True))
write_tx_input(w_txi, txi_sign) write_tx_input(w_txi, txi_sign)
tx_ser.serialized_tx = w_txi tx_ser.serialized_tx = w_txi
tx_req.serialized = tx_ser tx_req.serialized = tx_ser
elif coin.force_bip143: elif coin.force_bip143 or tx.overwintered:
# STAGE_REQUEST_SEGWIT_INPUT # STAGE_REQUEST_SEGWIT_INPUT
txi_sign = await request_tx_input(tx_req, i_sign) txi_sign = await request_tx_input(tx_req, i_sign)
input_check_wallet_path(txi_sign, wallet_path) input_check_wallet_path(txi_sign, wallet_path)
@ -214,7 +214,7 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
key_sign = node_derive(root, txi_sign.address_n) key_sign = node_derive(root, txi_sign.address_n)
key_sign_pub = key_sign.public_key() key_sign_pub = key_sign.public_key()
bip143_hash = bip143.preimage_hash( bip143_hash = bip143.preimage_hash(
tx, txi_sign, ecdsa_hash_pubkey(key_sign_pub), get_hash_type(coin)) coin, tx, txi_sign, ecdsa_hash_pubkey(key_sign_pub), get_hash_type(coin))
# if multisig, check if singing with a key that is included in multisig # if multisig, check if singing with a key that is included in multisig
if txi_sign.multisig: if txi_sign.multisig:
@ -230,7 +230,7 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
w_txi_sign = bytearray_with_cap( w_txi_sign = bytearray_with_cap(
5 + len(txi_sign.prev_hash) + 4 + len(txi_sign.script_sig) + 4) 5 + len(txi_sign.prev_hash) + 4 + len(txi_sign.script_sig) + 4)
if i_sign == 0: # serializing first input => prepend headers if i_sign == 0: # serializing first input => prepend headers
write_bytes(w_txi_sign, get_tx_header(tx)) write_bytes(w_txi_sign, get_tx_header(coin, tx))
write_tx_input(w_txi_sign, txi_sign) write_tx_input(w_txi_sign, txi_sign)
tx_ser.serialized_tx = w_txi_sign tx_ser.serialized_tx = w_txi_sign
@ -242,7 +242,12 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
# same as h_first, checked before signing the digest # same as h_first, checked before signing the digest
h_second = HashWriter(sha256) h_second = HashWriter(sha256)
write_uint32(h_sign, tx.version) if tx.overwintered:
write_uint32(h_sign, tx.version | 0x80000000) # nVersion | fOverwintered
write_uint32(h_sign, coin.version_group_id) # nVersionGroupId
else:
write_uint32(h_sign, tx.version) # nVersion
write_varint(h_sign, tx.inputs_count) write_varint(h_sign, tx.inputs_count)
for i in range(tx.inputs_count): for i in range(tx.inputs_count):
@ -281,6 +286,8 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
write_tx_output(h_sign, txo_bin) write_tx_output(h_sign, txo_bin)
write_uint32(h_sign, tx.lock_time) write_uint32(h_sign, tx.lock_time)
if tx.overwintered:
write_uint32(h_sign, tx.expiry)
write_uint32(h_sign, get_hash_type(coin)) write_uint32(h_sign, get_hash_type(coin))
@ -304,7 +311,7 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
w_txi_sign = bytearray_with_cap( w_txi_sign = bytearray_with_cap(
5 + len(txi_sign.prev_hash) + 4 + len(txi_sign.script_sig) + 4) 5 + len(txi_sign.prev_hash) + 4 + len(txi_sign.script_sig) + 4)
if i_sign == 0: # serializing first input => prepend headers if i_sign == 0: # serializing first input => prepend headers
write_bytes(w_txi_sign, get_tx_header(tx)) write_bytes(w_txi_sign, get_tx_header(coin, tx))
write_tx_input(w_txi_sign, txi_sign) write_tx_input(w_txi_sign, txi_sign)
tx_ser.serialized_tx = w_txi_sign tx_ser.serialized_tx = w_txi_sign
@ -349,7 +356,7 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
key_sign = node_derive(root, txi.address_n) key_sign = node_derive(root, txi.address_n)
key_sign_pub = key_sign.public_key() key_sign_pub = key_sign.public_key()
bip143_hash = bip143.preimage_hash( bip143_hash = bip143.preimage_hash(
tx, txi, ecdsa_hash_pubkey(key_sign_pub), get_hash_type(coin)) coin, tx, txi, ecdsa_hash_pubkey(key_sign_pub), get_hash_type(coin))
signature = ecdsa_sign(key_sign, bip143_hash) signature = ecdsa_sign(key_sign, bip143_hash)
if txi.multisig: if txi.multisig:
@ -370,11 +377,13 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
tx_req.serialized = tx_ser tx_req.serialized = tx_ser
write_uint32(tx_ser.serialized_tx, tx.lock_time) write_uint32(tx_ser.serialized_tx, tx.lock_time)
if tx.overwintered:
write_uint32(tx_ser.serialized_tx, tx.expiry)
await request_tx_finish(tx_req) await request_tx_finish(tx_req)
async def get_prevtx_output_value(tx_req: TxRequest, prev_hash: bytes, prev_index: int) -> int: async def get_prevtx_output_value(coin: CoinInfo, tx_req: TxRequest, prev_hash: bytes, prev_index: int) -> int:
total_out = 0 # sum of output amounts total_out = 0 # sum of output amounts
# STAGE_REQUEST_2_PREV_META # STAGE_REQUEST_2_PREV_META
@ -382,7 +391,12 @@ async def get_prevtx_output_value(tx_req: TxRequest, prev_hash: bytes, prev_inde
txh = HashWriter(sha256) txh = HashWriter(sha256)
write_uint32(txh, tx.version) if tx.overwintered:
write_uint32(txh, tx.version | 0x80000000) # nVersion | fOverwintered
write_uint32(txh, coin.version_group_id) # nVersionGroupId
else:
write_uint32(txh, tx.version) # nVersion
write_varint(txh, tx.inputs_cnt) write_varint(txh, tx.inputs_cnt)
for i in range(tx.inputs_cnt): for i in range(tx.inputs_cnt):
@ -428,9 +442,13 @@ def get_hash_type(coin: CoinInfo) -> int:
return hashtype return hashtype
def get_tx_header(tx: SignTx, segwit: bool = False): def get_tx_header(coin: CoinInfo, tx: SignTx, segwit: bool = False):
w_txi = bytearray() w_txi = bytearray()
write_uint32(w_txi, tx.version) if tx.overwintered:
write_uint32(w_txi, tx.version | 0x80000000) # nVersion | fOverwintered
write_uint32(w_txi, coin.version_group_id) # nVersionGroupId
else:
write_uint32(w_txi, tx.version) # nVersion
if segwit: if segwit:
write_varint(w_txi, 0x00) # segwit witness marker write_varint(w_txi, 0x00) # segwit witness marker
write_varint(w_txi, 0x01) # segwit witness flag write_varint(w_txi, 0x01) # segwit witness flag

View File

@ -10,7 +10,7 @@ class SignTx(p.MessageType):
3: ('coin_name', p.UnicodeType, 0), # default='Bitcoin' 3: ('coin_name', p.UnicodeType, 0), # default='Bitcoin'
4: ('version', p.UVarintType, 0), # default=1 4: ('version', p.UVarintType, 0), # default=1
5: ('lock_time', p.UVarintType, 0), # default=0 5: ('lock_time', p.UVarintType, 0), # default=0
6: ('decred_expiry', p.UVarintType, 0), 6: ('expiry', p.UVarintType, 0),
7: ('overwintered', p.BoolType, 0), 7: ('overwintered', p.BoolType, 0),
} }
@ -21,7 +21,7 @@ class SignTx(p.MessageType):
coin_name: str = None, coin_name: str = None,
version: int = None, version: int = None,
lock_time: int = None, lock_time: int = None,
decred_expiry: int = None, expiry: int = None,
overwintered: bool = None overwintered: bool = None
) -> None: ) -> None:
self.outputs_count = outputs_count self.outputs_count = outputs_count
@ -29,5 +29,5 @@ class SignTx(p.MessageType):
self.coin_name = coin_name self.coin_name = coin_name
self.version = version self.version = version
self.lock_time = lock_time self.lock_time = lock_time
self.decred_expiry = decred_expiry self.expiry = expiry
self.overwintered = overwintered self.overwintered = overwintered

View File

@ -19,7 +19,7 @@ class SimpleSignTx(p.MessageType):
4: ('coin_name', p.UnicodeType, 0), # default='Bitcoin' 4: ('coin_name', p.UnicodeType, 0), # default='Bitcoin'
5: ('version', p.UVarintType, 0), # default=1 5: ('version', p.UVarintType, 0), # default=1
6: ('lock_time', p.UVarintType, 0), # default=0 6: ('lock_time', p.UVarintType, 0), # default=0
7: ('decred_expiry', p.UVarintType, 0), 7: ('expiry', p.UVarintType, 0),
8: ('overwintered', p.BoolType, 0), 8: ('overwintered', p.BoolType, 0),
} }
@ -31,7 +31,7 @@ class SimpleSignTx(p.MessageType):
coin_name: str = None, coin_name: str = None,
version: int = None, version: int = None,
lock_time: int = None, lock_time: int = None,
decred_expiry: int = None, expiry: int = None,
overwintered: bool = None overwintered: bool = None
) -> None: ) -> None:
self.inputs = inputs if inputs is not None else [] self.inputs = inputs if inputs is not None else []
@ -40,5 +40,5 @@ class SimpleSignTx(p.MessageType):
self.coin_name = coin_name self.coin_name = coin_name
self.version = version self.version = version
self.lock_time = lock_time self.lock_time = lock_time
self.decred_expiry = decred_expiry self.expiry = expiry
self.overwintered = overwintered self.overwintered = overwintered

View File

@ -21,7 +21,7 @@ class TransactionType(p.MessageType):
7: ('outputs_cnt', p.UVarintType, 0), 7: ('outputs_cnt', p.UVarintType, 0),
8: ('extra_data', p.BytesType, 0), 8: ('extra_data', p.BytesType, 0),
9: ('extra_data_len', p.UVarintType, 0), 9: ('extra_data_len', p.UVarintType, 0),
10: ('decred_expiry', p.UVarintType, 0), 10: ('expiry', p.UVarintType, 0),
11: ('overwintered', p.BoolType, 0), 11: ('overwintered', p.BoolType, 0),
} }
@ -36,7 +36,7 @@ class TransactionType(p.MessageType):
outputs_cnt: int = None, outputs_cnt: int = None,
extra_data: bytes = None, extra_data: bytes = None,
extra_data_len: int = None, extra_data_len: int = None,
decred_expiry: int = None, expiry: int = None,
overwintered: bool = None overwintered: bool = None
) -> None: ) -> None:
self.version = version self.version = version
@ -48,5 +48,5 @@ class TransactionType(p.MessageType):
self.outputs_cnt = outputs_cnt self.outputs_cnt = outputs_cnt
self.extra_data = extra_data self.extra_data = extra_data
self.extra_data_len = extra_data_len self.extra_data_len = extra_data_len
self.decred_expiry = decred_expiry self.expiry = expiry
self.overwintered = overwintered self.overwintered = overwintered

View File

@ -29,10 +29,11 @@ for c in coins:
name = 'bitcoin_testnet' name = 'bitcoin_testnet'
data = json.load(open('../../vendor/trezor-common/defs/coins/%s.json' % name, 'r')) data = json.load(open('../../vendor/trezor-common/defs/coins/%s.json' % name, 'r'))
for n in fields: for n in fields:
if n == 'xpub_magic': if n in ['xpub_magic', 'version_group_id']:
print(' %s=0x%08x,' % (n, data[n])) v = '0x%08x' % data[n] if data[n] is not None else 'None'
else: else:
print(' %s=%s,' % (n, repr(data[n]))) v = repr(data[n])
print(' %s=%s,' % (n, v))
print(' ),') print(' ),')
print(']') print(']')

@ -1 +1 @@
Subproject commit 0f7118bb3d27c3c51a89c776c1b083db91f50541 Subproject commit 018eebac7e64ed082486d746d78d279fe815c65d