from __future__ import annotations

import io
import typing as t
from hashlib import sha256

from trezorlib import cosi, definitions, messages, protobuf

from ...common import PRIVATE_KEYS_DEV


def make_network(
    chain_id: int = 0,
    slip44: int = 0,
    symbol: str = "FAKE",
    name: str = "Fake network",
) -> messages.EthereumNetworkInfo:
    return messages.EthereumNetworkInfo(
        chain_id=chain_id,
        slip44=slip44,
        symbol=symbol,
        name=name,
    )


def make_token(
    symbol: str = "FAKE",
    decimals: int = 18,
    address: bytes = b"",
    chain_id: int = 0,
    name: str = "Fake token",
) -> messages.EthereumTokenInfo:
    return messages.EthereumTokenInfo(
        symbol=symbol,
        decimals=decimals,
        address=address,
        chain_id=chain_id,
        name=name,
    )


def make_payload(
    data_type: messages.EthereumDefinitionType = messages.EthereumDefinitionType.NETWORK,
    timestamp: int = 0xFFFF_FFFF,
    message: messages.EthereumNetworkInfo
    | messages.EthereumTokenInfo
    | bytes = make_network(),
) -> bytes:
    if isinstance(message, bytes):
        message_bytes = message
    else:
        writer = io.BytesIO()
        protobuf.dump_message(writer, message)
        message_bytes = writer.getvalue()

    payload = definitions.DefinitionPayload(
        magic=b"trzd1",
        data_type=data_type,
        timestamp=timestamp,
        data=message_bytes,
    )
    return payload.build()


def sign_payload(
    payload: bytes,
    merkle_neighbors: list[bytes],
    threshold: int = 3,
) -> tuple[bytes, bytes]:
    digest = sha256(b"\x00" + payload).digest()
    merkle_proof = []
    for item in merkle_neighbors:
        left, right = min(digest, item), max(digest, item)
        digest = sha256(b"\x01" + left + right).digest()
        merkle_proof.append(digest)

    merkle_proof = len(merkle_proof).to_bytes(1, "little") + b"".join(merkle_proof)
    signature = cosi.sign_with_privkeys(digest, PRIVATE_KEYS_DEV[:threshold])
    sigmask = 0
    for i in range(threshold):
        sigmask |= 1 << i
    sigmask_byte = sigmask.to_bytes(1, "little")
    return merkle_proof, sigmask_byte + signature


def encode_network(
    network: messages.EthereumNetworkInfo | None = None,
    chain_id: int = 0,
    slip44: int = 0,
    symbol: str = "FAKE",
    name: str = "Fake network",
) -> bytes:
    if network is None:
        network = make_network(chain_id, slip44, symbol, name)
    payload = make_payload(
        data_type=messages.EthereumDefinitionType.NETWORK, message=network
    )
    proof, signature = sign_payload(payload, [])
    return payload + proof + signature


def encode_token(
    token: messages.EthereumTokenInfo | None = None,
    symbol: str = "FakeTok",
    decimals: int = 18,
    address: t.AnyStr = b"",
    chain_id: int = 0,
    name: str = "Fake token",
) -> bytes:
    if token is None:
        if isinstance(address, str):
            if address.startswith("0x"):
                address = address[2:]
            address = bytes.fromhex(address)  # type: ignore (typechecker is lying)
        token = make_token(symbol, decimals, address, chain_id, name)  # type: ignore (typechecker is lying)
    payload = make_payload(
        data_type=messages.EthereumDefinitionType.TOKEN, message=token
    )
    proof, signature = sign_payload(payload, [])
    return payload + proof + signature


def make_defs(
    network: bytes | None, token: bytes | None
) -> messages.EthereumDefinitions:
    return messages.EthereumDefinitions(
        encoded_network=network,
        encoded_token=token,
    )