1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-26 09:28:13 +00:00

paths: validate curve as well

This commit is contained in:
Tomas Susanka 2019-04-05 11:23:06 +02:00
parent ac318cc65b
commit 8aa60e6cfd
45 changed files with 206 additions and 101 deletions

View File

@ -3,7 +3,8 @@ from trezor.messages import MessageType
from apps.common import HARDENED from apps.common import HARDENED
SEED_NAMESPACE = [[HARDENED | 44, HARDENED | 1815]] CURVE = "ed25519"
SEED_NAMESPACE = [HARDENED | 44, HARDENED | 1815]
def boot(): def boot():

View File

@ -1,7 +1,7 @@
from trezor import log, wire from trezor import log, wire
from trezor.messages.CardanoAddress import CardanoAddress from trezor.messages.CardanoAddress import CardanoAddress
from apps.cardano import seed from apps.cardano import CURVE, seed
from apps.cardano.address import derive_address_and_node, validate_full_path from apps.cardano.address import derive_address_and_node, validate_full_path
from apps.common import paths from apps.common import paths
from apps.common.layout import address_n_to_str, show_address, show_qr from apps.common.layout import address_n_to_str, show_address, show_qr
@ -10,7 +10,7 @@ from apps.common.layout import address_n_to_str, show_address, show_qr
async def get_address(ctx, msg): async def get_address(ctx, msg):
keychain = await seed.get_keychain(ctx) keychain = await seed.get_keychain(ctx)
await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n) await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n, CURVE)
try: try:
address, _ = derive_address_and_node(keychain, msg.address_n) address, _ = derive_address_and_node(keychain, msg.address_n)

View File

@ -4,7 +4,7 @@ from trezor import log, wire
from trezor.messages.CardanoPublicKey import CardanoPublicKey from trezor.messages.CardanoPublicKey import CardanoPublicKey
from trezor.messages.HDNodeType import HDNodeType from trezor.messages.HDNodeType import HDNodeType
from apps.cardano import seed from apps.cardano import CURVE, seed
from apps.cardano.address import derive_address_and_node from apps.cardano.address import derive_address_and_node
from apps.common import layout, paths from apps.common import layout, paths
from apps.common.seed import remove_ed25519_prefix from apps.common.seed import remove_ed25519_prefix
@ -18,6 +18,7 @@ async def get_public_key(ctx, msg):
paths.validate_path_for_get_public_key, paths.validate_path_for_get_public_key,
keychain, keychain,
msg.address_n, msg.address_n,
CURVE,
slip44_id=1815, slip44_id=1815,
) )

View File

@ -1,7 +1,7 @@
from trezor import wire from trezor import wire
from trezor.crypto import bip32 from trezor.crypto import bip32
from apps.cardano import SEED_NAMESPACE from apps.cardano import CURVE, SEED_NAMESPACE
from apps.common import cache, mnemonic, storage from apps.common import cache, mnemonic, storage
from apps.common.request_passphrase import protect_by_passphrase from apps.common.request_passphrase import protect_by_passphrase
@ -11,8 +11,8 @@ class Keychain:
self.path = path self.path = path
self.root = root self.root = root
def validate_path(self, checked_path: list): def validate_path(self, checked_path: list, checked_curve: str):
if checked_path[:2] != SEED_NAMESPACE[0]: if checked_curve != CURVE or checked_path[:2] != SEED_NAMESPACE:
raise wire.DataError("Forbidden key path") raise wire.DataError("Forbidden key path")
def derive(self, node_path: list) -> bip32.HDNode: def derive(self, node_path: list) -> bip32.HDNode:
@ -40,8 +40,8 @@ async def get_keychain(ctx: wire.Context) -> Keychain:
root = bip32.from_mnemonic_cardano(mnemonic.restore(), passphrase) root = bip32.from_mnemonic_cardano(mnemonic.restore(), passphrase)
# derive the namespaced root node # derive the namespaced root node
for i in SEED_NAMESPACE[0]: for i in SEED_NAMESPACE:
root.derive_cardano(i) root.derive_cardano(i)
keychain = Keychain(SEED_NAMESPACE[0], root) keychain = Keychain(SEED_NAMESPACE, root)
return keychain return keychain

View File

