1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-30 03:18:20 +00:00

apps/wallet/sign_tx: force_bip143

This commit is contained in:
Jan Pochyla 2018-02-08 18:59:51 +01:00
parent d0c742e884
commit 350de72c54
3 changed files with 78 additions and 23 deletions

View File

@ -12,9 +12,9 @@ from apps.wallet.sign_tx.writers import *
# =============== P2PKH =============== # =============== P2PKH ===============
def input_script_p2pkh_or_p2sh(pubkey: bytes, signature: bytes) -> bytearray: def input_script_p2pkh_or_p2sh(pubkey: bytes, signature: bytes, sighash: int) -> bytearray:
w = bytearray_with_cap(5 + len(signature) + 1 + 5 + len(pubkey)) w = bytearray_with_cap(5 + len(signature) + 1 + 5 + len(pubkey))
append_signature_and_pubkey(w, pubkey, signature) append_signature_and_pubkey(w, pubkey, signature, sighash)
return w return w
@ -111,10 +111,10 @@ def output_script_paytoopreturn(data: bytes) -> bytearray:
# === helpers # === helpers
def append_signature_and_pubkey(w: bytearray, pubkey: bytes, signature: bytes) -> bytearray: def append_signature_and_pubkey(w: bytearray, pubkey: bytes, signature: bytes, sighash: int) -> bytearray:
write_op_push(w, len(signature) + 1) write_op_push(w, len(signature) + 1)
write_bytes(w, signature) write_bytes(w, signature)
w.append(0x01) # SIGHASH_ALL w.append(sighash)
write_op_push(w, len(pubkey)) write_op_push(w, len(pubkey))
write_bytes(w, pubkey) write_bytes(w, pubkey)
return w return w

View File

@ -21,22 +21,22 @@ class Bip143:
write_bytes_rev(self.h_prevouts, txi.prev_hash) write_bytes_rev(self.h_prevouts, txi.prev_hash)
write_uint32(self.h_prevouts, txi.prev_index) write_uint32(self.h_prevouts, txi.prev_index)
def get_prevouts_hash(self) -> bytes:
return get_tx_hash(self.h_prevouts, True)
def add_sequence(self, txi: TxInputType): def add_sequence(self, txi: TxInputType):
write_uint32(self.h_sequence, txi.sequence) write_uint32(self.h_sequence, txi.sequence)
def get_sequence_hash(self) -> bytes:
return get_tx_hash(self.h_sequence, True)
def add_output(self, txo_bin: TxOutputBinType): def add_output(self, txo_bin: TxOutputBinType):
write_tx_output(self.h_outputs, txo_bin) write_tx_output(self.h_outputs, txo_bin)
def get_prevouts_hash(self) -> bytes:
return get_tx_hash(self.h_prevouts, True)
def get_sequence_hash(self) -> bytes:
return get_tx_hash(self.h_sequence, True)
def get_outputs_hash(self) -> bytes: def get_outputs_hash(self) -> bytes:
return get_tx_hash(self.h_outputs, True) return get_tx_hash(self.h_outputs, True)
def preimage_hash(self, tx: SignTx, txi: TxInputType, pubkeyhash) -> bytes: def preimage_hash(self, tx: SignTx, txi: TxInputType, pubkeyhash: bytes, sighash: int) -> bytes:
h_preimage = HashWriter(sha256) h_preimage = HashWriter(sha256)
write_uint32(h_preimage, tx.version) # nVersion write_uint32(h_preimage, tx.version) # nVersion
@ -54,7 +54,7 @@ class Bip143:
write_bytes(h_preimage, bytearray(self.get_outputs_hash())) # hashOutputs write_bytes(h_preimage, bytearray(self.get_outputs_hash())) # hashOutputs
write_uint32(h_preimage, tx.lock_time) # nLockTime write_uint32(h_preimage, tx.lock_time) # nLockTime
write_uint32(h_preimage, 0x00000001) # nHashType - only SIGHASH_ALL currently write_uint32(h_preimage, sighash) # nHashType
return get_tx_hash(h_preimage, True) return get_tx_hash(h_preimage, True)
@ -62,9 +62,10 @@ class Bip143:
# for P2WPKH this is always 0x1976a914{20-byte-pubkey-hash}88ac # for P2WPKH this is always 0x1976a914{20-byte-pubkey-hash}88ac
def derive_script_code(self, txi: TxInputType, pubkeyhash: bytes) -> bytearray: def derive_script_code(self, txi: TxInputType, pubkeyhash: bytes) -> bytearray:
# p2wpkh in p2sh or native p2wpkh # p2wpkh in p2sh or native p2wpkh
is_segwit = (txi.script_type == InputScriptType.SPENDWITNESS or p2pkh = (txi.script_type == InputScriptType.SPENDWITNESS or
txi.script_type == InputScriptType.SPENDP2SHWITNESS) txi.script_type == InputScriptType.SPENDP2SHWITNESS or
if is_segwit: txi.script_type == InputScriptType.SPENDADDRESS)
if p2pkh:
s = bytearray(25) s = bytearray(25)
s[0] = 0x76 # OP_DUP s[0] = 0x76 # OP_DUP
s[1] = 0xA9 # OP_HASH_160 s[1] = 0xA9 # OP_HASH_160

