diff --git a/gns3server/handlers/api/project_handler.py b/gns3server/handlers/api/project_handler.py index 95370189..b28d67b5 100644 --- a/gns3server/handlers/api/project_handler.py +++ b/gns3server/handlers/api/project_handler.py @@ -21,7 +21,6 @@ import json import os import psutil import tempfile -import zipfile from ...web.route import Route from ...schemas.project import PROJECT_OBJECT_SCHEMA, PROJECT_CREATE_SCHEMA, PROJECT_UPDATE_SCHEMA, PROJECT_FILE_LIST_SCHEMA, PROJECT_LIST_SCHEMA @@ -58,6 +57,7 @@ class ProjectHandler: description="Create a new project on the server", status_codes={ 201: "Project created", + 403: "You are not allowed to modify this property", 409: "Project already created" }, output=PROJECT_OBJECT_SCHEMA, @@ -382,14 +382,16 @@ class ProjectHandler: "project_id": "The UUID of the project", }, raw=True, + output=PROJECT_OBJECT_SCHEMA, status_codes={ - 200: "Return the file" + 200: "Project imported", + 403: "You are not allowed to modify this property" }) def import_project(request, response): pm = ProjectManager.instance() project_id = request.match_info["project_id"] - project = pm.create_project(project_id=project_id) + project = pm.get_project(project_id) # We write the content to a temporary location # and after extract all. It could be more optimal to stream @@ -403,10 +405,9 @@ class ProjectHandler: if not packet: break temp.write(packet) - - with zipfile.ZipFile(temp) as myzip: - myzip.extractall(project.path) + project.import_zip(temp) except OSError as e: raise aiohttp.web.HTTPInternalServerError(text="Could not import the project: {}".format(e)) + response.json(project) response.set_status(201) diff --git a/gns3server/modules/project.py b/gns3server/modules/project.py index 55f8ef1b..ecdc0a43 100644 --- a/gns3server/modules/project.py +++ b/gns3server/modules/project.py @@ -15,13 +15,14 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -import aiohttp import os +import aiohttp import shutil import asyncio import hashlib import zipstream import zipfile +import json from uuid import UUID, uuid4 from .port_manager import PortManager @@ -171,6 +172,8 @@ class Project: @name.setter def name(self, name): + if "/" in name or "\\" in name: + raise aiohttp.web.HTTPForbidden(text="Name can not contain path separator") self._name = name @property @@ -540,3 +543,24 @@ class Project: else: z.write(path, os.path.relpath(path, self._path)) return z + + def import_zip(self, stream): + """ + Import a project contain in a zip file + + :params: A io.BytesIO of the zifile + """ + + with zipfile.ZipFile(stream) as myzip: + myzip.extractall(self.path) + project_file = os.path.join(self.path, "project.gns3") + if os.path.exists(project_file): + with open(project_file) as f: + topology = json.load(f) + topology["project_id"] = self.id + topology["name"] = self.name + + with open(project_file, "w") as f: + json.dump(topology, f, indent=4) + + shutil.move(project_file, os.path.join(self.path, self.name + ".gns3")) diff --git a/tests/handlers/api/test_project.py b/tests/handlers/api/test_project.py index 1d5ec763..cf952794 100644 --- a/tests/handlers/api/test_project.py +++ b/tests/handlers/api/test_project.py @@ -306,12 +306,12 @@ def test_export(server, tmpdir, loop, project): assert content == b"hello" -def test_import(server, tmpdir, loop): +def test_import(server, tmpdir, loop, project): with zipfile.ZipFile(str(tmpdir / "test.zip"), 'w') as myzip: myzip.writestr("demo", b"hello") - project_id = str(uuid.uuid4()) + project_id = project.id with open(str(tmpdir / "test.zip"), "rb") as f: response = server.post("/projects/{project_id}/import".format(project_id=project_id), body=f.read(), raw=True) diff --git a/tests/modules/test_project.py b/tests/modules/test_project.py index 3ed80fa6..c61a12ae 100644 --- a/tests/modules/test_project.py +++ b/tests/modules/test_project.py @@ -17,6 +17,8 @@ # along with this program. If not, see . import os +import uuid +import json import asyncio import pytest import aiohttp @@ -293,3 +295,30 @@ def test_export(tmpdir): assert 'project.gns3' in myzip.namelist() assert 'project-files/snapshots/test' not in myzip.namelist() assert 'vm-1/dynamips/test_log.txt' not in myzip.namelist() + + +def test_import(tmpdir): + + project_id = str(uuid.uuid4()) + project = Project(name="test", project_id=project_id) + + with open(str(tmpdir / "project.gns3"), 'w+') as f: + f.write('{"project_id": "ddd", "name": "test"}') + with open(str(tmpdir / "b.png"), 'w+') as f: + f.write("B") + + zip_path = str(tmpdir / "project.zip") + with zipfile.ZipFile(zip_path, 'w') as myzip: + myzip.write(str(tmpdir / "project.gns3"), "project.gns3") + myzip.write(str(tmpdir / "b.png"), "b.png") + + with open(zip_path, "rb") as f: + project.import_zip(f) + + assert os.path.exists(os.path.join(project.path, "b.png")) + assert os.path.exists(os.path.join(project.path, "test.gns3")) + + with open(os.path.join(project.path, "test.gns3")) as f: + content = json.load(f) + assert content["project_id"] == project_id + assert content["name"] == project.name