Project duplication support.

pull/1537/head
grossmj 5 years ago
parent a8990c9e89
commit 52bfa636c1

@ -21,8 +21,10 @@ import json
import uuid
import copy
import shutil
import time
import asyncio
import aiohttp
import aiofiles
import tempfile
import zipfile
@ -949,15 +951,6 @@ class Project:
while self._loading:
await asyncio.sleep(0.5)
def _create_duplicate_project_file(self, path, zipstream):
"""
Creates the project file (to be run in its own thread)
"""
with open(path, "wb") as f:
for data in zipstream:
f.write(data)
async def duplicate(self, name=None, location=None):
"""
Duplicate a project
@ -977,13 +970,23 @@ class Project:
self.dump()
assert self._status != "closed"
try:
begin = time.time()
with tempfile.TemporaryDirectory() as tmpdir:
with aiozipstream.ZipFile(compression=zipfile.ZIP_STORED) as zstream:
zipstream = await export_project(zstream, self, tmpdir, keep_compute_id=True, allow_all_nodes=True, reset_mac_addresses=True)
with aiozipstream.ZipFile(compression=zipfile.ZIP_DEFLATED) as zstream:
await export_project(zstream, self, tmpdir, keep_compute_id=True, allow_all_nodes=True, reset_mac_addresses=True)
# export the project to a temporary location
project_path = os.path.join(tmpdir, "project.gns3p")
await wait_run_in_executor(self._create_duplicate_project_file, project_path, zipstream)
with open(project_path, "rb") as f:
project = await import_project(self._controller, str(uuid.uuid4()), f, location=location, name=name, keep_compute_id=True)
log.info("Exporting project to '{}'".format(project_path))
async with aiofiles.open(project_path, 'wb') as f:
async for chunk in zstream:
await f.write(chunk)
# import the temporary project
with open(project_path, "rb") as f:
project = await import_project(self._controller, str(uuid.uuid4()), f, location=location, name=name, keep_compute_id=True)
log.info("Project '{}' duplicated in {:.4f} seconds".format(project.id, time.time() - begin))
except (ValueError, OSError, UnicodeEncodeError) as e:
raise aiohttp.web.HTTPConflict(text="Cannot duplicate project: {}".format(str(e)))

@ -300,7 +300,7 @@ class ZipFile(zipfile.ZipFile):
if cmpr:
buf = await self._run_in_executor(cmpr.compress, buf)
compress_size = compress_size + len(buf)
await yield_(self.fp.write(buf))
await yield_(self.fp.write(buf))
if cmpr:
buf = cmpr.flush()

Loading…
Cancel
Save