1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-22 22:38:08 +00:00

signing: code style

This commit is contained in:
Jan Pochyla 2018-02-26 11:42:47 +01:00
parent c01ebeb552
commit 9291de47d0
2 changed files with 54 additions and 34 deletions

View File

@ -1,22 +1,21 @@
from micropython import const
from trezor.crypto.hashlib import sha256
from trezor.crypto import base58, bip32, der
from trezor.crypto.curve import secp256k1
from trezor.crypto import base58, der
from trezor.crypto.hashlib import sha256
from trezor.messages.TxRequestSerializedType import TxRequestSerializedType
from trezor.messages.TxRequestDetailsType import TxRequestDetailsType
from trezor.messages import OutputScriptType
from trezor.messages.TxRequestDetailsType import TxRequestDetailsType
from trezor.messages.TxRequestSerializedType import TxRequestSerializedType
from apps.common import address_type
from apps.common import coins
from apps.common import address_type, coins
from apps.common.hash_writer import HashWriter
from apps.wallet.sign_tx.addresses import *
from apps.wallet.sign_tx.helpers import *
from apps.wallet.sign_tx.segwit_bip143 import *
from apps.wallet.sign_tx.scripts import *
from apps.wallet.sign_tx.writers import *
from apps.wallet.sign_tx.segwit_bip143 import *
from apps.wallet.sign_tx.tx_weight_calculator import *
from apps.common.hash_writer import HashWriter
from apps.wallet.sign_tx.writers import *
# the number of bip32 levels used in a wallet (chain and address)
_BIP32_WALLET_DEPTH = const(2)
@ -44,7 +43,7 @@ class SigningError(ValueError):
# - check inputs, previous transactions, and outputs
# - ask for confirmations
# - check fee
async def check_tx_fee(tx: SignTx, root):
async def check_tx_fee(tx: SignTx, root: bip32.HDNode):
coin = coins.by_name(tx.coin_name)
# h_first is used to make sure the inputs and outputs streamed in Phase 1
@ -53,13 +52,9 @@ async def check_tx_fee(tx: SignTx, root):
h_first = HashWriter(sha256) # not a real tx hash
bip143 = Bip143() # bip143 transaction hashing
multifp = MultisigFingerprint() # control fp of multisig inputs
multifp = MultisigFingerprint() # control checksum of multisig inputs
weight = TxWeightCalculator(tx.inputs_count, tx.outputs_count)
txo_bin = TxOutputBinType()
tx_req = TxRequest()
tx_req.details = TxRequestDetailsType()
total_in = 0 # sum of input amounts
segwit_in = 0 # sum of segwit input amounts
total_out = 0 # sum of output amounts
@ -67,6 +62,11 @@ async def check_tx_fee(tx: SignTx, root):
wallet_path = [] # common prefix of input paths
segwit = {} # dict of booleans stating if input is segwit
# output structures
txo_bin = TxOutputBinType()
tx_req = TxRequest()
tx_req.details = TxRequestDetailsType()
for i in range(tx.inputs_count):
# STAGE_REQUEST_1_INPUT
txi = await request_tx_input(tx_req, i)
@ -115,11 +115,14 @@ async def check_tx_fee(tx: SignTx, root):
txo_bin.amount = txo.amount
txo_bin.script_pubkey = output_derive_script(txo, coin, root)
weight.add_output(txo_bin.script_pubkey)
if (change_out == 0) and is_change(txo, wallet_path, segwit_in, multifp):
if change_out == 0 and is_change(txo, wallet_path, segwit_in, multifp):
# output is change and does not need confirmation
change_out = txo.amount
elif not await confirm_output(txo, coin):
raise SigningError(FailureType.ActionCancelled,
'Output cancelled')
write_tx_output(h_first, txo_bin)
bip143.add_output(txo_bin)
total_out += txo_bin.amount
@ -142,7 +145,7 @@ async def check_tx_fee(tx: SignTx, root):
return h_first, bip143, segwit, total_in, wallet_path
async def sign_tx(tx: SignTx, root):
async def sign_tx(tx: SignTx, root: bip32.HDNode):
tx = sanitize_sign_tx(tx)
# Phase 1
@ -281,7 +284,7 @@ async def sign_tx(tx: SignTx, root):
# if multisig, check if singing with a key that is included in multisig
if txi_sign.multisig:
pubkey_idx = multisig_pubkey_index(txi_sign.multisig, key_sign_pub)
multisig_pubkey_index(txi_sign.multisig, key_sign_pub)
# compute the signature from the tx digest
signature = ecdsa_sign(key_sign, get_tx_hash(h_sign, True))
@ -416,7 +419,7 @@ def get_hash_type(coin: CoinType) -> int:
return hashtype
def get_tx_header(tx: SignTx, segwit=False):
def get_tx_header(tx: SignTx, segwit: bool = False):
w_txi = bytearray()
write_uint32(w_txi, tx.version)
if segwit:
@ -430,14 +433,17 @@ def get_tx_header(tx: SignTx, segwit=False):
# ===
def output_derive_script(o: TxOutputType, coin: CoinType, root) -> bytes:
def output_derive_script(o: TxOutputType, coin: CoinType, root: bip32.HDNode) -> bytes:
if o.script_type == OutputScriptType.PAYTOOPRETURN:
# op_return output
if o.amount != 0:
raise SigningError(FailureType.DataError,
'OP_RETURN output with non-zero amount')
return output_script_paytoopreturn(o.op_return_data)
if o.address_n: # change output
if o.address_n:
# change output
if o.address:
raise SigningError(FailureType.DataError, 'Address in change output')
o.address = get_address_for_change(o, coin, root)
@ -445,24 +451,27 @@ def output_derive_script(o: TxOutputType, coin: CoinType, root) -> bytes:
if not o.address:
raise SigningError(FailureType.DataError, 'Missing address')
if coin.bech32_prefix and o.address.startswith(coin.bech32_prefix): # p2wpkh or p2wsh
if coin.bech32_prefix and o.address.startswith(coin.bech32_prefix):
# p2wpkh or p2wsh
witprog = decode_bech32_address(coin.bech32_prefix, o.address)
return output_script_native_p2wpkh_or_p2wsh(witprog)
raw_address = base58.decode_check(o.address)
if address_type.check(coin.address_type, raw_address): # p2pkh
if address_type.check(coin.address_type, raw_address):
# p2pkh
pubkeyhash = address_type.strip(coin.address_type, raw_address)
return output_script_p2pkh(pubkeyhash)
elif address_type.check(coin.address_type_p2sh, raw_address): # p2sh
elif address_type.check(coin.address_type_p2sh, raw_address):
# p2sh
scripthash = address_type.strip(coin.address_type_p2sh, raw_address)
return output_script_p2sh(scripthash)
raise SigningError(FailureType.DataError, 'Invalid address type')
def get_address_for_change(o: TxOutputType, coin: CoinType, root):
def get_address_for_change(o: TxOutputType, coin: CoinType, root: bip32.HDNode):
if o.script_type == OutputScriptType.PAYTOADDRESS:
input_script_type = InputScriptType.SPENDADDRESS
elif o.script_type == OutputScriptType.PAYTOMULTISIG:
@ -498,10 +507,15 @@ def output_is_change(o: TxOutputType, wallet_path: list, segwit_in: int) -> bool
def input_derive_script(coin: CoinType, i: TxInputType, pubkey: bytes, signature: bytes=None) -> bytes:
if i.script_type == InputScriptType.SPENDADDRESS:
return input_script_p2pkh_or_p2sh(pubkey, signature, get_hash_type(coin)) # p2pkh or p2sh
# p2pkh or p2sh
return input_script_p2pkh_or_p2sh(
pubkey, signature, get_hash_type(coin))
if i.script_type == InputScriptType.SPENDP2SHWITNESS: # p2wpkh or p2wsh using p2sh
if i.multisig: # p2wsh in p2sh
if i.script_type == InputScriptType.SPENDP2SHWITNESS:
# p2wpkh or p2wsh using p2sh
if i.multisig:
# p2wsh in p2sh
pubkeys = multisig_get_pubkeys(i.multisig)
witness_script = output_script_multisig(pubkeys, i.multisig.m)
witness_script_hash = sha256(witness_script).digest()
@ -510,14 +524,16 @@ def input_derive_script(coin: CoinType, i: TxInputType, pubkey: bytes, signature
# p2wpkh in p2sh
return input_script_p2wpkh_in_p2sh(ecdsa_hash_pubkey(pubkey))
elif i.script_type == InputScriptType.SPENDWITNESS: # native p2wpkh or p2wsh
elif i.script_type == InputScriptType.SPENDWITNESS:
# native p2wpkh or p2wsh
return input_script_native_p2wpkh_or_p2wsh()
# multisig
elif i.script_type == InputScriptType.SPENDMULTISIG:
# p2sh multisig
signature_index = multisig_pubkey_index(i.multisig, pubkey)
return input_script_multisig(
i.multisig, signature, signature_index, get_hash_type(coin))
else:
raise SigningError(FailureType.ProcessError, 'Invalid script type')
@ -544,19 +560,23 @@ def input_check_wallet_path(txi: TxInputType, wallet_path: list) -> list:
'Transaction has changed during signing')
def node_derive(root, address_n: list):
def node_derive(root: bip32.HDNode, address_n: list):
node = root.clone()
node.derive_path(address_n)
return node
def ecdsa_sign(node, digest: bytes) -> bytes:
def ecdsa_sign(node: bip32.HDNode, digest: bytes) -> bytes:
sig = secp256k1.sign(node.private_key(), digest)
sigder = der.encode_seq((sig[1:33], sig[33:65]))
return sigder
def is_change(txo: TxOutputType, wallet_path, segwit_in: int, multifp: MultisigFingerprint) -> bool:
def is_change(
txo: TxOutputType,
wallet_path: list,
segwit_in: int,
multifp: MultisigFingerprint) -> bool:
if txo.multisig:
if not multifp.matches(txo.multisig):
return False

View File

@ -1,8 +1,8 @@
from trezor.crypto.hashlib import sha256
from apps.wallet.sign_tx.writers import *
# TX Serialization
# ===