@ -7,7 +7,7 @@ from trezor.messages.CardanoSignedTx import CardanoSignedTx
from trezor.messages.CardanoTxRequest import CardanoTxRequest from trezor.messages.CardanoTxRequest import CardanoTxRequest
from trezor.messages.MessageType import CardanoTxAck from trezor.messages.MessageType import CardanoTxAck
from apps.cardano import cbor, seed from apps.cardano import CURVE, cbor, seed
from apps.cardano.address import ( from apps.cardano.address import (
derive_address_and_node, derive_address_and_node,
is_safe_output_address, is_safe_output_address,
@ -85,7 +85,7 @@ async def sign_tx(ctx, msg):
display_homescreen() display_homescreen()
for i in msg.inputs: for i in msg.inputs:
await validate_path(ctx, validate_full_path, keychain, i.address_n) await validate_path(ctx, validate_full_path, keychain, i.address_n, CURVE)
# sign the transaction bundle and prepare the result # sign the transaction bundle and prepare the result
transaction = Transaction( transaction = Transaction(

View File

@ -8,8 +8,8 @@ from apps.common import HARDENED
from apps.common.confirm import require_confirm from apps.common.confirm import require_confirm
async def validate_path(ctx, validate_func, keychain, path, **kwargs): async def validate_path(ctx, validate_func, keychain, path, curve, **kwargs):
keychain.validate_path(path) keychain.validate_path(path, curve)
if not validate_func(path, **kwargs): if not validate_func(path, **kwargs):
await show_path_warning(ctx, path) await show_path_warning(ctx, path)

View File

@ -1,7 +1,7 @@
from trezor import ui, wire from trezor import ui, wire
from trezor.crypto import bip32 from trezor.crypto import bip32
from apps.common import cache, mnemonic, storage from apps.common import HARDENED, cache, mnemonic, storage
from apps.common.request_passphrase import protect_by_passphrase from apps.common.request_passphrase import protect_by_passphrase
allow = list allow = list
@ -25,9 +25,11 @@ class Keychain:
del self.roots del self.roots
del self.seed del self.seed
def validate_path(self, checked_path: list): def validate_path(self, checked_path: list, checked_curve: str):
for curve, *path in self.namespaces: for curve, *path in self.namespaces:
if path == checked_path[: len(path)]: # TODO: check curve_name if path == checked_path[: len(path)] and curve == checked_curve:
if curve == "ed25519" and not _path_hardened(checked_path):
break
return return
raise wire.DataError("Forbidden key path") raise wire.DataError("Forbidden key path")
@ -67,6 +69,14 @@ async def get_keychain(ctx: wire.Context, namespaces: list) -> Keychain:
return keychain return keychain
def _path_hardened(path: list) -> bool:
# TODO: move to paths.py after #538 is fixed
for i in path:
if not (i & HARDENED):
return False
return True
@ui.layout_no_slide @ui.layout_no_slide
async def _compute_seed(ctx: wire.Context) -> bytes: async def _compute_seed(ctx: wire.Context) -> bytes:
passphrase = cache.get_passphrase() passphrase = cache.get_passphrase()

View File

@ -4,11 +4,13 @@ from trezor.messages import MessageType
from apps.common import HARDENED from apps.common import HARDENED
from apps.ethereum.networks import all_slip44_ids_hardened from apps.ethereum.networks import all_slip44_ids_hardened
CURVE = "secp256k1"
def boot(): def boot():
ns = [] ns = []
for i in all_slip44_ids_hardened(): for i in all_slip44_ids_hardened():
ns.append(["secp256k1", HARDENED | 44, i]) ns.append([CURVE, HARDENED | 44, i])
wire.add(MessageType.EthereumGetAddress, __name__, "get_address", ns) wire.add(MessageType.EthereumGetAddress, __name__, "get_address", ns)
wire.add(MessageType.EthereumGetPublicKey, __name__, "get_public_key", ns) wire.add(MessageType.EthereumGetPublicKey, __name__, "get_public_key", ns)
wire.add(MessageType.EthereumSignTx, __name__, "sign_tx", ns) wire.add(MessageType.EthereumSignTx, __name__, "sign_tx", ns)

View File

@ -4,12 +4,12 @@ from trezor.messages.EthereumAddress import EthereumAddress
from apps.common import paths from apps.common import paths
from apps.common.layout import address_n_to_str, show_address, show_qr from apps.common.layout import address_n_to_str, show_address, show_qr
from apps.ethereum import networks from apps.ethereum import CURVE, networks
from apps.ethereum.address import address_from_bytes, validate_full_path from apps.ethereum.address import address_from_bytes, validate_full_path
async def get_address(ctx, msg, keychain): async def get_address(ctx, msg, keychain):
await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n) await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n, CURVE)
node = keychain.derive(msg.address_n) node = keychain.derive(msg.address_n)
seckey = node.private_key() seckey = node.private_key()

View File

@ -2,12 +2,12 @@ from trezor.messages.EthereumPublicKey import EthereumPublicKey
from trezor.messages.HDNodeType import HDNodeType from trezor.messages.HDNodeType import HDNodeType
from apps.common import coins, layout, paths from apps.common import coins, layout, paths
from apps.ethereum import address from apps.ethereum import CURVE, address
async def get_public_key(ctx, msg, keychain): async def get_public_key(ctx, msg, keychain):
await paths.validate_path( await paths.validate_path(
ctx, address.validate_path_for_get_public_key, keychain, msg.address_n ctx, address.validate_path_for_get_public_key, keychain, msg.address_n, CURVE
) )
node = keychain.derive(msg.address_n) node = keychain.derive(msg.address_n)

View File

@ -7,7 +7,7 @@ from trezor.utils import HashWriter
from apps.common import paths from apps.common import paths
from apps.common.confirm import require_confirm from apps.common.confirm import require_confirm
from apps.common.signverify import split_message from apps.common.signverify import split_message
from apps.ethereum import address from apps.ethereum import CURVE, address
def message_digest(message): def message_digest(message):
@ -20,7 +20,9 @@ def message_digest(message):
async def sign_message(ctx, msg, keychain): async def sign_message(ctx, msg, keychain):
await paths.validate_path(ctx, address.validate_full_path, keychain, msg.address_n) await paths.validate_path(
ctx, address.validate_full_path, keychain, msg.address_n, CURVE
)
await require_confirm_sign_message(ctx, msg.message) await require_confirm_sign_message(ctx, msg.message)
node = keychain.derive(msg.address_n) node = keychain.derive(msg.address_n)

View File

@ -8,7 +8,7 @@ from trezor.messages.MessageType import EthereumTxAck
from trezor.utils import HashWriter from trezor.utils import HashWriter
from apps.common import paths from apps.common import paths
from apps.ethereum import address, tokens from apps.ethereum import CURVE, address, tokens
from apps.ethereum.address import validate_full_path from apps.ethereum.address import validate_full_path
from apps.ethereum.layout import ( from apps.ethereum.layout import (
require_confirm_data, require_confirm_data,
@ -23,7 +23,7 @@ MAX_CHAIN_ID = 2147483629
async def sign_tx(ctx, msg, keychain): async def sign_tx(ctx, msg, keychain):
msg = sanitize(msg) msg = sanitize(msg)
check(msg) check(msg)
await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n) await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n, CURVE)
data_total = msg.data_length data_total = msg.data_length

View File

@ -3,9 +3,11 @@ from trezor.messages import MessageType
from apps.common import HARDENED from apps.common import HARDENED
CURVE = "ed25519"
def boot(): def boot():
ns = [["ed25519", HARDENED | 44, HARDENED | 134]] ns = [[CURVE, HARDENED | 44, HARDENED | 134]]
wire.add(MessageType.LiskGetPublicKey, __name__, "get_public_key", ns) wire.add(MessageType.LiskGetPublicKey, __name__, "get_public_key", ns)
wire.add(MessageType.LiskGetAddress, __name__, "get_address", ns) wire.add(MessageType.LiskGetAddress, __name__, "get_address", ns)
wire.add(MessageType.LiskSignTx, __name__, "sign_tx", ns) wire.add(MessageType.LiskSignTx, __name__, "sign_tx", ns)

View File

@ -1,15 +1,16 @@
from trezor.messages.LiskAddress import LiskAddress from trezor.messages.LiskAddress import LiskAddress
from .helpers import LISK_CURVE, get_address_from_public_key, validate_full_path from .helpers import get_address_from_public_key, validate_full_path
from apps.common import paths from apps.common import paths
from apps.common.layout import address_n_to_str, show_address, show_qr from apps.common.layout import address_n_to_str, show_address, show_qr
from apps.lisk import CURVE
async def get_address(ctx, msg, keychain): async def get_address(ctx, msg, keychain):
await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n) await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n, CURVE)
node = keychain.derive(msg.address_n, LISK_CURVE) node = keychain.derive(msg.address_n, CURVE)
pubkey = node.public_key() pubkey = node.public_key()
pubkey = pubkey[1:] # skip ed25519 pubkey marker pubkey = pubkey[1:] # skip ed25519 pubkey marker
address = get_address_from_public_key(pubkey) address = get_address_from_public_key(pubkey)

