From 149d086fd8a4f42efbbd812d781e9439c0daa95e Mon Sep 17 00:00:00 2001 From: grossmj Date: Tue, 5 Jul 2022 23:01:44 +0200 Subject: [PATCH] Reactivate project importation --- gns3server/api/routes/controller/projects.py | 14 ++---- tests/api/routes/controller/test_projects.py | 48 ++++++++++++++------ 2 files changed, 36 insertions(+), 26 deletions(-) diff --git a/gns3server/api/routes/controller/projects.py b/gns3server/api/routes/controller/projects.py index 8dca5985..f2c2855f 100644 --- a/gns3server/api/routes/controller/projects.py +++ b/gns3server/api/routes/controller/projects.py @@ -349,33 +349,25 @@ async def export_project( async def import_project( project_id: UUID, request: Request, - path: Optional[Path] = None, name: Optional[str] = None ) -> schemas.Project: """ Import a project from a portable archive. """ - #TODO: import project remotely - raise NotImplementedError() - controller = Controller.instance() - # We write the content to a temporary location and after we extract it all. + # We write the content to a temporary location and then we extract it all. # It could be more optimal to stream this but it is not implemented in Python. try: begin = time.time() - # use the parent directory or projects dir as a temporary working dir - if path: - working_dir = os.path.abspath(os.path.join(path, os.pardir)) - else: - working_dir = controller.projects_directory() + working_dir = controller.projects_directory() with tempfile.TemporaryDirectory(dir=working_dir) as tmpdir: temp_project_path = os.path.join(tmpdir, "project.zip") async with aiofiles.open(temp_project_path, "wb") as f: async for chunk in request.stream(): await f.write(chunk) with open(temp_project_path, "rb") as f: - project = await import_controller_project(controller, str(project_id), f, location=path, name=name) + project = await import_controller_project(controller, str(project_id), f, name=name) log.info(f"Project '{project.name}' imported in {time.time() - begin:.4f} seconds") except OSError as e: diff --git a/tests/api/routes/controller/test_projects.py b/tests/api/routes/controller/test_projects.py index 0787a0aa..e347707a 100644 --- a/tests/api/routes/controller/test_projects.py +++ b/tests/api/routes/controller/test_projects.py @@ -432,21 +432,39 @@ async def test_write_and_get_file_with_leading_slashes_in_filename( assert response.status_code == status.HTTP_403_FORBIDDEN -# async def test_import(app: FastAPI, client: AsyncClient, tmpdir, controller: Controller) -> None: -# -# with zipfile.ZipFile(str(tmpdir / "test.zip"), 'w') as myzip: -# myzip.writestr("project.gns3", b'{"project_id": "c6992992-ac72-47dc-833b-54aa334bcd05", "version": "2.0.0", "name": "test"}') -# myzip.writestr("demo", b"hello") -# -# project_id = str(uuid.uuid4()) -# with open(str(tmpdir / "test.zip"), "rb") as f: -# response = await client.post(app.url_path_for("import_project", project_id=project_id), content=f.read()) -# assert response.status_code == status.HTTP_201_CREATED -# -# project = controller.get_project(project_id) -# with open(os.path.join(project.path, "demo")) as f: -# content = f.read() -# assert content == "hello" +async def test_import(app: FastAPI, client: AsyncClient, tmpdir, controller: Controller) -> None: + + with zipfile_zstd.ZipFile(str(tmpdir / "test.zip"), 'w') as myzip: + myzip.writestr("project.gns3", b'{"project_id": "c6992992-ac72-47dc-833b-54aa334bcd05", "version": "2.0.0", "name": "test"}') + myzip.writestr("demo", b"hello") + + project_id = str(uuid.uuid4()) + with open(str(tmpdir / "test.zip"), "rb") as f: + response = await client.post(app.url_path_for("import_project", project_id=project_id), content=f.read()) + assert response.status_code == status.HTTP_201_CREATED + + project = controller.get_project(project_id) + with open(os.path.join(project.path, "demo")) as f: + content = f.read() + assert content == "hello" + + +async def test_import_with_project_name(app: FastAPI, client: AsyncClient, tmpdir, controller: Controller) -> None: + + with zipfile_zstd.ZipFile(str(tmpdir / "test.zip"), 'w') as myzip: + myzip.writestr("project.gns3", b'{"project_id": "c6992992-ac72-47dc-833b-54aa334bcd05", "version": "2.0.0", "name": "test"}') + myzip.writestr("demo", b"hello") + + project_id = str(uuid.uuid4()) + with open(str(tmpdir / "test.zip"), "rb") as f: + response = await client.post( + app.url_path_for("import_project", project_id=project_id), + content=f.read(), + params={"name": "my-imported-project-name"} + ) + assert response.status_code == status.HTTP_201_CREATED + project = controller.get_project(project_id) + assert project.name == "my-imported-project-name" async def test_duplicate(app: FastAPI, client: AsyncClient, project: Project) -> None: