Repository: https://github.com/GNS3/gns3-server

API for duplicating a project

Ref https://github.com/GNS3/gns3-gui/issues/995
Author: Julien Duponchelle 2016-07-25 14:47:37 +02:00
parent fb3b6b62f5
commit f357879186
GPG Key ID: CE8B29639E07F5E8
9 changed files with 285 additions and 57 deletions

File: gns3server/controller/export_project.py

@@ -25,7 +25,7 @@ import zipstream

 @asyncio.coroutine
-def export_project(project, temporary_dir, include_images=False):
+def export_project(project, temporary_dir, include_images=False, keep_compute_id=False, allow_all_nodes=False):
     """
     Export the project as zip. It's a ZipStream object.
     The file will be read chunk by chunk when you iterate on
@@ -34,6 +34,8 @@ def export_project(project, temporary_dir, include_images=False):
     It will ignore some files like snapshots and

     :param temporary_dir: A temporary dir where to store intermediate data
+    :param keep_compute_id: If false, replace all compute ids by "local"; this is the standard behavior for .gns3project to make it portable
+    :param allow_all_nodes: Allow all node types to be included in the zip even if not portable (default False)
     :returns: ZipStream object
     """
@@ -46,7 +48,7 @@ def export_project(project, temporary_dir, include_images=False):
     # First we process the .gns3 in order to be sure we don't have an error
     for file in os.listdir(project._path):
         if file.endswith(".gns3"):
-            _export_project_file(project, os.path.join(project._path, file), z, include_images)
+            _export_project_file(project, os.path.join(project._path, file), z, include_images, keep_compute_id, allow_all_nodes)

     for root, dirs, files in os.walk(project._path, topdown=True):
         files = [f for f in files if not _filter_files(os.path.join(root, f))]
@@ -61,10 +63,10 @@ def export_project(project, temporary_dir, include_images=False):
                 log.warn(msg)
                 project.emit("log.warning", {"message": msg})
                 continue
            if file.endswith(".gns3"):
                pass
            else:
                z.write(path, os.path.relpath(path, project._path), compress_type=zipfile.ZIP_DEFLATED)

     for compute in project.computes:
         if compute.id != "local":
@@ -104,7 +106,7 @@ def _filter_files(path):
     return False

-def _export_project_file(project, path, z, include_images):
+def _export_project_file(project, path, z, include_images, keep_compute_id, allow_all_nodes):
     """
     Take a project file (.gns3) and patch it for the export
@@ -118,22 +120,26 @@ def _export_project_file(project, path, z, include_images):
     with open(path) as f:
         topology = json.load(f)

-    if "topology" in topology and "nodes" in topology["topology"]:
-        for node in topology["topology"]["nodes"]:
-            if node["node_type"] in ["virtualbox", "vmware", "cloud"]:
-                raise aiohttp.web.HTTPConflict(text="Topology with a {} could not be exported".format(node["node_type"]))
-            node["compute_id"] = "local"  # To make the project portable, all nodes run on the local compute by default
-            if "properties" in node and node["node_type"] != "Docker":
-                for prop, value in node["properties"].items():
-                    if prop.endswith("image"):
-                        node["properties"][prop] = os.path.basename(value)
-                        if include_images is True:
-                            images.add(value)
     if "topology" in topology:
-        topology["topology"]["computes"] = []  # Strip compute information because it could contain secrets like passwords
+        if "nodes" in topology["topology"]:
+            for node in topology["topology"]["nodes"]:
+                if not allow_all_nodes and node["node_type"] in ["virtualbox", "vmware", "cloud"]:
+                    raise aiohttp.web.HTTPConflict(text="Topology with a {} could not be exported".format(node["node_type"]))
+                if not keep_compute_id:
+                    node["compute_id"] = "local"  # To make the project portable, all nodes run on the local compute by default
+                if "properties" in node and node["node_type"] != "Docker":
+                    for prop, value in node["properties"].items():
+                        if prop.endswith("image"):
+                            if not keep_compute_id:  # If we keep the original compute we can keep the image path
+                                node["properties"][prop] = os.path.basename(value)
+                            if include_images is True:
+                                images.add(value)
+        if not keep_compute_id:
+            topology["topology"]["computes"] = []  # Strip compute information because it could contain secrets like passwords

     for image in images:
         _export_images(project, image, z)
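For orientation, a minimal caller-side sketch of the two new flags; the helper name export_to_file and the import path are illustrative assumptions, but the same pattern is what Project.duplicate() uses later in this commit:

import asyncio
import tempfile

from gns3server.controller.export_project import export_project


@asyncio.coroutine
def export_to_file(project, dst_path):
    # keep_compute_id=True preserves the original compute ids and image paths;
    # allow_all_nodes=True also lets virtualbox/vmware/cloud nodes through
    with tempfile.TemporaryDirectory() as tmpdir:
        zipstream = yield from export_project(project, tmpdir,
                                              keep_compute_id=True,
                                              allow_all_nodes=True)
        with open(dst_path, "wb") as f:
            for chunk in zipstream:  # the ZipStream is consumed chunk by chunk
                f.write(chunk)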

File: gns3server/controller/import_project.py

@@ -34,7 +34,7 @@ Handle the import of project from a .gns3project

 @asyncio.coroutine
-def import_project(controller, project_id, stream, location=None, name=None):
+def import_project(controller, project_id, stream, location=None, name=None, keep_compute_id=False):
     """
     Import a project contained in a zip file
@@ -45,13 +45,9 @@ def import_project(controller, project_id, stream, location=None, name=None):
     :param stream: A io.BytesIO of the zipfile
     :param location: Parent directory for the project; if None, use the default directory
     :param name: Wanted project name; generate one from the .gns3 if None
+    :param keep_compute_id: If true, do not touch the compute ids
     :returns: Project
     """

-    if location:
-        projects_path = location
-    else:
-        projects_path = controller.projects_directory()
-    os.makedirs(projects_path, exist_ok=True)

     with zipfile.ZipFile(stream) as myzip:
@@ -65,31 +61,42 @@ def import_project(controller, project_id, stream, location=None, name=None):
         except KeyError:
             raise aiohttp.web.HTTPConflict(text="Can't import topology the .gns3 is corrupted or missing")

-        path = os.path.join(projects_path, project_name)
+        if location:
+            path = location
+        else:
+            projects_path = controller.projects_directory()
+            path = os.path.join(projects_path, project_name)
         os.makedirs(path)
         myzip.extractall(path)

         topology = load_topology(os.path.join(path, "project.gns3"))
         topology["name"] = project_name

-        # For some VM types we move them to the GNS3 VM if it's not a Linux host
-        if not sys.platform.startswith("linux"):
-            vm_created = False
-
-            for node in topology["topology"]["nodes"]:
-                if node["node_type"] in ("docker", "qemu", "iou"):
-                    node["compute_id"] = "vm"
-
-                    # Project created on the remote GNS3 VM?
-                    if not vm_created:
-                        compute = controller.get_compute("vm")
-                        yield from compute.post("/projects", data={
-                            "name": project_name,
-                            "project_id": project_id,
-                        })
-                        vm_created = True
+        # Modify the compute ids of the nodes depending on the compute capacity
+        if not keep_compute_id:
+            # For some VM types we move them to the GNS3 VM if it's not a Linux host
+            if not sys.platform.startswith("linux"):
+                for node in topology["topology"]["nodes"]:
+                    if node["node_type"] in ("docker", "qemu", "iou"):
+                        node["compute_id"] = "vm"
+            else:
+                for node in topology["topology"]["nodes"]:
+                    node["compute_id"] = "local"
+
+        compute_created = set()
+        for node in topology["topology"]["nodes"]:
+            if node["compute_id"] != "local":
+                # Project not yet created on the remote compute?
+                if node["compute_id"] not in compute_created:
+                    compute = controller.get_compute(node["compute_id"])
+                    yield from compute.post("/projects", data={
+                        "name": project_name,
+                        "project_id": project_id,
+                    })
+                    compute_created.add(node["compute_id"])

                 yield from _move_files_to_compute(compute, project_id, path, os.path.join("project-files", node["node_type"], node["node_id"]))

         # And we dump the updated .gns3
         dot_gns3_path = os.path.join(path, project_name + ".gns3")
@@ -111,12 +118,14 @@ def _move_files_to_compute(compute, project_id, directory, files_path):
     """
     Move the files to a remote compute
     """
-    for (dirpath, dirnames, filenames) in os.walk(os.path.join(directory, files_path)):
-        for filename in filenames:
-            path = os.path.join(dirpath, filename)
-            dst = os.path.relpath(path, directory)
-            yield from _upload_file(compute, project_id, path, dst)
-    shutil.rmtree(os.path.join(directory, files_path))
+    location = os.path.join(directory, files_path)
+    if os.path.exists(location):
+        for (dirpath, dirnames, filenames) in os.walk(location):
+            for filename in filenames:
+                path = os.path.join(dirpath, filename)
+                dst = os.path.relpath(path, directory)
+                yield from _upload_file(compute, project_id, path, dst)
+        shutil.rmtree(os.path.join(directory, files_path))

 @asyncio.coroutine
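A matching sketch for the import side, assuming a running Controller instance; the wrapper name import_from_file is hypothetical:

import uuid
import asyncio

from gns3server.controller.import_project import import_project


@asyncio.coroutine
def import_from_file(controller, zip_path):
    # keep_compute_id=True leaves the compute ids from the .gns3 untouched,
    # so nodes that lived on a remote compute stay there on import
    with open(zip_path, "rb") as f:
        project = yield from import_project(controller, str(uuid.uuid4()), f,
                                            keep_compute_id=True)
    return project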

File: gns3server/controller/project.py

@@ -17,9 +17,11 @@

 import os
 import json
+import uuid
+import shutil
 import asyncio
 import aiohttp
-import shutil
+import tempfile

 from uuid import UUID, uuid4
@@ -29,12 +31,26 @@ from .topology import project_to_topology, load_topology
 from .udp_link import UDPLink
 from ..config import Config
 from ..utils.path import check_path_allowed, get_default_project_directory
+from .export_project import export_project
+from .import_project import import_project

 import logging
 log = logging.getLogger(__name__)


+def open_required(func):
+    """
+    Use this decorator to raise an error if the project is not opened
+    """
+
+    def wrapper(self, *args, **kwargs):
+        if self._status == "closed":
+            raise aiohttp.web.HTTPForbidden(text="The project is not opened")
+        return func(self, *args, **kwargs)
+    return wrapper
+
+
 class Project:
     """
     A project inside a controller

@@ -74,8 +90,13 @@ class Project:
             self._filename = filename
         else:
             self._filename = self.name + ".gns3"
+
         self.reset()

+        # At project creation we write an empty .gns3
+        if not os.path.exists(self._topology_file()):
+            self.dump()
+
     def reset(self):
         """
         Called when open/close a project. Cleanup internal stuff
@@ -212,6 +233,7 @@ class Project:
                 return self.update_allocated_node_name(new_name)
         return new_name

+    @open_required
     @asyncio.coroutine
     def add_node(self, compute, name, node_id, **kwargs):
         """

@@ -243,6 +265,7 @@ class Project:
             return node
         return self._nodes[node_id]

+    @open_required
     @asyncio.coroutine
     def delete_node(self, node_id):

@@ -258,6 +281,7 @@ class Project:
         self.dump()
         self.controller.notification.emit("node.deleted", node.__json__())

+    @open_required
     def get_node(self, node_id):
         """
         Return the node or raise a 404 if the node is unknown

@@ -281,6 +305,7 @@ class Project:
         """
         return self._drawings

+    @open_required
     @asyncio.coroutine
     def add_drawing(self, drawing_id=None, **kwargs):
         """

@@ -296,6 +321,7 @@ class Project:
             return drawing
         return self._drawings[drawing_id]

+    @open_required
     def get_drawing(self, drawing_id):
         """
         Return the Drawing or raise a 404 if the drawing is unknown

@@ -305,6 +331,7 @@ class Project:
         except KeyError:
             raise aiohttp.web.HTTPNotFound(text="Drawing ID {} doesn't exist".format(drawing_id))

+    @open_required
     @asyncio.coroutine
     def delete_drawing(self, drawing_id):
         drawing = self.get_drawing(drawing_id)

@@ -312,6 +339,7 @@ class Project:
         self.dump()
         self.controller.notification.emit("drawing.deleted", drawing.__json__())

+    @open_required
     @asyncio.coroutine
     def add_link(self, link_id=None):
         """

@@ -324,6 +352,7 @@ class Project:
         self.dump()
         return link

+    @open_required
     @asyncio.coroutine
     def delete_link(self, link_id):
         link = self.get_link(link_id)

@@ -332,6 +361,7 @@ class Project:
         self.dump()
         self.controller.notification.emit("link.deleted", link.__json__())

+    @open_required
     def get_link(self, link_id):
         """
         Return the Link or raise a 404 if the link is unknown

@@ -371,6 +401,7 @@ class Project:
         except OSError as e:
             log.warning(str(e))

+    @open_required
     @asyncio.coroutine
     def delete(self):
         yield from self.close()

@@ -406,6 +437,8 @@ class Project:
             return
         self.reset()
+        self._status = "opened"
+
         path = self._topology_file()
         if os.path.exists(path):
             topology = load_topology(path)["topology"]
@@ -424,7 +457,29 @@ class Project:
             for drawing_data in topology.get("drawings", []):
                 drawing = yield from self.add_drawing(**drawing_data)

-        self._status = "opened"
+    @open_required
+    @asyncio.coroutine
+    def duplicate(self, name=None, location=None):
+        """
+        Duplicate a project
+
+        It's the "save as" feature of 1.X, implemented on top of the
+        export / import features. It generates a .gns3p and reimports it.
+        A little slower, but there is only one implementation to maintain.
+
+        :param name: Name of the new project. A new one will be generated in case of conflict
+        :param location: Parent directory of the new project
+        """
+        with tempfile.TemporaryDirectory() as tmpdir:
+            zipstream = yield from export_project(self, tmpdir, keep_compute_id=True, allow_all_nodes=True)
+            with open(os.path.join(tmpdir, "project.gns3p"), "wb+") as f:
+                for data in zipstream:
+                    f.write(data)
+            with open(os.path.join(tmpdir, "project.gns3p"), "rb") as f:
+                project = yield from import_project(self._controller, str(uuid.uuid4()), f, location=location, name=name, keep_compute_id=True)
+        return project

     def is_running(self):
         """

File: gns3server/handlers/api/controller/project_handler.py

@@ -236,7 +236,7 @@ class ProjectHandler:
         project = controller.get_project(request.match_info["project_id"])

         with tempfile.TemporaryDirectory() as tmp_dir:
-            datas = yield from export_project(project, tmp_dir, include_images=bool(request.GET.get("include_images", "0")))
+            datas = yield from export_project(project, tmp_dir, include_images=bool(request.get("include_images", "0")))
             # We need to do that now because the export could fail and raise an HTTP error;
             # that's why the response start needs to happen as late as possible
             response.content_type = 'application/gns3project'
@@ -285,6 +285,38 @@ class ProjectHandler:
         response.json(project)
         response.set_status(201)

+    @Route.post(
+        r"/projects/{project_id}/duplicate",
+        description="Duplicate a project",
+        parameters={
+            "project_id": "Project UUID",
+        },
+        input=PROJECT_CREATE_SCHEMA,
+        output=PROJECT_OBJECT_SCHEMA,
+        status_codes={
+            201: "Project duplicated",
+            403: "The server is not the local server",
+            404: "The project doesn't exist"
+        })
+    def duplicate(request, response):
+        controller = Controller.instance()
+        project = controller.get_project(request.match_info["project_id"])
+
+        if request.json.get("path"):
+            config = Config.instance()
+            if config.get_section_config("Server").getboolean("local", False) is False:
+                response.set_status(403)
+                return
+            location = request.json.get("path")
+        else:
+            location = None
+
+        new_project = yield from project.duplicate(name=request.json.get("name"), location=location)
+
+        response.json(new_project)
+        response.set_status(201)
+
     @Route.get(
         r"/projects/{project_id}/files/{path:.+}",
         description="Get a file from a project. Beware: you are only guaranteed access to files global to the project (for example README.txt)",

File: tests/controller/test_export_project.py

@@ -200,6 +200,7 @@ def test_export_disallow_some_type(tmpdir, project, async_run):

     with pytest.raises(aiohttp.web.HTTPConflict):
         z = async_run(export_project(project, str(tmpdir)))
+    z = async_run(export_project(project, str(tmpdir), allow_all_nodes=True))


 def test_export_fix_path(tmpdir, project, async_run):
@@ -271,3 +272,44 @@ def test_export_with_images(tmpdir, project, async_run):

     with zipfile.ZipFile(str(tmpdir / 'zipfile.zip')) as myzip:
         myzip.getinfo("images/IOS/test.image")
+
+
+def test_export_keep_compute_id(tmpdir, project, async_run):
+    """
+    If we want to restore the same computes, we can ask to keep them
+    in the file
+    """
+    with open(os.path.join(project.path, "test.gns3"), 'w+') as f:
+        data = {
+            "topology": {
+                "computes": [
+                    {
+                        "compute_id": "6b7149c8-7d6e-4ca0-ab6b-daa8ab567be0",
+                        "host": "127.0.0.1",
+                        "name": "Remote 1",
+                        "port": 8001,
+                        "protocol": "http"
+                    }
+                ],
+                "nodes": [
+                    {
+                        "compute_id": "6b7149c8-7d6e-4ca0-ab6b-daa8ab567be0",
+                        "node_type": "vpcs"
+                    }
+                ]
+            }
+        }
+        json.dump(data, f)
+
+    z = async_run(export_project(project, str(tmpdir), keep_compute_id=True))
+
+    with open(str(tmpdir / 'zipfile.zip'), 'wb') as f:
+        for data in z:
+            f.write(data)
+
+    with zipfile.ZipFile(str(tmpdir / 'zipfile.zip')) as myzip:
+        with myzip.open("project.gns3") as myfile:
+            topo = json.loads(myfile.read().decode())["topology"]
+            assert topo["nodes"][0]["compute_id"] == "6b7149c8-7d6e-4ca0-ab6b-daa8ab567be0"
+            assert len(topo["computes"]) == 1

File: tests/controller/test_import_project.py

@@ -131,7 +131,7 @@ def test_import_with_images(tmpdir, async_run, controller):
         assert os.path.exists(path), path

-def test_import_iou_non_linux(linux_platform, async_run, tmpdir, controller):
+def test_import_iou_linux(linux_platform, async_run, tmpdir, controller):
     """
     On a Linux host IOU should be local
     """
@@ -224,6 +224,49 @@ def test_import_iou_non_linux(windows_platform, async_run, tmpdir, controller):
     assert topo["topology"]["nodes"][1]["compute_id"] == "local"

+
+def test_import_keep_compute_id(windows_platform, async_run, tmpdir, controller):
+    """
+    With keep_compute_id the compute ids from the .gns3 are left untouched
+    """
+    project_id = str(uuid.uuid4())
+    controller._computes["vm"] = AsyncioMagicMock()
+
+    topology = {
+        "project_id": str(uuid.uuid4()),
+        "name": "test",
+        "type": "topology",
+        "topology": {
+            "nodes": [
+                {
+                    "compute_id": "local",
+                    "node_id": "0fd3dd4d-dc93-4a04-a9b9-7396a9e22e8b",
+                    "node_type": "iou",
+                    "properties": {}
+                }
+            ],
+            "links": [],
+            "computes": [],
+            "drawings": []
+        },
+        "revision": 5,
+        "version": "2.0.0"
+    }
+
+    with open(str(tmpdir / "project.gns3"), 'w+') as f:
+        json.dump(topology, f)
+
+    zip_path = str(tmpdir / "project.zip")
+    with zipfile.ZipFile(zip_path, 'w') as myzip:
+        myzip.write(str(tmpdir / "project.gns3"), "project.gns3")
+
+    with open(zip_path, "rb") as f:
+        project = async_run(import_project(controller, project_id, f, keep_compute_id=True))
+
+    with open(os.path.join(project.path, "test.gns3")) as f:
+        topo = json.load(f)
+        assert topo["topology"]["nodes"][0]["compute_id"] == "local"

 def test_move_files_to_compute(tmpdir, async_run):
     project_id = str(uuid.uuid4())
@@ -261,11 +304,11 @@ def test_import_project_name_and_location(async_run, tmpdir, controller):
         myzip.write(str(tmpdir / "project.gns3"), "project.gns3")

     with open(zip_path, "rb") as f:
-        project = async_run(import_project(controller, project_id, f, name="hello", location=str(tmpdir / "test")))
+        project = async_run(import_project(controller, project_id, f, name="hello", location=str(tmpdir / "hello")))

     assert project.name == "hello"
-    assert os.path.exists(str(tmpdir / "test" / "hello" / "hello.gns3"))
+    assert os.path.exists(str(tmpdir / "hello" / "hello.gns3"))

     # A new project name is generated when you import twice the same name
     with open(zip_path, "rb") as f:

File: tests/controller/test_project.py

@@ -203,7 +203,7 @@ def test_delete_node_delete_link(async_run, controller):
     controller.notification.emit.assert_any_call("link.deleted", link.__json__())

-def test_getVM(async_run, controller):
+def test_get_node(async_run, controller):
     compute = MagicMock()

     project = Project(controller=controller, name="Test")
@@ -217,6 +217,11 @@ def test_getVM(async_run, controller):
     with pytest.raises(aiohttp.web_exceptions.HTTPNotFound):
         project.get_node("test")

+    # Raise an error if the project is not opened
+    async_run(project.close())
+    with pytest.raises(aiohttp.web.HTTPForbidden):
+        project.get_node(vm.id)
+

 def test_addLink(async_run, project, controller):
     compute = MagicMock()
@@ -339,3 +344,32 @@ def test_is_running(project, async_run, node):
     assert project.is_running() is False
     node._status = "started"
     assert project.is_running() is True
+
+
+def test_duplicate(project, async_run, controller):
+    """
+    Duplicate a project: nodes should remain on the remote compute
+    if they were there before
+    """
+    compute = MagicMock()
+    compute.id = "remote"
+    compute.list_files = AsyncioMagicMock(return_value=[])
+    controller._computes["remote"] = compute
+
+    response = MagicMock()
+    response.json = {"console": 2048}
+    compute.post = AsyncioMagicMock(return_value=response)
+
+    remote_vpcs = async_run(project.add_node(compute, "test", None, node_type="vpcs", properties={"startup_config": "test.cfg"}))
+
+    # We allow node types that are not allowed for standard import / export
+    remote_virtualbox = async_run(project.add_node(compute, "test", None, node_type="virtualbox", properties={"startup_config": "test.cfg"}))
+
+    new_project = async_run(project.duplicate(name="Hello"))
+    assert new_project.id != project.id
+    assert new_project.name == "Hello"
+
+    async_run(new_project.open())
+    assert new_project.get_node(remote_vpcs.id).compute.id == "remote"
+    assert new_project.get_node(remote_virtualbox.id).compute.id == "remote"

File: tests/handlers/api/controller/test_project.py

@@ -220,3 +220,10 @@ def test_import(http_controller, tmpdir, controller):
     with open(os.path.join(project.path, "demo")) as f:
         content = f.read()
     assert content == "hello"
+
+
+def test_duplicate(http_controller, tmpdir, loop, project):
+    response = http_controller.post("/projects/{project_id}/duplicate".format(project_id=project.id), {"name": "hello"}, example=True)
+    assert response.status == 201
+    assert response.json["name"] == "hello"

File: tests/utils.py

@@ -74,7 +74,7 @@ class AsyncioMagicMock(unittest.mock.MagicMock):
         """
         :return_values: Array of return values; each call returns the next one
         """
-        if return_value:
+        if return_value is not None:
             future = asyncio.Future()
             future.set_result(return_value)
             kwargs["return_value"] = future