# This file is part of the Trezor project.
#
# Copyright (C) 2012-2018 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.

from copy import deepcopy

from mnemonic import Mnemonic

from . import messages as proto, protobuf, tools
from .client import TrezorClient
from .tools import expect

# How many expected responses to show before and after the offending one
# when TrezorClientDebugLink reports an unexpected (or missing) response.
EXPECTED_RESPONSES_CONTEXT_LINES = 3


class DebugLink:
    """Client for the device's DebugLink interface.

    All requests go through ``_call``: the message is written to the debug
    transport and, unless ``nowait`` is set, a single response is read back.
    """

    def __init__(self, transport, auto_interact=True):
        self.transport = transport
        # When False, input() and the press_yes/press_no/swipe_* helpers built
        # on it return without sending anything.
        self.allow_interactions = auto_interact

    def open(self):
        self.transport.begin_session()

    def close(self):
        self.transport.end_session()

    def _call(self, msg, nowait=False):
        """Write ``msg``; return the device's reply, or None if ``nowait``."""
        self.transport.write(msg)
        return None if nowait else self.transport.read()

    def state(self):
        """Return the device's current DebugLinkState."""
        return self._call(proto.DebugLinkGetState())

    def read_pin(self):
        """Return ``(pin, matrix)`` from the current debug state."""
        current = self.state()
        return current.pin, current.matrix

    def read_pin_encoded(self):
        """Return the stored PIN encoded for the currently displayed matrix."""
        pin, matrix = self.read_pin()
        return self.encode_pin(pin, matrix)

    def encode_pin(self, pin, matrix=None):
        """Transform correct PIN according to the displayed matrix.

        Each PIN character is replaced by its 1-based position in ``matrix``;
        the matrix is fetched from the device when not supplied.
        """
        if matrix is None:
            matrix = self.read_pin()[1]
        return "".join(str(matrix.index(digit) + 1) for digit in pin)

    def read_layout(self):
        return self.state().layout

    def read_mnemonic_secret(self):
        return self.state().mnemonic_secret

    def read_recovery_word(self):
        """Return ``(recovery_fake_word, recovery_word_pos)``."""
        current = self.state()
        return current.recovery_fake_word, current.recovery_word_pos

    def read_reset_word(self):
        return self.state().reset_word

    def read_reset_word_pos(self):
        return self.state().reset_word_pos

    def read_reset_entropy(self):
        return self.state().reset_entropy

    def read_passphrase_protection(self):
        return self.state().passphrase_protection

    def input(self, word=None, button=None, swipe=None):
        """Send one DebugLinkDecision without waiting for a reply.

        Only the first provided of ``button``, ``word``, ``swipe`` (checked in
        that order) is sent. Does nothing when ``allow_interactions`` is False.

        :raises ValueError: if none of the three is provided
        """
        if not self.allow_interactions:
            return
        if button is None and word is None and swipe is None:
            raise ValueError("You need to provide input data.")

        decision = proto.DebugLinkDecision()
        if button is not None:
            decision.yes_no = button
        elif word is not None:
            decision.input = word
        else:
            decision.up_down = swipe
        self._call(decision, nowait=True)

    def press_button(self, yes_no):
        # Unlike input(), this does not consult allow_interactions.
        self._call(proto.DebugLinkDecision(yes_no=yes_no), nowait=True)

    def press_yes(self):
        self.input(button=True)

    def press_no(self):
        self.input(button=False)

    def swipe_up(self):
        self.input(swipe=True)

    def swipe_down(self):
        self.input(swipe=False)

    def stop(self):
        self._call(proto.DebugLinkStop(), nowait=True)

    @expect(proto.DebugLinkMemory, field="memory")
    def memory_read(self, address, length):
        """Read ``length`` bytes at ``address``; the decorator extracts ``memory``."""
        return self._call(proto.DebugLinkMemoryRead(address=address, length=length))

    def memory_write(self, address, memory, flash=False):
        request = proto.DebugLinkMemoryWrite(address=address, memory=memory, flash=flash)
        self._call(request, nowait=True)

    def flash_erase(self, sector):
        self._call(proto.DebugLinkFlashErase(sector=sector), nowait=True)


class NullDebugLink(DebugLink):
    """DebugLink stand-in used when no debug transport is available.

    It has no transport: fire-and-forget messages are dropped, a
    DebugLinkGetState request is answered with an empty DebugLinkState, and
    any other request that expects a reply raises RuntimeError.
    """

    def __init__(self):
        super().__init__(None)

    def open(self):
        pass

    def close(self):
        pass

    def _call(self, msg, nowait=False):
        if nowait:
            return None
        if not isinstance(msg, proto.DebugLinkGetState):
            raise RuntimeError("unexpected call to a fake debuglink")
        return proto.DebugLinkState()