View File

@ -1,14 +1,14 @@
from trezor.messages.LiskPublicKey import LiskPublicKey from trezor.messages.LiskPublicKey import LiskPublicKey
from .helpers import LISK_CURVE, validate_full_path
from apps.common import layout, paths from apps.common import layout, paths
from apps.lisk import CURVE
from apps.lisk.helpers import validate_full_path
async def get_public_key(ctx, msg, keychain): async def get_public_key(ctx, msg, keychain):
await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n) await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n, CURVE)
node = keychain.derive(msg.address_n, LISK_CURVE) node = keychain.derive(msg.address_n, CURVE)
pubkey = node.public_key() pubkey = node.public_key()
pubkey = pubkey[1:] # skip ed25519 pubkey marker pubkey = pubkey[1:] # skip ed25519 pubkey marker

View File

@ -2,8 +2,6 @@ from trezor.crypto.hashlib import sha256
from apps.common import HARDENED from apps.common import HARDENED
LISK_CURVE = "ed25519"
def get_address_from_public_key(pubkey): def get_address_from_public_key(pubkey):
pubkeyhash = sha256(pubkey).digest() pubkeyhash = sha256(pubkey).digest()

View File

@ -4,11 +4,11 @@ from trezor.messages.LiskMessageSignature import LiskMessageSignature
from trezor.ui.text import Text from trezor.ui.text import Text
from trezor.utils import HashWriter from trezor.utils import HashWriter
from .helpers import LISK_CURVE, validate_full_path
from apps.common import paths from apps.common import paths
from apps.common.confirm import require_confirm from apps.common.confirm import require_confirm
from apps.common.signverify import split_message from apps.common.signverify import split_message
from apps.lisk import CURVE
from apps.lisk.helpers import validate_full_path
from apps.wallet.sign_tx.writers import write_varint from apps.wallet.sign_tx.writers import write_varint
@ -23,10 +23,10 @@ def message_digest(message):
async def sign_message(ctx, msg, keychain): async def sign_message(ctx, msg, keychain):
await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n) await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n, CURVE)
await require_confirm_sign_message(ctx, msg.message) await require_confirm_sign_message(ctx, msg.message)
node = keychain.derive(msg.address_n, LISK_CURVE) node = keychain.derive(msg.address_n, CURVE)
seckey = node.private_key() seckey = node.private_key()
pubkey = node.public_key() pubkey = node.public_key()
pubkey = pubkey[1:] # skip ed25519 pubkey marker pubkey = pubkey[1:] # skip ed25519 pubkey marker

