Merge pull request #411 from trezor/keychain

Introduce Keychain API
pull/25/head
Jan Pochyla 6 years ago committed by GitHub
commit 4c46a055ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -422,7 +422,16 @@ STATIC mp_obj_t mod_trezorcrypto_HDNode_ethereum_pubkeyhash(mp_obj_t self) {
}
STATIC MP_DEFINE_CONST_FUN_OBJ_1(mod_trezorcrypto_HDNode_ethereum_pubkeyhash_obj, mod_trezorcrypto_HDNode_ethereum_pubkeyhash);
STATIC mp_obj_t mod_trezorcrypto_HDNode___del__(mp_obj_t self) {
mp_obj_HDNode_t *o = MP_OBJ_TO_PTR(self);
o->fingerprint = 0;
memzero(&o->hdnode, sizeof(o->hdnode));
return mp_const_none;
}
STATIC MP_DEFINE_CONST_FUN_OBJ_1(mod_trezorcrypto_HDNode___del___obj, mod_trezorcrypto_HDNode___del__);
STATIC const mp_rom_map_elem_t mod_trezorcrypto_HDNode_locals_dict_table[] = {
{ MP_ROM_QSTR(MP_QSTR___del__), MP_ROM_PTR(&mod_trezorcrypto_HDNode___del___obj) },
{ MP_ROM_QSTR(MP_QSTR_derive), MP_ROM_PTR(&mod_trezorcrypto_HDNode_derive_obj) },
{ MP_ROM_QSTR(MP_QSTR_derive_cardano), MP_ROM_PTR(&mod_trezorcrypto_HDNode_derive_cardano_obj) },
{ MP_ROM_QSTR(MP_QSTR_derive_path), MP_ROM_PTR(&mod_trezorcrypto_HDNode_derive_path_obj) },
@ -539,9 +548,7 @@ STATIC mp_obj_t mod_trezorcrypto_bip32_from_mnemonic_cardano(mp_obj_t mnemonic,
return MP_OBJ_FROM_PTR(o);
}
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorcrypto_bip32_from_mnemonic_cardano_obj,
mod_trezorcrypto_bip32_from_mnemonic_cardano);
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorcrypto_bip32_from_mnemonic_cardano_obj, mod_trezorcrypto_bip32_from_mnemonic_cardano);
STATIC const mp_rom_map_elem_t mod_trezorcrypto_bip32_globals_table[] = {
{ MP_ROM_QSTR(MP_QSTR___name__), MP_ROM_QSTR(MP_QSTR_bip32) },

@ -1,6 +1,10 @@
from trezor import wire
from trezor.messages import MessageType
from apps.common import HARDENED
SEED_NAMESPACE = [[HARDENED | 44, HARDENED | 1815]]
def boot():
wire.add(MessageType.CardanoGetAddress, __name__, "get_address")

@ -1,8 +1,27 @@
from trezor.crypto import base58, crc, hashlib
from . import cbor
from apps.cardano import cbor
from apps.common import HARDENED
from apps.common.seed import remove_ed25519_prefix
from apps.common import HARDENED, seed
def derive_address_and_node(keychain, path: list):
node = keychain.derive(path)
address_payload = None
address_attributes = {}
address_root = _get_address_root(node, address_payload)
address_type = 0
address_data = [address_root, address_attributes, address_type]
address_data_encoded = cbor.encode(address_data)
address = base58.encode(
cbor.encode(
[cbor.Tagged(24, address_data_encoded), crc.crc32(address_data_encoded)]
)
)
return (address, node)
def validate_full_path(path: list) -> bool:
@ -36,31 +55,9 @@ def _address_hash(data) -> bytes:
def _get_address_root(node, payload):
extpubkey = seed.remove_ed25519_prefix(node.public_key()) + node.chain_code()
extpubkey = remove_ed25519_prefix(node.public_key()) + node.chain_code()
if payload:
payload = {1: cbor.encode(payload)}
else:
payload = {}
return _address_hash([0, [0, extpubkey], payload])
def derive_address_and_node(root_node, path: list):
derived_node = root_node.clone()
address_payload = None
address_attributes = {}
for indice in path:
derived_node.derive_cardano(indice)
address_root = _get_address_root(derived_node, address_payload)
address_type = 0
address_data = [address_root, address_attributes, address_type]
address_data_encoded = cbor.encode(address_data)
address = base58.encode(
cbor.encode(
[cbor.Tagged(24, address_data_encoded), crc.crc32(address_data_encoded)]
)
)
return (address, derived_node)

@ -1,28 +1,23 @@
from trezor import log, ui, wire
from trezor.crypto import bip32
from trezor.messages.CardanoAddress import CardanoAddress
from .address import derive_address_and_node, validate_full_path
from .layout import confirm_with_pagination
from apps.common import paths, seed, storage
from apps.cardano import seed
from apps.cardano.address import derive_address_and_node, validate_full_path
from apps.cardano.layout import confirm_with_pagination
from apps.common import paths
async def get_address(ctx, msg):
await paths.validate_path(ctx, validate_full_path, path=msg.address_n)
keychain = await seed.get_keychain(ctx)
mnemonic = storage.get_mnemonic()
passphrase = await seed._get_cached_passphrase(ctx)
root_node = bip32.from_mnemonic_cardano(mnemonic, passphrase)
await paths.validate_path(ctx, validate_full_path, path=msg.address_n)
try:
address, _ = derive_address_and_node(root_node, msg.address_n)
address, _ = derive_address_and_node(keychain, msg.address_n)
except ValueError as e:
if __debug__:
log.exception(__name__, e)
raise wire.ProcessError("Deriving address failed")
mnemonic = None
root_node = None
if msg.show_display:
if not await confirm_with_pagination(

@ -1,42 +1,38 @@
from ubinascii import hexlify
from trezor import log, wire
from trezor.crypto import bip32
from trezor.messages.CardanoPublicKey import CardanoPublicKey
from trezor.messages.HDNodeType import HDNodeType
from .address import derive_address_and_node
from apps.common import layout, paths, seed, storage
from apps.cardano import seed
from apps.cardano.address import derive_address_and_node
from apps.common import layout, paths
from apps.common.seed import remove_ed25519_prefix
async def get_public_key(ctx, msg):
keychain = await seed.get_keychain(ctx)
await paths.validate_path(
ctx, paths.validate_path_for_get_public_key, path=msg.address_n, slip44_id=1815
)
mnemonic = storage.get_mnemonic()
passphrase = await seed._get_cached_passphrase(ctx)
root_node = bip32.from_mnemonic_cardano(mnemonic, passphrase)
try:
key = _get_public_key(root_node, msg.address_n)
key = _get_public_key(keychain, msg.address_n)
except ValueError as e:
if __debug__:
log.exception(__name__, e)
raise wire.ProcessError("Deriving public key failed")
mnemonic = None
root_node = None
if msg.show_display:
await layout.show_pubkey(ctx, key.node.public_key)
return key
def _get_public_key(root_node, derivation_path: list):
_, node = derive_address_and_node(root_node, derivation_path)
def _get_public_key(keychain, derivation_path: list):
_, node = derive_address_and_node(keychain, derivation_path)
public_key = hexlify(seed.remove_ed25519_prefix(node.public_key())).decode()
public_key = hexlify(remove_ed25519_prefix(node.public_key())).decode()
chain_code = hexlify(node.chain_code()).decode()
xpub_key = public_key + chain_code
@ -45,7 +41,7 @@ def _get_public_key(root_node, derivation_path: list):
child_num=node.child_num(),
fingerprint=node.fingerprint(),
chain_code=node.chain_code(),
public_key=seed.remove_ed25519_prefix(node.public_key()),
public_key=remove_ed25519_prefix(node.public_key()),
)
return CardanoPublicKey(node=node_type, xpub=xpub_key)

@ -0,0 +1,43 @@
from trezor import wire
from trezor.crypto import bip32
from apps.cardano import SEED_NAMESPACE
from apps.common import cache, storage
from apps.common.request_passphrase import protect_by_passphrase
class Keychain:
def __init__(self, path: list, root: bip32.HDNode):
self.path = path
self.root = root
def derive(self, node_path: list) -> bip32.HDNode:
# check we are in the cardano namespace
prefix = node_path[: len(self.path)]
suffix = node_path[len(self.path) :]
if prefix != self.path:
raise wire.DataError("Forbidden key path")
# derive child node from the root
node = self.root.clone()
for i in suffix:
node.derive_cardano(i)
return node
async def get_keychain(ctx: wire.Context) -> Keychain:
if not storage.is_initialized():
raise wire.ProcessError("Device is not initialized")
# derive the root node from mnemonic and passphrase
passphrase = cache.get_passphrase()
if passphrase is None:
passphrase = await protect_by_passphrase(ctx)
cache.set_passphrase(passphrase)
root = bip32.from_mnemonic_cardano(storage.get_mnemonic(), passphrase)
# derive the namespaced root node
for i in SEED_NAMESPACE[0]:
root.derive_cardano(i)
keychain = Keychain(SEED_NAMESPACE[0], root)
return keychain

@ -1,18 +1,17 @@
from trezor import log, ui, wire
from trezor.crypto import base58, bip32, hashlib
from trezor.crypto import base58, hashlib
from trezor.crypto.curve import ed25519
from trezor.messages.CardanoSignedTx import CardanoSignedTx
from trezor.messages.CardanoTxRequest import CardanoTxRequest
from trezor.messages.MessageType import CardanoTxAck
from trezor.ui.text import BR
from .address import derive_address_and_node, validate_full_path
from .layout import confirm_with_pagination, progress
from apps.cardano import cbor
from apps.common import seed, storage
from apps.cardano import cbor, seed
from apps.cardano.address import derive_address_and_node, validate_full_path
from apps.cardano.layout import confirm_with_pagination, progress
from apps.common.layout import address_n_to_str, split_address
from apps.common.paths import validate_path
from apps.common.seed import remove_ed25519_prefix
from apps.homescreen.homescreen import display_homescreen
@ -80,9 +79,7 @@ async def request_transaction(ctx, tx_req: CardanoTxRequest, index: int):
async def sign_tx(ctx, msg):
mnemonic = storage.get_mnemonic()
passphrase = await seed._get_cached_passphrase(ctx)
root_node = bip32.from_mnemonic_cardano(mnemonic, passphrase)
keychain = await seed.get_keychain(ctx)
progress.init(msg.transactions_count, "Loading data")
@ -103,7 +100,7 @@ async def sign_tx(ctx, msg):
# sign the transaction bundle and prepare the result
transaction = Transaction(
msg.inputs, msg.outputs, transactions, root_node, msg.network
msg.inputs, msg.outputs, transactions, keychain, msg.network
)
tx_body, tx_hash = transaction.serialise_tx()
tx = CardanoSignedTx(tx_body=tx_body, tx_hash=tx_hash)
@ -135,12 +132,12 @@ def _micro_ada_to_ada(amount: float) -> float:
class Transaction:
def __init__(
self, inputs: list, outputs: list, transactions: list, root_node, network: int
self, inputs: list, outputs: list, transactions: list, keychain, network: int
):
self.inputs = inputs
self.outputs = outputs
self.transactions = transactions
self.root_node = root_node
self.keychain = keychain
# attributes have to be always empty in current Cardano
self.attributes = {}
if network == 1:
@ -170,7 +167,7 @@ class Transaction:
nodes = []
for input in self.inputs:
_, node = derive_address_and_node(self.root_node, input.address_n)
_, node = derive_address_and_node(self.keychain, input.address_n)
nodes.append(node)
for index, output_index in enumerate(output_indexes):
@ -198,7 +195,7 @@ class Transaction:
for output in self.outputs:
if output.address_n:
address, _ = derive_address_and_node(self.root_node, output.address_n)
address, _ = derive_address_and_node(self.keychain, output.address_n)
change_addresses.append(address)
change_derivation_paths.append(output.address_n)
change_coins.append(output.amount)
@ -225,7 +222,7 @@ class Transaction:
node.private_key(), node.private_key_ext(), message
)
extended_public_key = (
seed.remove_ed25519_prefix(node.public_key()) + node.chain_code()
remove_ed25519_prefix(node.public_key()) + node.chain_code()
)
witnesses.append(
[

@ -4,41 +4,75 @@ from trezor.crypto import bip32, bip39
from apps.common import cache, storage
from apps.common.request_passphrase import protect_by_passphrase
_DEFAULT_CURVE = "secp256k1"
allow = list
async def derive_node(
ctx: wire.Context, path: list, curve_name: str = _DEFAULT_CURVE
) -> bip32.HDNode:
seed = await _get_cached_seed(ctx)
node = bip32.from_seed(seed, curve_name)
node.derive_path(path)
return node
class Keychain:
"""
Keychain provides an API for deriving HD keys from previously allowed
key-spaces.
"""
def __init__(self, seed: bytes, namespaces: list):
self.seed = seed
self.namespaces = namespaces
self.roots = [None] * len(namespaces)
def __del__(self):
for root in self.roots:
if root is not None:
root.__del__()
del self.roots
del self.seed
def derive(self, node_path: list, curve_name: str = "secp256k1") -> bip32.HDNode:
# find the root node index
root_index = 0
for curve, *path in self.namespaces:
prefix = node_path[: len(path)]
suffix = node_path[len(path) :]
if curve == curve_name and path == prefix:
break
root_index += 1
else:
raise wire.DataError("Forbidden key path")
async def _get_cached_seed(ctx: wire.Context) -> bytes:
# create the root node if not cached
root = self.roots[root_index]
if root is None:
root = bip32.from_seed(self.seed, curve_name)
root.derive_path(path)
self.roots[root_index] = root
# derive child node from the root
node = root.clone()
node.derive_path(suffix)
return node
async def get_keychain(ctx: wire.Context, namespaces: list) -> Keychain:
if not storage.is_initialized():
raise wire.ProcessError("Device is not initialized")
if cache.get_seed() is None:
passphrase = await _get_cached_passphrase(ctx)
seed = cache.get_seed()
if seed is None:
# derive seed from mnemonic and passphrase
passphrase = cache.get_passphrase()
if passphrase is None:
passphrase = await protect_by_passphrase(ctx)
cache.set_passphrase(passphrase)
seed = bip39.seed(storage.get_mnemonic(), passphrase)
cache.set_seed(seed)
return cache.get_seed()
async def _get_cached_passphrase(ctx: wire.Context) -> str:
if cache.get_passphrase() is None:
passphrase = await protect_by_passphrase(ctx)
cache.set_passphrase(passphrase)
return cache.get_passphrase()
keychain = Keychain(seed, namespaces)
return keychain
def derive_node_without_passphrase(
path: list, curve_name: str = _DEFAULT_CURVE
path: list, curve_name: str = "secp256k1"
) -> bip32.HDNode:
if not storage.is_initialized():
raise Exception("Device is not initialized")
seed = bip39.seed(storage.get_mnemonic(), "")
node = bip32.from_seed(seed, curve_name)
node.derive_path(path)

@ -3,7 +3,7 @@ from ubinascii import hexlify
from trezor.crypto.hashlib import blake256, sha256
from trezor.utils import HashWriter
from apps.wallet.sign_tx.signing import write_varint
from apps.wallet.sign_tx.writers import write_varint
def message_digest(coin, message):

@ -1,9 +1,12 @@
from trezor import wire
from trezor.messages import MessageType
from apps.common import HARDENED
def boot():
wire.add(MessageType.EthereumGetAddress, __name__, "get_address")
wire.add(MessageType.EthereumSignTx, __name__, "sign_tx")
wire.add(MessageType.EthereumSignMessage, __name__, "sign_message")
ns = [["secp256k1", HARDENED | 44, HARDENED | 60]]
wire.add(MessageType.EthereumGetAddress, __name__, "get_address", ns)
wire.add(MessageType.EthereumSignTx, __name__, "sign_tx", ns)
wire.add(MessageType.EthereumSignMessage, __name__, "sign_message", ns)
wire.add(MessageType.EthereumVerifyMessage, __name__, "verify_message")

@ -1,20 +1,17 @@
from .address import ethereum_address_hex, validate_full_path
from trezor.crypto.curve import secp256k1
from trezor.crypto.hashlib import sha3_256
from trezor.messages.EthereumAddress import EthereumAddress
from apps.common import paths
from apps.common.layout import address_n_to_str, show_address, show_qr
from apps.ethereum import networks
from apps.ethereum.address import ethereum_address_hex, validate_full_path
async def get_address(ctx, msg):
from trezor.messages.EthereumAddress import EthereumAddress
from trezor.crypto.curve import secp256k1
from trezor.crypto.hashlib import sha3_256
from apps.common import seed
async def get_address(ctx, msg, keychain):
await paths.validate_path(ctx, validate_full_path, path=msg.address_n)
node = await seed.derive_node(ctx, msg.address_n)
node = keychain.derive(msg.address_n)
seckey = node.private_key()
public_key = secp256k1.publickey(seckey, False) # uncompressed
address = sha3_256(public_key[1:], keccak=True).digest()[12:]

@ -82,14 +82,14 @@ NETWORKS = [
NetworkInfo(
chain_id=30,
slip44=137,
shortcut="RSK",
shortcut="RBTC",
name="RSK",
rskip60=True,
),
NetworkInfo(
chain_id=31,
slip44=37310,
shortcut="tRSK",
shortcut="tRBTC",
name="RSK Testnet",
rskip60=True,
),

@ -4,11 +4,10 @@ from trezor.messages.EthereumMessageSignature import EthereumMessageSignature
from trezor.ui.text import Text
from trezor.utils import HashWriter
from .address import validate_full_path
from apps.common import paths, seed
from apps.common import paths
from apps.common.confirm import require_confirm
from apps.common.signverify import split_message
from apps.ethereum.address import validate_full_path
def message_digest(message):
@ -20,12 +19,11 @@ def message_digest(message):
return h.get_digest()
async def sign_message(ctx, msg):
async def sign_message(ctx, msg, keychain):
await paths.validate_path(ctx, validate_full_path, path=msg.address_n)
await require_confirm_sign_message(ctx, msg.message)
node = await seed.derive_node(ctx, msg.address_n)
node = keychain.derive(msg.address_n)
signature = secp256k1.sign(
node.private_key(),
message_digest(msg.message),

@ -7,10 +7,9 @@ from trezor.messages.EthereumTxRequest import EthereumTxRequest
from trezor.messages.MessageType import EthereumTxAck
from trezor.utils import HashWriter
from .address import validate_full_path
from apps.common import paths, seed
from apps.common import paths
from apps.ethereum import tokens
from apps.ethereum.address import validate_full_path
from apps.ethereum.layout import (
require_confirm_data,
require_confirm_fee,
@ -21,7 +20,7 @@ from apps.ethereum.layout import (
MAX_CHAIN_ID = 2147483629
async def sign_tx(ctx, msg):
async def sign_tx(ctx, msg, keychain):
msg = sanitize(msg)
check(msg)
await paths.validate_path(ctx, validate_full_path, path=msg.address_n)
@ -91,7 +90,9 @@ async def sign_tx(ctx, msg):
sha.extend(rlp.encode(0))
digest = sha.get_digest()
return await send_signature(ctx, msg, digest)
result = sign_digest(msg, keychain, digest)
return result
def get_total_length(msg: EthereumSignTx, data_total: int) -> int:
@ -130,9 +131,8 @@ async def send_request_chunk(ctx, data_left: int):
return await ctx.call(req, EthereumTxAck)
async def send_signature(ctx, msg: EthereumSignTx, digest):
node = await seed.derive_node(ctx, msg.address_n)
def sign_digest(msg: EthereumSignTx, keychain, digest):
node = keychain.derive(msg.address_n)
signature = secp256k1.sign(
node.private_key(), digest, False, secp256k1.CANONICAL_SIG_ETHEREUM
)

@ -1,10 +1,13 @@
from trezor import wire
from trezor.messages import MessageType
from apps.common import HARDENED
def boot():
wire.add(MessageType.LiskGetPublicKey, __name__, "get_public_key")
wire.add(MessageType.LiskGetAddress, __name__, "get_address")
wire.add(MessageType.LiskSignMessage, __name__, "sign_message")
ns = [["ed25519", HARDENED | 44, HARDENED | 134]]
wire.add(MessageType.LiskGetPublicKey, __name__, "get_public_key", ns)
wire.add(MessageType.LiskGetAddress, __name__, "get_address", ns)
wire.add(MessageType.LiskSignTx, __name__, "sign_tx", ns)
wire.add(MessageType.LiskSignMessage, __name__, "sign_message", ns)
wire.add(MessageType.LiskVerifyMessage, __name__, "verify_message")
wire.add(MessageType.LiskSignTx, __name__, "sign_tx")

@ -2,14 +2,14 @@ from trezor.messages.LiskAddress import LiskAddress
from .helpers import LISK_CURVE, get_address_from_public_key, validate_full_path
from apps.common import paths, seed
from apps.common import paths
from apps.common.layout import address_n_to_str, show_address, show_qr
async def get_address(ctx, msg):
async def get_address(ctx, msg, keychain):
await paths.validate_path(ctx, validate_full_path, path=msg.address_n)
node = await seed.derive_node(ctx, msg.address_n, LISK_CURVE)
node = keychain.derive(msg.address_n, LISK_CURVE)
pubkey = node.public_key()
pubkey = pubkey[1:] # skip ed25519 pubkey marker
address = get_address_from_public_key(pubkey)

@ -2,13 +2,13 @@ from trezor.messages.LiskPublicKey import LiskPublicKey
from .helpers import LISK_CURVE, validate_full_path
from apps.common import layout, paths, seed
from apps.common import layout, paths
async def get_public_key(ctx, msg):
async def get_public_key(ctx, msg, keychain):
await paths.validate_path(ctx, validate_full_path, path=msg.address_n)
node = await seed.derive_node(ctx, msg.address_n, LISK_CURVE)
node = keychain.derive(msg.address_n, LISK_CURVE)
pubkey = node.public_key()
pubkey = pubkey[1:] # skip ed25519 pubkey marker

@ -6,10 +6,10 @@ from trezor.utils import HashWriter
from .helpers import LISK_CURVE, validate_full_path
from apps.common import paths, seed
from apps.common import paths
from apps.common.confirm import require_confirm
from apps.common.signverify import split_message
from apps.wallet.sign_tx.signing import write_varint
from apps.wallet.sign_tx.writers import write_varint
def message_digest(message):
@ -22,11 +22,11 @@ def message_digest(message):
return sha256(h.get_digest()).digest()
async def sign_message(ctx, msg):
async def sign_message(ctx, msg, keychain):
await paths.validate_path(ctx, validate_full_path, path=msg.address_n)
await require_confirm_sign_message(ctx, msg.message)
node = await seed.derive_node(ctx, msg.address_n, LISK_CURVE)
node = keychain.derive(msg.address_n, LISK_CURVE)
seckey = node.private_key()
pubkey = node.public_key()
pubkey = pubkey[1:] # skip ed25519 pubkey marker

@ -7,16 +7,19 @@ from trezor.messages import LiskTransactionType
from trezor.messages.LiskSignedTx import LiskSignedTx
from trezor.utils import HashWriter
from . import layout
from .helpers import LISK_CURVE, get_address_from_public_key, validate_full_path
from apps.common import paths
from apps.lisk import layout
from apps.lisk.helpers import (
LISK_CURVE,
get_address_from_public_key,
validate_full_path,
)
from apps.common import paths, seed
async def sign_tx(ctx, msg):
async def sign_tx(ctx, msg, keychain):
await paths.validate_path(ctx, validate_full_path, path=msg.address_n)
pubkey, seckey = await _get_keys(ctx, msg)
pubkey, seckey = _get_keys(keychain, msg)
transaction = _update_raw_tx(msg.transaction, pubkey)
try:
@ -37,8 +40,8 @@ async def sign_tx(ctx, msg):
return LiskSignedTx(signature=signature)
async def _get_keys(ctx, msg):
node = await seed.derive_node(ctx, msg.address_n, LISK_CURVE)
def _get_keys(keychain, msg):
node = keychain.derive(msg.address_n, LISK_CURVE)
seckey = node.private_key()
pubkey = node.public_key()

@ -1,12 +1,20 @@
from trezor import wire
from trezor.messages import MessageType
from apps.common import HARDENED
def boot():
wire.add(MessageType.MoneroGetAddress, __name__, "get_address")
wire.add(MessageType.MoneroGetWatchKey, __name__, "get_watch_only")
wire.add(MessageType.MoneroTransactionInitRequest, __name__, "sign_tx")
wire.add(MessageType.MoneroKeyImageExportInitRequest, __name__, "key_image_sync")
ns = [
["secp256k1", HARDENED | 44, HARDENED | 128],
["ed25519", HARDENED | 44, HARDENED | 128],
]
wire.add(MessageType.MoneroGetAddress, __name__, "get_address", ns)
wire.add(MessageType.MoneroGetWatchKey, __name__, "get_watch_only", ns)
wire.add(MessageType.MoneroTransactionInitRequest, __name__, "sign_tx", ns)
wire.add(
MessageType.MoneroKeyImageExportInitRequest, __name__, "key_image_sync", ns
)
if __debug__ and hasattr(MessageType, "DebugMoneroDiagRequest"):
wire.add(MessageType.DebugMoneroDiagRequest, __name__, "diag")

@ -5,10 +5,10 @@ from apps.common.layout import address_n_to_str, show_address, show_qr
from apps.monero import misc
async def get_address(ctx, msg):
async def get_address(ctx, msg, keychain):
await paths.validate_path(ctx, misc.validate_full_path, path=msg.address_n)
creds = await misc.get_creds(ctx, msg.address_n, msg.network_type)
creds = misc.get_creds(keychain, msg.address_n, msg.network_type)
if msg.show_display:
desc = address_n_to_str(msg.address_n)

@ -7,12 +7,12 @@ from apps.monero.layout import confirms
from apps.monero.xmr import crypto
async def get_watch_only(ctx, msg: MoneroGetWatchKey):
async def get_watch_only(ctx, msg: MoneroGetWatchKey, keychain):
await paths.validate_path(ctx, misc.validate_full_path, path=msg.address_n)
await confirms.require_confirm_watchkey(ctx)
creds = await misc.get_creds(ctx, msg.address_n, msg.network_type)
creds = misc.get_creds(keychain, msg.address_n, msg.network_type)
address = creds.address
watch_key = crypto.encodeint(creds.view_key_private)

@ -14,10 +14,10 @@ from apps.monero.xmr import crypto, key_image, monero
from apps.monero.xmr.crypto import chacha_poly
async def key_image_sync(ctx, msg):
async def key_image_sync(ctx, msg, keychain):
state = KeyImageSync()
res = await _init_step(state, ctx, msg)
res = await _init_step(state, ctx, msg, keychain)
while True:
msg = await ctx.call(
res,
@ -46,10 +46,10 @@ class KeyImageSync:
self.hasher = crypto.get_keccak()
async def _init_step(s, ctx, msg):
async def _init_step(s, ctx, msg, keychain):
await paths.validate_path(ctx, misc.validate_full_path, path=msg.address_n)
s.creds = await misc.get_creds(ctx, msg.address_n, msg.network_type)
s.creds = misc.get_creds(keychain, msg.address_n, msg.network_type)
await confirms.require_confirm_keyimage_sync(ctx)

@ -1,8 +1,7 @@
from apps.common import HARDENED
async def get_creds(ctx, address_n=None, network_type=None):
from apps.common import seed
def get_creds(keychain, address_n=None, network_type=None):
from apps.monero.xmr import crypto, monero
from apps.monero.xmr.credentials import AccountCreds
@ -12,7 +11,7 @@ async def get_creds(ctx, address_n=None, network_type=None):
curve = "ed25519"
else:
curve = "secp256k1"
node = await seed.derive_node(ctx, address_n, curve)
node = keychain.derive(address_n, curve)
if use_slip0010:
key_seed = node.private_key()

@ -6,7 +6,7 @@ from trezor.messages import MessageType
from apps.monero.signing.state import State
async def sign_tx(ctx, received_msg):
async def sign_tx(ctx, received_msg, keychain):
state = State(ctx)
mods = utils.unimport_begin()
@ -18,7 +18,7 @@ async def sign_tx(ctx, received_msg):
gc.collect()
gc.threshold(gc.mem_free() // 4 + gc.mem_alloc())
result_msg, accept_msgs = await sign_tx_dispatch(state, received_msg)
result_msg, accept_msgs = await sign_tx_dispatch(state, received_msg, keychain)
if accept_msgs is None:
break
@ -32,13 +32,13 @@ async def sign_tx(ctx, received_msg):
return result_msg
async def sign_tx_dispatch(state, msg):
async def sign_tx_dispatch(state, msg, keychain):
if msg.MESSAGE_WIRE_TYPE == MessageType.MoneroTransactionInitRequest:
from apps.monero.signing import step_01_init_transaction
return (
await step_01_init_transaction.init_transaction(
state, msg.address_n, msg.network_type, msg.tsx_data
state, msg.address_n, msg.network_type, msg.tsx_data, keychain
),
(MessageType.MoneroTransactionSetInputRequest,),
)

@ -16,14 +16,18 @@ if False:
async def init_transaction(
state: State, address_n: list, network_type: int, tsx_data: MoneroTransactionData
state: State,
address_n: list,
network_type: int,
tsx_data: MoneroTransactionData,
keychain,
):
from apps.monero.signing import offloading_keys
from apps.common import paths
await paths.validate_path(state.ctx, misc.validate_full_path, path=address_n)
state.creds = await misc.get_creds(state.ctx, address_n, network_type)
state.creds = misc.get_creds(keychain, address_n, network_type)
state.fee = state.fee if state.fee > 0 else 0
state.tx_priv = crypto.random_scalar()
state.tx_pub = crypto.scalarmult_base(state.tx_priv)

@ -1,7 +1,13 @@
from trezor import wire
from trezor.messages import MessageType
from apps.common import HARDENED
def boot():
wire.add(MessageType.NEMGetAddress, __name__, "get_address")
wire.add(MessageType.NEMSignTx, __name__, "sign_tx")
ns = [
["ed25519-keccak", HARDENED | 44, HARDENED | 43],
["ed25519-keccak", HARDENED | 44, HARDENED | 1],
]
wire.add(MessageType.NEMGetAddress, __name__, "get_address", ns)
wire.add(MessageType.NEMSignTx, __name__, "sign_tx", ns)

@ -3,16 +3,15 @@ from trezor.messages.NEMAddress import NEMAddress
from .helpers import NEM_CURVE, check_path, get_network_str
from .validators import validate_network
from apps.common import seed
from apps.common.layout import address_n_to_str, show_address, show_qr
from apps.common.paths import validate_path
async def get_address(ctx, msg):
async def get_address(ctx, msg, keychain):
network = validate_network(msg.network)
await validate_path(ctx, check_path, path=msg.address_n, network=msg.network)
await validate_path(ctx, check_path, path=msg.address_n, network=network)
node = await seed.derive_node(ctx, msg.address_n, NEM_CURVE)
node = keychain.derive(msg.address_n, NEM_CURVE)
address = node.nem_address(network)
if msg.show_display:

@ -2,21 +2,21 @@ from trezor.crypto.curve import ed25519
from trezor.messages.NEMSignedTx import NEMSignedTx
from trezor.messages.NEMSignTx import NEMSignTx
from . import mosaic, multisig, namespace, transfer
from .helpers import NEM_CURVE, NEM_HASH_ALG, check_path
from .validators import validate
from apps.common import seed
from apps.common.paths import validate_path
from apps.nem import mosaic, multisig, namespace, transfer
from apps.nem.helpers import NEM_CURVE, NEM_HASH_ALG, check_path
from apps.nem.validators import validate
async def sign_tx(ctx, msg: NEMSignTx):
async def sign_tx(ctx, msg: NEMSignTx, keychain):
validate(msg)
await validate_path(
ctx, check_path, path=msg.transaction.address_n, network=msg.transaction.network
)
node = await seed.derive_node(ctx, msg.transaction.address_n, NEM_CURVE)
node = keychain.derive(msg.transaction.address_n, NEM_CURVE)
if msg.multisig:
public_key = msg.multisig.signer

@ -1,7 +1,10 @@
from trezor import wire
from trezor.messages import MessageType
from apps.common import HARDENED
def boot():
wire.add(MessageType.RippleGetAddress, __name__, "get_address")
wire.add(MessageType.RippleSignTx, __name__, "sign_tx")
ns = [["secp256k1", HARDENED | 44, HARDENED | 144]]
wire.add(MessageType.RippleGetAddress, __name__, "get_address", ns)
wire.add(MessageType.RippleSignTx, __name__, "sign_tx", ns)

@ -1,16 +1,15 @@
from trezor.messages.RippleAddress import RippleAddress
from trezor.messages.RippleGetAddress import RippleGetAddress
from . import helpers
from apps.common import paths, seed
from apps.common import paths
from apps.common.layout import address_n_to_str, show_address, show_qr
from apps.ripple import helpers
async def get_address(ctx, msg: RippleGetAddress):
async def get_address(ctx, msg: RippleGetAddress, keychain):
await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n)
node = await seed.derive_node(ctx, msg.address_n)
node = keychain.derive(msg.address_n)
pubkey = node.public_key()
address = helpers.address_from_public_key(pubkey)

@ -5,17 +5,17 @@ from trezor.messages.RippleSignedTx import RippleSignedTx
from trezor.messages.RippleSignTx import RippleSignTx
from trezor.wire import ProcessError
from . import helpers, layout
from .serialize import serialize
from apps.common import paths
from apps.ripple import helpers, layout
from apps.ripple.serialize import serialize
from apps.common import paths, seed
async def sign_tx(ctx, msg: RippleSignTx):
async def sign_tx(ctx, msg: RippleSignTx, keychain):
validate(msg)
await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n)
node = await seed.derive_node(ctx, msg.address_n)
node = keychain.derive(msg.address_n)
source_address = helpers.address_from_public_key(node.public_key())
set_canonical_flag(msg)

@ -1,7 +1,10 @@
from trezor import wire
from trezor.messages import MessageType
from apps.common import HARDENED
def boot():
wire.add(MessageType.StellarGetAddress, __name__, "get_address")
wire.add(MessageType.StellarSignTx, __name__, "sign_tx")
ns = [["ed25519", HARDENED | 44, HARDENED | 148]]
wire.add(MessageType.StellarGetAddress, __name__, "get_address", ns)
wire.add(MessageType.StellarSignTx, __name__, "sign_tx", ns)

@ -6,10 +6,10 @@ from apps.common.layout import address_n_to_str, show_address, show_qr
from apps.stellar import helpers
async def get_address(ctx, msg: StellarGetAddress):
async def get_address(ctx, msg: StellarGetAddress, keychain):
await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n)
node = await seed.derive_node(ctx, msg.address_n, helpers.STELLAR_CURVE)
node = keychain.derive(msg.address_n, helpers.STELLAR_CURVE)
pubkey = seed.remove_ed25519_prefix(node.public_key())
address = helpers.address_from_public_key(pubkey)

@ -12,15 +12,15 @@ from apps.stellar import consts, helpers, layout, writers
from apps.stellar.operations import process_operation
async def sign_tx(ctx, msg: StellarSignTx):
if msg.num_operations == 0:
raise ProcessError("Stellar: At least one operation is required")
async def sign_tx(ctx, msg: StellarSignTx, keychain):
await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n)
node = await seed.derive_node(ctx, msg.address_n, consts.STELLAR_CURVE)
node = keychain.derive(msg.address_n, consts.STELLAR_CURVE)
pubkey = seed.remove_ed25519_prefix(node.public_key())
if msg.num_operations == 0:
raise ProcessError("Stellar: At least one operation is required")
w = bytearray()
await _init(ctx, w, pubkey, msg)
_timebounds(w, msg.timebounds_start, msg.timebounds_end)

@ -1,8 +1,11 @@
from trezor import wire
from trezor.messages import MessageType
from apps.common import HARDENED
def boot():
wire.add(MessageType.TezosGetAddress, __name__, "get_address")
wire.add(MessageType.TezosSignTx, __name__, "sign_tx")
wire.add(MessageType.TezosGetPublicKey, __name__, "get_public_key")
ns = [["ed25519", HARDENED | 44, HARDENED | 1729]]
wire.add(MessageType.TezosGetAddress, __name__, "get_address", ns)
wire.add(MessageType.TezosSignTx, __name__, "sign_tx", ns)
wire.add(MessageType.TezosGetPublicKey, __name__, "get_public_key", ns)

@ -6,9 +6,10 @@ from apps.common.layout import address_n_to_str, show_address, show_qr
from apps.tezos import helpers
async def get_address(ctx, msg):
async def get_address(ctx, msg, keychain):
await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n)
node = await seed.derive_node(ctx, msg.address_n, helpers.TEZOS_CURVE)
node = keychain.derive(msg.address_n, helpers.TEZOS_CURVE)
pk = seed.remove_ed25519_prefix(node.public_key())
pkh = hashlib.blake2b(pk, outlen=20).digest()

@ -9,10 +9,10 @@ from apps.common.confirm import require_confirm
from apps.tezos import helpers
async def get_public_key(ctx, msg):
async def get_public_key(ctx, msg, keychain):
await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n)
node = await seed.derive_node(ctx, msg.address_n, helpers.TEZOS_CURVE)
node = keychain.derive(msg.address_n, helpers.TEZOS_CURVE)
pk = seed.remove_ed25519_prefix(node.public_key())
pk_prefixed = helpers.base58_encode_check(pk, prefix=helpers.TEZOS_PUBLICKEY_PREFIX)

@ -4,14 +4,15 @@ from trezor.crypto.curve import ed25519
from trezor.messages import TezosContractType
from trezor.messages.TezosSignedTx import TezosSignedTx
from apps.common import paths, seed
from apps.common import paths
from apps.common.writers import write_bytes, write_uint8
from apps.tezos import helpers, layout
async def sign_tx(ctx, msg):
async def sign_tx(ctx, msg, keychain):
await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n)
node = await seed.derive_node(ctx, msg.address_n, helpers.TEZOS_CURVE)
node = keychain.derive(msg.address_n, helpers.TEZOS_CURVE)
if msg.transaction is not None:
to = _get_address_from_contract(msg.transaction.destination)

@ -3,12 +3,22 @@ from trezor.messages import MessageType
def boot():
wire.add(MessageType.GetPublicKey, __name__, "get_public_key")
wire.add(MessageType.GetAddress, __name__, "get_address")
ns = [
["curve25519"],
["ed25519"],
["ed25519-keccak"],
["nist256p1"],
["secp256k1"],
["secp256k1-decred"],
["secp256k1-groestl"],
["secp256k1-smart"],
]
wire.add(MessageType.GetPublicKey, __name__, "get_public_key", ns)
wire.add(MessageType.GetAddress, __name__, "get_address", ns)
wire.add(MessageType.GetEntropy, __name__, "get_entropy")
wire.add(MessageType.SignTx, __name__, "sign_tx")
wire.add(MessageType.SignMessage, __name__, "sign_message")
wire.add(MessageType.SignTx, __name__, "sign_tx", ns)
wire.add(MessageType.SignMessage, __name__, "sign_message", ns)
wire.add(MessageType.VerifyMessage, __name__, "verify_message")
wire.add(MessageType.SignIdentity, __name__, "sign_identity")
wire.add(MessageType.GetECDHSessionKey, __name__, "get_ecdh_session_key")
wire.add(MessageType.CipherKeyValue, __name__, "cipher_key_value")
wire.add(MessageType.SignIdentity, __name__, "sign_identity", ns)
wire.add(MessageType.GetECDHSessionKey, __name__, "get_ecdh_session_key", ns)
wire.add(MessageType.CipherKeyValue, __name__, "cipher_key_value", ns)

@ -4,11 +4,10 @@ from trezor.crypto.hashlib import sha512
from trezor.messages.CipheredKeyValue import CipheredKeyValue
from trezor.ui.text import Text
from apps.common import seed
from apps.common.confirm import require_confirm
async def cipher_key_value(ctx, msg):
async def cipher_key_value(ctx, msg, keychain):
if len(msg.value) % 16 > 0:
raise wire.DataError("Value length must be a multiple of 16")
@ -23,7 +22,7 @@ async def cipher_key_value(ctx, msg):
text.normal(msg.key)
await require_confirm(ctx, text)
node = await seed.derive_node(ctx, msg.address_n)
node = keychain.derive(msg.address_n)
value = compute_cipher_key_value(msg, node.private_key())
return CipheredKeyValue(value=value)

@ -1,13 +1,13 @@
from trezor.messages import InputScriptType
from trezor.messages.Address import Address
from apps.common import coins, seed
from apps.common import coins
from apps.common.layout import address_n_to_str, show_address, show_qr
from apps.common.paths import validate_path
from apps.wallet.sign_tx import addresses
async def get_address(ctx, msg):
async def get_address(ctx, msg, keychain):
coin_name = msg.coin_name or "Bitcoin"
coin = coins.by_name(coin_name)
@ -19,7 +19,7 @@ async def get_address(ctx, msg):
script_type=msg.script_type,
)
node = await seed.derive_node(ctx, msg.address_n, curve_name=coin.curve_name)
node = keychain.derive(msg.address_n, coin.curve_name)
address = addresses.get_address(msg.script_type, coin, node, msg.multisig)
address_short = addresses.address_short(coin, address)