class DebugUI:
    """UI callback object that answers device prompts through a DebugLink.

    Button requests go to ``input_flow`` (a primed generator) when one is
    set; otherwise they are confirmed with ``press_yes``.
    """

    # Sentinel stored in input_flow once the generator has finished.
    INPUT_FLOW_DONE = object()

    def __init__(self, debuglink: DebugLink):
        self.debuglink = debuglink
        self.pin = None
        self.passphrase = "sphinx of black quartz, judge my wov"
        self.input_flow = None

    def button_request(self, code):
        flow = self.input_flow
        if flow is self.INPUT_FLOW_DONE:
            # A button request arrived after the scripted flow was exhausted.
            raise AssertionError("input flow ended prematurely")
        if flow is None:
            self.debuglink.press_yes()
            return
        try:
            flow.send(code)
        except StopIteration:
            self.input_flow = self.INPUT_FLOW_DONE

    def get_pin(self, code=None):
        # A preset (truthy) pin wins; otherwise read and encode the device's PIN.
        return self.pin if self.pin else self.debuglink.read_pin_encoded()

    def get_passphrase(self):
        return self.passphrase


class TrezorClientDebugLink(TrezorClient):
    """TrezorClient that automates device interaction for unit tests.

    Callbacks (button confirmations, PIN, passphrase) are answered
    automatically through ``DebugUI`` and the DebugLink channel. Inside a
    ``with client:`` block, ``set_expected_responses`` makes every message
    read from the device be checked against an expected sequence.

    This mixin should be used only for unit testing, because it will fail to
    work without the special DebugLink interface provided by the device.
    """

    def __init__(self, transport, auto_interact=True):
        try:
            debug_transport = transport.find_debug()
            self.debug = DebugLink(debug_transport, auto_interact)
        except Exception:
            # Without a debug transport we can only continue when automatic
            # interaction is not requested; use a no-op debuglink then.
            if not auto_interact:
                self.debug = NullDebugLink()
            else:
                raise

        self.ui = DebugUI(self.debug)

        self.in_with_statement = 0
        self.screenshot_id = 0

        # message class -> callback(deepcopy(msg)) returning the message to use
        self.filters = {}

        # Always press Yes and provide correct pin
        self.setup_debuglink(True, True)

        # Do not expect any specific response from device
        self.expected_responses = None
        self.current_response = None

        # Use blank passphrase
        self.set_passphrase("")
        super().__init__(transport, ui=self.ui)

    def open(self):
        super().open()
        self.debug.open()

    def close(self):
        self.debug.close()
        super().close()

    def set_filter(self, message_type, callback):
        """Register ``callback`` to rewrite messages of ``message_type``.

        Filters apply to messages both read from and written to the device.
        """
        self.filters[message_type] = callback

    def _filter_message(self, msg):
        callback = self.filters.get(msg.__class__)
        if callable(callback):
            # The callback gets a copy, so it may mutate it freely.
            return callback(deepcopy(msg))
        return msg

    def set_input_flow(self, input_flow):
        """Install a generator (or generator function) answering button requests.

        Passing None removes the current flow.

        :raises RuntimeError: if ``input_flow`` is not a generator
        """
        if input_flow is None:
            self.ui.input_flow = None
            return

        if callable(input_flow):
            input_flow = input_flow()
        if not hasattr(input_flow, "send"):
            raise RuntimeError("input_flow should be a generator function")
        self.ui.input_flow = input_flow
        next(input_flow)  # can't send before first yield

    def __enter__(self):
        # For usage in with/expected_responses
        self.in_with_statement += 1
        return self

    def __exit__(self, _type, value, traceback):
        __tracebackhide__ = True  # for pytest # pylint: disable=W0612
        self.in_with_statement -= 1

        try:
            # Report expected responses that never arrived, but only when the
            # block itself succeeded - otherwise its exception takes priority.
            if (
                _type is None
                and self.expected_responses is not None
                and self.current_response < len(self.expected_responses)
            ):
                self._raise_unexpected_response(None)
        finally:
            # Always drop the expectations, even when the block raised or the
            # check above failed, so they cannot leak into later calls.
            self.expected_responses = None
            self.current_response = None
        return False

    def set_expected_responses(self, expected):
        """Expect exactly the messages in ``expected``, in order.

        :raises RuntimeError: if called outside a ``with client:`` block
        """
        if not self.in_with_statement:
            raise RuntimeError("Must be called inside 'with' statement")
        self.expected_responses = expected
        self.current_response = 0

    def setup_debuglink(self, button, pin_correct):
        """Choose whether the UI supplies the correct PIN.

        ``button`` is accepted for compatibility but currently unused.
        """
        if pin_correct:
            self.ui.pin = None
        else:
            self.ui.pin = "444222"

    def set_passphrase(self, passphrase):
        self.ui.passphrase = Mnemonic.normalize_string(passphrase)

    def set_mnemonic(self, mnemonic):
        """Store the words that mnemonic_callback uses to answer word requests."""
        # split() without an argument ignores repeated/leading/trailing
        # whitespace, which would otherwise shift word positions.
        self.mnemonic = Mnemonic.normalize_string(mnemonic).split()

    def _raw_read(self):
        __tracebackhide__ = True  # for pytest # pylint: disable=W0612

        resp = super()._raw_read()
        resp = self._filter_message(resp)
        self._check_request(resp)
        return resp

    def _raw_write(self, msg):
        return super()._raw_write(self._filter_message(msg))

    def _raise_unexpected_response(self, msg):
        """Raise AssertionError describing the mismatch at ``current_response``.

        Lists up to EXPECTED_RESPONSES_CONTEXT_LINES expected messages on each
        side of the current position (marked with ``>>>``), followed by the
        received ``msg``, or a note that it was never received if ``msg`` is None.
        """
        __tracebackhide__ = True  # for pytest # pylint: disable=W0612

        start_at = max(self.current_response - EXPECTED_RESPONSES_CONTEXT_LINES, 0)
        stop_at = min(
            self.current_response + EXPECTED_RESPONSES_CONTEXT_LINES + 1,
            len(self.expected_responses),
        )
        output = []
        output.append("Expected responses:")
        if start_at > 0:
            output.append("    (...{} previous responses omitted)".format(start_at))
        for i in range(start_at, stop_at):
            exp = self.expected_responses[i]
            prefix = "    " if i != self.current_response else ">>> "
            set_fields = {
                key: value
                for key, value in exp.__dict__.items()
                if value is not None and value != []
            }
            oneline_str = ", ".join(
                "{}={!r}".format(key, value) for key, value in set_fields.items()
            )
            if len(oneline_str) < 60:
                output.append(
                    "{}{}({})".format(prefix, exp.__class__.__name__, oneline_str)
                )
            else:
                item = []
                item.append("{}{}(".format(prefix, exp.__class__.__name__))
                for key, value in set_fields.items():
                    item.append("{}    {}={!r}".format(prefix, key, value))
                item.append("{})".format(prefix))
                output.append("\n".join(item))
        if stop_at < len(self.expected_responses):
            omitted = len(self.expected_responses) - stop_at
            output.append("    (...{} following responses omitted)".format(omitted))

        output.append("")
        if msg is not None:
            output.append("Actually received:")
            output.append(protobuf.format_message(msg))
        else:
            output.append("This message was never received.")
        raise AssertionError("\n".join(output))

    def _check_request(self, msg):
        """Compare ``msg`` with the next expected response, if any are set.

        The class must match, and every field that is set (not None and not
        an empty list) on the expected message must be equal on ``msg``.

        :raises AssertionError: on a mismatch or an unexpected extra message
        """
        __tracebackhide__ = True  # for pytest # pylint: disable=W0612
        if self.expected_responses is None:
            return

        if self.current_response >= len(self.expected_responses):
            raise AssertionError(
                "No more messages were expected, but we got:\n"
                + protobuf.format_message(msg)
            )

        expected = self.expected_responses[self.current_response]

        if msg.__class__ != expected.__class__:
            self._raise_unexpected_response(msg)

        for field, value in expected.__dict__.items():
            if value is None or value == []:
                continue
            if getattr(msg, field) != value:
                self._raise_unexpected_response(msg)

        self.current_response += 1

    def mnemonic_callback(self, _):
        """Answer a recovery word request using the DebugLink state.

        A fake word reported by the device is returned as-is; otherwise the
        word at the (1-based) requested position of the stored mnemonic.
        """
        word, pos = self.debug.read_recovery_word()
        if word != "":
            return word
        if pos != 0:
            return self.mnemonic[pos - 1]

        raise RuntimeError("Unexpected call")


