# -*- coding: utf-8 -*-
#
# Copyright (C) 2014 GNS3 Technologies Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import re
import copy
import asyncio
import asyncio.subprocess

from gns3server.utils.asyncio import asyncio_ensure_future

import logging
# Module-level logger, named after this module
log = logging.getLogger(__name__)

# Maximum number of bytes requested by each read() on the network and process streams
READ_SIZE = 4096


class AsyncioRawCommandServer:
    """
    Expose a process on the network: its stdout (merged with stderr) and
    its stdin are forwarded over the network connection.

    A new process is spawned for every call to :meth:`run`; a semaphore
    limits each server instance to four processes running at the same time.
    """

    def __init__(self, command, replaces=None):
        """
        :param command: Command to run (list of program and arguments)
        :param replaces: List of tuple to replace in the output ex: [(b":8080", b":6000")].
                         The replacement may contain b"{{HOST}}", which is substituted by the
                         server address as seen by the client.
        """
        self._command = command
        # Avoid a mutable default argument shared between instances
        self._replaces = replaces if replaces is not None else []
        # We limit number of process
        self._lock = asyncio.Semaphore(value=4)

    @asyncio.coroutine
    def run(self, network_reader, network_writer):
        """
        Spawn the command and relay data between it and the network connection.

        Suitable as the client callback of asyncio.start_server().

        :param network_reader: Stream reader of the client connection
        :param network_writer: Stream writer of the client connection
        """
        yield from self._lock.acquire()
        try:
            process = yield from asyncio.subprocess.create_subprocess_exec(
                *self._command,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.STDOUT,
                stdin=asyncio.subprocess.PIPE)
            try:
                yield from self._process(network_reader, network_writer, process.stdout, process.stdin)
            except ConnectionResetError:
                network_writer.close()
            finally:
                # Whatever ended the relay, never leave the child process behind
                if process.returncode is None:
                    try:
                        process.kill()
                    except ProcessLookupError:
                        # The process is already gone, nothing to kill
                        pass
                yield from process.wait()
        finally:
            # Always give the slot back, even when the process could not be
            # started, otherwise failed connections would exhaust the semaphore
            self._lock.release()

    @asyncio.coroutine
    def _process(self, network_reader, network_writer, process_reader, process_writer):
        """
        Relay data in both directions until one side reaches EOF or stays silent too long.

        Data read from the network is written to the process; data read from
        the process goes through the configured replacements and is written
        to the network.

        :param network_reader: Stream reader of the client connection
        :param network_writer: Stream writer of the client connection
        :param process_reader: Stream reader of the process output
        :param process_writer: Stream writer of the process input
        :raises ConnectionResetError: on EOF from either side or on timeout
        """
        # Server host from the client point of view
        host = network_writer.transport.get_extra_info("sockname")[0]
        replaces = []
        for replace in self._replaces:
            replacement = replace[1]
            if b'{{HOST}}' in replacement:
                replacement = replacement.replace(b'{{HOST}}', host.encode())
            replaces.append((replace[0], replacement, ))

        network_read = asyncio_ensure_future(network_reader.read(READ_SIZE))
        reader_read = asyncio_ensure_future(process_reader.read(READ_SIZE))
        timeout = 30

        try:
            while True:
                done, _ = yield from asyncio.wait(
                    [
                        network_read,
                        reader_read
                    ],
                    timeout=timeout,
                    return_when=asyncio.FIRST_COMPLETED)
                if not done:
                    # Nothing happened on either side before the timeout
                    raise ConnectionResetError()
                for task in done:
                    data = task.result()
                    if task is network_read:
                        # Forward before testing for EOF: the last chunk can
                        # arrive together with EOF and must not be dropped
                        if data:
                            process_writer.write(data)
                            yield from process_writer.drain()
                        if network_reader.at_eof():
                            raise ConnectionResetError()

                        network_read = asyncio_ensure_future(network_reader.read(READ_SIZE))
                    elif task is reader_read:
                        if data:
                            for replace in replaces:
                                data = data.replace(replace[0], replace[1])
                            # We reduce the timeout when the process starts to return stuff
                            # to avoid problems with a server not closing the connection
                            timeout = 2

                            network_writer.write(data)
                            yield from network_writer.drain()
                        if process_reader.at_eof():
                            raise ConnectionResetError()

                        reader_read = asyncio_ensure_future(process_reader.read(READ_SIZE))
        finally:
            # Do not leave an orphan read task waiting on a stream nobody uses anymore
            for task in (network_read, reader_read):
                if not task.done():
                    task.cancel()


if __name__ == '__main__':
    # Manual demo: every client connecting to port 4444 gets its own
    # "nc localhost 80" process, with b"work" in the process output
    # replaced by the server address seen by the client.
    logging.basicConfig(level=logging.DEBUG)
    event_loop = asyncio.get_event_loop()

    demo_server = AsyncioRawCommandServer(["nc", "localhost", "80"],
                                          replaces=[(b"work", b"{{HOST}}", )])
    listener = event_loop.run_until_complete(
        asyncio.start_server(demo_server.run, '0.0.0.0', 4444, loop=event_loop))

    try:
        event_loop.run_forever()
    except KeyboardInterrupt:
        # Ctrl+C stops the demo; fall through to the shutdown sequence
        pass
    # Stop listening, wait for the listening sockets to close, then dispose of the loop
    listener.close()
    event_loop.run_until_complete(listener.wait_closed())
    event_loop.close()