@ -5,7 +5,7 @@ from trezor.messages.ECDHSessionKey import ECDHSessionKey
from trezor.ui.text import Text
from trezor.utils import chunks
from apps.common import HARDENED, seed
from apps.common import HARDENED
from apps.common.confirm import require_confirm
from apps.wallet.sign_identity import (
serialize_identity,
@ -13,7 +13,7 @@ from apps.wallet.sign_identity import (
)
async def get_ecdh_session_key(ctx, msg):
async def get_ecdh_session_key(ctx, msg, keychain):
if msg.ecdsa_curve_name is None:
msg.ecdsa_curve_name = "secp256k1"
@ -22,7 +22,7 @@ async def get_ecdh_session_key(ctx, msg):
await require_confirm_ecdh_session_key(ctx, msg.identity)
address_n = get_ecdh_path(identity, msg.identity.index or 0)
node = await seed.derive_node(ctx, address_n, msg.ecdsa_curve_name)
node = keychain.derive(address_n, msg.ecdsa_curve_name)
session_key = ecdh(
seckey=node.private_key(),

@ -7,7 +7,6 @@ from apps.common.confirm import require_confirm
async def get_entropy(ctx, msg):
text = Text("Confirm entropy")
text.bold("Do you really want", "to send entropy?")
text.normal("Continue only if you", "know what you are doing!")

@ -3,18 +3,16 @@ from trezor.messages import InputScriptType
from trezor.messages.HDNodeType import HDNodeType
from trezor.messages.PublicKey import PublicKey
from apps.common import coins, layout, seed
from apps.common import coins, layout
async def get_public_key(ctx, msg):
async def get_public_key(ctx, msg, keychain):
coin_name = msg.coin_name or "Bitcoin"
coin = coins.by_name(coin_name)
curve_name = msg.ecdsa_curve_name or coin.curve_name
script_type = msg.script_type or InputScriptType.SPENDADDRESS
curve_name = msg.ecdsa_curve_name
if not curve_name:
curve_name = coin.curve_name
node = await seed.derive_node(ctx, msg.address_n, curve_name=curve_name)
node = keychain.derive(msg.address_n, curve_name=curve_name)
if (
script_type in [InputScriptType.SPENDADDRESS, InputScriptType.SPENDMULTISIG]

@ -6,11 +6,11 @@ from trezor.messages.SignedIdentity import SignedIdentity
from trezor.ui.text import Text
from trezor.utils import chunks
from apps.common import HARDENED, coins, seed
from apps.common import HARDENED, coins
from apps.common.confirm import require_confirm
async def sign_identity(ctx, msg):
async def sign_identity(ctx, msg, keychain):
if msg.ecdsa_curve_name is None:
msg.ecdsa_curve_name = "secp256k1"
@ -19,7 +19,7 @@ async def sign_identity(ctx, msg):
await require_confirm_sign_identity(ctx, msg.identity, msg.challenge_visual)
address_n = get_identity_path(identity, msg.identity.index or 0)
node = await seed.derive_node(ctx, address_n, msg.ecdsa_curve_name)
node = keychain.derive(address_n, msg.ecdsa_curve_name)
coin = coins.by_name("Bitcoin")
if msg.ecdsa_curve_name == "secp256k1":

@ -4,14 +4,14 @@ from trezor.messages.InputScriptType import SPENDADDRESS, SPENDP2SHWITNESS, SPEN
from trezor.messages.MessageSignature import MessageSignature
from trezor.ui.text import Text
from apps.common import coins, seed
from apps.common import coins
from apps.common.confirm import require_confirm
from apps.common.paths import validate_path
from apps.common.signverify import message_digest, split_message
from apps.wallet.sign_tx.addresses import get_address, validate_full_path
async def sign_message(ctx, msg):
async def sign_message(ctx, msg, keychain):
message = msg.message
address_n = msg.address_n
coin_name = msg.coin_name or "Bitcoin"
@ -19,7 +19,6 @@ async def sign_message(ctx, msg):
coin = coins.by_name(coin_name)
await require_confirm_sign_message(ctx, message)
await validate_path(
ctx,
validate_full_path,
@ -29,7 +28,7 @@ async def sign_message(ctx, msg):
validate_script_type=False,
)
node = await seed.derive_node(ctx, address_n, curve_name=coin.curve_name)
node = keychain.derive(address_n, coin.curve_name)
seckey = node.private_key()
address = get_address(script_type, coin, node)

@ -3,53 +3,51 @@ from trezor.messages.MessageType import TxAck
from trezor.messages.RequestType import TXFINISHED
from trezor.messages.TxRequest import TxRequest
from apps.common import coins, paths, seed
from apps.wallet.sign_tx.helpers import (
UiConfirmFeeOverThreshold,
UiConfirmForeignAddress,
UiConfirmOutput,
UiConfirmTotal,
from apps.common import paths
from apps.wallet.sign_tx import (
addresses,
helpers,
layout,
multisig,
progress,
scripts,
segwit_bip143,
signing,
)
@ui.layout
async def sign_tx(ctx, msg):
from apps.wallet.sign_tx import layout, progress, signing
async def sign_tx(ctx, msg, keychain):
signer = signing.sign_tx(msg, keychain)
coin_name = msg.coin_name or "Bitcoin"
coin = coins.by_name(coin_name)
# TODO: rework this so we don't have to pass root to signing.sign_tx
root = await seed.derive_node(ctx, [], curve_name=coin.curve_name)
signer = signing.sign_tx(msg, root)
res = None
while True:
try:
req = signer.send(res)
except signing.SigningError as e:
raise wire.Error(*e.args)
except signing.MultisigError as e:
except multisig.MultisigError as e:
raise wire.Error(*e.args)
except signing.AddressError as e:
except addresses.AddressError as e:
raise wire.Error(*e.args)
except signing.ScriptsError as e:
except scripts.ScriptsError as e:
raise wire.Error(*e.args)
except signing.Bip143Error as e:
except segwit_bip143.Bip143Error as e:
raise wire.Error(*e.args)
if isinstance(req, TxRequest):
if req.request_type == TXFINISHED:
break
res = await ctx.call(req, TxAck)
elif isinstance(req, UiConfirmOutput):
elif isinstance(req, helpers.UiConfirmOutput):
res = await layout.confirm_output(ctx, req.output, req.coin)
progress.report_init()
elif isinstance(req, UiConfirmTotal):
elif isinstance(req, helpers.UiConfirmTotal):
res = await layout.confirm_total(ctx, req.spending, req.fee, req.coin)
progress.report_init()
elif isinstance(req, UiConfirmFeeOverThreshold):
elif isinstance(req, helpers.UiConfirmFeeOverThreshold):
res = await layout.confirm_feeoverthreshold(ctx, req.fee, req.coin)
progress.report_init()
elif isinstance(req, UiConfirmForeignAddress):
elif isinstance(req, helpers.UiConfirmForeignAddress):
res = await paths.show_path_warning(ctx, req.address_n)
else:
raise TypeError("Invalid signing instruction")

@ -12,6 +12,7 @@ from trezor.messages.TxInputType import TxInputType
from trezor.messages.TxOutputBinType import TxOutputBinType
from trezor.messages.TxOutputType import TxOutputType
from trezor.messages.TxRequest import TxRequest
from trezor.utils import obj_eq
from apps.common.coininfo import CoinInfo
@ -24,6 +25,8 @@ class UiConfirmOutput:
self.output = output
self.coin = coin
__eq__ = obj_eq
class UiConfirmTotal:
def __init__(self, spending: int, fee: int, coin: CoinInfo):
@ -31,17 +34,23 @@ class UiConfirmTotal:
self.fee = fee
self.coin = coin
__eq__ = obj_eq
class UiConfirmFeeOverThreshold:
def __init__(self, fee: int, coin: CoinInfo):
self.fee = fee
self.coin = coin
__eq__ = obj_eq
class UiConfirmForeignAddress:
def __init__(self, address_n: list):
self.address_n = address_n
__eq__ = obj_eq
def confirm_output(output: TxOutputType, coin: CoinInfo):
return (yield UiConfirmOutput(output, coin))

@ -1,35 +1,30 @@
from micropython import const
from trezor import utils
from trezor.crypto import base58, bip32, cashaddr, der
from trezor.crypto.curve import secp256k1
from trezor.crypto.hashlib import blake256, sha256
from trezor.messages import OutputScriptType
from trezor.messages import FailureType, InputScriptType, OutputScriptType
from trezor.messages.SignTx import SignTx
from trezor.messages.TxInputType import TxInputType
from trezor.messages.TxOutputBinType import TxOutputBinType
from trezor.messages.TxOutputType import TxOutputType
from trezor.messages.TxRequest import TxRequest
from trezor.messages.TxRequestDetailsType import TxRequestDetailsType
from trezor.messages.TxRequestSerializedType import TxRequestSerializedType
from trezor.utils import HashWriter
from apps.common import address_type, coins
from apps.common.coininfo import CoinInfo
from apps.common.writers import empty_bytearray
from apps.wallet.sign_tx import progress
from apps.wallet.sign_tx.addresses import *
from apps.wallet.sign_tx.decred_prefix_hasher import (
DECRED_SERIALIZE_NO_WITNESS,
DECRED_SERIALIZE_WITNESS_SIGNING,
DECRED_SIGHASHALL,
DecredPrefixHasher,
)
from apps.wallet.sign_tx.helpers import *
from apps.wallet.sign_tx.multisig import *
from apps.wallet.sign_tx.scripts import *
from apps.wallet.sign_tx.segwit_bip143 import Bip143, Bip143Error # noqa:F401
from apps.wallet.sign_tx.tx_weight_calculator import *
from apps.wallet.sign_tx.writers import *
from apps.wallet.sign_tx.zcash import ( # noqa:F401
OVERWINTERED,
ZcashError,
Zip143,
Zip243,
from apps.common import address_type, coininfo, coins, seed
from apps.wallet.sign_tx import (
addresses,
decred,
helpers,
multisig,
progress,
scripts,
segwit_bip143,
tx_weight,
writers,
zcash,
)
# the number of bip32 levels used in a wallet (chain and address)
@ -58,32 +53,32 @@ class SigningError(ValueError):
# - check inputs, previous transactions, and outputs
# - ask for confirmations
# - check fee
async def check_tx_fee(tx: SignTx, root: bip32.HDNode):
async def check_tx_fee(tx: SignTx, keychain: seed.Keychain):
coin = coins.by_name(tx.coin_name)
# h_first is used to make sure the inputs and outputs streamed in Phase 1
# are the same as in Phase 2. it is thus not required to fully hash the
# tx, as the SignTx info is streamed only once
h_first = HashWriter(sha256()) # not a real tx hash
h_first = utils.HashWriter(sha256()) # not a real tx hash
if coin.decred:
hash143 = DecredPrefixHasher(tx) # pseudo bip143 prefix hashing
hash143 = decred.DecredPrefixHasher(tx) # pseudo BIP-0143 prefix hashing
tx_ser = TxRequestSerializedType()
elif tx.overwintered:
if tx.version == 3:
hash143 = Zip143() # ZIP-0143 transaction hashing
hash143 = zcash.Zip143() # ZIP-0143 transaction hashing
elif tx.version == 4:
hash143 = Zip243() # ZIP-0243 transaction hashing
hash143 = zcash.Zip243() # ZIP-0243 transaction hashing
else:
raise SigningError(
FailureType.DataError,
"Unsupported version for overwintered transaction",
)
else:
hash143 = Bip143() # BIP-0143 transaction hashing
hash143 = segwit_bip143.Bip143() # BIP-0143 transaction hashing
multifp = MultisigFingerprint() # control checksum of multisig inputs
weight = TxWeightCalculator(tx.inputs_count, tx.outputs_count)
multifp = multisig.MultisigFingerprint() # control checksum of multisig inputs
weight = tx_weight.TxWeightCalculator(tx.inputs_count, tx.outputs_count)
total_in = 0 # sum of input amounts
segwit_in = 0 # sum of segwit input amounts
@ -100,15 +95,15 @@ async def check_tx_fee(tx: SignTx, root: bip32.HDNode):
for i in range(tx.inputs_count):
progress.advance()
# STAGE_REQUEST_1_INPUT
txi = await request_tx_input(tx_req, i)
txi = await helpers.request_tx_input(tx_req, i)
wallet_path = input_extract_wallet_path(txi, wallet_path)
write_tx_input_check(h_first, txi)
writers.write_tx_input_check(h_first, txi)
weight.add_input(txi)
hash143.add_prevouts(txi) # all inputs are included (non-segwit as well)
hash143.add_sequence(txi)
if not validate_full_path(txi.address_n, coin, txi.script_type):
await confirm_foreign_address(txi.address_n)
if not addresses.validate_full_path(txi.address_n, coin, txi.script_type):
await helpers.confirm_foreign_address(txi.address_n)
if txi.multisig:
multifp.add(txi.multisig)
@ -149,10 +144,10 @@ async def check_tx_fee(tx: SignTx, root: bip32.HDNode):
raise SigningError(FailureType.DataError, "Wrong input script type")
if coin.decred:
w_txi = empty_bytearray(8 if i == 0 else 0 + 9 + len(txi.prev_hash))
w_txi = writers.empty_bytearray(8 if i == 0 else 0 + 9 + len(txi.prev_hash))
if i == 0: # serializing first input => prepend headers
write_bytes(w_txi, get_tx_header(coin, tx))
write_tx_input_decred(w_txi, txi)
writers.write_bytes(w_txi, get_tx_header(coin, tx))
writers.write_tx_input_decred(w_txi, txi)
tx_ser.serialized_tx = w_txi
tx_req.serialized = tx_ser
@ -161,15 +156,15 @@ async def check_tx_fee(tx: SignTx, root: bip32.HDNode):
for o in range(tx.outputs_count):
# STAGE_REQUEST_3_OUTPUT
txo = await request_tx_output(tx_req, o)
txo = await helpers.request_tx_output(tx_req, o)
txo_bin.amount = txo.amount
txo_bin.script_pubkey = output_derive_script(txo, coin, root)
txo_bin.script_pubkey = output_derive_script(txo, coin, keychain)
weight.add_output(txo_bin.script_pubkey)
if change_out == 0 and is_change(txo, wallet_path, segwit_in, multifp):
if change_out == 0 and output_is_change(txo, wallet_path, segwit_in, multifp):
# output is change and does not need confirmation
change_out = txo.amount
elif not await confirm_output(txo, coin):
elif not await helpers.confirm_output(txo, coin):
raise SigningError(FailureType.ActionCancelled, "Output cancelled")
if coin.decred:
@ -180,15 +175,17 @@ async def check_tx_fee(tx: SignTx, root: bip32.HDNode):
)
txo_bin.decred_script_version = txo.decred_script_version
w_txo_bin = empty_bytearray(4 + 8 + 2 + 4 + len(txo_bin.script_pubkey))
w_txo_bin = writers.empty_bytearray(
4 + 8 + 2 + 4 + len(txo_bin.script_pubkey)
)
if o == 0: # serializing first output => prepend outputs count
write_varint(w_txo_bin, tx.outputs_count)
write_tx_output(w_txo_bin, txo_bin)
writers.write_varint(w_txo_bin, tx.outputs_count)
writers.write_tx_output(w_txo_bin, txo_bin)
tx_ser.serialized_tx = w_txo_bin
tx_req.serialized = tx_ser
hash143.set_last_output_bytes(w_txo_bin)
write_tx_output(h_first, txo_bin)
writers.write_tx_output(h_first, txo_bin)
hash143.add_output(txo_bin)
total_out += txo_bin.amount
@ -198,10 +195,10 @@ async def check_tx_fee(tx: SignTx, root: bip32.HDNode):
# fee > (coin.maxfee per byte * tx size)
if fee > (coin.maxfee_kb / 1000) * (weight.get_total() / 4):
if not await confirm_feeoverthreshold(fee, coin):
if not await helpers.confirm_feeoverthreshold(fee, coin):
raise SigningError(FailureType.ActionCancelled, "Signing cancelled")
if not await confirm_total(total_in - change_out, fee, coin):
if not await helpers.confirm_total(total_in - change_out, fee, coin):
raise SigningError(FailureType.ActionCancelled, "Total cancelled")
if coin.decred:
@ -210,14 +207,16 @@ async def check_tx_fee(tx: SignTx, root: bip32.HDNode):
return h_first, hash143, segwit, total_in, wallet_path
async def sign_tx(tx: SignTx, root: bip32.HDNode):
tx = sanitize_sign_tx(tx)
async def sign_tx(tx: SignTx, keychain: seed.Keychain):
tx = helpers.sanitize_sign_tx(tx)
progress.init(tx.inputs_count, tx.outputs_count)
# Phase 1
h_first, hash143, segwit, authorized_in, wallet_path = await check_tx_fee(tx, root)
h_first, hash143, segwit, authorized_in, wallet_path = await check_tx_fee(
tx, keychain
)
# Phase 2
# - sign inputs
@ -242,34 +241,30 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
if segwit[i_sign]:
# STAGE_REQUEST_SEGWIT_INPUT
txi_sign = await request_tx_input(tx_req, i_sign)
txi_sign = await helpers.request_tx_input(tx_req, i_sign)
is_segwit = (
txi_sign.script_type == InputScriptType.SPENDWITNESS
or txi_sign.script_type == InputScriptType.SPENDP2SHWITNESS
)
if not is_segwit:
if not input_is_segwit(txi_sign):
raise SigningError(
FailureType.ProcessError, "Transaction has changed during signing"
)
input_check_wallet_path(txi_sign, wallet_path)
key_sign = node_derive(root, txi_sign.address_n)
key_sign = keychain.derive(txi_sign.address_n, coin.curve_name)
key_sign_pub = key_sign.public_key()
txi_sign.script_sig = input_derive_script(coin, txi_sign, key_sign_pub)
w_txi = empty_bytearray(
w_txi = writers.empty_bytearray(
7 + len(txi_sign.prev_hash) + 4 + len(txi_sign.script_sig) + 4
)
if i_sign == 0: # serializing first input => prepend headers
write_bytes(w_txi, get_tx_header(coin, tx, True))
write_tx_input(w_txi, txi_sign)
writers.write_bytes(w_txi, get_tx_header(coin, tx, True))
writers.write_tx_input(w_txi, txi_sign)
tx_ser.serialized_tx = w_txi
tx_req.serialized = tx_ser
elif coin.force_bip143 or tx.overwintered:
# STAGE_REQUEST_SEGWIT_INPUT
txi_sign = await request_tx_input(tx_req, i_sign)
txi_sign = await helpers.request_tx_input(tx_req, i_sign)
input_check_wallet_path(txi_sign, wallet_path)
is_bip143 = (
@ -282,19 +277,19 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
)
authorized_in -= txi_sign.amount
key_sign = node_derive(root, txi_sign.address_n)
key_sign = keychain.derive(txi_sign.address_n, coin.curve_name)
key_sign_pub = key_sign.public_key()
hash143_hash = hash143.preimage_hash(
coin,
tx,
txi_sign,
ecdsa_hash_pubkey(key_sign_pub, coin),
addresses.ecdsa_hash_pubkey(key_sign_pub, coin),
get_hash_type(coin),
)
# if multisig, check if singing with a key that is included in multisig
# if multisig, check if signing with a key that is included in multisig
if txi_sign.multisig:
multisig_pubkey_index(txi_sign.multisig, key_sign_pub)
multisig.multisig_pubkey_index(txi_sign.multisig, key_sign_pub)
signature = ecdsa_sign(key_sign, hash143_hash)
tx_ser.signature_index = i_sign
@ -304,56 +299,59 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
txi_sign.script_sig = input_derive_script(
coin, txi_sign, key_sign_pub, signature
)
w_txi_sign = empty_bytearray(
w_txi_sign = writers.empty_bytearray(
5 + len(txi_sign.prev_hash) + 4 + len(txi_sign.script_sig) + 4
)
if i_sign == 0: # serializing first input => prepend headers
write_bytes(w_txi_sign, get_tx_header(coin, tx))
write_tx_input(w_txi_sign, txi_sign)
writers.write_bytes(w_txi_sign, get_tx_header(coin, tx))
writers.write_tx_input(w_txi_sign, txi_sign)
tx_ser.serialized_tx = w_txi_sign
tx_req.serialized = tx_ser
elif coin.decred:
txi_sign = await request_tx_input(tx_req, i_sign)
txi_sign = await helpers.request_tx_input(tx_req, i_sign)
input_check_wallet_path(txi_sign, wallet_path)
key_sign = node_derive(root, txi_sign.address_n)
key_sign = keychain.derive(txi_sign.address_n, coin.curve_name)
key_sign_pub = key_sign.public_key()
if txi_sign.script_type == InputScriptType.SPENDMULTISIG:
prev_pkscript = output_script_multisig(
multisig_get_pubkeys(txi_sign.multisig), txi_sign.multisig.m
prev_pkscript = scripts.output_script_multisig(
multisig.multisig_get_pubkeys(txi_sign.multisig),
txi_sign.multisig.m,
)
elif txi_sign.script_type == InputScriptType.SPENDADDRESS:
prev_pkscript = output_script_p2pkh(
ecdsa_hash_pubkey(key_sign_pub, coin)
prev_pkscript = scripts.output_script_p2pkh(
addresses.ecdsa_hash_pubkey(key_sign_pub, coin)
)
else:
raise ValueError("Unknown input script type")
h_witness = HashWriter(blake256())
write_uint32(h_witness, tx.version | DECRED_SERIALIZE_WITNESS_SIGNING)
write_varint(h_witness, tx.inputs_count)
h_witness = utils.HashWriter(blake256())
writers.write_uint32(
h_witness, tx.version | decred.DECRED_SERIALIZE_WITNESS_SIGNING
)
writers.write_varint(h_witness, tx.inputs_count)
for ii in range(tx.inputs_count):
if ii == i_sign:
write_varint(h_witness, len(prev_pkscript))
write_bytes(h_witness, prev_pkscript)
writers.write_varint(h_witness, len(prev_pkscript))
writers.write_bytes(h_witness, prev_pkscript)
else:
write_varint(h_witness, 0)
writers.write_varint(h_witness, 0)
witness_hash = get_tx_hash(
witness_hash = writers.get_tx_hash(
h_witness, double=coin.sign_hash_double, reverse=False
)
h_sign = HashWriter(blake256())
write_uint32(h_sign, DECRED_SIGHASHALL)
write_bytes(h_sign, prefix_hash)
write_bytes(h_sign, witness_hash)
h_sign = utils.HashWriter(blake256())
writers.write_uint32(h_sign, decred.DECRED_SIGHASHALL)
writers.write_bytes(h_sign, prefix_hash)
writers.write_bytes(h_sign, witness_hash)
sig_hash = get_tx_hash(h_sign, double=coin.sign_hash_double)
sig_hash = writers.get_tx_hash(h_sign, double=coin.sign_hash_double)
signature = ecdsa_sign(key_sign, sig_hash)
tx_ser.signature_index = i_sign
tx_ser.signature = signature
@ -362,61 +360,62 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
txi_sign.script_sig = input_derive_script(
coin, txi_sign, key_sign_pub, signature
)
w_txi_sign = empty_bytearray(
w_txi_sign = writers.empty_bytearray(
8 + 4 + len(hash143.get_last_output_bytes())
if i_sign == 0
else 0 + 16 + 4 + len(txi_sign.script_sig)
)
if i_sign == 0:
write_bytes(w_txi_sign, hash143.get_last_output_bytes())
write_uint32(w_txi_sign, tx.lock_time)
write_uint32(w_txi_sign, tx.expiry)
write_varint(w_txi_sign, tx.inputs_count)
writers.write_bytes(w_txi_sign, hash143.get_last_output_bytes())
writers.write_uint32(w_txi_sign, tx.lock_time)
writers.write_uint32(w_txi_sign, tx.expiry)
writers.write_varint(w_txi_sign, tx.inputs_count)
write_tx_input_decred_witness(w_txi_sign, txi_sign)
writers.write_tx_input_decred_witness(w_txi_sign, txi_sign)
tx_ser.serialized_tx = w_txi_sign
tx_req.serialized = tx_ser
else:
# hash of what we are signing with this input
h_sign = HashWriter(sha256())
h_sign = utils.HashWriter(sha256())
# same as h_first, checked before signing the digest
h_second = HashWriter(sha256())
h_second = utils.HashWriter(sha256())
if tx.overwintered:
write_uint32(
h_sign, tx.version | OVERWINTERED
writers.write_uint32(
h_sign, tx.version | zcash.OVERWINTERED
) # nVersion | fOverwintered
write_uint32(h_sign, tx.version_group_id) # nVersionGroupId
writers.write_uint32(h_sign, tx.version_group_id) # nVersionGroupId
else:
write_uint32(h_sign, tx.version) # nVersion
writers.write_uint32(h_sign, tx.version) # nVersion
if tx.timestamp:
write_uint32(h_sign, tx.timestamp)
writers.write_uint32(h_sign, tx.timestamp)
write_varint(h_sign, tx.inputs_count)
writers.write_varint(h_sign, tx.inputs_count)
for i in range(tx.inputs_count):
# STAGE_REQUEST_4_INPUT
txi = await request_tx_input(tx_req, i)
txi = await helpers.request_tx_input(tx_req, i)
input_check_wallet_path(txi, wallet_path)
write_tx_input_check(h_second, txi)
writers.write_tx_input_check(h_second, txi)
if i == i_sign:
txi_sign = txi
key_sign = node_derive(root, txi.address_n)
key_sign = keychain.derive(txi.address_n, coin.curve_name)
key_sign_pub = key_sign.public_key()
# for the signing process the script_sig is equal
# to the previous tx's scriptPubKey (P2PKH) or a redeem script (P2SH)
if txi_sign.script_type == InputScriptType.SPENDMULTISIG:
txi_sign.script_sig = output_script_multisig(
multisig_get_pubkeys(txi_sign.multisig), txi_sign.multisig.m
txi_sign.script_sig = scripts.output_script_multisig(
multisig.multisig_get_pubkeys(txi_sign.multisig),
txi_sign.multisig.m,
)
elif txi_sign.script_type == InputScriptType.SPENDADDRESS:
txi_sign.script_sig = output_script_p2pkh(
ecdsa_hash_pubkey(key_sign_pub, coin)
txi_sign.script_sig = scripts.output_script_p2pkh(
addresses.ecdsa_hash_pubkey(key_sign_pub, coin)
)
if coin.bip115:
txi_sign.script_sig += script_replay_protection_bip115(
txi_sign.script_sig += scripts.script_replay_protection_bip115(
txi_sign.prev_block_hash_bip115,
txi_sign.prev_block_height_bip115,
)
@ -426,38 +425,38 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
)
else:
txi.script_sig = bytes()
write_tx_input(h_sign, txi)
writers.write_tx_input(h_sign, txi)
write_varint(h_sign, tx.outputs_count)
writers.write_varint(h_sign, tx.outputs_count)
for o in range(tx.outputs_count):
# STAGE_REQUEST_4_OUTPUT
txo = await request_tx_output(tx_req, o)
txo = await helpers.request_tx_output(tx_req, o)
txo_bin.amount = txo.amount
txo_bin.script_pubkey = output_derive_script(txo, coin, root)
write_tx_output(h_second, txo_bin)
write_tx_output(h_sign, txo_bin)
txo_bin.script_pubkey = output_derive_script(txo, coin, keychain)
writers.write_tx_output(h_second, txo_bin)
writers.write_tx_output(h_sign, txo_bin)
write_uint32(h_sign, tx.lock_time)
writers.write_uint32(h_sign, tx.lock_time)
if tx.overwintered:
write_uint32(h_sign, tx.expiry) # expiryHeight
write_varint(h_sign, 0) # nJoinSplit
writers.write_uint32(h_sign, tx.expiry) # expiryHeight
writers.write_varint(h_sign, 0) # nJoinSplit
write_uint32(h_sign, get_hash_type(coin))
writers.write_uint32(h_sign, get_hash_type(coin))
# check the control digests
if get_tx_hash(h_first, False) != get_tx_hash(h_second):
if writers.get_tx_hash(h_first, False) != writers.get_tx_hash(h_second):
raise SigningError(
FailureType.ProcessError, "Transaction has changed during signing"
)
# if multisig, check if singing with a key that is included in multisig
# if multisig, check if signing with a key that is included in multisig
if txi_sign.multisig:
multisig_pubkey_index(txi_sign.multisig, key_sign_pub)
multisig.multisig_pubkey_index(txi_sign.multisig, key_sign_pub)
# compute the signature from the tx digest
signature = ecdsa_sign(
key_sign, get_tx_hash(h_sign, double=coin.sign_hash_double)
key_sign, writers.get_tx_hash(h_sign, double=coin.sign_hash_double)
)
tx_ser.signature_index = i_sign
tx_ser.signature = signature
@ -466,31 +465,31 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
txi_sign.script_sig = input_derive_script(
coin, txi_sign, key_sign_pub, signature
)
w_txi_sign = empty_bytearray(
w_txi_sign = writers.empty_bytearray(
5 + len(txi_sign.prev_hash) + 4 + len(txi_sign.script_sig) + 4
)
if i_sign == 0: # serializing first input => prepend headers
write_bytes(w_txi_sign, get_tx_header(coin, tx))
write_tx_input(w_txi_sign, txi_sign)
writers.write_bytes(w_txi_sign, get_tx_header(coin, tx))
writers.write_tx_input(w_txi_sign, txi_sign)
tx_ser.serialized_tx = w_txi_sign
tx_req.serialized = tx_ser
if coin.decred:
return await request_tx_finish(tx_req)
return await helpers.request_tx_finish(tx_req)
for o in range(tx.outputs_count):
progress.advance()
# STAGE_REQUEST_5_OUTPUT
txo = await request_tx_output(tx_req, o)
txo = await helpers.request_tx_output(tx_req, o)
txo_bin.amount = txo.amount
txo_bin.script_pubkey = output_derive_script(txo, coin, root)
txo_bin.script_pubkey = output_derive_script(txo, coin, keychain)
# serialize output
w_txo_bin = empty_bytearray(5 + 8 + 5 + len(txo_bin.script_pubkey) + 4)
w_txo_bin = writers.empty_bytearray(5 + 8 + 5 + len(txo_bin.script_pubkey) + 4)
if o == 0: # serializing first output => prepend outputs count
write_varint(w_txo_bin, tx.outputs_count)
write_tx_output(w_txo_bin, txo_bin)
writers.write_varint(w_txo_bin, tx.outputs_count)
writers.write_tx_output(w_txo_bin, txo_bin)
tx_ser.signature_index = None
tx_ser.signature = None
@ -504,38 +503,38 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
progress.advance()
if segwit[i]:
# STAGE_REQUEST_SEGWIT_WITNESS
txi = await request_tx_input(tx_req, i)
txi = await helpers.request_tx_input(tx_req, i)
input_check_wallet_path(txi, wallet_path)
is_segwit = (
txi.script_type == InputScriptType.SPENDWITNESS
or txi.script_type == InputScriptType.SPENDP2SHWITNESS
)
if not is_segwit or txi.amount > authorized_in:
if not input_is_segwit(txi) or txi.amount > authorized_in:
raise SigningError(
FailureType.ProcessError, "Transaction has changed during signing"
)
authorized_in -= txi.amount
key_sign = node_derive(root, txi.address_n)
key_sign = keychain.derive(txi.address_n, coin.curve_name)
key_sign_pub = key_sign.public_key()
hash143_hash = hash143.preimage_hash(
coin,
tx,
txi,
ecdsa_hash_pubkey(key_sign_pub, coin),
addresses.ecdsa_hash_pubkey(key_sign_pub, coin),
get_hash_type(coin),
)
signature = ecdsa_sign(key_sign, hash143_hash)
if txi.multisig:
# find out place of our signature based on the pubkey
signature_index = multisig_pubkey_index(txi.multisig, key_sign_pub)
witness = witness_p2wsh(
signature_index = multisig.multisig_pubkey_index(
txi.multisig, key_sign_pub
)
witness = scripts.witness_p2wsh(
txi.multisig, signature, signature_index, get_hash_type(coin)
)
else:
witness = witness_p2wpkh(signature, key_sign_pub, get_hash_type(coin))
witness = scripts.witness_p2wpkh(
signature, key_sign_pub, get_hash_type(coin)
)
tx_ser.serialized_tx = witness
tx_ser.signature_index = i
@ -547,66 +546,68 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode):
tx_req.serialized = tx_ser
write_uint32(tx_ser.serialized_tx, tx.lock_time)
writers.write_uint32(tx_ser.serialized_tx, tx.lock_time)
if tx.overwintered:
if tx.version == 3:
write_uint32(tx_ser.serialized_tx, tx.expiry) # expiryHeight
write_varint(tx_ser.serialized_tx, 0) # nJoinSplit
writers.write_uint32(tx_ser.serialized_tx, tx.expiry) # expiryHeight
writers.write_varint(tx_ser.serialized_tx, 0) # nJoinSplit
elif tx.version == 4:
write_uint32(tx_ser.serialized_tx, tx.expiry) # expiryHeight
write_uint64(tx_ser.serialized_tx, 0) # valueBalance
write_varint(tx_ser.serialized_tx, 0) # nShieldedSpend
write_varint(tx_ser.serialized_tx, 0) # nShieldedOutput
write_varint(tx_ser.serialized_tx, 0) # nJoinSplit
writers.write_uint32(tx_ser.serialized_tx, tx.expiry) # expiryHeight
writers.write_uint64(tx_ser.serialized_tx, 0) # valueBalance
writers.write_varint(tx_ser.serialized_tx, 0) # nShieldedSpend
writers.write_varint(tx_ser.serialized_tx, 0) # nShieldedOutput
writers.write_varint(tx_ser.serialized_tx, 0) # nJoinSplit
else:
raise SigningError(
FailureType.DataError,
"Unsupported version for overwintered transaction",
)
await request_tx_finish(tx_req)
await helpers.request_tx_finish(tx_req)
async def get_prevtx_output_value(
coin: CoinInfo, tx_req: TxRequest, prev_hash: bytes, prev_index: int
coin: coininfo.CoinInfo, tx_req: TxRequest, prev_hash: bytes, prev_index: int
) -> int:
total_out = 0 # sum of output amounts
# STAGE_REQUEST_2_PREV_META
tx = await request_tx_meta(tx_req, prev_hash)
tx = await helpers.request_tx_meta(tx_req, prev_hash)
if coin.decred:
txh = HashWriter(blake256())
txh = utils.HashWriter(blake256())
else:
txh = HashWriter(sha256())
txh = utils.HashWriter(sha256())
if tx.overwintered:
write_uint32(txh, tx.version | OVERWINTERED) # nVersion | fOverwintered
write_uint32(txh, tx.version_group_id) # nVersionGroupId
writers.write_uint32(
txh, tx.version | zcash.OVERWINTERED
) # nVersion | fOverwintered
writers.write_uint32(txh, tx.version_group_id) # nVersionGroupId
elif coin.decred:
write_uint32(txh, tx.version | DECRED_SERIALIZE_NO_WITNESS)
writers.write_uint32(txh, tx.version | decred.DECRED_SERIALIZE_NO_WITNESS)
else:
write_uint32(txh, tx.version) # nVersion
writers.write_uint32(txh, tx.version) # nVersion
if tx.timestamp:
write_uint32(txh, tx.timestamp)
writers.write_uint32(txh, tx.timestamp)
write_varint(txh, tx.inputs_cnt)
writers.write_varint(txh, tx.inputs_cnt)
for i in range(tx.inputs_cnt):
# STAGE_REQUEST_2_PREV_INPUT
txi = await request_tx_input(tx_req, i, prev_hash)
txi = await helpers.request_tx_input(tx_req, i, prev_hash)
if coin.decred:
write_tx_input_decred(txh, txi)
writers.write_tx_input_decred(txh, txi)
else:
write_tx_input(txh, txi)
writers.write_tx_input(txh, txi)
write_varint(txh, tx.outputs_cnt)
writers.write_varint(txh, tx.outputs_cnt)
for o in range(tx.outputs_cnt):
# STAGE_REQUEST_2_PREV_OUTPUT
txo_bin = await request_tx_output(tx_req, o, prev_hash)
write_tx_output(txh, txo_bin)
txo_bin = await helpers.request_tx_output(tx_req, o, prev_hash)
writers.write_tx_output(txh, txo_bin)
if o == prev_index:
total_out += txo_bin.amount
if (
@ -619,19 +620,22 @@ async def get_prevtx_output_value(
"Cannot use utxo that has script_version != 0",
)
write_uint32(txh, tx.lock_time)
writers.write_uint32(txh, tx.lock_time)
if tx.overwintered or coin.decred:
write_uint32(txh, tx.expiry)
writers.write_uint32(txh, tx.expiry)
ofs = 0
while ofs < tx.extra_data_len:
size = min(1024, tx.extra_data_len - ofs)
data = await request_tx_extra_data(tx_req, ofs, size, prev_hash)
write_bytes(txh, data)
data = await helpers.request_tx_extra_data(tx_req, ofs, size, prev_hash)
writers.write_bytes(txh, data)
ofs += len(data)
if get_tx_hash(txh, double=coin.sign_hash_double, reverse=True) != prev_hash:
if (
writers.get_tx_hash(txh, double=coin.sign_hash_double, reverse=True)
!= prev_hash
):
raise SigningError(FailureType.ProcessError, "Encountered invalid prev_hash")
return total_out
@ -641,7 +645,7 @@ async def get_prevtx_output_value(
# ===
def get_hash_type(coin: CoinInfo) -> int:
def get_hash_type(coin: coininfo.CoinInfo) -> int:
SIGHASH_FORKID = const(0x40)
SIGHASH_ALL = const(0x01)
hashtype = SIGHASH_ALL
@ -650,19 +654,21 @@ def get_hash_type(coin: CoinInfo) -> int:
return hashtype
def get_tx_header(coin: CoinInfo, tx: SignTx, segwit: bool = False):
def get_tx_header(coin: coininfo.CoinInfo, tx: SignTx, segwit: bool = False):
w_txi = bytearray()
if tx.overwintered:
write_uint32(w_txi, tx.version | OVERWINTERED) # nVersion | fOverwintered
write_uint32(w_txi, tx.version_group_id) # nVersionGroupId
writers.write_uint32(
w_txi, tx.version | zcash.OVERWINTERED
) # nVersion | fOverwintered
writers.write_uint32(w_txi, tx.version_group_id) # nVersionGroupId
else:
write_uint32(w_txi, tx.version) # nVersion
writers.write_uint32(w_txi, tx.version) # nVersion
if tx.timestamp:
write_uint32(w_txi, tx.timestamp)
writers.write_uint32(w_txi, tx.timestamp)
if segwit:
write_varint(w_txi, 0x00) # segwit witness marker
write_varint(w_txi, 0x01) # segwit witness flag
write_varint(w_txi, tx.inputs_count)
writers.write_varint(w_txi, 0x00) # segwit witness marker
writers.write_varint(w_txi, 0x01) # segwit witness flag
writers.write_varint(w_txi, tx.inputs_count)
return w_txi
@ -670,7 +676,9 @@ def get_tx_header(coin: CoinInfo, tx: SignTx, segwit: bool = False):
# ===
def output_derive_script(o: TxOutputType, coin: CoinInfo, root: bip32.HDNode) -> bytes:
def output_derive_script(
o: TxOutputType, coin: coininfo.CoinInfo, keychain: seed.Keychain
) -> bytes:
if o.script_type == OutputScriptType.PAYTOOPRETURN:
# op_return output
@ -678,21 +686,21 @@ def output_derive_script(o: TxOutputType, coin: CoinInfo, root: bip32.HDNode) ->
raise SigningError(
FailureType.DataError, "OP_RETURN output with non-zero amount"
)
return output_script_paytoopreturn(o.op_return_data)
return scripts.output_script_paytoopreturn(o.op_return_data)
if o.address_n:
# change output
if o.address:
raise SigningError(FailureType.DataError, "Address in change output")
o.address = get_address_for_change(o, coin, root)
o.address = get_address_for_change(o, coin, keychain)
else:
if not o.address:
raise SigningError(FailureType.DataError, "Missing address")
if coin.bech32_prefix and o.address.startswith(coin.bech32_prefix):
# p2wpkh or p2wsh
witprog = decode_bech32_address(coin.bech32_prefix, o.address)
return output_script_native_p2wpkh_or_p2wsh(witprog)
witprog = addresses.decode_bech32_address(coin.bech32_prefix, o.address)
return scripts.output_script_native_p2wpkh_or_p2wsh(witprog)
if coin.cashaddr_prefix is not None and o.address.startswith(
coin.cashaddr_prefix + ":"
@ -712,9 +720,9 @@ def output_derive_script(o: TxOutputType, coin: CoinInfo, root: bip32.HDNode) ->
if address_type.check(coin.address_type, raw_address):
# p2pkh
pubkeyhash = address_type.strip(coin.address_type, raw_address)
script = output_script_p2pkh(pubkeyhash)
script = scripts.output_script_p2pkh(pubkeyhash)
if coin.bip115:
script += script_replay_protection_bip115(
script += scripts.script_replay_protection_bip115(
o.block_hash_bip115, o.block_height_bip115
)
return script
@ -722,9 +730,9 @@ def output_derive_script(o: TxOutputType, coin: CoinInfo, root: bip32.HDNode) ->
elif address_type.check(coin.address_type_p2sh, raw_address):
# p2sh
scripthash = address_type.strip(coin.address_type_p2sh, raw_address)
script = output_script_p2sh(scripthash)
script = scripts.output_script_p2sh(scripthash)
if coin.bip115:
script += script_replay_protection_bip115(
script += scripts.script_replay_protection_bip115(
o.block_hash_bip115, o.block_height_bip115
)
return script
@ -732,7 +740,9 @@ def output_derive_script(o: TxOutputType, coin: CoinInfo, root: bip32.HDNode) ->
raise SigningError(FailureType.DataError, "Invalid address type")
def get_address_for_change(o: TxOutputType, coin: CoinInfo, root: bip32.HDNode):
def get_address_for_change(
o: TxOutputType, coin: coininfo.CoinInfo, keychain: seed.Keychain
):
if o.script_type == OutputScriptType.PAYTOADDRESS:
input_script_type = InputScriptType.SPENDADDRESS
elif o.script_type == OutputScriptType.PAYTOMULTISIG:
@ -743,17 +753,19 @@ def get_address_for_change(o: TxOutputType, coin: CoinInfo, root: bip32.HDNode):
input_script_type = InputScriptType.SPENDP2SHWITNESS
else:
raise SigningError(FailureType.DataError, "Invalid script type")
return get_address(
input_script_type, coin, node_derive(root, o.address_n), o.multisig
)
node = keychain.derive(o.address_n, coin.curve_name)
return addresses.get_address(input_script_type, coin, node, o.multisig)
def output_is_change(o: TxOutputType, wallet_path: list, segwit_in: int) -> bool:
is_segwit = (
o.script_type == OutputScriptType.PAYTOWITNESS
or o.script_type == OutputScriptType.PAYTOP2SHWITNESS
)
if is_segwit and o.amount > segwit_in:
def output_is_change(
o: TxOutputType,
wallet_path: list,
segwit_in: int,
multifp: multisig.MultisigFingerprint,
) -> bool:
if o.multisig and not multifp.matches(o.multisig):
return False
if output_is_segwit(o) and o.amount > segwit_in:
# if the output is segwit, make sure it doesn't spend more than what the
# segwit inputs paid. this is to prevent user being tricked into
# creating ANYONECANSPEND outputs before full segwit activation.
@ -766,38 +778,49 @@ def output_is_change(o: TxOutputType, wallet_path: list, segwit_in: int) -> bool
)
def output_is_segwit(o: TxOutputType) -> bool:
return (
o.script_type == OutputScriptType.PAYTOWITNESS
or o.script_type == OutputScriptType.PAYTOP2SHWITNESS
)
# Tx Inputs
# ===
def input_derive_script(
coin: CoinInfo, i: TxInputType, pubkey: bytes, signature: bytes = None
coin: coininfo.CoinInfo, i: TxInputType, pubkey: bytes, signature: bytes = None
) -> bytes:
if i.script_type == InputScriptType.SPENDADDRESS:
# p2pkh or p2sh
return input_script_p2pkh_or_p2sh(pubkey, signature, get_hash_type(coin))
return scripts.input_script_p2pkh_or_p2sh(
pubkey, signature, get_hash_type(coin)
)
if i.script_type == InputScriptType.SPENDP2SHWITNESS:
# p2wpkh or p2wsh using p2sh
if i.multisig:
# p2wsh in p2sh
pubkeys = multisig_get_pubkeys(i.multisig)
witness_script = output_script_multisig(pubkeys, i.multisig.m)
pubkeys = multisig.multisig_get_pubkeys(i.multisig)
witness_script = scripts.output_script_multisig(pubkeys, i.multisig.m)
witness_script_hash = sha256(witness_script).digest()
return input_script_p2wsh_in_p2sh(witness_script_hash)
return scripts.input_script_p2wsh_in_p2sh(witness_script_hash)
# p2wpkh in p2sh
return input_script_p2wpkh_in_p2sh(ecdsa_hash_pubkey(pubkey, coin))
return scripts.input_script_p2wpkh_in_p2sh(
addresses.ecdsa_hash_pubkey(pubkey, coin)
)
elif i.script_type == InputScriptType.SPENDWITNESS:
# native p2wpkh or p2wsh
return input_script_native_p2wpkh_or_p2wsh()
return scripts.input_script_native_p2wpkh_or_p2wsh()
elif i.script_type == InputScriptType.SPENDMULTISIG:
# p2sh multisig
signature_index = multisig_pubkey_index(i.multisig, pubkey)
return input_script_multisig(
signature_index = multisig.multisig_pubkey_index(i.multisig, pubkey)
return scripts.input_script_multisig(
i.multisig, signature, signature_index, get_hash_type(coin), coin
)
@ -805,6 +828,13 @@ def input_derive_script(
raise SigningError(FailureType.ProcessError, "Invalid script type")
def input_is_segwit(i: TxInputType) -> bool:
return (
i.script_type == InputScriptType.SPENDWITNESS
or i.script_type == InputScriptType.SPENDP2SHWITNESS
)
def input_extract_wallet_path(txi: TxInputType, wallet_path: list) -> list:
if wallet_path is None:
return None # there was a mismatch in previous inputs
@ -828,22 +858,7 @@ def input_check_wallet_path(txi: TxInputType, wallet_path: list) -> list:
)
def node_derive(root: bip32.HDNode, address_n: list) -> bip32.HDNode:
node = root.clone()
node.derive_path(address_n)
return node
def ecdsa_sign(node: bip32.HDNode, digest: bytes) -> bytes:
sig = secp256k1.sign(node.private_key(), digest)
sigder = der.encode_seq((sig[1:33], sig[33:65]))
return sigder
def is_change(
txo: TxOutputType, wallet_path: list, segwit_in: int, multifp: MultisigFingerprint
) -> bool:
if txo.multisig:
if not multifp.matches(txo.multisig):
return False
return output_is_change(txo, wallet_path, segwit_in)

@ -3,7 +3,8 @@ from trezor.messages.TxInputType import TxInputType
from trezor.messages.TxOutputBinType import TxOutputBinType
from trezor.utils import ensure
from apps.common.writers import (
from apps.common.writers import ( # noqa: F401
empty_bytearray,
write_bytes,
write_bytes_reversed,
write_uint8,

@ -3,12 +3,25 @@ from trezor import log, loop, messages, utils, workflow
from trezor.wire import codec_v1
from trezor.wire.errors import *
from apps.common import seed
workflow_handlers = {}
def add(mtype, pkgname, modname, *args):
def add(mtype, pkgname, modname, namespace=None):
"""Shortcut for registering a dynamically-imported Protobuf workflow."""
register(mtype, protobuf_workflow, import_workflow, pkgname, modname, *args)
if namespace is not None:
register(
mtype,
protobuf_workflow,
keychain_workflow,
namespace,
import_workflow,
pkgname,
modname,
)
else:
register(mtype, protobuf_workflow, import_workflow, pkgname, modname)
def register(mtype, handler, *args):
@ -133,10 +146,12 @@ async def session_handler(iface, sid):
continue
except Error as exc:
# we log wire.Error as warning, not as exception
log.warning(__name__, "failure: %s", exc.message)
if __debug__:
log.warning(__name__, "failure: %s", exc.message)
except Exception as exc:
# sessions are never closed by raised exceptions
log.exception(__name__, exc)
if __debug__:
log.exception(__name__, exc)
# read new message in next iteration
reader = None
@ -155,7 +170,7 @@ async def protobuf_workflow(ctx, reader, handler, *args):
# respond with specific code and message
await ctx.write(Failure(code=exc.code, message=exc.message))
raise
except Exception: # as exc:
except Exception:
# respond with a generic code and message
await ctx.write(
Failure(code=FailureType.FirmwareError, message="Firmware error")
@ -166,6 +181,15 @@ async def protobuf_workflow(ctx, reader, handler, *args):
await ctx.write(res)
async def keychain_workflow(ctx, req, namespace, handler, *args):
keychain = await seed.get_keychain(ctx, namespace)
args += (keychain,)
try:
return await handler(ctx, req, *args)
finally:
keychain.__del__()
def import_workflow(ctx, req, pkgname, modname, *args):
modpath = "%s.%s" % (pkgname, modname)
module = __import__(modpath, None, None, (modname,), 0)

@ -8,6 +8,7 @@ from apps.cardano.address import (
validate_full_path,
derive_address_and_node
)
from apps.cardano.seed import Keychain
from trezor.crypto import bip32
@ -16,6 +17,9 @@ class TestCardanoAddress(unittest.TestCase):
mnemonic = "all all all all all all all all all all all all"
passphrase = ""
node = bip32.from_mnemonic_cardano(mnemonic, passphrase)
node.derive_cardano(0x80000000 | 44)
node.derive_cardano(0x80000000 | 1815)
keychain = Keychain([0x80000000 | 44, 0x80000000 | 1815], node)
addresses = [
"Ae2tdPwUPEZ98eHFwxSsPBDz73amioKpr58Vw85mP1tMkzq8siaftiejJ3j",
@ -25,7 +29,7 @@ class TestCardanoAddress(unittest.TestCase):
for i, expected in enumerate(addresses):
# 44'/1815'/0'/0/i'
address, _ = derive_address_and_node(node, [0x80000000 | 44, 0x80000000 | 1815, 0x80000000, 0, 0x80000000 + i])
address, _ = derive_address_and_node(keychain, [0x80000000 | 44, 0x80000000 | 1815, 0x80000000, 0, 0x80000000 + i])
self.assertEqual(expected, address)
nodes = [
@ -50,7 +54,7 @@ class TestCardanoAddress(unittest.TestCase):
]
for i, (priv, ext, pub, chain) in enumerate(nodes):
_, n = derive_address_and_node(node, [0x80000000 | 44, 0x80000000 | 1815, 0x80000000, 0, 0x80000000 + i])
_, n = derive_address_and_node(keychain, [0x80000000 | 44, 0x80000000 | 1815, 0x80000000, 0, 0x80000000 + i])
self.assertEqual(hexlify(n.private_key()), priv)
self.assertEqual(hexlify(n.private_key_ext()), ext)
self.assertEqual(hexlify(seed.remove_ed25519_prefix(n.public_key())), pub)
@ -60,6 +64,9 @@ class TestCardanoAddress(unittest.TestCase):
mnemonic = "all all all all all all all all all all all all"
passphrase = ""
node = bip32.from_mnemonic_cardano(mnemonic, passphrase)
node.derive_cardano(0x80000000 | 44)
node.derive_cardano(0x80000000 | 1815)
keychain = Keychain([0x80000000 | 44, 0x80000000 | 1815], node)
addresses = [
"Ae2tdPwUPEZ5YUb8sM3eS8JqKgrRLzhiu71crfuH2MFtqaYr5ACNRdsswsZ",
@ -69,7 +76,7 @@ class TestCardanoAddress(unittest.TestCase):
for i, expected in enumerate(addresses):
# 44'/1815'/0'/0/i
address, _ = derive_address_and_node(node, [0x80000000 | 44, 0x80000000 | 1815, 0x80000000, 0, i])
address, _ = derive_address_and_node(keychain, [0x80000000 | 44, 0x80000000 | 1815, 0x80000000, 0, i])
self.assertEqual(address, expected)
nodes = [
@ -94,7 +101,7 @@ class TestCardanoAddress(unittest.TestCase):
]
for i, (priv, ext, pub, chain) in enumerate(nodes):
_, n = derive_address_and_node(node, [0x80000000 | 44, 0x80000000 | 1815, 0x80000000, 0, i])
_, n = derive_address_and_node(keychain, [0x80000000 | 44, 0x80000000 | 1815, 0x80000000, 0, i])
self.assertEqual(hexlify(n.private_key()), priv)
self.assertEqual(hexlify(n.private_key_ext()), ext)
self.assertEqual(hexlify(seed.remove_ed25519_prefix(n.public_key())), pub)
@ -105,9 +112,12 @@ class TestCardanoAddress(unittest.TestCase):
mnemonic = "all all all all all all all all all all all all"
passphrase = ""
node = bip32.from_mnemonic_cardano(mnemonic, passphrase)
node.derive_cardano(0x80000000 | 44)
node.derive_cardano(0x80000000 | 1815)
keychain = Keychain([0x80000000 | 44, 0x80000000 | 1815], node)
# 44'/1815'
address, _ = derive_address_and_node(node, [0x80000000 | 44, 0x80000000 | 1815])
address, _ = derive_address_and_node(keychain, [0x80000000 | 44, 0x80000000 | 1815])
self.assertEqual(address, "Ae2tdPwUPEZ2FGHX3yCKPSbSgyuuTYgMxNq652zKopxT4TuWvEd8Utd92w3")
priv, ext, pub, chain = (
@ -117,7 +127,7 @@ class TestCardanoAddress(unittest.TestCase):
b"02ac67c59a8b0264724a635774ca2c242afa10d7ab70e2bf0a8f7d4bb10f1f7a"
)
_, n = derive_address_and_node(node, [0x80000000 | 44, 0x80000000 | 1815])
_, n = derive_address_and_node(keychain, [0x80000000 | 44, 0x80000000 | 1815])
self.assertEqual(hexlify(n.private_key()), priv)
self.assertEqual(hexlify(n.private_key_ext()), ext)
self.assertEqual(hexlify(seed.remove_ed25519_prefix(n.public_key())), pub)

@ -1,5 +1,6 @@
from common import *
from apps.cardano.seed import Keychain
from apps.cardano.get_public_key import _get_public_key
from trezor.crypto import bip32
from ubinascii import hexlify
@ -10,6 +11,9 @@ class TestCardanoGetPublicKey(unittest.TestCase):
mnemonic = "all all all all all all all all all all all all"
passphrase = ""
node = bip32.from_mnemonic_cardano(mnemonic, passphrase)
node.derive_cardano(0x80000000 | 44)
node.derive_cardano(0x80000000 | 1815)
keychain = Keychain([0x80000000 | 44, 0x80000000 | 1815], node)
derivation_paths = [
[0x80000000 | 44, 0x80000000 | 1815, 0x80000000, 0, 0x80000000],
@ -40,7 +44,7 @@ class TestCardanoGetPublicKey(unittest.TestCase):
]
for index, derivation_path in enumerate(derivation_paths):
key = _get_public_key(node, derivation_path)
key = _get_public_key(keychain, derivation_path)
self.assertEqual(hexlify(key.node.public_key), public_keys[index])
self.assertEqual(hexlify(key.node.chain_code), chain_codes[index])

@ -45,7 +45,7 @@ class TestEthereumLayout(unittest.TestCase):
text = format_ethereum_amount(1000000000000000000, None, 61)
self.assertEqual(text, '1 ETC')
text = format_ethereum_amount(1000000000000000000, None, 31)
self.assertEqual(text, '1 tRSK')
self.assertEqual(text, '1 tRBTC')
text = format_ethereum_amount(1000000000000000001, None, 1)
self.assertEqual(text, '1.000000000000000001 ETH')
@ -54,7 +54,7 @@ class TestEthereumLayout(unittest.TestCase):
text = format_ethereum_amount(10000000000000000001, None, 61)
self.assertEqual(text, '10.000000000000000001 ETC')
text = format_ethereum_amount(1000000000000000001, None, 31)
self.assertEqual(text, '1.000000000000000001 tRSK')
self.assertEqual(text, '1.000000000000000001 tRBTC')
# unknown chain
text = format_ethereum_amount(1, None, 9999)

@ -1,10 +1,19 @@
from common import *
from trezor.crypto import bip32, bip39
from trezor.utils import HashWriter
from apps.wallet.sign_tx.addresses import validate_full_path, validate_path_for_bitcoin_public_key
from apps.common.paths import HARDENED
from apps.common import coins
from apps.wallet.sign_tx.addresses import *
from apps.wallet.sign_tx.signing import *
from apps.wallet.sign_tx.writers import *
def node_derive(root, path):
node = root.clone()
node.derive_path(path)
return node
class TestAddress(unittest.TestCase):

@ -1,10 +1,17 @@
from common import *
from apps.wallet.sign_tx.signing import *
from apps.wallet.sign_tx.addresses import *
from apps.common import coins
from trezor.crypto import bip32, bip39
def node_derive(root, path):
node = root.clone()
node.derive_path(path)
return node
class TestAddressGRS(unittest.TestCase):
# pylint: disable=C0301

@ -1,6 +1,7 @@
from common import *
from apps.wallet.sign_tx.signing import *
from apps.wallet.sign_tx.segwit_bip143 import *
from apps.common import coins
from trezor.messages.SignTx import SignTx
from trezor.messages.TxInputType import TxInputType

@ -1,6 +1,7 @@
from common import *
from apps.wallet.sign_tx.signing import *
from apps.wallet.sign_tx.segwit_bip143 import *
from apps.common import coins
from trezor.messages.SignTx import SignTx
from trezor.messages.TxInputType import TxInputType

@ -1,7 +1,7 @@
from common import *
from trezor.utils import chunks
from trezor.crypto import bip32, bip39
from trezor.crypto import bip39
from trezor.messages.SignTx import SignTx
from trezor.messages.TxInputType import TxInputType
from trezor.messages.TxOutputType import TxOutputType
@ -15,7 +15,8 @@ from trezor.messages import InputScriptType
from trezor.messages import OutputScriptType
from apps.common import coins
from apps.wallet.sign_tx import signing
from apps.common.seed import Keychain
from apps.wallet.sign_tx import helpers, signing
class TestSignSegwitTxNativeP2WPKH(unittest.TestCase):
@ -24,9 +25,7 @@ class TestSignSegwitTxNativeP2WPKH(unittest.TestCase):
def test_send_native_p2wpkh(self):
coin = coins.by_name('Testnet')
seed = bip39.seed(' '.join(['all'] * 12), '')
root = bip32.from_seed(seed, 'secp256k1')
inp1 = TxInputType(
# 49'/1'/0'/0/0" - tb1qqzv60m9ajw8drqulta4ld4gfx0rdh82un5s65s
@ -61,22 +60,22 @@ class TestSignSegwitTxNativeP2WPKH(unittest.TestCase):
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None)),
TxAck(tx=TransactionType(inputs=[inp1])),
signing.UiConfirmForeignAddress(address_n=inp1.address_n),
helpers.UiConfirmForeignAddress(address_n=inp1.address_n),
True,
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None),
TxAck(tx=TransactionType(outputs=[out1])),
signing.UiConfirmOutput(out1, coin),
helpers.UiConfirmOutput(out1, coin),
True,
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=1, tx_hash=None), serialized=None),
TxAck(tx=TransactionType(outputs=[out2])),
signing.UiConfirmOutput(out2, coin),
helpers.UiConfirmOutput(out2, coin),
True,
signing.UiConfirmTotal(12300000, 11000, coin),
helpers.UiConfirmTotal(12300000, 11000, coin),
True,
# sign tx
@ -113,18 +112,17 @@ class TestSignSegwitTxNativeP2WPKH(unittest.TestCase):
)),
]
signer = signing.sign_tx(tx, root)
keychain = Keychain(seed, [[coin.curve_name]])
signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2):
self.assertEqualEx(signer.send(request), response)
self.assertEqual(signer.send(request), response)
with self.assertRaises(StopIteration):
signer.send(None)
def test_send_native_p2wpkh_change(self):
coin = coins.by_name('Testnet')
seed = bip39.seed(' '.join(['all'] * 12), '')
root = bip32.from_seed(seed, 'secp256k1')
inp1 = TxInputType(
# 49'/1'/0'/0/0" - tb1qqzv60m9ajw8drqulta4ld4gfx0rdh82un5s65s
@ -159,19 +157,19 @@ class TestSignSegwitTxNativeP2WPKH(unittest.TestCase):
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None)),
TxAck(tx=TransactionType(inputs=[inp1])),
signing.UiConfirmForeignAddress(address_n=inp1.address_n),
helpers.UiConfirmForeignAddress(address_n=inp1.address_n),
True,
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None),
TxAck(tx=TransactionType(outputs=[out1])),
signing.UiConfirmOutput(out1, coin),
helpers.UiConfirmOutput(out1, coin),
True,
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=1, tx_hash=None), serialized=None),
TxAck(tx=TransactionType(outputs=[out2])),
signing.UiConfirmTotal(5000000 + 11000, 11000, coin),
helpers.UiConfirmTotal(5000000 + 11000, 11000, coin),
True,
# sign tx
@ -209,21 +207,13 @@ class TestSignSegwitTxNativeP2WPKH(unittest.TestCase):
)),
]
signer = signing.sign_tx(tx, root)
keychain = Keychain(seed, [[coin.curve_name]])
signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2):
self.assertEqualEx(signer.send(request), response)
self.assertEqual(signer.send(request), response)
with self.assertRaises(StopIteration):
signer.send(None)
def assertEqualEx(self, a, b):
# hack to avoid adding __eq__ to signing.Ui* classes
if ((isinstance(a, signing.UiConfirmOutput) and isinstance(b, signing.UiConfirmOutput)) or
(isinstance(a, signing.UiConfirmTotal) and isinstance(b, signing.UiConfirmTotal)) or
(isinstance(a, signing.UiConfirmForeignAddress) and isinstance(b, signing.UiConfirmForeignAddress))):
return self.assertEqual(a.__dict__, b.__dict__)
else:
return self.assertEqual(a, b)
if __name__ == '__main__':
unittest.main()

@ -1,7 +1,7 @@
from common import *
from trezor.utils import chunks
from trezor.crypto import bip32, bip39
from trezor.crypto import bip39
from trezor.messages.SignTx import SignTx
from trezor.messages.TxInputType import TxInputType
from trezor.messages.TxOutputType import TxOutputType
@ -15,7 +15,8 @@ from trezor.messages import InputScriptType
from trezor.messages import OutputScriptType
from apps.common import coins
from apps.wallet.sign_tx import signing
from apps.common.seed import Keychain
from apps.wallet.sign_tx import helpers, signing
# https://groestlsight-test.groestlcoin.org/api/tx/9b5c4859a8a31e69788cb4402812bb28f14ad71cbd8c60b09903478bc56f79a3
class TestSignSegwitTxNativeP2WPKH_GRS(unittest.TestCase):
@ -24,9 +25,7 @@ class TestSignSegwitTxNativeP2WPKH_GRS(unittest.TestCase):
def test_send_native_p2wpkh(self):
coin = coins.by_name('Groestlcoin Testnet')
seed = bip39.seed(' '.join(['all'] * 12), '')
root = bip32.from_seed(seed, coin.curve_name)
inp1 = TxInputType(
# 84'/1'/0'/0/0" - tgrs1qkvwu9g3k2pdxewfqr7syz89r3gj557l3ued7ja
@ -64,16 +63,16 @@ class TestSignSegwitTxNativeP2WPKH_GRS(unittest.TestCase):
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None),
TxAck(tx=TransactionType(outputs=[out1])),
signing.UiConfirmOutput(out1, coin),
helpers.UiConfirmOutput(out1, coin),
True,
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=1, tx_hash=None), serialized=None),
TxAck(tx=TransactionType(outputs=[out2])),
signing.UiConfirmOutput(out2, coin),
helpers.UiConfirmOutput(out2, coin),
True,
signing.UiConfirmTotal(12300000, 11000, coin),
helpers.UiConfirmTotal(12300000, 11000, coin),
True,
# sign tx
@ -110,18 +109,17 @@ class TestSignSegwitTxNativeP2WPKH_GRS(unittest.TestCase):
)),
]
signer = signing.sign_tx(tx, root)
keychain = Keychain(seed, [[coin.curve_name]])
signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2):
self.assertEqualEx(signer.send(request), response)
self.assertEqual(signer.send(request), response)
with self.assertRaises(StopIteration):
signer.send(None)
def test_send_native_p2wpkh_change(self):
coin = coins.by_name('Groestlcoin Testnet')
seed = bip39.seed(' '.join(['all'] * 12), '')
root = bip32.from_seed(seed, coin.curve_name)
inp1 = TxInputType(
# 84'/1'/0'/0/0" - tgrs1qkvwu9g3k2pdxewfqr7syz89r3gj557l3ued7ja
@ -159,13 +157,13 @@ class TestSignSegwitTxNativeP2WPKH_GRS(unittest.TestCase):
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None),
TxAck(tx=TransactionType(outputs=[out1])),
signing.UiConfirmOutput(out1, coin),
helpers.UiConfirmOutput(out1, coin),
True,
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=1, tx_hash=None), serialized=None),
TxAck(tx=TransactionType(outputs=[out2])),
signing.UiConfirmTotal(5000000 + 11000, 11000, coin),
helpers.UiConfirmTotal(5000000 + 11000, 11000, coin),
True,
# sign tx
@ -203,20 +201,13 @@ class TestSignSegwitTxNativeP2WPKH_GRS(unittest.TestCase):
)),
]
signer = signing.sign_tx(tx, root)
keychain = Keychain(seed, [[coin.curve_name]])
signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2):
self.assertEqualEx(signer.send(request), response)
self.assertEqual(signer.send(request), response)
with self.assertRaises(StopIteration):
signer.send(None)
def assertEqualEx(self, a, b):
# hack to avoid adding __eq__ to signing.Ui* classes
if ((isinstance(a, signing.UiConfirmOutput) and isinstance(b, signing.UiConfirmOutput)) or
(isinstance(a, signing.UiConfirmTotal) and isinstance(b, signing.UiConfirmTotal))):
return self.assertEqual(a.__dict__, b.__dict__)
else:
return self.assertEqual(a, b)
if __name__ == '__main__':
unittest.main()

