wallet/signing: refactoring

pull/25/head
Tomas Susanka 7 years ago
parent e63d0adc23
commit b7f01baf99

@ -0,0 +1,121 @@
from trezor.messages.CoinType import CoinType
from trezor.messages.TxOutputType import TxOutputType
from trezor.messages.TxOutputBinType import TxOutputBinType
from trezor.messages.TxInputType import TxInputType
from trezor.messages.SignTx import SignTx
from trezor.messages.TxRequest import TxRequest
from trezor.messages.TransactionType import TransactionType
from trezor.messages.RequestType import TXINPUT, TXOUTPUT, TXMETA, TXFINISHED
from trezor.messages import InputScriptType
# Machine instructions
# ===
class UiConfirmOutput:
def __init__(self, output: TxOutputType, coin: CoinType):
self.output = output
self.coin = coin
class UiConfirmTotal:
def __init__(self, spending: int, fee: int, coin: CoinType):
self.spending = spending
self.fee = fee
self.coin = coin
class UiConfirmFeeOverThreshold:
def __init__(self, fee: int, coin: CoinType):
self.fee = fee
self.coin = coin
def confirm_output(output: TxOutputType, coin: CoinType):
return (yield UiConfirmOutput(output, coin))
def confirm_total(spending: int, fee: int, coin: CoinType):
return (yield UiConfirmTotal(spending, fee, coin))
def confirm_feeoverthreshold(fee: int, coin: CoinType):
return (yield UiConfirmFeeOverThreshold(fee, coin))
def request_tx_meta(tx_req: TxRequest, tx_hash: bytes=None):
tx_req.request_type = TXMETA
tx_req.details.tx_hash = tx_hash
tx_req.details.request_index = None
ack = yield tx_req
tx_req.serialized = None
return sanitize_tx_meta(ack.tx)
def request_tx_input(tx_req: TxRequest, i: int, tx_hash: bytes=None):
tx_req.request_type = TXINPUT
tx_req.details.request_index = i
tx_req.details.tx_hash = tx_hash
ack = yield tx_req
tx_req.serialized = None
return sanitize_tx_input(ack.tx)
def request_tx_output(tx_req: TxRequest, i: int, tx_hash: bytes=None):
tx_req.request_type = TXOUTPUT
tx_req.details.request_index = i
tx_req.details.tx_hash = tx_hash
ack = yield tx_req
tx_req.serialized = None
if tx_hash is None:
return sanitize_tx_output(ack.tx)
else:
return sanitize_tx_binoutput(ack.tx)
def request_tx_finish(tx_req: TxRequest):
tx_req.request_type = TXFINISHED
tx_req.details = None
yield tx_req
tx_req.serialized = None
# Data sanitizers
# ===
def sanitize_sign_tx(tx: SignTx) -> SignTx:
tx.version = tx.version if tx.version is not None else 1
tx.lock_time = tx.lock_time if tx.lock_time is not None else 0
tx.inputs_count = tx.inputs_count if tx.inputs_count is not None else 0
tx.outputs_count = tx.outputs_count if tx.outputs_count is not None else 0
tx.coin_name = tx.coin_name if tx.coin_name is not None else 'Bitcoin'
return tx
def sanitize_tx_meta(tx: TransactionType) -> TransactionType:
tx.version = tx.version if tx.version is not None else 1
tx.lock_time = tx.lock_time if tx.lock_time is not None else 0
tx.inputs_cnt = tx.inputs_cnt if tx.inputs_cnt is not None else 0
tx.outputs_cnt = tx.outputs_cnt if tx.outputs_cnt is not None else 0
return tx
def sanitize_tx_input(tx: TransactionType) -> TxInputType:
txi = tx.inputs[0]
txi.script_type = (
txi.script_type if txi.script_type is not None else InputScriptType.SPENDADDRESS)
return txi
def sanitize_tx_output(tx: TransactionType) -> TxOutputType:
return tx.outputs[0]
def sanitize_tx_binoutput(tx: TransactionType) -> TxOutputBinType:
return tx.bin_outputs[0]