@expect(proto.Success, field="message")
def load_device_by_mnemonic(
    client,
    mnemonic,
    pin,
    passphrase_protection,
    label,
    language="english",
    skip_checksum=False,
    expand=False,
):
    """Load ``mnemonic`` onto an uninitialized device via LoadDevice.

    :param mnemonic: space-separated mnemonic sentence
    :param skip_checksum: skip the local checksum check and forward the flag
        to the device
    :param expand: expand abbreviated words with ``Mnemonic.expand`` first
    :raises ValueError: if the checksum check fails
    :raises RuntimeError: if the device is already initialized
    """
    # Convert mnemonic to Unicode NFKD
    mnemonic = Mnemonic.normalize_string(mnemonic)

    m = Mnemonic("english")

    # Expand and check while the mnemonic is still a str: Mnemonic.expand()
    # splits on a str separator, so it fails on bytes.
    if expand:
        mnemonic = m.expand(mnemonic)

    if not skip_checksum and not m.check(mnemonic):
        raise ValueError("Invalid mnemonic checksum")

    if client.features.initialized:
        raise RuntimeError(
            "Device is initialized already. Call device.wipe() and try again."
        )

    resp = client.call(
        proto.LoadDevice(
            # Sent as UTF-8 encoded bytes.
            mnemonic=mnemonic.encode(),
            pin=pin,
            passphrase_protection=passphrase_protection,
            language=language,
            label=label,
            skip_checksum=skip_checksum,
        )
    )
    # Refresh client.features after the device state changed.
    client.init_device()
    return resp