@ -1,7 +1,7 @@
from common import *
from trezor.utils import chunks
from trezor.crypto import bip32, bip39
from trezor.crypto import bip39
from trezor.messages.SignTx import SignTx
from trezor.messages.TxInputType import TxInputType
from trezor.messages.TxOutputType import TxOutputType
@ -15,7 +15,8 @@ from trezor.messages import InputScriptType
from trezor.messages import OutputScriptType
from apps.common import coins
from apps.wallet.sign_tx import signing
from apps.common.seed import Keychain
from apps.wallet.sign_tx import helpers, signing
class TestSignSegwitTxP2WPKHInP2SH(unittest.TestCase):
@ -24,9 +25,7 @@ class TestSignSegwitTxP2WPKHInP2SH(unittest.TestCase):
def test_send_p2wpkh_in_p2sh(self):
coin = coins.by_name('Testnet')
seed = bip39.seed(' '.join(['all'] * 12), '')
root = bip32.from_seed(seed, 'secp256k1')
inp1 = TxInputType(
# 49'/1'/0'/1/0" - 2N1LGaGg836mqSQqiuUBLfcyGBhyZbremDX
@ -64,16 +63,16 @@ class TestSignSegwitTxP2WPKHInP2SH(unittest.TestCase):
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None),
TxAck(tx=TransactionType(outputs=[out1])),
signing.UiConfirmOutput(out1, coin),
helpers.UiConfirmOutput(out1, coin),
True,
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=1, tx_hash=None), serialized=None),
TxAck(tx=TransactionType(outputs=[out2])),
signing.UiConfirmOutput(out2, coin),
helpers.UiConfirmOutput(out2, coin),
True,
signing.UiConfirmTotal(123445789 + 11000, 11000, coin),
helpers.UiConfirmTotal(123445789 + 11000, 11000, coin),
True,
# sign tx
@ -110,18 +109,17 @@ class TestSignSegwitTxP2WPKHInP2SH(unittest.TestCase):
)),
]
signer = signing.sign_tx(tx, root)
keychain = Keychain(seed, [[coin.curve_name]])
signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2):
self.assertEqualEx(signer.send(request), response)
self.assertEqual(signer.send(request), response)
with self.assertRaises(StopIteration):
signer.send(None)
def test_send_p2wpkh_in_p2sh_change(self):
coin = coins.by_name('Testnet')
seed = bip39.seed(' '.join(['all'] * 12), '')
root = bip32.from_seed(seed, 'secp256k1')
inp1 = TxInputType(
# 49'/1'/0'/1/0" - 2N1LGaGg836mqSQqiuUBLfcyGBhyZbremDX
@ -160,14 +158,14 @@ class TestSignSegwitTxP2WPKHInP2SH(unittest.TestCase):
serialized=None),
TxAck(tx=TransactionType(outputs=[out1])),
signing.UiConfirmOutput(out1, coin),
helpers.UiConfirmOutput(out1, coin),
True,
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=1, tx_hash=None),
serialized=None),
TxAck(tx=TransactionType(outputs=[out2])),
signing.UiConfirmTotal(12300000 + 11000, 11000, coin),
helpers.UiConfirmTotal(12300000 + 11000, 11000, coin),
True,
# sign tx
@ -213,9 +211,10 @@ class TestSignSegwitTxP2WPKHInP2SH(unittest.TestCase):
)),
]
signer = signing.sign_tx(tx, root)
keychain = Keychain(seed, [[coin.curve_name]])
signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2):
self.assertEqualEx(signer.send(request), response)
self.assertEqual(signer.send(request), response)
with self.assertRaises(StopIteration):
signer.send(None)
@ -224,9 +223,7 @@ class TestSignSegwitTxP2WPKHInP2SH(unittest.TestCase):
def test_send_p2wpkh_in_p2sh_attack_amount(self):
coin = coins.by_name('Testnet')
seed = bip39.seed(' '.join(['all'] * 12), '')
root = bip32.from_seed(seed, 'secp256k1')
inp1 = TxInputType(
# 49'/1'/0'/1/0" - 2N1LGaGg836mqSQqiuUBLfcyGBhyZbremDX
@ -275,14 +272,14 @@ class TestSignSegwitTxP2WPKHInP2SH(unittest.TestCase):
serialized=None),
TxAck(tx=TransactionType(outputs=[out1])),
signing.UiConfirmOutput(out1, coin),
helpers.UiConfirmOutput(out1, coin),
True,
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=1, tx_hash=None),
serialized=None),
TxAck(tx=TransactionType(outputs=[out2])),
signing.UiConfirmTotal(8, 0, coin),
helpers.UiConfirmTotal(8, 0, coin),
True,
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None),
@ -322,26 +319,19 @@ class TestSignSegwitTxP2WPKHInP2SH(unittest.TestCase):
TxRequest(request_type=TXFINISHED, details=None)
]
signer = signing.sign_tx(tx, root)
keychain = Keychain(seed, [[coin.curve_name]])
signer = signing.sign_tx(tx, keychain)
i = 0
messages_count = int(len(messages) / 2)
for request, response in chunks(messages, 2):
if i == messages_count - 1: # last message should throw SigningError
self.assertRaises(signing.SigningError, signer.send, request)
else:
self.assertEqualEx(signer.send(request), response)
self.assertEqual(signer.send(request), response)
i += 1
with self.assertRaises(StopIteration):
signer.send(None)
def assertEqualEx(self, a, b):
# hack to avoid adding __eq__ to signing.Ui* classes
if ((isinstance(a, signing.UiConfirmOutput) and isinstance(b, signing.UiConfirmOutput)) or
(isinstance(a, signing.UiConfirmTotal) and isinstance(b, signing.UiConfirmTotal))):
return self.assertEqual(a.__dict__, b.__dict__)
else:
return self.assertEqual(a, b)
if __name__ == '__main__':
unittest.main()