View File

@ -8,16 +8,12 @@ from trezor.messages.LiskSignedTx import LiskSignedTx
from trezor.utils import HashWriter from trezor.utils import HashWriter
from apps.common import paths from apps.common import paths
from apps.lisk import layout from apps.lisk import CURVE, layout
from apps.lisk.helpers import ( from apps.lisk.helpers import get_address_from_public_key, validate_full_path
LISK_CURVE,
get_address_from_public_key,
validate_full_path,
)
async def sign_tx(ctx, msg, keychain): async def sign_tx(ctx, msg, keychain):
await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n) await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n, CURVE)
pubkey, seckey = _get_keys(keychain, msg) pubkey, seckey = _get_keys(keychain, msg)
transaction = _update_raw_tx(msg.transaction, pubkey) transaction = _update_raw_tx(msg.transaction, pubkey)
@ -41,7 +37,7 @@ async def sign_tx(ctx, msg, keychain):
def _get_keys(keychain, msg): def _get_keys(keychain, msg):
node = keychain.derive(msg.address_n, LISK_CURVE) node = keychain.derive(msg.address_n, CURVE)
seckey = node.private_key() seckey = node.private_key()
pubkey = node.public_key() pubkey = node.public_key()

View File

@ -2,11 +2,13 @@ from trezor.messages.MoneroAddress import MoneroAddress
from apps.common import paths from apps.common import paths
from apps.common.layout import address_n_to_str, show_address, show_qr from apps.common.layout import address_n_to_str, show_address, show_qr
from apps.monero import misc from apps.monero import CURVE, misc
async def get_address(ctx, msg, keychain): async def get_address(ctx, msg, keychain):
await paths.validate_path(ctx, misc.validate_full_path, keychain, msg.address_n) await paths.validate_path(
ctx, misc.validate_full_path, keychain, msg.address_n, CURVE
)
creds = misc.get_creds(keychain, msg.address_n, msg.network_type) creds = misc.get_creds(keychain, msg.address_n, msg.network_type)

View File

@ -20,7 +20,7 @@ from trezor.messages.MoneroGetTxKeyAck import MoneroGetTxKeyAck
from trezor.messages.MoneroGetTxKeyRequest import MoneroGetTxKeyRequest from trezor.messages.MoneroGetTxKeyRequest import MoneroGetTxKeyRequest
from apps.common import paths from apps.common import paths
from apps.monero import misc from apps.monero import CURVE, misc
from apps.monero.layout import confirms from apps.monero.layout import confirms
from apps.monero.xmr import crypto from apps.monero.xmr import crypto
from apps.monero.xmr.crypto import chacha_poly from apps.monero.xmr.crypto import chacha_poly
@ -30,7 +30,9 @@ _GET_TX_KEY_REASON_TX_DERIVATION = 1
async def get_tx_keys(ctx, msg: MoneroGetTxKeyRequest, keychain): async def get_tx_keys(ctx, msg: MoneroGetTxKeyRequest, keychain):
await paths.validate_path(ctx, misc.validate_full_path, keychain, msg.address_n) await paths.validate_path(
ctx, misc.validate_full_path, keychain, msg.address_n, CURVE
)
do_deriv = msg.reason == _GET_TX_KEY_REASON_TX_DERIVATION do_deriv = msg.reason == _GET_TX_KEY_REASON_TX_DERIVATION
await confirms.require_confirm_tx_key(ctx, export_key=not do_deriv) await confirms.require_confirm_tx_key(ctx, export_key=not do_deriv)

View File

@ -2,13 +2,15 @@ from trezor.messages.MoneroGetWatchKey import MoneroGetWatchKey
from trezor.messages.MoneroWatchKey import MoneroWatchKey from trezor.messages.MoneroWatchKey import MoneroWatchKey
from apps.common import paths from apps.common import paths
from apps.monero import misc from apps.monero import CURVE, misc
from apps.monero.layout import confirms from apps.monero.layout import confirms
from apps.monero.xmr import crypto from apps.monero.xmr import crypto
async def get_watch_only(ctx, msg: MoneroGetWatchKey, keychain): async def get_watch_only(ctx, msg: MoneroGetWatchKey, keychain):
await paths.validate_path(ctx, misc.validate_full_path, keychain, msg.address_n) await paths.validate_path(
ctx, misc.validate_full_path, keychain, msg.address_n, CURVE
)
await confirms.require_confirm_watchkey(ctx) await confirms.require_confirm_watchkey(ctx)

View File

@ -8,7 +8,7 @@ from trezor.messages.MoneroKeyImageSyncFinalAck import MoneroKeyImageSyncFinalAc
from trezor.messages.MoneroKeyImageSyncStepAck import MoneroKeyImageSyncStepAck from trezor.messages.MoneroKeyImageSyncStepAck import MoneroKeyImageSyncStepAck
from apps.common import paths from apps.common import paths
from apps.monero import misc from apps.monero import CURVE, misc
from apps.monero.layout import confirms from apps.monero.layout import confirms
from apps.monero.xmr import crypto, key_image, monero from apps.monero.xmr import crypto, key_image, monero
from apps.monero.xmr.crypto import chacha_poly from apps.monero.xmr.crypto import chacha_poly
@ -47,7 +47,9 @@ class KeyImageSync:
async def _init_step(s, ctx, msg, keychain): async def _init_step(s, ctx, msg, keychain):
await paths.validate_path(ctx, misc.validate_full_path, keychain, msg.address_n) await paths.validate_path(
ctx, misc.validate_full_path, keychain, msg.address_n, CURVE
)
s.creds = misc.get_creds(keychain, msg.address_n, msg.network_type) s.creds = misc.get_creds(keychain, msg.address_n, msg.network_type)

