#
# Copyright (C) 2014 GNS3 Technologies Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import sys
import socket
import asyncio
import asyncio.subprocess
import struct

import logging

# Module-level logger, named after this module
log = logging.getLogger(__name__)

# The telnet constants below are mostly taken from
# https://code.google.com/p/miniboa/source/browse/trunk/miniboa/telnet.py
# They are plain byte values, used to build and to recognize command sequences.

# Telnet Commands
SE = 240  # End of sub-negotiation parameters
NOP = 241  # No operation
DATMK = 242  # Data stream portion of a sync.
BREAK = 243  # NVT Character BRK
IP = 244  # Interrupt Process
AO = 245  # Abort Output
AYT = 246  # Are you there
EC = 247  # Erase Character
EL = 248  # Erase Line
GA = 249  # The Go Ahead Signal
SB = 250  # Sub-option to follow
WILL = 251  # Will; request or confirm option begin
WONT = 252  # Won't; deny option request
DO = 253  # Do; request or confirm remote option
DONT = 254  # Don't; demand or confirm option halt
IAC = 255  # Interpret as Command; first byte of every telnet command
SEND = 1  # Sub-process negotiation SEND command
IS = 0  # Sub-process negotiation IS command

# Telnet Options
BINARY = 0  # Transmit Binary
ECHO = 1  # Echo characters back to sender
RECON = 2  # Reconnection
SGA = 3  # Suppress Go-Ahead
TMARK = 6  # Timing Mark
TTYPE = 24  # Terminal Type
NAWS = 31  # Negotiate About Window Size
LINEMO = 34  # Line Mode

# Maximum number of bytes requested by each read() call on a stream
READ_SIZE = 1024


class TelnetConnection:
    """
    Default implementation of a telnet client connection.

    It only keeps the client's reader/writer pair, forwards window size
    changes to an optional callback and records when the connection is asked
    to close. The hook methods (`connected`, `disconnected`, `feed`) do
    nothing here and exist to be overridden.
    """

    def __init__(self, reader, writer, window_size_changed_callback=None):
        """
        :param reader: object the client's data is read from
        :param writer: object used to send data back to the client
        :param window_size_changed_callback: optional coroutine function
            awaited with (columns, rows) when the window size changes
        """
        self.is_closing = False
        self._reader = reader
        self._writer = writer
        self._window_size_changed_callback = window_size_changed_callback

    @property
    def reader(self):
        """Reader given at construction time."""
        return self._reader

    @property
    def writer(self):
        """Writer given at construction time."""
        return self._writer

    async def connected(self):
        """Method called when client is connected"""
        pass

    async def disconnected(self):
        """Method called when client is disconnecting"""
        pass

    async def window_size_changed(self, columns, rows):
        """
        Method called when the client's window size changed.

        Forwards the new size to the callback given at construction time,
        if there is one.

        :param columns: new window width
        :param rows: new window height
        """

        if self._window_size_changed_callback:
            await self._window_size_changed_callback(columns, rows)

    async def feed(self, data):
        """
        Handles incoming data; this default implementation ignores it.

        :param data: data received from the client
        :return:
        """

    def send(self, data):
        """
        Sends data back to the client, turning every "\\n" into "\\r\\n".

        The replacement is done on the raw bytes so that payloads which are
        not valid UTF-8 (e.g. binary console output) are sent instead of
        raising UnicodeDecodeError.

        :param data: bytes-like payload; a str is encoded first
        :return:
        """
        if isinstance(data, str):
            data = data.encode()
        self.writer.write(bytes(data).replace(b"\n", b"\r\n"))

    def close(self):
        """
        Marks the connection as closing (sets `is_closing`).

        :return:
        """
        self.is_closing = True