@ -3,11 +3,6 @@ from trezor.crypto.curve import secp256k1
from trezor.crypto import base58, der
from trezor.utils import ensure
from trezor.messages.CoinType import CoinType
from trezor.messages.TxOutputType import TxOutputType
from trezor.messages.TxRequest import TxRequest
from trezor.messages.TransactionType import TransactionType
from trezor.messages.RequestType import TXINPUT, TXOUTPUT, TXMETA, TXFINISHED
from trezor.messages.TxRequestSerializedType import TxRequestSerializedType
from trezor.messages.TxRequestDetailsType import TxRequestDetailsType
from trezor.messages import OutputScriptType
@ -16,124 +11,18 @@ from apps.common import address_type
from apps.common import coins
from apps.wallet.sign_tx.segwit_bip143 import *
from apps.wallet.sign_tx.writers import *
# Machine instructions
# ===
from apps.wallet.sign_tx.helpers import *
class SigningError(ValueError):
pass
class UiConfirmOutput:
def __init__(self, output: TxOutputType, coin: CoinType):
self.output = output
self.coin = coin
class UiConfirmTotal:
def __init__(self, spending: int, fee: int, coin: CoinType):
self.spending = spending
self.fee = fee
self.coin = coin
class UiConfirmFeeOverThreshold:
def __init__(self, fee: int, coin: CoinType):
self.fee = fee
self.coin = coin
def confirm_output(output: TxOutputType, coin: CoinType):
return (yield UiConfirmOutput(output, coin))
def confirm_total(spending: int, fee: int, coin: CoinType):
return (yield UiConfirmTotal(spending, fee, coin))
def confirm_feeoverthreshold(fee: int, coin: CoinType):
return (yield UiConfirmFeeOverThreshold(fee, coin))
def request_tx_meta(tx_req: TxRequest, tx_hash: bytes=None):
tx_req.request_type = TXMETA
tx_req.details.tx_hash = tx_hash
tx_req.details.request_index = None
ack = yield tx_req
tx_req.serialized = None
return sanitize_tx_meta(ack.tx)
def request_tx_input(tx_req: TxRequest, i: int, tx_hash: bytes=None):
tx_req.request_type = TXINPUT
tx_req.details.request_index = i
tx_req.details.tx_hash = tx_hash
ack = yield tx_req
tx_req.serialized = None
return sanitize_tx_input(ack.tx)
def request_tx_output(tx_req: TxRequest, i: int, tx_hash: bytes=None):
tx_req.request_type = TXOUTPUT
tx_req.details.request_index = i
tx_req.details.tx_hash = tx_hash
ack = yield tx_req
tx_req.serialized = None
if tx_hash is None:
return sanitize_tx_output(ack.tx)
else:
return sanitize_tx_binoutput(ack.tx)
def request_tx_finish(tx_req: TxRequest):
tx_req.request_type = TXFINISHED
tx_req.details = None
yield tx_req
tx_req.serialized = None
# Data sanitizers
# ===
def sanitize_sign_tx(tx: SignTx) -> SignTx:
tx.version = tx.version if tx.version is not None else 1
tx.lock_time = tx.lock_time if tx.lock_time is not None else 0
tx.inputs_count = tx.inputs_count if tx.inputs_count is not None else 0
tx.outputs_count = tx.outputs_count if tx.outputs_count is not None else 0
tx.coin_name = tx.coin_name if tx.coin_name is not None else 'Bitcoin'
return tx
def sanitize_tx_meta(tx: TransactionType) -> TransactionType:
tx.version = tx.version if tx.version is not None else 1
tx.lock_time = tx.lock_time if tx.lock_time is not None else 0
tx.inputs_cnt = tx.inputs_cnt if tx.inputs_cnt is not None else 0
tx.outputs_cnt = tx.outputs_cnt if tx.outputs_cnt is not None else 0
return tx
def sanitize_tx_input(tx: TransactionType) -> TxInputType:
txi = tx.inputs[0]
txi.script_type = (
txi.script_type if txi.script_type is not None else InputScriptType.SPENDADDRESS)
return txi
def sanitize_tx_output(tx: TransactionType) -> TxOutputType:
return tx.outputs[0]
def sanitize_tx_binoutput(tx: TransactionType) -> TxOutputBinType:
return tx.bin_outputs[0]
# Transaction signing
# ===
# see https://github.com/trezor/trezor-mcu/blob/master/firmware/signing.c#L84
# for pseudo code overview
# ===
# Phase 1
# - check inputs, previous transactions, and outputs
@ -246,11 +135,8 @@ async def sign_tx(tx: SignTx, root):
txi_sign.script_sig = input_derive_script(txi_sign, key_sign_pub)
w_txi = bytearray_with_cap(
7 + len(txi_sign.prev_hash) + 4 + len(txi_sign.script_sig) + 4)
if i_sign == 0: # serializing first input => prepend meta
write_uint32(w_txi, tx.version)
write_varint(w_txi, 0x00) # segwit witness marker
write_varint(w_txi, 0x01) # segwit witness flag
write_varint(w_txi, tx.inputs_count)
if i_sign == 0: # serializing first input => prepend headers
write_bytes(w_txi, get_tx_header(tx, True))
write_tx_input(w_txi, txi_sign)
tx_ser.serialized_tx = w_txi
@ -299,9 +185,8 @@ async def sign_tx(tx: SignTx, root):
txi_sign, key_sign_pub, signature)
w_txi_sign = bytearray_with_cap(
5 + len(txi_sign.prev_hash) + 4 + len(txi_sign.script_sig) + 4)
if i_sign == 0: # serializing first input => prepend tx version and inputs count
write_uint32(w_txi_sign, tx.version)
write_varint(w_txi_sign, tx.inputs_count)
if i_sign == 0: # serializing first input => prepend headers
write_bytes(w_txi_sign, get_tx_header(tx))
write_tx_input(w_txi_sign, txi_sign)
tx_ser.serialized_tx = w_txi_sign
@ -389,6 +274,26 @@ def estimate_tx_size(inputs, outputs):
return 10 + inputs * 149 + outputs * 35
# TX Helpers
# ===
def get_tx_header(tx: SignTx, segwit=False):
w_txi = bytearray()
write_uint32(w_txi, tx.version)
if segwit:
write_varint(w_txi, 0x00) # segwit witness marker
write_varint(w_txi, 0x01) # segwit witness flag
write_varint(w_txi, tx.inputs_count)
return w_txi
def get_p2wpkh_witness(signature: bytes, pubkey: bytes):
w = bytearray_with_cap(1 + 5 + len(signature) + 1 + 5 + len(pubkey))
write_varint(w, 0x02) # num of segwit items, in P2WPKH it's always 2
append_signature_and_pubkey(w, pubkey, signature)
return w
# TX Outputs
# ===
@ -532,13 +437,6 @@ def script_spendaddress_new(pubkey: bytes, signature: bytes) -> bytearray:
return w
def get_p2wpkh_witness(signature: bytes, pubkey: bytes):
w = bytearray_with_cap(1 + 5 + len(signature) + 1 + 5 + len(pubkey))
write_varint(w, 0x02) # num of segwit items, in P2WPKH it's always 2
append_signature_and_pubkey(w, pubkey, signature)
return w
def append_signature_and_pubkey(w: bytearray, pubkey: bytes, signature: bytes) -> bytearray:
write_op_push(w, len(signature) + 1)
write_bytes(w, signature)

@ -44,7 +44,7 @@ class TestSignSegwitTx(unittest.TestCase):
)
out2 = TxOutputType(
address='2N1LGaGg836mqSQqiuUBLfcyGBhyZbremDX',
script_type=OutputScriptType.PAYTOSCRIPTHASH, # todo
script_type=OutputScriptType.PAYTOSCRIPTHASH, # todo!
amount=123456789 - 11000 - 12300000,
address_n=None,
)

Loading…
Cancel
Save