@ -1,7 +1,7 @@
from common import *
from trezor.utils import chunks
from trezor.crypto import bip32, bip39
from trezor.crypto import bip39
from trezor.messages.SignTx import SignTx
from trezor.messages.TxInputType import TxInputType
from trezor.messages.TxOutputType import TxOutputType
@ -15,7 +15,8 @@ from trezor.messages import InputScriptType
from trezor.messages import OutputScriptType
from apps.common import coins
from apps.wallet.sign_tx import signing
from apps.common.seed import Keychain
from apps.wallet.sign_tx import helpers, signing
# https://groestlsight-test.groestlcoin.org/api/tx/4ce0220004bdfe14e3dd49fd8636bcb770a400c0c9e9bff670b6a13bb8f15c72
class TestSignSegwitTxP2WPKHInP2SH_GRS(unittest.TestCase):
@ -24,9 +25,7 @@ class TestSignSegwitTxP2WPKHInP2SH_GRS(unittest.TestCase):
def test_send_p2wpkh_in_p2sh(self):
coin = coins.by_name('Groestlcoin Testnet')
seed = bip39.seed(' '.join(['all'] * 12), '')
root = bip32.from_seed(seed, coin.curve_name)
inp1 = TxInputType(
# 49'/1'/0'/1/0" - 2N1LGaGg836mqSQqiuUBLfcyGBhyZYBtBZ7
@ -64,16 +63,16 @@ class TestSignSegwitTxP2WPKHInP2SH_GRS(unittest.TestCase):
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None),
TxAck(tx=TransactionType(outputs=[out1])),
signing.UiConfirmOutput(out1, coin),
helpers.UiConfirmOutput(out1, coin),
True,
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=1, tx_hash=None), serialized=None),
TxAck(tx=TransactionType(outputs=[out2])),
signing.UiConfirmOutput(out2, coin),
helpers.UiConfirmOutput(out2, coin),
True,
signing.UiConfirmTotal(123445789 + 11000, 11000, coin),
helpers.UiConfirmTotal(123445789 + 11000, 11000, coin),
True,
# sign tx
@ -110,18 +109,17 @@ class TestSignSegwitTxP2WPKHInP2SH_GRS(unittest.TestCase):
)),
]
signer = signing.sign_tx(tx, root)
keychain = Keychain(seed, [[coin.curve_name]])
signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2):
self.assertEqualEx(signer.send(request), response)
self.assertEqual(signer.send(request), response)
with self.assertRaises(StopIteration):
signer.send(None)
def test_send_p2wpkh_in_p2sh_change(self):
coin = coins.by_name('Groestlcoin Testnet')
seed = bip39.seed(' '.join(['all'] * 12), '')
root = bip32.from_seed(seed, coin.curve_name)
inp1 = TxInputType(
# 49'/1'/0'/1/0" - 2N1LGaGg836mqSQqiuUBLfcyGBhyZYBtBZ7
@ -160,14 +158,14 @@ class TestSignSegwitTxP2WPKHInP2SH_GRS(unittest.TestCase):
serialized=None),
TxAck(tx=TransactionType(outputs=[out1])),
signing.UiConfirmOutput(out1, coin),
helpers.UiConfirmOutput(out1, coin),
True,
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=1, tx_hash=None),
serialized=None),
TxAck(tx=TransactionType(outputs=[out2])),
signing.UiConfirmTotal(12300000 + 11000, 11000, coin),
helpers.UiConfirmTotal(12300000 + 11000, 11000, coin),
True,
# sign tx
@ -212,20 +210,13 @@ class TestSignSegwitTxP2WPKHInP2SH_GRS(unittest.TestCase):
)),
]
signer = signing.sign_tx(tx, root)
keychain = Keychain(seed, [[coin.curve_name]])
signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2):
self.assertEqualEx(signer.send(request), response)
self.assertEqual(signer.send(request), response)
with self.assertRaises(StopIteration):
signer.send(None)
def assertEqualEx(self, a, b):
# hack to avoid adding __eq__ to signing.Ui* classes
if ((isinstance(a, signing.UiConfirmOutput) and isinstance(b, signing.UiConfirmOutput)) or
(isinstance(a, signing.UiConfirmTotal) and isinstance(b, signing.UiConfirmTotal))):
return self.assertEqual(a.__dict__, b.__dict__)
else:
return self.assertEqual(a, b)
if __name__ == '__main__':
unittest.main()