View File

@ -76,7 +76,18 @@ async def check_tx_fee(tx: SignTx, root):
bip143.add_sequence(txi) bip143.add_sequence(txi)
is_segwit = (txi.script_type == InputScriptType.SPENDWITNESS or is_segwit = (txi.script_type == InputScriptType.SPENDWITNESS or
txi.script_type == InputScriptType.SPENDP2SHWITNESS) txi.script_type == InputScriptType.SPENDP2SHWITNESS)
if is_segwit: if coin.force_bip143:
is_bip143 = (txi.script_type == InputScriptType.SPENDADDRESS)
if not is_bip143:
raise SigningError(FailureType.DataError,
'Wrong input script type')
if not txi.amount:
raise SigningError(FailureType.DataError,
'BIP 143 input without amount')
segwit[i] = False
segwit_in += txi.amount
total_in += txi.amount
elif is_segwit:
if not coin.segwit: if not coin.segwit:
raise SigningError(FailureType.DataError, raise SigningError(FailureType.DataError,
'Segwit not enabled on this coin') 'Segwit not enabled on this coin')
@ -155,7 +166,39 @@ async def sign_tx(tx: SignTx, root):
key_sign = None key_sign = None
key_sign_pub = None key_sign_pub = None
if segwit[i_sign]: if coin.force_bip143:
# STAGE_REQUEST_SEGWIT_INPUT
txi_sign = await request_tx_input(tx_req, i_sign)
input_check_wallet_path(txi_sign, wallet_path)
is_bip143 = (txi_sign.script_type == InputScriptType.SPENDADDRESS)
if not is_bip143 or txi_sign.amount > authorized_in:
raise SigningError(FailureType.ProcessError,
'Transaction has changed during signing')
authorized_in -= txi_sign.amount
key_sign = node_derive(root, txi_sign.address_n)
key_sign_pub = key_sign.public_key()
bip143_hash = bip143.preimage_hash(
tx, txi_sign, ecdsa_hash_pubkey(key_sign_pub), get_hash_type(coin))
signature = ecdsa_sign(key_sign, bip143_hash)
tx_ser.signature_index = i_sign
tx_ser.signature = signature
# serialize input with correct signature
txi_sign.script_sig = input_derive_script(
coin, txi_sign, key_sign_pub, signature)
w_txi_sign = bytearray_with_cap(
5 + len(txi_sign.prev_hash) + 4 + len(txi_sign.script_sig) + 4)
if i_sign == 0: # serializing first input => prepend headers
write_bytes(w_txi_sign, get_tx_header(tx))
write_tx_input(w_txi_sign, txi_sign)
tx_ser.serialized_tx = w_txi_sign
tx_req.serialized = tx_ser
elif segwit[i_sign]:
# STAGE_REQUEST_SEGWIT_INPUT # STAGE_REQUEST_SEGWIT_INPUT
txi_sign = await request_tx_input(tx_req, i_sign) txi_sign = await request_tx_input(tx_req, i_sign)
@ -168,7 +211,8 @@ async def sign_tx(tx: SignTx, root):
key_sign = node_derive(root, txi_sign.address_n) key_sign = node_derive(root, txi_sign.address_n)
key_sign_pub = key_sign.public_key() key_sign_pub = key_sign.public_key()
txi_sign.script_sig = input_derive_script(txi_sign, key_sign_pub) txi_sign.script_sig = input_derive_script(coin, txi_sign, key_sign_pub)
w_txi = bytearray_with_cap( w_txi = bytearray_with_cap(
7 + len(txi_sign.prev_hash) + 4 + len(txi_sign.script_sig) + 4) 7 + len(txi_sign.prev_hash) + 4 + len(txi_sign.script_sig) + 4)
if i_sign == 0: # serializing first input => prepend headers if i_sign == 0: # serializing first input => prepend headers
@ -213,7 +257,7 @@ async def sign_tx(tx: SignTx, root):
write_uint32(h_sign, tx.lock_time) write_uint32(h_sign, tx.lock_time)
write_uint32(h_sign, 0x00000001) # SIGHASH_ALL hash_type write_uint32(h_sign, get_hash_type(coin))
# check the control digests # check the control digests
if get_tx_hash(h_first, False) != get_tx_hash(h_second, False): if get_tx_hash(h_first, False) != get_tx_hash(h_second, False):
@ -227,7 +271,7 @@ async def sign_tx(tx: SignTx, root):
# serialize input with correct signature # serialize input with correct signature
txi_sign.script_sig = input_derive_script( txi_sign.script_sig = input_derive_script(
txi_sign, key_sign_pub, signature) coin, txi_sign, key_sign_pub, signature)
w_txi_sign = bytearray_with_cap( w_txi_sign = bytearray_with_cap(
5 + len(txi_sign.prev_hash) + 4 + len(txi_sign.script_sig) + 4) 5 + len(txi_sign.prev_hash) + 4 + len(txi_sign.script_sig) + 4)
if i_sign == 0: # serializing first input => prepend headers if i_sign == 0: # serializing first input => prepend headers
@ -273,7 +317,8 @@ async def sign_tx(tx: SignTx, root):
key_sign = node_derive(root, txi.address_n) key_sign = node_derive(root, txi.address_n)
key_sign_pub = key_sign.public_key() key_sign_pub = key_sign.public_key()
bip143_hash = bip143.preimage_hash(tx, txi, ecdsa_hash_pubkey(key_sign_pub)) bip143_hash = bip143.preimage_hash(
tx, txi, ecdsa_hash_pubkey(key_sign_pub), get_hash_type(coin))
signature = ecdsa_sign(key_sign, bip143_hash) signature = ecdsa_sign(key_sign, bip143_hash)
witness = get_p2wpkh_witness(signature, key_sign_pub) witness = get_p2wpkh_witness(signature, key_sign_pub)
@ -338,6 +383,15 @@ async def get_prevtx_output_value(tx_req: TxRequest, prev_hash: bytes, prev_inde
# === # ===
def get_hash_type(coin: CoinType) -> int:
SIGHASH_FORKID = const(0x40)
SIGHASH_ALL = const(0x01)
hashtype = SIGHASH_ALL
if coin.forkid is not None:
hashtype |= (coin.forkid << 8) | SIGHASH_FORKID
return hashtype
def get_tx_header(tx: SignTx, segwit=False): def get_tx_header(tx: SignTx, segwit=False):
w_txi = bytearray() w_txi = bytearray()
write_uint32(w_txi, tx.version) write_uint32(w_txi, tx.version)
@ -425,9 +479,9 @@ def output_is_change(o: TxOutputType, wallet_path: list, segwit_in: int) -> bool
# === # ===
def input_derive_script(i: TxInputType, pubkey: bytes, signature: bytes=None) -> bytes: def input_derive_script(coin: CoinType, i: TxInputType, pubkey: bytes, signature: bytes=None) -> bytes:
if i.script_type == InputScriptType.SPENDADDRESS: if i.script_type == InputScriptType.SPENDADDRESS:
return input_script_p2pkh_or_p2sh(pubkey, signature) # p2pkh or p2sh return input_script_p2pkh_or_p2sh(pubkey, signature, get_hash_type(coin)) # p2pkh or p2sh
if i.script_type == InputScriptType.SPENDP2SHWITNESS: # p2wpkh using p2sh if i.script_type == InputScriptType.SPENDP2SHWITNESS: # p2wpkh using p2sh
return input_script_p2wpkh_in_p2sh(ecdsa_hash_pubkey(pubkey)) return input_script_p2wpkh_in_p2sh(ecdsa_hash_pubkey(pubkey))
elif i.script_type == InputScriptType.SPENDWITNESS: # native p2wpkh or p2wsh elif i.script_type == InputScriptType.SPENDWITNESS: # native p2wpkh or p2wsh