From b126c396c9466c69a4dcac9558241ab81c635921 Mon Sep 17 00:00:00 2001 From: Julien Duponchelle Date: Tue, 24 May 2016 11:21:49 +0200 Subject: [PATCH] Start all, with a controlled concurrency Fix #536 --- .../handlers/api/controller/node_handler.py | 20 ++++-- gns3server/utils/asyncio/pool.py | 61 +++++++++++++++++++ 2 files changed, 76 insertions(+), 5 deletions(-) create mode 100644 gns3server/utils/asyncio/pool.py diff --git a/gns3server/handlers/api/controller/node_handler.py b/gns3server/handlers/api/controller/node_handler.py index ca27478e..c578d43c 100644 --- a/gns3server/handlers/api/controller/node_handler.py +++ b/gns3server/handlers/api/controller/node_handler.py @@ -17,6 +17,7 @@ from gns3server.web.route import Route from gns3server.controller import Controller +from gns3server.utils.asyncio.pool import Pool from gns3server.schemas.node import ( NODE_OBJECT_SCHEMA, @@ -103,8 +104,10 @@ class NodeHandler: def start_all(request, response): project = Controller.instance().get_project(request.match_info["project_id"]) + pool = Pool(concurrency=3) for node in project.nodes.values(): - yield from node.start() + pool.append(node.start) + yield from pool.join() response.set_status(204) @Route.post( @@ -122,8 +125,10 @@ class NodeHandler: def stop_all(request, response): project = Controller.instance().get_project(request.match_info["project_id"]) + pool = Pool(concurrency=3) for node in project.nodes.values(): - yield from node.stop() + pool.append(node.stop) + yield from pool.join() response.set_status(204) @Route.post( @@ -141,8 +146,10 @@ class NodeHandler: def suspend_all(request, response): project = Controller.instance().get_project(request.match_info["project_id"]) + pool = Pool(concurrency=3) for node in project.nodes.values(): - yield from node.suspend() + pool.append(node.suspend) + yield from pool.join() response.set_status(204) @Route.post( @@ -160,10 +167,13 @@ class NodeHandler: def reload_all(request, response): project = Controller.instance().get_project(request.match_info["project_id"]) + pool = Pool(concurrency=3) for node in project.nodes.values(): - yield from node.stop() + pool.append(node.stop) + yield from pool.join() for node in project.nodes.values(): - yield from node.start() + pool.append(node.start) + yield from pool.join() response.set_status(204) @Route.post( diff --git a/gns3server/utils/asyncio/pool.py b/gns3server/utils/asyncio/pool.py new file mode 100644 index 00000000..b3ea0ea9 --- /dev/null +++ b/gns3server/utils/asyncio/pool.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python +# +# Copyright (C) 2016 GNS3 Technologies Inc. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import copy +import asyncio + + +class Pool(): + """ + Limit concurrency for running parallel task + """ + def __init__(self, concurrency=2): + self._tasks = [] + self._concurrency = concurrency + + def append(self, task, *args, **kwargs): + self._tasks.append((copy.copy(task), args, kwargs)) + + @asyncio.coroutine + def join(self): + """ + Wait for all task to finish + """ + pending = set() + while len(self._tasks) > 0 or len(pending) > 0: + while len(self._tasks) > 0 and len(pending) < self._concurrency: + task, args, kwargs = self._tasks.pop(0) + pending.add(task(*args, **kwargs)) + (done, pending) = yield from asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) + print(done) + + +def main(): + @asyncio.coroutine + def task(id): + print("Run", id) + yield from asyncio.sleep(0.5) + + pool = Pool(concurrency=5) + for i in range(1, 20): + pool.append(task, i) + loop = asyncio.get_event_loop() + loop.run_until_complete(pool.join()) + + +if __name__ == '__main__': + main()