mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-07-06 14:52:33 +00:00
wallet/signing: refactoring
This commit is contained in:
parent
e63d0adc23
commit
b7f01baf99
121
src/apps/wallet/sign_tx/helpers.py
Normal file
121
src/apps/wallet/sign_tx/helpers.py
Normal file
@ -0,0 +1,121 @@
|
|||||||
|
|
||||||
|
from trezor.messages.CoinType import CoinType
|
||||||
|
from trezor.messages.TxOutputType import TxOutputType
|
||||||
|
from trezor.messages.TxOutputBinType import TxOutputBinType
|
||||||
|
from trezor.messages.TxInputType import TxInputType
|
||||||
|
from trezor.messages.SignTx import SignTx
|
||||||
|
from trezor.messages.TxRequest import TxRequest
|
||||||
|
from trezor.messages.TransactionType import TransactionType
|
||||||
|
from trezor.messages.RequestType import TXINPUT, TXOUTPUT, TXMETA, TXFINISHED
|
||||||
|
from trezor.messages import InputScriptType
|
||||||
|
|
||||||
|
# Machine instructions
|
||||||
|
# ===
|
||||||
|
|
||||||
|
|
||||||
|
class UiConfirmOutput:
|
||||||
|
|
||||||
|
def __init__(self, output: TxOutputType, coin: CoinType):
|
||||||
|
self.output = output
|
||||||
|
self.coin = coin
|
||||||
|
|
||||||
|
|
||||||
|
class UiConfirmTotal:
|
||||||
|
|
||||||
|
def __init__(self, spending: int, fee: int, coin: CoinType):
|
||||||
|
self.spending = spending
|
||||||
|
self.fee = fee
|
||||||
|
self.coin = coin
|
||||||
|
|
||||||
|
|
||||||
|
class UiConfirmFeeOverThreshold:
|
||||||
|
|
||||||
|
def __init__(self, fee: int, coin: CoinType):
|
||||||
|
self.fee = fee
|
||||||
|
self.coin = coin
|
||||||
|
|
||||||
|
|
||||||
|
def confirm_output(output: TxOutputType, coin: CoinType):
|
||||||
|
return (yield UiConfirmOutput(output, coin))
|
||||||
|
|
||||||
|
|
||||||
|
def confirm_total(spending: int, fee: int, coin: CoinType):
|
||||||
|
return (yield UiConfirmTotal(spending, fee, coin))
|
||||||
|
|
||||||
|
|
||||||
|
def confirm_feeoverthreshold(fee: int, coin: CoinType):
|
||||||
|
return (yield UiConfirmFeeOverThreshold(fee, coin))
|
||||||
|
|
||||||
|
|
||||||
|
def request_tx_meta(tx_req: TxRequest, tx_hash: bytes=None):
|
||||||
|
tx_req.request_type = TXMETA
|
||||||
|
tx_req.details.tx_hash = tx_hash
|
||||||
|
tx_req.details.request_index = None
|
||||||
|
ack = yield tx_req
|
||||||
|
tx_req.serialized = None
|
||||||
|
return sanitize_tx_meta(ack.tx)
|
||||||
|
|
||||||
|
|
||||||
|
def request_tx_input(tx_req: TxRequest, i: int, tx_hash: bytes=None):
|
||||||
|
tx_req.request_type = TXINPUT
|
||||||
|
tx_req.details.request_index = i
|
||||||
|
tx_req.details.tx_hash = tx_hash
|
||||||
|
ack = yield tx_req
|
||||||
|
tx_req.serialized = None
|
||||||
|
return sanitize_tx_input(ack.tx)
|
||||||
|
|
||||||
|
|
||||||
|
def request_tx_output(tx_req: TxRequest, i: int, tx_hash: bytes=None):
|
||||||
|
tx_req.request_type = TXOUTPUT
|
||||||
|
tx_req.details.request_index = i
|
||||||
|
tx_req.details.tx_hash = tx_hash
|
||||||
|
ack = yield tx_req
|
||||||
|
tx_req.serialized = None
|
||||||
|
if tx_hash is None:
|
||||||
|
return sanitize_tx_output(ack.tx)
|
||||||
|
else:
|
||||||
|
return sanitize_tx_binoutput(ack.tx)
|
||||||
|
|
||||||
|
|
||||||
|
def request_tx_finish(tx_req: TxRequest):
|
||||||
|
tx_req.request_type = TXFINISHED
|
||||||
|
tx_req.details = None
|
||||||
|
yield tx_req
|
||||||
|
tx_req.serialized = None
|
||||||
|
|
||||||
|
|
||||||
|
# Data sanitizers
|
||||||
|
# ===
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_sign_tx(tx: SignTx) -> SignTx:
|
||||||
|
tx.version = tx.version if tx.version is not None else 1
|
||||||
|
tx.lock_time = tx.lock_time if tx.lock_time is not None else 0
|
||||||
|
tx.inputs_count = tx.inputs_count if tx.inputs_count is not None else 0
|
||||||
|
tx.outputs_count = tx.outputs_count if tx.outputs_count is not None else 0
|
||||||
|
tx.coin_name = tx.coin_name if tx.coin_name is not None else 'Bitcoin'
|
||||||
|
return tx
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_tx_meta(tx: TransactionType) -> TransactionType:
|
||||||
|
tx.version = tx.version if tx.version is not None else 1
|
||||||
|
tx.lock_time = tx.lock_time if tx.lock_time is not None else 0
|
||||||
|
tx.inputs_cnt = tx.inputs_cnt if tx.inputs_cnt is not None else 0
|
||||||
|
tx.outputs_cnt = tx.outputs_cnt if tx.outputs_cnt is not None else 0
|
||||||
|
return tx
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_tx_input(tx: TransactionType) -> TxInputType:
|
||||||
|
txi = tx.inputs[0]
|
||||||
|
txi.script_type = (
|
||||||
|
txi.script_type if txi.script_type is not None else InputScriptType.SPENDADDRESS)
|
||||||
|
return txi
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_tx_output(tx: TransactionType) -> TxOutputType:
|
||||||
|
return tx.outputs[0]
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_tx_binoutput(tx: TransactionType) -> TxOutputBinType:
|
||||||
|
return tx.bin_outputs[0]
|
||||||
|
|
@ -3,11 +3,6 @@ from trezor.crypto.curve import secp256k1
|
|||||||
from trezor.crypto import base58, der
|
from trezor.crypto import base58, der
|
||||||
from trezor.utils import ensure
|
from trezor.utils import ensure
|
||||||
|
|
||||||
from trezor.messages.CoinType import CoinType
|
|
||||||
from trezor.messages.TxOutputType import TxOutputType
|
|
||||||
from trezor.messages.TxRequest import TxRequest
|
|
||||||
from trezor.messages.TransactionType import TransactionType
|
|
||||||
from trezor.messages.RequestType import TXINPUT, TXOUTPUT, TXMETA, TXFINISHED
|
|
||||||
from trezor.messages.TxRequestSerializedType import TxRequestSerializedType
|
from trezor.messages.TxRequestSerializedType import TxRequestSerializedType
|
||||||
from trezor.messages.TxRequestDetailsType import TxRequestDetailsType
|
from trezor.messages.TxRequestDetailsType import TxRequestDetailsType
|
||||||
from trezor.messages import OutputScriptType
|
from trezor.messages import OutputScriptType
|
||||||
@ -16,124 +11,18 @@ from apps.common import address_type
|
|||||||
from apps.common import coins
|
from apps.common import coins
|
||||||
from apps.wallet.sign_tx.segwit_bip143 import *
|
from apps.wallet.sign_tx.segwit_bip143 import *
|
||||||
from apps.wallet.sign_tx.writers import *
|
from apps.wallet.sign_tx.writers import *
|
||||||
|
from apps.wallet.sign_tx.helpers import *
|
||||||
# Machine instructions
|
|
||||||
# ===
|
|
||||||
|
|
||||||
|
|
||||||
class SigningError(ValueError):
|
class SigningError(ValueError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class UiConfirmOutput:
|
|
||||||
|
|
||||||
def __init__(self, output: TxOutputType, coin: CoinType):
|
|
||||||
self.output = output
|
|
||||||
self.coin = coin
|
|
||||||
|
|
||||||
|
|
||||||
class UiConfirmTotal:
|
|
||||||
|
|
||||||
def __init__(self, spending: int, fee: int, coin: CoinType):
|
|
||||||
self.spending = spending
|
|
||||||
self.fee = fee
|
|
||||||
self.coin = coin
|
|
||||||
|
|
||||||
|
|
||||||
class UiConfirmFeeOverThreshold:
|
|
||||||
|
|
||||||
def __init__(self, fee: int, coin: CoinType):
|
|
||||||
self.fee = fee
|
|
||||||
self.coin = coin
|
|
||||||
|
|
||||||
|
|
||||||
def confirm_output(output: TxOutputType, coin: CoinType):
|
|
||||||
return (yield UiConfirmOutput(output, coin))
|
|
||||||
|
|
||||||
|
|
||||||
def confirm_total(spending: int, fee: int, coin: CoinType):
|
|
||||||
return (yield UiConfirmTotal(spending, fee, coin))
|
|
||||||
|
|
||||||
|
|
||||||
def confirm_feeoverthreshold(fee: int, coin: CoinType):
|
|
||||||
return (yield UiConfirmFeeOverThreshold(fee, coin))
|
|
||||||
|
|
||||||
|
|
||||||
def request_tx_meta(tx_req: TxRequest, tx_hash: bytes=None):
|
|
||||||
tx_req.request_type = TXMETA
|
|
||||||
tx_req.details.tx_hash = tx_hash
|
|
||||||
tx_req.details.request_index = None
|
|
||||||
ack = yield tx_req
|
|
||||||
tx_req.serialized = None
|
|
||||||
return sanitize_tx_meta(ack.tx)
|
|
||||||
|
|
||||||
|
|
||||||
def request_tx_input(tx_req: TxRequest, i: int, tx_hash: bytes=None):
|
|
||||||
tx_req.request_type = TXINPUT
|
|
||||||
tx_req.details.request_index = i
|
|
||||||
tx_req.details.tx_hash = tx_hash
|
|
||||||
ack = yield tx_req
|
|
||||||
tx_req.serialized = None
|
|
||||||
return sanitize_tx_input(ack.tx)
|
|
||||||
|
|
||||||
|
|
||||||
def request_tx_output(tx_req: TxRequest, i: int, tx_hash: bytes=None):
|
|
||||||
tx_req.request_type = TXOUTPUT
|
|
||||||
tx_req.details.request_index = i
|
|
||||||
tx_req.details.tx_hash = tx_hash
|
|
||||||
ack = yield tx_req
|
|
||||||
tx_req.serialized = None
|
|
||||||
if tx_hash is None:
|
|
||||||
return sanitize_tx_output(ack.tx)
|
|
||||||
else:
|
|
||||||
return sanitize_tx_binoutput(ack.tx)
|
|
||||||
|
|
||||||
|
|
||||||
def request_tx_finish(tx_req: TxRequest):
|
|
||||||
tx_req.request_type = TXFINISHED
|
|
||||||
tx_req.details = None
|
|
||||||
yield tx_req
|
|
||||||
tx_req.serialized = None
|
|
||||||
|
|
||||||
|
|
||||||
# Data sanitizers
|
|
||||||
# ===
|
|
||||||
|
|
||||||
|
|
||||||
def sanitize_sign_tx(tx: SignTx) -> SignTx:
|
|
||||||
tx.version = tx.version if tx.version is not None else 1
|
|
||||||
tx.lock_time = tx.lock_time if tx.lock_time is not None else 0
|
|
||||||
tx.inputs_count = tx.inputs_count if tx.inputs_count is not None else 0
|
|
||||||
tx.outputs_count = tx.outputs_count if tx.outputs_count is not None else 0
|
|
||||||
tx.coin_name = tx.coin_name if tx.coin_name is not None else 'Bitcoin'
|
|
||||||
return tx
|
|
||||||
|
|
||||||
|
|
||||||
def sanitize_tx_meta(tx: TransactionType) -> TransactionType:
|
|
||||||
tx.version = tx.version if tx.version is not None else 1
|
|
||||||
tx.lock_time = tx.lock_time if tx.lock_time is not None else 0
|
|
||||||
tx.inputs_cnt = tx.inputs_cnt if tx.inputs_cnt is not None else 0
|
|
||||||
tx.outputs_cnt = tx.outputs_cnt if tx.outputs_cnt is not None else 0
|
|
||||||
return tx
|
|
||||||
|
|
||||||
|
|
||||||
def sanitize_tx_input(tx: TransactionType) -> TxInputType:
|
|
||||||
txi = tx.inputs[0]
|
|
||||||
txi.script_type = (
|
|
||||||
txi.script_type if txi.script_type is not None else InputScriptType.SPENDADDRESS)
|
|
||||||
return txi
|
|
||||||
|
|
||||||
|
|
||||||
def sanitize_tx_output(tx: TransactionType) -> TxOutputType:
|
|
||||||
return tx.outputs[0]
|
|
||||||
|
|
||||||
|
|
||||||
def sanitize_tx_binoutput(tx: TransactionType) -> TxOutputBinType:
|
|
||||||
return tx.bin_outputs[0]
|
|
||||||
|
|
||||||
|
|
||||||
# Transaction signing
|
# Transaction signing
|
||||||
# ===
|
# ===
|
||||||
|
# see https://github.com/trezor/trezor-mcu/blob/master/firmware/signing.c#L84
|
||||||
|
# for pseudo code overview
|
||||||
|
# ===
|
||||||
|
|
||||||
# Phase 1
|
# Phase 1
|
||||||
# - check inputs, previous transactions, and outputs
|
# - check inputs, previous transactions, and outputs
|
||||||
@ -246,11 +135,8 @@ async def sign_tx(tx: SignTx, root):
|
|||||||
txi_sign.script_sig = input_derive_script(txi_sign, key_sign_pub)
|
txi_sign.script_sig = input_derive_script(txi_sign, key_sign_pub)
|
||||||
w_txi = bytearray_with_cap(
|
w_txi = bytearray_with_cap(
|
||||||
7 + len(txi_sign.prev_hash) + 4 + len(txi_sign.script_sig) + 4)
|
7 + len(txi_sign.prev_hash) + 4 + len(txi_sign.script_sig) + 4)
|
||||||
if i_sign == 0: # serializing first input => prepend meta
|
if i_sign == 0: # serializing first input => prepend headers
|
||||||
write_uint32(w_txi, tx.version)
|
write_bytes(w_txi, get_tx_header(tx, True))
|
||||||
write_varint(w_txi, 0x00) # segwit witness marker
|
|
||||||
write_varint(w_txi, 0x01) # segwit witness flag
|
|
||||||
write_varint(w_txi, tx.inputs_count)
|
|
||||||
write_tx_input(w_txi, txi_sign)
|
write_tx_input(w_txi, txi_sign)
|
||||||
tx_ser.serialized_tx = w_txi
|
tx_ser.serialized_tx = w_txi
|
||||||
|
|
||||||
@ -299,9 +185,8 @@ async def sign_tx(tx: SignTx, root):
|
|||||||
txi_sign, key_sign_pub, signature)
|
txi_sign, key_sign_pub, signature)
|
||||||
w_txi_sign = bytearray_with_cap(
|
w_txi_sign = bytearray_with_cap(
|
||||||
5 + len(txi_sign.prev_hash) + 4 + len(txi_sign.script_sig) + 4)
|
5 + len(txi_sign.prev_hash) + 4 + len(txi_sign.script_sig) + 4)
|
||||||
if i_sign == 0: # serializing first input => prepend tx version and inputs count
|
if i_sign == 0: # serializing first input => prepend headers
|
||||||
write_uint32(w_txi_sign, tx.version)
|
write_bytes(w_txi_sign, get_tx_header(tx))
|
||||||
write_varint(w_txi_sign, tx.inputs_count)
|
|
||||||
write_tx_input(w_txi_sign, txi_sign)
|
write_tx_input(w_txi_sign, txi_sign)
|
||||||
tx_ser.serialized_tx = w_txi_sign
|
tx_ser.serialized_tx = w_txi_sign
|
||||||
|
|
||||||
@ -389,6 +274,26 @@ def estimate_tx_size(inputs, outputs):
|
|||||||
return 10 + inputs * 149 + outputs * 35
|
return 10 + inputs * 149 + outputs * 35
|
||||||
|
|
||||||
|
|
||||||
|
# TX Helpers
|
||||||
|
# ===
|
||||||
|
|
||||||
|
def get_tx_header(tx: SignTx, segwit=False):
|
||||||
|
w_txi = bytearray()
|
||||||
|
write_uint32(w_txi, tx.version)
|
||||||
|
if segwit:
|
||||||
|
write_varint(w_txi, 0x00) # segwit witness marker
|
||||||
|
write_varint(w_txi, 0x01) # segwit witness flag
|
||||||
|
write_varint(w_txi, tx.inputs_count)
|
||||||
|
return w_txi
|
||||||
|
|
||||||
|
|
||||||
|
def get_p2wpkh_witness(signature: bytes, pubkey: bytes):
|
||||||
|
w = bytearray_with_cap(1 + 5 + len(signature) + 1 + 5 + len(pubkey))
|
||||||
|
write_varint(w, 0x02) # num of segwit items, in P2WPKH it's always 2
|
||||||
|
append_signature_and_pubkey(w, pubkey, signature)
|
||||||
|
return w
|
||||||
|
|
||||||
|
|
||||||
# TX Outputs
|
# TX Outputs
|
||||||
# ===
|
# ===
|
||||||
|
|
||||||
@ -532,13 +437,6 @@ def script_spendaddress_new(pubkey: bytes, signature: bytes) -> bytearray:
|
|||||||
return w
|
return w
|
||||||
|
|
||||||
|
|
||||||
def get_p2wpkh_witness(signature: bytes, pubkey: bytes):
|
|
||||||
w = bytearray_with_cap(1 + 5 + len(signature) + 1 + 5 + len(pubkey))
|
|
||||||
write_varint(w, 0x02) # num of segwit items, in P2WPKH it's always 2
|
|
||||||
append_signature_and_pubkey(w, pubkey, signature)
|
|
||||||
return w
|
|
||||||
|
|
||||||
|
|
||||||
def append_signature_and_pubkey(w: bytearray, pubkey: bytes, signature: bytes) -> bytearray:
|
def append_signature_and_pubkey(w: bytearray, pubkey: bytes, signature: bytes) -> bytearray:
|
||||||
write_op_push(w, len(signature) + 1)
|
write_op_push(w, len(signature) + 1)
|
||||||
write_bytes(w, signature)
|
write_bytes(w, signature)
|
||||||
|
@ -44,7 +44,7 @@ class TestSignSegwitTx(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
out2 = TxOutputType(
|
out2 = TxOutputType(
|
||||||
address='2N1LGaGg836mqSQqiuUBLfcyGBhyZbremDX',
|
address='2N1LGaGg836mqSQqiuUBLfcyGBhyZbremDX',
|
||||||
script_type=OutputScriptType.PAYTOSCRIPTHASH, # todo
|
script_type=OutputScriptType.PAYTOSCRIPTHASH, # todo!
|
||||||
amount=123456789 - 11000 - 12300000,
|
amount=123456789 - 11000 - 12300000,
|
||||||
address_n=None,
|
address_n=None,
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user