parent
b594248ac2
commit
0dff3853a7
@ -1,18 +1,12 @@
|
||||
from trezor import wire
|
||||
from trezor.messages import MessageType
|
||||
|
||||
from apps.common import HARDENED
|
||||
from apps.ethereum.networks import all_slip44_ids_hardened
|
||||
|
||||
CURVE = "secp256k1"
|
||||
|
||||
|
||||
def boot() -> None:
|
||||
ns = []
|
||||
for i in all_slip44_ids_hardened():
|
||||
ns.append([CURVE, HARDENED | 44, i])
|
||||
wire.add(MessageType.EthereumGetAddress, __name__, "get_address", ns)
|
||||
wire.add(MessageType.EthereumGetPublicKey, __name__, "get_public_key", ns)
|
||||
wire.add(MessageType.EthereumSignTx, __name__, "sign_tx", ns)
|
||||
wire.add(MessageType.EthereumSignMessage, __name__, "sign_message", ns)
|
||||
wire.add(MessageType.EthereumGetAddress, __name__, "get_address")
|
||||
wire.add(MessageType.EthereumGetPublicKey, __name__, "get_public_key")
|
||||
wire.add(MessageType.EthereumSignTx, __name__, "sign_tx")
|
||||
wire.add(MessageType.EthereumSignMessage, __name__, "sign_message")
|
||||
wire.add(MessageType.EthereumVerifyMessage, __name__, "verify_message")
|
||||
|
@ -0,0 +1,64 @@
|
||||
from trezor import wire
|
||||
|
||||
from . import CURVE, networks
|
||||
|
||||
from apps.common import HARDENED, seed
|
||||
from apps.common.seed import get_keychain
|
||||
|
||||
if False:
|
||||
from typing import List
|
||||
from typing_extensions import Protocol
|
||||
|
||||
from protobuf import MessageType
|
||||
|
||||
from trezor.messages.EthereumSignTx import EthereumSignTx
|
||||
|
||||
from apps.common.seed import MsgOut, Handler, HandlerWithKeychain
|
||||
|
||||
class MsgWithAddressN(MessageType, Protocol):
|
||||
address_n = ... # type: List[int]
|
||||
|
||||
|
||||
async def from_address_n(ctx: wire.Context, address_n: List[int]) -> seed.Keychain:
|
||||
if len(address_n) < 2:
|
||||
raise wire.DataError("Forbidden key path")
|
||||
slip44_hardened = address_n[1]
|
||||
if slip44_hardened not in networks.all_slip44_ids_hardened():
|
||||
raise wire.DataError("Forbidden key path")
|
||||
namespace = CURVE, [44 | HARDENED, slip44_hardened]
|
||||
return await get_keychain(ctx, [namespace])
|
||||
|
||||
|
||||
def with_keychain_from_path(
|
||||
func: HandlerWithKeychain[MsgWithAddressN, MsgOut]
|
||||
) -> Handler[MsgWithAddressN, MsgOut]:
|
||||
async def wrapper(ctx: wire.Context, msg: MsgWithAddressN) -> MsgOut:
|
||||
keychain = await from_address_n(ctx, msg.address_n)
|
||||
with keychain:
|
||||
return await func(ctx, msg, keychain)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def with_keychain_from_chain_id(
|
||||
func: HandlerWithKeychain[EthereumSignTx, MsgOut]
|
||||
) -> Handler[EthereumSignTx, MsgOut]:
|
||||
async def wrapper(ctx: wire.Context, msg: EthereumSignTx) -> MsgOut:
|
||||
if msg.chain_id is None:
|
||||
keychain = await from_address_n(ctx, msg.address_n)
|
||||
else:
|
||||
info = networks.by_chain_id(msg.chain_id)
|
||||
if info is None:
|
||||
raise wire.DataError("Unsupported chain id")
|
||||
|
||||
slip44 = info.slip44
|
||||
if networks.is_wanchain(msg.chain_id, msg.tx_type):
|
||||
slip44 = networks.SLIP44_WANCHAIN
|
||||
|
||||
namespace = CURVE, [44 | HARDENED, slip44 | HARDENED]
|
||||
keychain = await get_keychain(ctx, [namespace])
|
||||
|
||||
with keychain:
|
||||
return await func(ctx, msg, keychain)
|
||||
|
||||
return wrapper
|
Loading…
Reference in new issue