From f5e5cf505917e294f0df5547f2371bff095c6648 Mon Sep 17 00:00:00 2001 From: Julien Duponchelle Date: Tue, 19 Apr 2016 15:35:50 +0200 Subject: [PATCH] Save the list of compute node Fix #494 --- gns3server/controller/__init__.py | 50 +++++++++++++++++++++++ gns3server/controller/compute.py | 14 +++++++ gns3server/run.py | 2 +- gns3server/web/web_server.py | 5 +++ tests/conftest.py | 11 +++++- tests/controller/test_controller.py | 61 +++++++++++++++++++++++++++-- 6 files changed, 137 insertions(+), 6 deletions(-) diff --git a/gns3server/controller/__init__.py b/gns3server/controller/__init__.py index 557ad20f..82094687 100644 --- a/gns3server/controller/__init__.py +++ b/gns3server/controller/__init__.py @@ -15,12 +15,19 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . +import os +import sys +import json import asyncio import aiohttp from ..config import Config from .project import Project from .compute import Compute +from ..version import __version__ + +import logging +log = logging.getLogger(__name__) class Controller: @@ -30,6 +37,48 @@ class Controller: self._computes = {} self._projects = {} + if sys.platform.startswith("win"): + config_path = os.path.join(os.path.expandvars("%APPDATA%"), "GNS3") + else: + config_path = os.path.join(os.path.expanduser("~"), ".config", "GNS3") + self._config_file = os.path.join(config_path, "gns3_controller.conf") + + def save(self): + """ + Save the controller configuration on disk + """ + data = { + "computes": [ { + "host": c.host, + "port": c.port, + "protocol": c.protocol, + "user": c.user, + "password": c.password, + "compute_id": c.id + } for c in self._computes.values() ], + "version": __version__ + } + os.makedirs(os.path.dirname(self._config_file), exist_ok=True) + with open(self._config_file, 'w+') as f: + json.dump(data, f, indent=4) + + @asyncio.coroutine + def load(self): + """ + Reload the controller configuration from disk + """ + if not os.path.exists(self._config_file): + return + try: + with open(self._config_file) as f: + data = json.load(f) + except OSError as e: + log.critical("Can not load %s: %s", self._config_file, str(e)) + return + for c in data["computes"]: + compute_id = c.pop("compute_id") + yield from self.addCompute(compute_id, **c) + def isEnabled(self): """ :returns: True if current instance is the controller @@ -47,6 +96,7 @@ class Controller: if compute_id not in self._computes: compute = Compute(compute_id=compute_id, controller=self, **kwargs) self._computes[compute_id] = compute + self.save() return self._computes[compute_id] @property diff --git a/gns3server/controller/compute.py b/gns3server/controller/compute.py index e35d0d58..17f58537 100644 --- a/gns3server/controller/compute.py +++ b/gns3server/controller/compute.py @@ -87,6 +87,20 @@ class Compute: """ return self._host + @property + def port(self): + """ + :returns: Compute port (integer) + """ + return self._port + + @property + def protocol(self): + """ + :returns: Compute protocol (string) + """ + return self._protocol + @property def user(self): return self._user diff --git a/gns3server/run.py b/gns3server/run.py index 5eb87609..2bfb0944 100644 --- a/gns3server/run.py +++ b/gns3server/run.py @@ -26,7 +26,6 @@ import datetime import sys import locale import argparse -import asyncio from gns3server.web.web_server import WebServer from gns3server.web.logger import init_logger @@ -35,6 +34,7 @@ from gns3server.config import Config from gns3server.compute.project import Project from gns3server.crash_report import CrashReport + import logging log = logging.getLogger(__name__) diff --git a/gns3server/web/web_server.py b/gns3server/web/web_server.py index 348ede1d..d43c88c1 100644 --- a/gns3server/web/web_server.py +++ b/gns3server/web/web_server.py @@ -34,6 +34,8 @@ from .request_handler import RequestHandler from ..config import Config from ..compute import MODULES from ..compute.port_manager import PortManager +from ..controller import Controller + # do not delete this import import gns3server.handlers @@ -198,6 +200,9 @@ class WebServer: # Asyncio will raise error if coroutine is not called self._loop.set_debug(True) + if server_config.getboolean("controller"): + asyncio.async(Controller.instance().load()) + for key, val in os.environ.items(): log.debug("ENV %s=%s", key, val) diff --git a/tests/conftest.py b/tests/conftest.py index fba38c1f..85f2c194 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -174,9 +174,16 @@ def ethernet_device(): @pytest.fixture -def controller(): +def controller_config_path(tmpdir): + return str(tmpdir / "config" / "gns3_controller.conf") + + +@pytest.fixture +def controller(tmpdir, controller_config_path): Controller._instance = None - return Controller.instance() + controller = Controller.instance() + controller._config_file = controller_config_path + return controller @pytest.fixture diff --git a/tests/controller/test_controller.py b/tests/controller/test_controller.py index 90fd057a..3c3d5f64 100644 --- a/tests/controller/test_controller.py +++ b/tests/controller/test_controller.py @@ -15,8 +15,10 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -import pytest +import os import uuid +import json +import pytest import aiohttp from unittest.mock import MagicMock @@ -25,6 +27,44 @@ from gns3server.controller import Controller from gns3server.controller.compute import Compute from gns3server.controller.project import Project from gns3server.config import Config +from gns3server.version import __version__ + + +def test_save(controller, controller_config_path): + controller.save() + assert os.path.exists(controller_config_path) + with open(controller_config_path) as f: + data = json.load(f) + assert data["computes"] == [] + assert data["version"] == __version__ + + +def test_load(controller, controller_config_path, async_run): + controller.save() + with open(controller_config_path) as f: + data = json.load(f) + data["computes"] = [ + { + "host": "localhost", + "port": 8000, + "protocol": "http", + "user": "admin", + "password": "root", + "compute_id": "test1" + } + ] + with open(controller_config_path, "w+") as f: + json.dump(data, f) + async_run(controller.load()) + assert len(controller.computes) == 1 + assert controller.computes["test1"].__json__() == { + "compute_id": "test1", + "connected": False, + "host": "localhost", + "port": 8000, + "protocol": "http", + "user": "admin" + } def test_isEnabled(controller): @@ -34,7 +74,7 @@ def test_isEnabled(controller): assert controller.isEnabled() -def test_addCompute(controller, async_run): +def test_addCompute(controller, controller_config_path, async_run): async_run(controller.addCompute("test1")) assert len(controller.computes) == 1 async_run(controller.addCompute("test1")) @@ -42,9 +82,24 @@ def test_addCompute(controller, async_run): async_run(controller.addCompute("test2")) assert len(controller.computes) == 2 +def test_addComputeConfigFile(controller, controller_config_path, async_run): + async_run(controller.addCompute("test1")) + assert len(controller.computes) == 1 + with open(controller_config_path) as f: + data = json.load(f) + assert data["computes"] == [ + { + 'compute_id': 'test1', + 'host': 'localhost', + 'port': 8000, + 'protocol': 'http', + 'user': None, + 'password': None + } + ] + def test_getCompute(controller, async_run): - compute = async_run(controller.addCompute("test1")) assert controller.getCompute("test1") == compute