View File

@ -9,7 +9,7 @@ from trezor.messages.MoneroLiveRefreshStepAck import MoneroLiveRefreshStepAck
from trezor.messages.MoneroLiveRefreshStepRequest import MoneroLiveRefreshStepRequest from trezor.messages.MoneroLiveRefreshStepRequest import MoneroLiveRefreshStepRequest
from apps.common import paths from apps.common import paths
from apps.monero import misc from apps.monero import CURVE, misc
from apps.monero.layout import confirms from apps.monero.layout import confirms
from apps.monero.xmr import crypto, key_image, monero from apps.monero.xmr import crypto, key_image, monero
from apps.monero.xmr.crypto import chacha_poly from apps.monero.xmr.crypto import chacha_poly
@ -44,7 +44,9 @@ class LiveRefreshState:
async def _init_step( async def _init_step(
s: LiveRefreshState, ctx, msg: MoneroLiveRefreshStartRequest, keychain s: LiveRefreshState, ctx, msg: MoneroLiveRefreshStartRequest, keychain
): ):
await paths.validate_path(ctx, misc.validate_full_path, keychain, msg.address_n) await paths.validate_path(
ctx, misc.validate_full_path, keychain, msg.address_n, CURVE
)
await confirms.require_confirm_live_refresh(ctx) await confirms.require_confirm_live_refresh(ctx)

View File

@ -4,7 +4,7 @@ Initializes a new transaction.
import gc import gc
from apps.monero import misc, signing from apps.monero import CURVE, misc, signing
from apps.monero.layout import confirms from apps.monero.layout import confirms
from apps.monero.signing.state import State from apps.monero.signing.state import State
from apps.monero.xmr import crypto, monero from apps.monero.xmr import crypto, monero
@ -24,7 +24,9 @@ async def init_transaction(
from apps.monero.signing import offloading_keys from apps.monero.signing import offloading_keys
from apps.common import paths from apps.common import paths
await paths.validate_path(state.ctx, misc.validate_full_path, keychain, address_n) await paths.validate_path(
state.ctx, misc.validate_full_path, keychain, address_n, CURVE
)
state.creds = misc.get_creds(keychain, address_n, network_type) state.creds = misc.get_creds(keychain, address_n, network_type)
state.client_version = tsx_data.client_version or 0 state.client_version = tsx_data.client_version or 0

View File

@ -3,11 +3,10 @@ from trezor.messages import MessageType
from apps.common import HARDENED from apps.common import HARDENED
CURVE = "ed25519-keccak"
def boot(): def boot():
ns = [ ns = [[CURVE, HARDENED | 44, HARDENED | 43], [CURVE, HARDENED | 44, HARDENED | 1]]
["ed25519-keccak", HARDENED | 44, HARDENED | 43],
["ed25519-keccak", HARDENED | 44, HARDENED | 1],
]
wire.add(MessageType.NEMGetAddress, __name__, "get_address", ns) wire.add(MessageType.NEMGetAddress, __name__, "get_address", ns)
wire.add(MessageType.NEMSignTx, __name__, "sign_tx", ns) wire.add(MessageType.NEMSignTx, __name__, "sign_tx", ns)

View File

@ -1,17 +1,19 @@
from trezor.messages.NEMAddress import NEMAddress from trezor.messages.NEMAddress import NEMAddress
from .helpers import NEM_CURVE, check_path, get_network_str
from .validators import validate_network
from apps.common.layout import address_n_to_str, show_address, show_qr from apps.common.layout import address_n_to_str, show_address, show_qr
from apps.common.paths import validate_path from apps.common.paths import validate_path
from apps.nem import CURVE
from apps.nem.helpers import check_path, get_network_str
from apps.nem.validators import validate_network
async def get_address(ctx, msg, keychain): async def get_address(ctx, msg, keychain):
network = validate_network(msg.network) network = validate_network(msg.network)
await validate_path(ctx, check_path, keychain, msg.address_n, network=network) await validate_path(
ctx, check_path, keychain, msg.address_n, CURVE, network=network
)
node = keychain.derive(msg.address_n, NEM_CURVE) node = keychain.derive(msg.address_n, CURVE)
address = node.nem_address(network) address = node.nem_address(network)
if msg.show_display: if msg.show_display:

View File

@ -5,7 +5,6 @@ from apps.common import HARDENED
NEM_NETWORK_MAINNET = const(0x68) NEM_NETWORK_MAINNET = const(0x68)
NEM_NETWORK_TESTNET = const(0x98) NEM_NETWORK_TESTNET = const(0x98)
NEM_NETWORK_MIJIN = const(0x60) NEM_NETWORK_MIJIN = const(0x60)
NEM_CURVE = "ed25519-keccak"
NEM_TRANSACTION_TYPE_TRANSFER = const(0x0101) NEM_TRANSACTION_TYPE_TRANSFER = const(0x0101)
NEM_TRANSACTION_TYPE_IMPORTANCE_TRANSFER = const(0x0801) NEM_TRANSACTION_TYPE_IMPORTANCE_TRANSFER = const(0x0801)

View File

@ -4,8 +4,8 @@ from trezor.messages.NEMSignTx import NEMSignTx
from apps.common import seed from apps.common import seed
from apps.common.paths import validate_path from apps.common.paths import validate_path
from apps.nem import mosaic, multisig, namespace, transfer from apps.nem import CURVE, mosaic, multisig, namespace, transfer
from apps.nem.helpers import NEM_CURVE, NEM_HASH_ALG, check_path from apps.nem.helpers import NEM_HASH_ALG, check_path
from apps.nem.validators import validate from apps.nem.validators import validate
@ -17,10 +17,11 @@ async def sign_tx(ctx, msg: NEMSignTx, keychain):
check_path, check_path,
keychain, keychain,
msg.transaction.address_n, msg.transaction.address_n,
CURVE,
network=msg.transaction.network, network=msg.transaction.network,
) )
node = keychain.derive(msg.transaction.address_n, NEM_CURVE) node = keychain.derive(msg.transaction.address_n, CURVE)
if msg.multisig: if msg.multisig:
public_key = msg.multisig.signer public_key = msg.multisig.signer