@expect(proto.Success, field="message")
def load_device_by_xprv(client, xprv, pin, passphrase_protection, label, language):
    """Load an extended private key (``xprv``/``tprv``) onto an uninitialized device.

    The base58-decoded key is split into depth, fingerprint, child number,
    chain code and private key, and sent as an HDNodeType in LoadDevice.

    :raises RuntimeError: if the device is already initialized
    :raises ValueError: on unknown prefix, implausible length, a key payload
        not starting with 0x00, or a checksum mismatch
    """
    if client.features.initialized:
        raise RuntimeError(
            "Device is initialized already. Call device.wipe() and try again."
        )

    if xprv[0:4] not in ("xprv", "tprv"):
        raise ValueError("Unknown type of xprv")

    # Chained comparison: 100 < len(xprv) and len(xprv) < 112.
    if not 100 < len(xprv) < 112:
        raise ValueError("Invalid length of xprv")

    node = proto.HDNodeType()
    # Hex string: every byte is two characters, so all offsets below are
    # twice the corresponding byte offsets.
    data = tools.b58decode(xprv, None).hex()

    # Byte 45 precedes the key; 0x00 marks a private key.
    if data[90:92] != "00":
        raise ValueError("Contain invalid private key")

    # The trailing 4 bytes must equal the first 4 bytes of btc_hash over the
    # preceding 78 bytes.
    checksum = (tools.btc_hash(bytes.fromhex(data[:156]))[:4]).hex()
    if checksum != data[156:]:
        raise ValueError("Checksum doesn't match")

    # Example layout:
    # version 0488ade4
    # depth 00
    # fingerprint 00000000
    # child_num 00000000
    # chaincode 873dff81c02f525623fd1fe5167eac3a55a049de3d314bb42ee227ffed37d508
    # privkey   00e8f32e723decf4051aefac8e2c93c9c5b214313817cdb01a1494b917c8436b35
    # checksum e77e9d71

    node.depth = int(data[8:10], 16)
    node.fingerprint = int(data[10:18], 16)
    node.child_num = int(data[18:26], 16)
    node.chain_code = bytes.fromhex(data[26:90])
    node.private_key = bytes.fromhex(data[92:156])  # skip 0x00 indicating privkey

    resp = client.call(
        proto.LoadDevice(
            node=node,
            pin=pin,
            passphrase_protection=passphrase_protection,
            language=language,
            label=label,
        )
    )
    # Refresh client.features after the device state changed.
    client.init_device()
    return resp


@expect(proto.Success, field="message")
def self_test(client):
    """Send a SelfTest request to a device that is in bootloader mode.

    :raises RuntimeError: if the device is not in bootloader mode
    """
    in_bootloader = client.features.bootloader_mode is True
    if not in_bootloader:
        raise RuntimeError("Device must be in bootloader mode")

    # Fixed test pattern: marker bytes, ASCII alphabet and digits, marker bytes.
    payload = (
        b"\x00\xFF\x55\xAA\x66\x99\x33\xCC"
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!"
        b"\x00\xFF\x55\xAA\x66\x99\x33\xCC"
    )
    return client.call(proto.SelfTest(payload=payload))