@ -14,7 +14,8 @@ from trezor.messages.TxRequestDetailsType import TxRequestDetailsType
from trezor.messages import OutputScriptType
from apps.common import coins
from apps.wallet.sign_tx import signing
from apps.common.seed import Keychain
from apps.wallet.sign_tx import helpers, signing
class TestSignTxFeeThreshold(unittest.TestCase):
@ -60,7 +61,7 @@ class TestSignTxFeeThreshold(unittest.TestCase):
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None)),
TxAck(tx=TransactionType(inputs=[inp1])),
signing.UiConfirmForeignAddress(address_n=inp1.address_n),
helpers.UiConfirmForeignAddress(address_n=inp1.address_n),
True,
TxRequest(request_type=TXMETA, details=TxRequestDetailsType(request_index=None, tx_hash=unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882')), serialized=None),
TxAck(tx=ptx1),
@ -72,11 +73,11 @@ class TestSignTxFeeThreshold(unittest.TestCase):
TxAck(tx=TransactionType(bin_outputs=[pout1])),
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None),
TxAck(tx=TransactionType(outputs=[out1])),
signing.UiConfirmOutput(out1, coin_bitcoin),
helpers.UiConfirmOutput(out1, coin_bitcoin),
True,
signing.UiConfirmFeeOverThreshold(100000, coin_bitcoin),
helpers.UiConfirmFeeOverThreshold(100000, coin_bitcoin),
True,
signing.UiConfirmTotal(290000 + 100000, 100000, coin_bitcoin),
helpers.UiConfirmTotal(290000 + 100000, 100000, coin_bitcoin),
True,
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None),
]
@ -84,9 +85,10 @@ class TestSignTxFeeThreshold(unittest.TestCase):
seed = bip39.seed('alcohol woman abuse must during monitor noble actual mixed trade anger aisle', '')
root = bip32.from_seed(seed, 'secp256k1')
signer = signing.sign_tx(tx, root)
keychain = Keychain([[coin_bitcoin.curve_name]], [root])
signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2):
self.assertEqualEx(signer.send(request), response)
self.assertEqual(signer.send(request), response)
def test_under_threshold(self):
coin_bitcoin = coins.by_name('Bitcoin')
@ -127,7 +129,7 @@ class TestSignTxFeeThreshold(unittest.TestCase):
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None)),
TxAck(tx=TransactionType(inputs=[inp1])),
signing.UiConfirmForeignAddress(address_n=inp1.address_n),
helpers.UiConfirmForeignAddress(address_n=inp1.address_n),
True,
TxRequest(request_type=TXMETA, details=TxRequestDetailsType(request_index=None, tx_hash=unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882')), serialized=None),
TxAck(tx=ptx1),
@ -139,9 +141,9 @@ class TestSignTxFeeThreshold(unittest.TestCase):
TxAck(tx=TransactionType(bin_outputs=[pout1])),
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None),
TxAck(tx=TransactionType(outputs=[out1])),
signing.UiConfirmOutput(out1, coin_bitcoin),
helpers.UiConfirmOutput(out1, coin_bitcoin),
True,
signing.UiConfirmTotal(300000 + 90000, 90000, coin_bitcoin),
helpers.UiConfirmTotal(300000 + 90000, 90000, coin_bitcoin),
True,
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None),
]
@ -149,19 +151,10 @@ class TestSignTxFeeThreshold(unittest.TestCase):
seed = bip39.seed('alcohol woman abuse must during monitor noble actual mixed trade anger aisle', '')
root = bip32.from_seed(seed, 'secp256k1')
signer = signing.sign_tx(tx, root)
keychain = Keychain([[coin_bitcoin.curve_name]], [root])
signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2):
self.assertEqualEx(signer.send(request), response)
def assertEqualEx(self, a, b):
# hack to avoid adding __eq__ to signing.Ui* classes
if ((isinstance(a, signing.UiConfirmOutput) and isinstance(b, signing.UiConfirmOutput)) or
(isinstance(a, signing.UiConfirmTotal) and isinstance(b, signing.UiConfirmTotal)) or
(isinstance(a, signing.UiConfirmForeignAddress) and isinstance(b, signing.UiConfirmForeignAddress)) or
(isinstance(a, signing.UiConfirmFeeOverThreshold) and isinstance(b, signing.UiConfirmFeeOverThreshold))):
return self.assertEqual(a.__dict__, b.__dict__)
else:
return self.assertEqual(a, b)
self.assertEqual(signer.send(request), response)
if __name__ == '__main__':

