diff --git a/gns3server/controller/project.py b/gns3server/controller/project.py index eb7abc9c..1d99d451 100644 --- a/gns3server/controller/project.py +++ b/gns3server/controller/project.py @@ -32,6 +32,7 @@ from .topology import project_to_topology, load_topology from .udp_link import UDPLink from ..config import Config from ..utils.path import check_path_allowed, get_default_project_directory +from ..utils.asyncio.pool import Pool from .export_project import export_project from .import_project import import_project @@ -61,14 +62,14 @@ class Project: :param status: Status of the project (opened / closed) """ - def __init__(self, name=None, project_id=None, path=None, controller=None, status="opened", filename=None, auto_start=False): + def __init__(self, name=None, project_id=None, path=None, controller=None, status="opened", filename=None, auto_start=False, auto_open=False, auto_close=False): self._controller = controller assert name is not None self._name = name - self._auto_start = False - self._auto_close = False - self._auto_open = False + self._auto_start = auto_start + self._auto_close = auto_close + self._auto_open = auto_open self._status = status # Disallow overwrite of existing project @@ -103,9 +104,8 @@ class Project: @asyncio.coroutine def update(self, **kwargs): """ - Update the node on the compute server - - :param kwargs: Node properties + Update the project + :param kwargs: Project properties """ old_json = self.__json__() @@ -492,6 +492,7 @@ class Project: @asyncio.coroutine def close(self, ignore_notification=False): + yield from self.stop_all() for compute in self._project_created_on_compute: yield from compute.post("/projects/{}/close".format(self._id)) self._cleanPictures() @@ -580,6 +581,10 @@ class Project: for drawing_data in topology.get("drawings", []): drawing = yield from self.add_drawing(**drawing_data) + # Should we start the nodes when project is open + if self._auto_start: + yield from self.start_all() + @open_required @asyncio.coroutine def duplicate(self, name=None, location=None): @@ -626,6 +631,36 @@ class Project: except OSError as e: raise aiohttp.web.HTTPInternalServerError(text="Could not write topology: {}".format(e)) + @asyncio.coroutine + def start_all(self): + """ + Start all nodes + """ + pool = Pool(concurrency=3) + for node in self.nodes.values(): + pool.append(node.start) + yield from pool.join() + + @asyncio.coroutine + def stop_all(self): + """ + Stop all nodes + """ + pool = Pool(concurrency=3) + for node in self.nodes.values(): + pool.append(node.stop) + yield from pool.join() + + @asyncio.coroutine + def suspend_all(self): + """ + Suspend all nodes + """ + pool = Pool(concurrency=3) + for node in self.nodes.values(): + pool.append(node.suspend) + yield from pool.join() + def __json__(self): return { diff --git a/gns3server/controller/topology.py b/gns3server/controller/topology.py index 46a235dc..b2f1cc52 100644 --- a/gns3server/controller/topology.py +++ b/gns3server/controller/topology.py @@ -53,6 +53,8 @@ def project_to_topology(project): "project_id": project.id, "name": project.name, "auto_start": project.auto_start, + "auto_open": project.auto_open, + "auto_close": project.auto_close, "topology": { "nodes": [], "links": [], diff --git a/gns3server/handlers/api/controller/node_handler.py b/gns3server/handlers/api/controller/node_handler.py index 0368dab4..ff404585 100644 --- a/gns3server/handlers/api/controller/node_handler.py +++ b/gns3server/handlers/api/controller/node_handler.py @@ -20,7 +20,6 @@ import aiohttp from gns3server.web.route import Route from gns3server.controller import Controller -from gns3server.utils.asyncio.pool import Pool from gns3server.schemas.node import ( NODE_OBJECT_SCHEMA, @@ -107,10 +106,7 @@ class NodeHandler: def start_all(request, response): project = Controller.instance().get_project(request.match_info["project_id"]) - pool = Pool(concurrency=3) - for node in project.nodes.values(): - pool.append(node.start) - yield from pool.join() + yield from project.start_all() response.set_status(204) @Route.post( @@ -128,10 +124,7 @@ class NodeHandler: def stop_all(request, response): project = Controller.instance().get_project(request.match_info["project_id"]) - pool = Pool(concurrency=3) - for node in project.nodes.values(): - pool.append(node.stop) - yield from pool.join() + yield from project.stop_all() response.set_status(204) @Route.post( @@ -149,10 +142,7 @@ class NodeHandler: def suspend_all(request, response): project = Controller.instance().get_project(request.match_info["project_id"]) - pool = Pool(concurrency=3) - for node in project.nodes.values(): - pool.append(node.suspend) - yield from pool.join() + yield from project.suspend_all() response.set_status(204) @Route.post( @@ -170,13 +160,8 @@ class NodeHandler: def reload_all(request, response): project = Controller.instance().get_project(request.match_info["project_id"]) - pool = Pool(concurrency=3) - for node in project.nodes.values(): - pool.append(node.stop) - yield from pool.join() - for node in project.nodes.values(): - pool.append(node.start) - yield from pool.join() + yield from project.stop_all() + yield from project.start_all() response.set_status(204) @Route.post( diff --git a/gns3server/schemas/topology.py b/gns3server/schemas/topology.py index 37025f43..f3b76e2e 100644 --- a/gns3server/schemas/topology.py +++ b/gns3server/schemas/topology.py @@ -45,6 +45,14 @@ TOPOLOGY_SCHEMA = { "description": "Start the topology when opened", "type": "boolean" }, + "auto_close": { + "description": "Close the topology when no client is connected", + "type": "boolean" + }, + "auto_open": { + "description": "Open the topology with GNS3", + "type": "boolean" + }, "revision": { "description": "Version of the .gns3 specification.", "type": "integer" diff --git a/gns3server/templates/project.html b/gns3server/templates/project.html index 1906273e..9b6bc3ab 100644 --- a/gns3server/templates/project.html +++ b/gns3server/templates/project.html @@ -27,6 +27,7 @@ in futur GNS3 versions. Name ID + Status Compute Console @@ -34,6 +35,7 @@ in futur GNS3 versions. {{node.name}} {{node.id}} + {{node.status}} {{node.compute.id}} Console diff --git a/tests/controller/test_project.py b/tests/controller/test_project.py index f528a29f..7f073be8 100644 --- a/tests/controller/test_project.py +++ b/tests/controller/test_project.py @@ -23,7 +23,7 @@ import pytest import aiohttp import zipfile from unittest.mock import MagicMock -from tests.utils import AsyncioMagicMock +from tests.utils import AsyncioMagicMock, asyncio_patch from unittest.mock import patch from uuid import uuid4 @@ -349,7 +349,9 @@ def test_dump(): def test_open_close(async_run, controller): project = Project(controller=controller, status="closed", name="Test") assert project.status == "closed" + project.start_all = AsyncioMagicMock() async_run(project.open()) + assert not project.start_all.called assert project.status == "opened" controller._notification = MagicMock() async_run(project.close()) @@ -357,6 +359,14 @@ def test_open_close(async_run, controller): controller.notification.emit.assert_any_call("project.closed", project.__json__()) +def test_open_auto_start(async_run, controller): + project = Project(controller=controller, status="closed", name="Test") + project.auto_start = True + project.start_all = AsyncioMagicMock() + async_run(project.open()) + assert project.start_all.called + + def test_is_running(project, async_run, node): """ If a node is started or paused return True @@ -451,3 +461,49 @@ def test_snapshot(project, async_run): # Raise a conflict if name is already use with pytest.raises(aiohttp.web_exceptions.HTTPConflict): snapshot = async_run(project.snapshot("test1")) + + +def test_start_all(project, async_run): + compute = MagicMock() + compute.id = "local" + response = MagicMock() + response.json = {"console": 2048} + compute.post = AsyncioMagicMock(return_value=response) + + for node_i in range(0, 10): + async_run(project.add_node(compute, "test", None, node_type="vpcs", properties={"startup_config": "test.cfg"})) + + compute.post = AsyncioMagicMock() + async_run(project.start_all()) + assert len(compute.post.call_args_list) == 10 + + +def test_stop_all(project, async_run): + compute = MagicMock() + compute.id = "local" + response = MagicMock() + response.json = {"console": 2048} + compute.post = AsyncioMagicMock(return_value=response) + + for node_i in range(0, 10): + async_run(project.add_node(compute, "test", None, node_type="vpcs", properties={"startup_config": "test.cfg"})) + + compute.post = AsyncioMagicMock() + async_run(project.stop_all()) + assert len(compute.post.call_args_list) == 10 + + +def test_suspend_all(project, async_run): + compute = MagicMock() + compute.id = "local" + response = MagicMock() + response.json = {"console": 2048} + compute.post = AsyncioMagicMock(return_value=response) + + for node_i in range(0, 10): + async_run(project.add_node(compute, "test", None, node_type="vpcs", properties={"startup_config": "test.cfg"})) + + compute.post = AsyncioMagicMock() + async_run(project.suspend_all()) + assert len(compute.post.call_args_list) == 10 + diff --git a/tests/controller/test_topology.py b/tests/controller/test_topology.py index db1f5af3..400adddb 100644 --- a/tests/controller/test_topology.py +++ b/tests/controller/test_topology.py @@ -35,6 +35,8 @@ def test_project_to_topology_empty(tmpdir): "project_id": project.id, "name": "Test", "auto_start": False, + "auto_close": False, + "auto_open": False, "revision": 5, "topology": { "nodes": [],