View File

@ -3,8 +3,10 @@ from trezor.messages import MessageType
from apps.common import HARDENED from apps.common import HARDENED
CURVE = "secp256k1"
def boot(): def boot():
ns = [["secp256k1", HARDENED | 44, HARDENED | 144]] ns = [[CURVE, HARDENED | 44, HARDENED | 144]]
wire.add(MessageType.RippleGetAddress, __name__, "get_address", ns) wire.add(MessageType.RippleGetAddress, __name__, "get_address", ns)
wire.add(MessageType.RippleSignTx, __name__, "sign_tx", ns) wire.add(MessageType.RippleSignTx, __name__, "sign_tx", ns)

View File

@ -3,11 +3,13 @@ from trezor.messages.RippleGetAddress import RippleGetAddress
from apps.common import paths from apps.common import paths
from apps.common.layout import address_n_to_str, show_address, show_qr from apps.common.layout import address_n_to_str, show_address, show_qr
from apps.ripple import helpers from apps.ripple import CURVE, helpers
async def get_address(ctx, msg: RippleGetAddress, keychain): async def get_address(ctx, msg: RippleGetAddress, keychain):
await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n) await paths.validate_path(
ctx, helpers.validate_full_path, keychain, msg.address_n, CURVE
)
node = keychain.derive(msg.address_n) node = keychain.derive(msg.address_n)
pubkey = node.public_key() pubkey = node.public_key()

View File

@ -6,14 +6,16 @@ from trezor.messages.RippleSignTx import RippleSignTx
from trezor.wire import ProcessError from trezor.wire import ProcessError
from apps.common import paths from apps.common import paths
from apps.ripple import helpers, layout from apps.ripple import CURVE, helpers, layout
from apps.ripple.serialize import serialize from apps.ripple.serialize import serialize
async def sign_tx(ctx, msg: RippleSignTx, keychain): async def sign_tx(ctx, msg: RippleSignTx, keychain):
validate(msg) validate(msg)
await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n) await paths.validate_path(
ctx, helpers.validate_full_path, keychain, msg.address_n, CURVE
)
node = keychain.derive(msg.address_n) node = keychain.derive(msg.address_n)
source_address = helpers.address_from_public_key(node.public_key()) source_address = helpers.address_from_public_key(node.public_key())

View File

@ -3,8 +3,10 @@ from trezor.messages import MessageType
from apps.common import HARDENED from apps.common import HARDENED
CURVE = "ed25519"
def boot(): def boot():
ns = [["ed25519", HARDENED | 44, HARDENED | 148]] ns = [[CURVE, HARDENED | 44, HARDENED | 148]]
wire.add(MessageType.StellarGetAddress, __name__, "get_address", ns) wire.add(MessageType.StellarGetAddress, __name__, "get_address", ns)
wire.add(MessageType.StellarSignTx, __name__, "sign_tx", ns) wire.add(MessageType.StellarSignTx, __name__, "sign_tx", ns)

View File

@ -2,7 +2,6 @@ from micropython import const
from trezor.messages import MessageType from trezor.messages import MessageType
STELLAR_CURVE = "ed25519"
TX_TYPE = bytearray("\x00\x00\x00\x02") TX_TYPE = bytearray("\x00\x00\x00\x02")
# source: https://github.com/stellar/go/blob/3d2c1defe73dbfed00146ebe0e8d7e07ce4bb1b6/xdr/Stellar-transaction.x#L16 # source: https://github.com/stellar/go/blob/3d2c1defe73dbfed00146ebe0e8d7e07ce4bb1b6/xdr/Stellar-transaction.x#L16

View File

@ -3,13 +3,15 @@ from trezor.messages.StellarGetAddress import StellarGetAddress
from apps.common import paths, seed from apps.common import paths, seed
from apps.common.layout import address_n_to_str, show_address, show_qr from apps.common.layout import address_n_to_str, show_address, show_qr
from apps.stellar import helpers from apps.stellar import CURVE, helpers
async def get_address(ctx, msg: StellarGetAddress, keychain): async def get_address(ctx, msg: StellarGetAddress, keychain):
await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n) await paths.validate_path(
ctx, helpers.validate_full_path, keychain, msg.address_n, CURVE
)
node = keychain.derive(msg.address_n, helpers.STELLAR_CURVE) node = keychain.derive(msg.address_n, CURVE)
pubkey = seed.remove_ed25519_prefix(node.public_key()) pubkey = seed.remove_ed25519_prefix(node.public_key())
address = helpers.address_from_public_key(pubkey) address = helpers.address_from_public_key(pubkey)