@ -15,7 +15,8 @@ from trezor.messages.TxRequestSerializedType import TxRequestSerializedType
from trezor.messages import OutputScriptType
from apps.common import coins
from apps.wallet.sign_tx import signing
from apps.common.seed import Keychain
from apps.wallet.sign_tx import helpers, signing
class TestSignTx(unittest.TestCase):
@ -61,7 +62,7 @@ class TestSignTx(unittest.TestCase):
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None)),
TxAck(tx=TransactionType(inputs=[inp1])),
signing.UiConfirmForeignAddress(address_n=inp1.address_n),
helpers.UiConfirmForeignAddress(address_n=inp1.address_n),
True,
TxRequest(request_type=TXMETA, details=TxRequestDetailsType(request_index=None, tx_hash=unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882')), serialized=None),
TxAck(tx=ptx1),
@ -73,9 +74,9 @@ class TestSignTx(unittest.TestCase):
TxAck(tx=TransactionType(bin_outputs=[pout1])),
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None),
TxAck(tx=TransactionType(outputs=[out1])),
signing.UiConfirmOutput(out1, coin_bitcoin),
helpers.UiConfirmOutput(out1, coin_bitcoin),
True,
signing.UiConfirmTotal(380000 + 10000, 10000, coin_bitcoin),
helpers.UiConfirmTotal(380000 + 10000, 10000, coin_bitcoin),
True,
# ButtonRequest(code=ButtonRequest_ConfirmOutput),
# ButtonRequest(code=ButtonRequest_SignTx),
@ -96,26 +97,16 @@ class TestSignTx(unittest.TestCase):
]
seed = bip39.seed('alcohol woman abuse must during monitor noble actual mixed trade anger aisle', '')
root = bip32.from_seed(seed, 'secp256k1')
signer = signing.sign_tx(tx, root)
keychain = Keychain(seed, [[coin_bitcoin.curve_name]])
signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2):
res = signer.send(request)
self.assertEqualEx(res, response)
self.assertEqual(res, response)
with self.assertRaises(StopIteration):
signer.send(None)
def assertEqualEx(self, a, b):
# hack to avoid adding __eq__ to signing.Ui* classes
if ((isinstance(a, signing.UiConfirmOutput) and isinstance(b, signing.UiConfirmOutput)) or
(isinstance(a, signing.UiConfirmForeignAddress) and isinstance(b, signing.UiConfirmForeignAddress)) or
(isinstance(a, signing.UiConfirmTotal) and isinstance(b, signing.UiConfirmTotal))):
return self.assertEqual(a.__dict__, b.__dict__)
else:
return self.assertEqual(a, b)
if __name__ == '__main__':
unittest.main()