class AsyncioTelnetServer:
    MAX_NEGOTIATION_READ = 10

    def __init__(
        self,
        reader=None,
        writer=None,
        binary=True,
        echo=False,
        naws=False,
        window_size_changed_callback=None,
        connection_factory=None,
    ):
        """
        Initializes telnet server
        :param naws when True make a window size negotiation
        :param connection_factory: when set it's possible to inject own implementation of connection
        """
        assert connection_factory is None or (
            connection_factory is not None and reader is None and writer is None
        ), "Please use either reader and writer either connection_factory, otherwise duplicate data may be produced."

        self._reader = reader
        self._writer = writer
        self._connections = dict()
        self._lock = asyncio.Lock()
        self._reader_process = None
        self._current_read = None
        self._window_size_changed_callback = window_size_changed_callback

        self._binary = binary
        # If echo is true when the client send data
        # the data is echo on his terminal by telnet otherwise
        # it's our job (or the wrapped app) to send back the data
        self._echo = echo
        self._naws = naws

        def default_connection_factory(reader, writer, window_size_changed_callback):
            return TelnetConnection(reader, writer, window_size_changed_callback)

        if connection_factory is None:
            connection_factory = default_connection_factory

        self._connection_factory = connection_factory

    @staticmethod
    async def write_client_intro(writer, echo=False):
        # Send initial telnet session opening
        if echo:
            writer.write(bytes([IAC, WILL, ECHO]))
        else:
            writer.write(bytes([IAC, WONT, ECHO, IAC, DONT, ECHO]))
        await writer.drain()

    async def _write_intro(self, writer, binary=False, echo=False, naws=False):
        # Send initial telnet session opening
        if echo:
            writer.write(bytes([IAC, WILL, ECHO]))
        else:
            writer.write(bytes([IAC, WONT, ECHO, IAC, DONT, ECHO]))

        if binary:
            writer.write(bytes([IAC, WILL, SGA, IAC, WILL, BINARY, IAC, DO, BINARY]))
        else:
            writer.write(bytes([IAC, WONT, SGA, IAC, DONT, SGA, IAC, WONT, BINARY, IAC, DONT, BINARY]))

        if naws:
            writer.write(bytes([IAC, DO, NAWS]))
        await writer.drain()

    async def run(self, network_reader, network_writer):

        sock = network_writer.get_extra_info("socket")
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        # 60 sec keep alives, close tcp session after 4 missed
        # Will keep a firewall from aging out telnet console.
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 60)
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 10)
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)
        #log.debug("New connection from {}".format(sock.getpeername()))

        # Keep track of connected clients
        connection = self._connection_factory(network_reader, network_writer, self._window_size_changed_callback)
        self._connections[network_writer] = connection

        try:
            await self._write_intro(network_writer, echo=self._echo, binary=self._binary, naws=self._naws)
            await connection.connected()
            await self._process(network_reader, network_writer, connection)
        except ConnectionError:
            async with self._lock:
                network_writer.close()
                # await network_writer.wait_closed()  # this doesn't work in Python 3.6
                if self._reader_process == network_reader:
                    self._reader_process = None
                    # Cancel current read from this reader
                    if self._current_read is not None:
                        self._current_read.cancel()

            await connection.disconnected()
            del self._connections[network_writer]

    async def close(self):
        for writer, connection in self._connections.items():
            try:
                writer.write_eof()
                await writer.drain()
                writer.close()
                # await writer.wait_closed()  # this doesn't work in Python 3.6
            except (AttributeError, ConnectionError):
                continue

    async def client_connected_hook(self):
        pass

    async def _get_reader(self, network_reader):
        """
        Get a reader or None if another reader is already reading.
        """
        async with self._lock:
            if self._reader_process is None:
                self._reader_process = network_reader
            if self._reader:
                if self._reader_process == network_reader:
                    self._current_read = asyncio.ensure_future(self._reader.read(READ_SIZE))
                    return self._current_read
        return None

    async def _process(self, network_reader, network_writer, connection):
        network_read = asyncio.ensure_future(network_reader.read(READ_SIZE))
        reader_read = await self._get_reader(network_reader)

        while True:
            if reader_read is None:
                reader_read = await self._get_reader(network_reader)
            if reader_read is None:
                done, pending = await asyncio.wait(
                    [
                        network_read,
                    ],
                    timeout=1,
                    return_when=asyncio.FIRST_COMPLETED,
                )
            else:
                done, pending = await asyncio.wait([network_read, reader_read], return_when=asyncio.FIRST_COMPLETED)
            for coro in done:
                data = coro.result()
                if coro == network_read:
                    if network_reader.at_eof():
                        raise ConnectionResetError()

                    network_read = asyncio.ensure_future(network_reader.read(READ_SIZE))

                    if IAC in data:
                        data = await self._IAC_parser(data, network_reader, network_writer, connection)

                    if len(data) == 0:
                        continue

                    if not self._binary:
                        data = data.replace(b"\r\n", b"\n")

                    if self._writer:
                        self._writer.write(data)
                        await self._writer.drain()

                    await connection.feed(data)
                    if connection.is_closing:
                        raise ConnectionResetError()

                elif coro == reader_read:
                    if self._reader and self._reader.at_eof():
                        raise ConnectionResetError()

                    reader_read = await self._get_reader(network_reader)

                    # Replicate the output on all clients
                    for connection_key in list(self._connections.keys()):
                        client_info = connection_key.get_extra_info("socket").getpeername()
                        connection = self._connections[connection_key]

                        try:
                            connection.writer.write(data)
                            await asyncio.wait_for(connection.writer.drain(), timeout=10)
                        except:
                            log.debug(f"Timeout while sending data to client: {client_info}, closing and removing from connection table.")
                            connection.close()
                            del self._connections[connection_key]

    async def _read(self, cmd, buffer, location, reader):
        """ Reads next op from the buffer or reader"""
        try:
            op = buffer[location]
            cmd.append(op)
            return op
        except IndexError:
            op = await reader.read(1)
            buffer.extend(op)
            cmd.append(buffer[location])
            return op

    async def _negotiate(self, data, connection):
        """ Performs negotiation commands"""

        command, payload = data[0], data[1:]
        if command == NAWS:
            if len(payload) == 4:
                columns, rows = struct.unpack("!HH", bytes(payload))
                await connection.window_size_changed(columns, rows)
            else:
                log.warning("Wrong number of NAWS bytes")
        else:
            log.debug("Not supported negotiation sequence, received {} bytes", len(data))

    async def _IAC_parser(self, buf, network_reader, network_writer, connection):
        """
        Processes and removes any Telnet commands from the buffer.

        :param buf: buffer
        :returns: buffer minus Telnet commands
        """

        skip_to = 0
        while True:
            # Locate an IAC to process
            iac_loc = buf.find(IAC, skip_to)
            if iac_loc < 0:
                break

            # Get the TELNET command
            iac_cmd = bytearray([IAC])
            try:
                iac_cmd.append(buf[iac_loc + 1])
            except IndexError:
                d = await network_reader.read(1)
                buf.extend(d)
                iac_cmd.append(buf[iac_loc + 1])

            # Is this just a 2-byte TELNET command?
            if iac_cmd[1] not in [WILL, WONT, DO, DONT, SB]:
                if iac_cmd[1] == AYT:
                    log.debug("Telnet server received Are-You-There (AYT)")
                    network_writer.write(b"\r\nYour Are-You-There received. I am here.\r\n")
                elif iac_cmd[1] == IAC:
                    # It's data, not an IAC
                    iac_cmd.pop()
                    # This prevents the 0xff from being
                    # interrupted as yet another IAC
                    skip_to = iac_loc + 1
                    log.debug("Received IAC IAC")
                elif iac_cmd[1] == NOP:
                    pass
                else:
                    log.debug("Unhandled telnet command: " "{:#x} {:#x}".format(*iac_cmd))
            elif iac_cmd[1] == SB:  # starts negotiation commands
                negotiation = []
                for pos in range(2, self.MAX_NEGOTIATION_READ):
                    op = await self._read(iac_cmd, buf, iac_loc + pos, network_reader)
                    negotiation.append(op)
                    if op == SE:
                        # ends negotiation commands
                        break

                # SE command is followed by IAC, remove the last two operations from stack
                await self._negotiate(negotiation[0:-2], connection)

            # This must be a 3-byte TELNET command
            else:
                try:
                    iac_cmd.append(buf[iac_loc + 2])
                except IndexError:
                    d = await network_reader.read(1)
                    buf.extend(d)
                    iac_cmd.append(buf[iac_loc + 2])
                # We do ECHO, SGA, and BINARY. Period.
                if iac_cmd[1] == DO:
                    if iac_cmd[2] not in [ECHO, SGA, BINARY]:
                        network_writer.write(bytes([IAC, WONT, iac_cmd[2]]))
                        log.debug(f"Telnet WON'T {iac_cmd[2]:#x}")
                    else:
                        if iac_cmd[2] == SGA:
                            if self._binary:
                                network_writer.write(bytes([IAC, WILL, iac_cmd[2]]))
                            else:
                                network_writer.write(bytes([IAC, WONT, iac_cmd[2]]))
                                log.debug(f"Telnet WON'T {iac_cmd[2]:#x}")

                elif iac_cmd[1] == DONT:
                    log.debug("Unhandled DONT telnet command: " "{:#x} {:#x} {:#x}".format(*iac_cmd))
                elif iac_cmd[1] == WILL:
                    if iac_cmd[2] not in [BINARY, NAWS]:
                        log.debug("Unhandled WILL telnet command: " "{:#x} {:#x} {:#x}".format(*iac_cmd))
                elif iac_cmd[1] == WONT:
                    log.debug("Unhandled WONT telnet command: " "{:#x} {:#x} {:#x}".format(*iac_cmd))
                else:
                    log.debug("Unhandled telnet command: " "{:#x} {:#x} {:#x}".format(*iac_cmd))

            # Remove the entire TELNET command from the buffer
            buf = buf.replace(iac_cmd, b"", 1)

            await network_writer.drain()

        # Return the new copy of the buffer, minus telnet commands
        return buf


if __name__ == "__main__":
    # Demo: expose an interactive /bin/sh on telnet port 4444 of localhost
    logging.basicConfig(level=logging.DEBUG)

    # Create the loop explicitly: asyncio.get_event_loop() is deprecated when
    # no loop is running.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    process = loop.run_until_complete(
        asyncio.subprocess.create_subprocess_exec(
            "/bin/sh",
            "-i",
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.STDOUT,
            stdin=asyncio.subprocess.PIPE,
        )
    )
    server = AsyncioTelnetServer(reader=process.stdout, writer=process.stdin, binary=False, echo=False)

    s = loop.run_until_complete(asyncio.start_server(server.run, "127.0.0.1", 4444))

    try:
        loop.run_forever()
    except KeyboardInterrupt:
        pass
    finally:
        # Stop accepting clients, close the connected ones, then wait for the
        # server to be closed.
        s.close()
        loop.run_until_complete(server.close())
        loop.run_until_complete(s.wait_closed())
        loop.close()