View File

@ -5,8 +5,6 @@ from trezor.wire import ProcessError
from apps.common import HARDENED from apps.common import HARDENED
STELLAR_CURVE = "ed25519"
def public_key_from_address(address: str) -> bytes: def public_key_from_address(address: str) -> bytes:
"""Extracts public key from an address """Extracts public key from an address

View File

@ -8,14 +8,16 @@ from trezor.messages.StellarTxOpRequest import StellarTxOpRequest
from trezor.wire import ProcessError from trezor.wire import ProcessError
from apps.common import paths, seed from apps.common import paths, seed
from apps.stellar import consts, helpers, layout, writers from apps.stellar import CURVE, consts, helpers, layout, writers
from apps.stellar.operations import process_operation from apps.stellar.operations import process_operation
async def sign_tx(ctx, msg: StellarSignTx, keychain): async def sign_tx(ctx, msg: StellarSignTx, keychain):
await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n) await paths.validate_path(
ctx, helpers.validate_full_path, keychain, msg.address_n, CURVE
)
node = keychain.derive(msg.address_n, consts.STELLAR_CURVE) node = keychain.derive(msg.address_n, CURVE)
pubkey = seed.remove_ed25519_prefix(node.public_key()) pubkey = seed.remove_ed25519_prefix(node.public_key())
if msg.num_operations == 0: if msg.num_operations == 0:

View File

@ -3,9 +3,11 @@ from trezor.messages import MessageType
from apps.common import HARDENED from apps.common import HARDENED
CURVE = "ed25519"
def boot(): def boot():
ns = [["ed25519", HARDENED | 44, HARDENED | 1729]] ns = [[CURVE, HARDENED | 44, HARDENED | 1729]]
wire.add(MessageType.TezosGetAddress, __name__, "get_address", ns) wire.add(MessageType.TezosGetAddress, __name__, "get_address", ns)
wire.add(MessageType.TezosSignTx, __name__, "sign_tx", ns) wire.add(MessageType.TezosSignTx, __name__, "sign_tx", ns)
wire.add(MessageType.TezosGetPublicKey, __name__, "get_public_key", ns) wire.add(MessageType.TezosGetPublicKey, __name__, "get_public_key", ns)

View File

@ -3,13 +3,15 @@ from trezor.messages.TezosAddress import TezosAddress
from apps.common import paths, seed from apps.common import paths, seed
from apps.common.layout import address_n_to_str, show_address, show_qr from apps.common.layout import address_n_to_str, show_address, show_qr
from apps.tezos import helpers from apps.tezos import CURVE, helpers
async def get_address(ctx, msg, keychain): async def get_address(ctx, msg, keychain):
await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n) await paths.validate_path(
ctx, helpers.validate_full_path, keychain, msg.address_n, CURVE
)
node = keychain.derive(msg.address_n, helpers.TEZOS_CURVE) node = keychain.derive(msg.address_n, CURVE)
pk = seed.remove_ed25519_prefix(node.public_key()) pk = seed.remove_ed25519_prefix(node.public_key())
pkh = hashlib.blake2b(pk, outlen=20).digest() pkh = hashlib.blake2b(pk, outlen=20).digest()

View File

@ -6,13 +6,15 @@ from trezor.utils import chunks
from apps.common import paths, seed from apps.common import paths, seed
from apps.common.confirm import require_confirm from apps.common.confirm import require_confirm
from apps.tezos import helpers from apps.tezos import CURVE, helpers
async def get_public_key(ctx, msg, keychain): async def get_public_key(ctx, msg, keychain):
await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n) await paths.validate_path(
ctx, helpers.validate_full_path, keychain, msg.address_n, CURVE
)
node = keychain.derive(msg.address_n, helpers.TEZOS_CURVE) node = keychain.derive(msg.address_n, CURVE)
pk = seed.remove_ed25519_prefix(node.public_key()) pk = seed.remove_ed25519_prefix(node.public_key())
pk_prefixed = helpers.base58_encode_check(pk, prefix=helpers.TEZOS_PUBLICKEY_PREFIX) pk_prefixed = helpers.base58_encode_check(pk, prefix=helpers.TEZOS_PUBLICKEY_PREFIX)

View File

@ -4,7 +4,6 @@ from trezor.crypto import base58
from apps.common import HARDENED from apps.common import HARDENED
TEZOS_CURVE = "ed25519"
TEZOS_AMOUNT_DIVISIBILITY = const(6) TEZOS_AMOUNT_DIVISIBILITY = const(6)
TEZOS_ED25519_ADDRESS_PREFIX = "tz1" TEZOS_ED25519_ADDRESS_PREFIX = "tz1"
TEZOS_ORIGINATED_ADDRESS_PREFIX = "KT1" TEZOS_ORIGINATED_ADDRESS_PREFIX = "KT1"

View File