@ -1,7 +1,7 @@
from common import *
from trezor.utils import chunks
from trezor.crypto import bip32, bip39
from trezor.crypto import bip39
from trezor.messages.SignTx import SignTx
from trezor.messages.TxInputType import TxInputType
from trezor.messages.TxOutputType import TxOutputType
@ -15,7 +15,8 @@ from trezor.messages.TxRequestSerializedType import TxRequestSerializedType
from trezor.messages import OutputScriptType
from apps.common import coins
from apps.wallet.sign_tx import signing
from apps.common.seed import Keychain
from apps.wallet.sign_tx import helpers, signing
class TestSignTx_GRS(unittest.TestCase):
@ -62,9 +63,9 @@ class TestSignTx_GRS(unittest.TestCase):
TxAck(tx=TransactionType(bin_outputs=[pout1])),
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None),
TxAck(tx=TransactionType(outputs=[out1])),
signing.UiConfirmOutput(out1, coin),
helpers.UiConfirmOutput(out1, coin),
True,
signing.UiConfirmTotal(210016, 192, coin),
helpers.UiConfirmTotal(210016, 192, coin),
True,
# ButtonRequest(code=ButtonRequest_ConfirmOutput),
# ButtonRequest(code=ButtonRequest_SignTx),
@ -85,22 +86,13 @@ class TestSignTx_GRS(unittest.TestCase):
]
seed = bip39.seed(' '.join(['all'] * 12), '')
root = bip32.from_seed(seed, coin.curve_name)
signer = signing.sign_tx(tx, root)
keychain = Keychain(seed, [[coin.curve_name]])
signer = signing.sign_tx(tx, keychain)
for request, response in chunks(messages, 2):
self.assertEqualEx(signer.send(request), response)
self.assertEqual(signer.send(request), response)
with self.assertRaises(StopIteration):
signer.send(None)
def assertEqualEx(self, a, b):
# hack to avoid adding __eq__ to signing.Ui* classes
if ((isinstance(a, signing.UiConfirmOutput) and isinstance(b, signing.UiConfirmOutput)) or
(isinstance(a, signing.UiConfirmTotal) and isinstance(b, signing.UiConfirmTotal))):
return self.assertEqual(a.__dict__, b.__dict__)
else:
return self.assertEqual(a, b)
if __name__ == '__main__':
unittest.main()

@ -5,7 +5,7 @@ from trezor.messages import OutputScriptType
from trezor.crypto import bip32, bip39
from apps.common import coins
from apps.wallet.sign_tx.tx_weight_calculator import *
from apps.wallet.sign_tx.tx_weight import *
from apps.wallet.sign_tx import signing

@ -1 +1 @@
Subproject commit 71528b526020b5c6a95261b07336cff5d68ea66e
Subproject commit 8906ebf92cf754554f231d4341976c2cf5da9a22
Loading…
Cancel
Save