1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-23 06:48:16 +00:00

signing: code style

This commit is contained in:
Jan Pochyla 2018-02-26 11:42:47 +01:00
parent c01ebeb552
commit 9291de47d0
2 changed files with 54 additions and 34 deletions

View File

@ -1,22 +1,21 @@
from micropython import const from micropython import const
from trezor.crypto.hashlib import sha256 from trezor.crypto import base58, bip32, der
from trezor.crypto.curve import secp256k1 from trezor.crypto.curve import secp256k1
from trezor.crypto import base58, der from trezor.crypto.hashlib import sha256
from trezor.messages.TxRequestSerializedType import TxRequestSerializedType
from trezor.messages.TxRequestDetailsType import TxRequestDetailsType
from trezor.messages import OutputScriptType from trezor.messages import OutputScriptType
from trezor.messages.TxRequestDetailsType import TxRequestDetailsType
from trezor.messages.TxRequestSerializedType import TxRequestSerializedType
from apps.common import address_type from apps.common import address_type, coins
from apps.common import coins from apps.common.hash_writer import HashWriter
from apps.wallet.sign_tx.addresses import * from apps.wallet.sign_tx.addresses import *
from apps.wallet.sign_tx.helpers import * from apps.wallet.sign_tx.helpers import *
from apps.wallet.sign_tx.segwit_bip143 import *
from apps.wallet.sign_tx.scripts import * from apps.wallet.sign_tx.scripts import *
from apps.wallet.sign_tx.writers import * from apps.wallet.sign_tx.segwit_bip143 import *
from apps.wallet.sign_tx.tx_weight_calculator import * from apps.wallet.sign_tx.tx_weight_calculator import *
from apps.common.hash_writer import HashWriter from apps.wallet.sign_tx.writers import *
# the number of bip32 levels used in a wallet (chain and address) # the number of bip32 levels used in a wallet (chain and address)
_BIP32_WALLET_DEPTH = const(2) _BIP32_WALLET_DEPTH = const(2)
@ -44,7 +43,7 @@ class SigningError(ValueError):
# - check inputs, previous transactions, and outputs # - check inputs, previous transactions, and outputs
# - ask for confirmations # - ask for confirmations
# - check fee # - check fee
async def check_tx_fee(tx: SignTx, root): async def check_tx_fee(tx: SignTx, root: bip32.HDNode):
coin = coins.by_name(tx.coin_name) coin = coins.by_name(tx.coin_name)
# h_first is used to make sure the inputs and outputs streamed in Phase 1 # h_first is used to make sure the inputs and outputs streamed in Phase 1
@ -53,13 +52,9 @@ async def check_tx_fee(tx: SignTx, root):
h_first = HashWriter(sha256) # not a real tx hash h_first = HashWriter(sha256) # not a real tx hash
bip143 = Bip143() # bip143 transaction hashing bip143 = Bip143() # bip143 transaction hashing
multifp = MultisigFingerprint() # control fp of multisig inputs multifp = MultisigFingerprint() # control checksum of multisig inputs
weight = TxWeightCalculator(tx.inputs_count, tx.outputs_count) weight = TxWeightCalculator(tx.inputs_count, tx.outputs_count)
txo_bin = TxOutputBinType()
tx_req = TxRequest()
tx_req.details = TxRequestDetailsType()
total_in = 0 # sum of input amounts total_in = 0 # sum of input amounts
segwit_in = 0 # sum of segwit input amounts segwit_in = 0 # sum of segwit input amounts
total_out = 0 # sum of output amounts total_out = 0 # sum of output amounts
@ -67,6 +62,11 @@ async def check_tx_fee(tx: SignTx, root):
wallet_path = [] # common prefix of input paths wallet_path = [] # common prefix of input paths
segwit = {} # dict of booleans stating if input is segwit segwit = {} # dict of booleans stating if input is segwit
# output structures
txo_bin = TxOutputBinType()
tx_req = TxRequest()
tx_req.details = TxRequestDetailsType()
for i in range(tx.inputs_count): for i in range(tx.inputs_count):
# STAGE_REQUEST_1_INPUT # STAGE_REQUEST_1_INPUT
txi = await request_tx_input(tx_req, i) txi = await request_tx_input(tx_req, i)
@ -115,11 +115,14 @@ async def check_tx_fee(tx: SignTx, root):
txo_bin.amount = txo.amount txo_bin.amount = txo.amount
txo_bin.script_pubkey = output_derive_script(txo, coin, root) txo_bin.script_pubkey = output_derive_script(txo, coin, root)
weight.add_output(txo_bin.script_pubkey) weight.add_output(txo_bin.script_pubkey)
if (change_out == 0) and is_change(txo, wallet_path, segwit_in, multifp):
if change_out == 0 and is_change(txo, wallet_path, segwit_in, multifp):
# output is change and does not need confirmation
change_out = txo.amount change_out = txo.amount
elif not await confirm_output(txo, coin): elif not await confirm_output(txo, coin):
raise SigningError(FailureType.ActionCancelled, raise SigningError(FailureType.ActionCancelled,
'Output cancelled') 'Output cancelled')
write_tx_output(h_first, txo_bin) write_tx_output(h_first, txo_bin)
bip143.add_output(txo_bin) bip143.add_output(txo_bin)
total_out += txo_bin.amount total_out += txo_bin.amount
@ -142,7 +145,7 @@ async def check_tx_fee(tx: SignTx, root):
return h_first, bip143, segwit, total_in, wallet_path return h_first, bip143, segwit, total_in, wallet_path
async def sign_tx(tx: SignTx, root): async def sign_tx(tx: SignTx, root: bip32.HDNode):
tx = sanitize_sign_tx(tx) tx = sanitize_sign_tx(tx)
# Phase 1 # Phase 1
@ -281,7 +284,7 @@ async def sign_tx(tx: SignTx, root):
# if multisig, check if singing with a key that is included in multisig # if multisig, check if singing with a key that is included in multisig
if txi_sign.multisig: if txi_sign.multisig:
pubkey_idx = multisig_pubkey_index(txi_sign.multisig, key_sign_pub) multisig_pubkey_index(txi_sign.multisig, key_sign_pub)
# compute the signature from the tx digest # compute the signature from the tx digest
signature = ecdsa_sign(key_sign, get_tx_hash(h_sign, True)) signature = ecdsa_sign(key_sign, get_tx_hash(h_sign, True))
@ -416,7 +419,7 @@ def get_hash_type(coin: CoinType) -> int:
return hashtype return hashtype
def get_tx_header(tx: SignTx, segwit=False): def get_tx_header(tx: SignTx, segwit: bool = False):
w_txi = bytearray() w_txi = bytearray()
write_uint32(w_txi, tx.version) write_uint32(w_txi, tx.version)
if segwit: if segwit:
@ -430,14 +433,17 @@ def get_tx_header(tx: SignTx, segwit=False):
# === # ===
def output_derive_script(o: TxOutputType, coin: CoinType, root) -> bytes: def output_derive_script(o: TxOutputType, coin: CoinType, root: bip32.HDNode) -> bytes:
if o.script_type == OutputScriptType.PAYTOOPRETURN: if o.script_type == OutputScriptType.PAYTOOPRETURN:
# op_return output
if o.amount != 0: if o.amount != 0:
raise SigningError(FailureType.DataError, raise SigningError(FailureType.DataError,
'OP_RETURN output with non-zero amount') 'OP_RETURN output with non-zero amount')
return output_script_paytoopreturn(o.op_return_data) return output_script_paytoopreturn(o.op_return_data)
if o.address_n: # change output if o.address_n:
# change output
if o.address: if o.address:
raise SigningError(FailureType.DataError, 'Address in change output') raise SigningError(FailureType.DataError, 'Address in change output')
o.address = get_address_for_change(o, coin, root) o.address = get_address_for_change(o, coin, root)
@ -445,24 +451,27 @@ def output_derive_script(o: TxOutputType, coin: CoinType, root) -> bytes:
if not o.address: if not o.address:
raise SigningError(FailureType.DataError, 'Missing address') raise SigningError(FailureType.DataError, 'Missing address')
if coin.bech32_prefix and o.address.startswith(coin.bech32_prefix): # p2wpkh or p2wsh if coin.bech32_prefix and o.address.startswith(coin.bech32_prefix):
# p2wpkh or p2wsh
witprog = decode_bech32_address(coin.bech32_prefix, o.address) witprog = decode_bech32_address(coin.bech32_prefix, o.address)
return output_script_native_p2wpkh_or_p2wsh(witprog) return output_script_native_p2wpkh_or_p2wsh(witprog)
raw_address = base58.decode_check(o.address) raw_address = base58.decode_check(o.address)
if address_type.check(coin.address_type, raw_address): # p2pkh if address_type.check(coin.address_type, raw_address):
# p2pkh
pubkeyhash = address_type.strip(coin.address_type, raw_address) pubkeyhash = address_type.strip(coin.address_type, raw_address)
return output_script_p2pkh(pubkeyhash) return output_script_p2pkh(pubkeyhash)
elif address_type.check(coin.address_type_p2sh, raw_address): # p2sh elif address_type.check(coin.address_type_p2sh, raw_address):
# p2sh
scripthash = address_type.strip(coin.address_type_p2sh, raw_address) scripthash = address_type.strip(coin.address_type_p2sh, raw_address)
return output_script_p2sh(scripthash) return output_script_p2sh(scripthash)
raise SigningError(FailureType.DataError, 'Invalid address type') raise SigningError(FailureType.DataError, 'Invalid address type')
def get_address_for_change(o: TxOutputType, coin: CoinType, root): def get_address_for_change(o: TxOutputType, coin: CoinType, root: bip32.HDNode):
if o.script_type == OutputScriptType.PAYTOADDRESS: if o.script_type == OutputScriptType.PAYTOADDRESS:
input_script_type = InputScriptType.SPENDADDRESS input_script_type = InputScriptType.SPENDADDRESS
elif o.script_type == OutputScriptType.PAYTOMULTISIG: elif o.script_type == OutputScriptType.PAYTOMULTISIG:
@ -498,10 +507,15 @@ def output_is_change(o: TxOutputType, wallet_path: list, segwit_in: int) -> bool
def input_derive_script(coin: CoinType, i: TxInputType, pubkey: bytes, signature: bytes=None) -> bytes: def input_derive_script(coin: CoinType, i: TxInputType, pubkey: bytes, signature: bytes=None) -> bytes:
if i.script_type == InputScriptType.SPENDADDRESS: if i.script_type == InputScriptType.SPENDADDRESS:
return input_script_p2pkh_or_p2sh(pubkey, signature, get_hash_type(coin)) # p2pkh or p2sh # p2pkh or p2sh
return input_script_p2pkh_or_p2sh(
pubkey, signature, get_hash_type(coin))
if i.script_type == InputScriptType.SPENDP2SHWITNESS: # p2wpkh or p2wsh using p2sh if i.script_type == InputScriptType.SPENDP2SHWITNESS:
if i.multisig: # p2wsh in p2sh # p2wpkh or p2wsh using p2sh
if i.multisig:
# p2wsh in p2sh
pubkeys = multisig_get_pubkeys(i.multisig) pubkeys = multisig_get_pubkeys(i.multisig)
witness_script = output_script_multisig(pubkeys, i.multisig.m) witness_script = output_script_multisig(pubkeys, i.multisig.m)
witness_script_hash = sha256(witness_script).digest() witness_script_hash = sha256(witness_script).digest()
@ -510,14 +524,16 @@ def input_derive_script(coin: CoinType, i: TxInputType, pubkey: bytes, signature
# p2wpkh in p2sh # p2wpkh in p2sh
return input_script_p2wpkh_in_p2sh(ecdsa_hash_pubkey(pubkey)) return input_script_p2wpkh_in_p2sh(ecdsa_hash_pubkey(pubkey))
elif i.script_type == InputScriptType.SPENDWITNESS: # native p2wpkh or p2wsh elif i.script_type == InputScriptType.SPENDWITNESS:
# native p2wpkh or p2wsh
return input_script_native_p2wpkh_or_p2wsh() return input_script_native_p2wpkh_or_p2wsh()
# multisig
elif i.script_type == InputScriptType.SPENDMULTISIG: elif i.script_type == InputScriptType.SPENDMULTISIG:
# p2sh multisig
signature_index = multisig_pubkey_index(i.multisig, pubkey) signature_index = multisig_pubkey_index(i.multisig, pubkey)
return input_script_multisig( return input_script_multisig(
i.multisig, signature, signature_index, get_hash_type(coin)) i.multisig, signature, signature_index, get_hash_type(coin))
else: else:
raise SigningError(FailureType.ProcessError, 'Invalid script type') raise SigningError(FailureType.ProcessError, 'Invalid script type')
@ -544,19 +560,23 @@ def input_check_wallet_path(txi: TxInputType, wallet_path: list) -> list:
'Transaction has changed during signing') 'Transaction has changed during signing')
def node_derive(root, address_n: list): def node_derive(root: bip32.HDNode, address_n: list):
node = root.clone() node = root.clone()
node.derive_path(address_n) node.derive_path(address_n)
return node return node
def ecdsa_sign(node, digest: bytes) -> bytes: def ecdsa_sign(node: bip32.HDNode, digest: bytes) -> bytes:
sig = secp256k1.sign(node.private_key(), digest) sig = secp256k1.sign(node.private_key(), digest)
sigder = der.encode_seq((sig[1:33], sig[33:65])) sigder = der.encode_seq((sig[1:33], sig[33:65]))
return sigder return sigder
def is_change(txo: TxOutputType, wallet_path, segwit_in: int, multifp: MultisigFingerprint) -> bool: def is_change(
txo: TxOutputType,
wallet_path: list,
segwit_in: int,
multifp: MultisigFingerprint) -> bool:
if txo.multisig: if txo.multisig:
if not multifp.matches(txo.multisig): if not multifp.matches(txo.multisig):
return False return False

View File

@ -1,8 +1,8 @@
from trezor.crypto.hashlib import sha256 from trezor.crypto.hashlib import sha256
from apps.wallet.sign_tx.writers import * from apps.wallet.sign_tx.writers import *
# TX Serialization # TX Serialization
# === # ===