@ -6,13 +6,15 @@ from trezor.messages.TezosSignedTx import TezosSignedTx
from apps.common import paths from apps.common import paths
from apps.common.writers import write_bytes, write_uint8 from apps.common.writers import write_bytes, write_uint8
from apps.tezos import helpers, layout from apps.tezos import CURVE, helpers, layout
async def sign_tx(ctx, msg, keychain): async def sign_tx(ctx, msg, keychain):
await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n) await paths.validate_path(
ctx, helpers.validate_full_path, keychain, msg.address_n, CURVE
)
node = keychain.derive(msg.address_n, helpers.TEZOS_CURVE) node = keychain.derive(msg.address_n, CURVE)
if msg.transaction is not None: if msg.transaction is not None:
to = _get_address_from_contract(msg.transaction.destination) to = _get_address_from_contract(msg.transaction.destination)

View File

@ -16,6 +16,7 @@ async def get_address(ctx, msg, keychain):
addresses.validate_full_path, addresses.validate_full_path,
keychain, keychain,
msg.address_n, msg.address_n,
coin.curve_name,
coin=coin, coin=coin,
script_type=msg.script_type, script_type=msg.script_type,
) )

View File

@ -24,6 +24,7 @@ async def sign_message(ctx, msg, keychain):
validate_full_path, validate_full_path,
keychain, keychain,
msg.address_n, msg.address_n,
coin.curve_name,
coin=coin, coin=coin,
script_type=msg.script_type, script_type=msg.script_type,
validate_script_type=False, validate_script_type=False,

View File

@ -0,0 +1,60 @@
from common import *
from apps.common import HARDENED
from apps.common.seed import Keychain, _path_hardened
from trezor import wire
class TestKeychain(unittest.TestCase):
def test_validate_path(self):
n = [
["ed25519", 44 | HARDENED, 134 | HARDENED],
["secp256k1", 44 | HARDENED, 11 | HARDENED],
]
k = Keychain(b"", n)
correct = (
([44 | HARDENED, 134 | HARDENED], "ed25519"),
([44 | HARDENED, 11 | HARDENED], "secp256k1"),
([44 | HARDENED, 11 | HARDENED, 12], "secp256k1"),
)
for c in correct:
self.assertEqual(None, k.validate_path(*c))
fails = [
([44 | HARDENED, 134], "ed25519"), # path does not match
([44 | HARDENED, 134], "secp256k1"), # curve and path does not match
([44 | HARDENED, 134 | HARDENED], "nist256p"), # curve not included
([44, 134], "ed25519"), # path does not match (non-hardened items)
([44 | HARDENED, 134 | HARDENED, 123], "ed25519"), # non-hardened item in ed25519
([44 | HARDENED, 13 | HARDENED], "secp256k1"), # invalid second item
]
for f in fails:
with self.assertRaises(wire.DataError):
k.validate_path(*f)
def test_validate_path_empty_namespace(self):
k = Keychain(b"", [["secp256k1"]])
correct = (
([], "secp256k1"),
([1, 2, 3, 4], "secp256k1"),
([44 | HARDENED, 11 | HARDENED], "secp256k1"),
([44 | HARDENED, 11 | HARDENED, 12], "secp256k1"),
)
for c in correct:
self.assertEqual(None, k.validate_path(*c))
with self.assertRaises(wire.DataError):
k.validate_path([1, 2, 3, 4], "ed25519")
k.validate_path([], "ed25519")
def test_path_hardened(self):
self.assertTrue(_path_hardened([44 | HARDENED, 1 | HARDENED, 0 | HARDENED]))
self.assertTrue(_path_hardened([0 | HARDENED, ]))
self.assertFalse(_path_hardened([44, 44 | HARDENED, 0 | HARDENED]))
self.assertFalse(_path_hardened([0, ]))
self.assertFalse(_path_hardened([44 | HARDENED, 1 | HARDENED, 0 | HARDENED, 0 | HARDENED, 0]))
if __name__ == '__main__':
unittest.main()

View File

@ -1,7 +1,8 @@
from common import * from common import *
from ubinascii import unhexlify from ubinascii import unhexlify
from trezor.crypto import bip32 from trezor.crypto import bip32
from apps.nem.helpers import NEM_NETWORK_MAINNET, NEM_CURVE from apps.nem import CURVE
from apps.nem.helpers import NEM_NETWORK_MAINNET
class TestNemHDNode(unittest.TestCase): class TestNemHDNode(unittest.TestCase):
@ -81,7 +82,7 @@ class TestNemHDNode(unittest.TestCase):
child_num=0, child_num=0,
chain_code=bytearray(32), chain_code=bytearray(32),
private_key=private_key, private_key=private_key,
curve_name=NEM_CURVE curve_name=CURVE
) )
self.assertEqual(node.nem_address(NEM_NETWORK_MAINNET), test[2]) self.assertEqual(node.nem_address(NEM_NETWORK_MAINNET), test[2])
@ -222,7 +223,7 @@ class TestNemHDNode(unittest.TestCase):
child_num=0, child_num=0,
chain_code=bytearray(32), chain_code=bytearray(32),
private_key=private_key, private_key=private_key,
curve_name=NEM_CURVE curve_name=CURVE
) )
encrypted = node.nem_encrypt(unhexlify(test['public']), encrypted = node.nem_encrypt(unhexlify(test['public']),