From d606553e20d81f187a10511a8a557d3733f827f8 Mon Sep 17 00:00:00 2001 From: grossmj Date: Mon, 30 Aug 2021 16:53:41 +0930 Subject: [PATCH] Allow images to be stored in subdirs and used by templates. --- gns3server/api/routes/controller/images.py | 52 ++++++++------ gns3server/db/models/images.py | 4 +- gns3server/db/repositories/images.py | 23 +++++-- gns3server/db/repositories/templates.py | 15 +++-- gns3server/services/templates.py | 10 +-- gns3server/utils/images.py | 5 +- tests/api/routes/controller/test_images.py | 67 ++++++++++++++++--- tests/api/routes/controller/test_templates.py | 31 +++++++++ 8 files changed, 159 insertions(+), 48 deletions(-) diff --git a/gns3server/api/routes/controller/images.py b/gns3server/api/routes/controller/images.py index 73a6cd4a..46e0f4a1 100644 --- a/gns3server/api/routes/controller/images.py +++ b/gns3server/api/routes/controller/images.py @@ -23,6 +23,7 @@ import logging import urllib.parse from fastapi import APIRouter, Request, Response, Depends, status +from sqlalchemy.orm.exc import MultipleResultsFound from typing import List from gns3server import schemas @@ -53,9 +54,9 @@ async def get_images( return await images_repo.get_images() -@router.post("/upload/{image_name}", response_model=schemas.Image, status_code=status.HTTP_201_CREATED) +@router.post("/upload/{image_path:path}", response_model=schemas.Image, status_code=status.HTTP_201_CREATED) async def upload_image( - image_name: str, + image_path: str, request: Request, image_type: schemas.ImageType = schemas.ImageType.qemu, images_repo: ImagesRepository = Depends(get_repository(ImagesRepository)), @@ -64,19 +65,20 @@ async def upload_image( Upload an image. """ - image_name = urllib.parse.unquote(image_name) + image_path = urllib.parse.unquote(image_path) + image_dir, image_name = os.path.split(image_path) directory = default_images_directory(image_type) - path = os.path.abspath(os.path.join(directory, image_name)) - if os.path.commonprefix([directory, path]) != directory: - raise ControllerForbiddenError(f"Could not write image: {image_name}, '{path}' is forbidden") + full_path = os.path.abspath(os.path.join(directory, image_dir, image_name)) + if os.path.commonprefix([directory, full_path]) != directory: + raise ControllerForbiddenError(f"Could not write image, '{image_path}' is forbidden") - if await images_repo.get_image(image_name): - raise ControllerBadRequestError(f"Image '{image_name}' already exists") + if await images_repo.get_image(image_path): + raise ControllerBadRequestError(f"Image '{image_path}' already exists") try: - image = await write_image(image_name, image_type, path, request.stream(), images_repo) + image = await write_image(image_name, image_type, full_path, request.stream(), images_repo) except (OSError, InvalidImageError) as e: - raise ControllerError(f"Could not save {image_type} image '{image_name}': {e}") + raise ControllerError(f"Could not save {image_type} image '{image_path}': {e}") # TODO: automatically create template based on image checksum #from gns3server.controller import Controller @@ -86,45 +88,53 @@ async def upload_image( return image -@router.get("/{image_name}", response_model=schemas.Image) +@router.get("/{image_path:path}", response_model=schemas.Image) async def get_image( - image_name: str, + image_path: str, images_repo: ImagesRepository = Depends(get_repository(ImagesRepository)), ) -> schemas.Image: """ Return an image. """ - image = await images_repo.get_image(image_name) + image_path = urllib.parse.unquote(image_path) + image = await images_repo.get_image(image_path) if not image: - raise ControllerNotFoundError(f"Image '{image_name}' not found") + raise ControllerNotFoundError(f"Image '{image_path}' not found") return image -@router.delete("/{image_name}", status_code=status.HTTP_204_NO_CONTENT) +@router.delete("/{image_path:path}", status_code=status.HTTP_204_NO_CONTENT) async def delete_image( - image_name: str, + image_path: str, images_repo: ImagesRepository = Depends(get_repository(ImagesRepository)), ) -> None: """ Delete an image. """ - image = await images_repo.get_image(image_name) + image_path = urllib.parse.unquote(image_path) + + try: + image = await images_repo.get_image(image_path) + except MultipleResultsFound: + raise ControllerBadRequestError(f"Image '{image_path}' matches multiple images. " + f"Please include the relative path of the image") + if not image: - raise ControllerNotFoundError(f"Image '{image_name}' not found") + raise ControllerNotFoundError(f"Image '{image_path}' not found") if await images_repo.get_image_templates(image.id): - raise ControllerError(f"Image '{image_name}' is used by one or more templates") + raise ControllerError(f"Image '{image_path}' is used by one or more templates") try: os.remove(image.path) except OSError: log.warning(f"Could not delete image file {image.path}") - success = await images_repo.delete_image(image_name) + success = await images_repo.delete_image(image_path) if not success: - raise ControllerError(f"Image '{image_name}' could not be deleted") + raise ControllerError(f"Image '{image_path}' could not be deleted") @router.post("/prune", status_code=status.HTTP_204_NO_CONTENT) diff --git a/gns3server/db/models/images.py b/gns3server/db/models/images.py index aba30387..8e933f5f 100644 --- a/gns3server/db/models/images.py +++ b/gns3server/db/models/images.py @@ -34,10 +34,10 @@ class Image(BaseTable): __tablename__ = "images" id = Column(Integer, primary_key=True, autoincrement=True) - filename = Column(String, unique=True, index=True) + filename = Column(String) image_type = Column(String) image_size = Column(BigInteger) - path = Column(String) + path = Column(String, unique=True, index=True) checksum = Column(String) checksum_algorithm = Column(String) templates = relationship("Template", secondary=image_template_link, back_populates="images") diff --git a/gns3server/db/repositories/images.py b/gns3server/db/repositories/images.py index 365d558d..b13c5fc3 100644 --- a/gns3server/db/repositories/images.py +++ b/gns3server/db/repositories/images.py @@ -36,14 +36,19 @@ class ImagesRepository(BaseRepository): super().__init__(db_session) - async def get_image(self, image_name: str) -> Optional[models.Image]: + async def get_image(self, image_path: str) -> Optional[models.Image]: """ - Get an image by its name (filename). + Get an image by its path. """ - query = select(models.Image).where(models.Image.filename == image_name) + image_dir, image_name = os.path.split(image_path) + if image_dir: + query = select(models.Image).\ + where(models.Image.filename == image_name, models.Image.path.endswith(image_path)) + else: + query = select(models.Image).where(models.Image.filename == image_name) result = await self._db_session.execute(query) - return result.scalars().first() + return result.scalars().one_or_none() async def get_image_by_checksum(self, checksum: str) -> Optional[models.Image]: """ @@ -95,12 +100,18 @@ class ImagesRepository(BaseRepository): await self._db_session.refresh(db_image) return db_image - async def delete_image(self, image_name: str) -> bool: + async def delete_image(self, image_path: str) -> bool: """ Delete an image. """ - query = delete(models.Image).where(models.Image.filename == image_name) + image_dir, image_name = os.path.split(image_path) + if image_dir: + query = delete(models.Image).\ + where(models.Image.filename == image_name, models.Image.path.endswith(image_path)).\ + execution_options(synchronize_session=False) + else: + query = delete(models.Image).where(models.Image.filename == image_name) result = await self._db_session.execute(query) await self._db_session.commit() return result.rowcount > 0 diff --git a/gns3server/db/repositories/templates.py b/gns3server/db/repositories/templates.py index 27cc41b7..f24608ca 100644 --- a/gns3server/db/repositories/templates.py +++ b/gns3server/db/repositories/templates.py @@ -15,6 +15,8 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . +import os + from uuid import UUID from typing import List, Union, Optional from sqlalchemy import select, delete @@ -102,14 +104,19 @@ class TemplatesRepository(BaseRepository): await self._db_session.refresh(db_template) return db_template - async def get_image(self, image_name: str) -> Optional[models.Image]: + async def get_image(self, image_path: str) -> Optional[models.Image]: """ - Get an image by its name (filename). + Get an image by its path. """ - query = select(models.Image).where(models.Image.filename == image_name) + image_dir, image_name = os.path.split(image_path) + if image_dir: + query = select(models.Image).\ + where(models.Image.filename == image_name, models.Image.path.endswith(image_path)) + else: + query = select(models.Image).where(models.Image.filename == image_name) result = await self._db_session.execute(query) - return result.scalars().first() + return result.scalars().one_or_none() async def add_image_to_template( self, diff --git a/gns3server/services/templates.py b/gns3server/services/templates.py index c796470b..aad23334 100644 --- a/gns3server/services/templates.py +++ b/gns3server/services/templates.py @@ -155,11 +155,11 @@ class TemplatesService: templates.append(jsonable_encoder(builtin_template)) return templates - async def _find_image(self, image_name): + async def _find_image(self, image_path: str): - image = await self._templates_repo.get_image(image_name) + image = await self._templates_repo.get_image(image_path) if not image or not os.path.exists(image.path): - raise ControllerNotFoundError(f"Image '{image_name}' could not be found") + raise ControllerNotFoundError(f"Image '{image_path}' could not be found") return image async def _find_images(self, template_type: str, settings: dict) -> List[models.Image]: @@ -228,9 +228,9 @@ class TemplatesService: raise ControllerNotFoundError(f"Template '{template_id}' not found") return template - async def _remove_image(self, template_id: UUID, image:str) -> None: + async def _remove_image(self, template_id: UUID, image_path:str) -> None: - image = await self._templates_repo.get_image(image) + image = await self._templates_repo.get_image(image_path) await self._templates_repo.remove_image_from_template(template_id, image) async def update_template(self, template_id: UUID, template_update: schemas.TemplateUpdate) -> dict: diff --git a/gns3server/utils/images.py b/gns3server/utils/images.py index c8e96ca8..93ba9ef6 100644 --- a/gns3server/utils/images.py +++ b/gns3server/utils/images.py @@ -273,8 +273,9 @@ async def write_image( checksum = checksum.hexdigest() duplicate_image = await images_repo.get_image_by_checksum(checksum) - if duplicate_image: - raise InvalidImageError(f"Image {duplicate_image.filename} with same checksum already exists") + if duplicate_image and os.path.dirname(duplicate_image.path) == os.path.dirname(path): + raise InvalidImageError(f"Image {duplicate_image.filename} with " + f"same checksum already exists in the same directory") except InvalidImageError: os.remove(tmp_path) raise diff --git a/tests/api/routes/controller/test_images.py b/tests/api/routes/controller/test_images.py index f81ee1e1..2a4a2230 100644 --- a/tests/api/routes/controller/test_images.py +++ b/tests/api/routes/controller/test_images.py @@ -134,7 +134,7 @@ class TestImageRoutes: image_checksum.update(image_data) response = await client.post( - app.url_path_for("upload_image", image_name=image_name), + app.url_path_for("upload_image", image_path=image_name), params={"image_type": image_type}, content=image_data) @@ -155,7 +155,7 @@ class TestImageRoutes: async def test_image_get(self, app: FastAPI, client: AsyncClient, qcow2_image: str) -> None: image_name = os.path.basename(qcow2_image) - response = await client.get(app.url_path_for("get_image", image_name=image_name)) + response = await client.get(app.url_path_for("get_image", image_path=image_name)) assert response.status_code == status.HTTP_200_OK assert response.json()["filename"] == image_name @@ -165,21 +165,21 @@ class TestImageRoutes: with open(qcow2_image, "rb") as f: image_data = f.read() response = await client.post( - app.url_path_for("upload_image", image_name=image_name), + app.url_path_for("upload_image", image_path=image_name), params={"image_type": "qemu"}, content=image_data) assert response.status_code == status.HTTP_400_BAD_REQUEST - async def test_image_delete(self, app: FastAPI, client: AsyncClient, images_dir: str, qcow2_image: str) -> None: + async def test_image_delete(self, app: FastAPI, client: AsyncClient, qcow2_image: str) -> None: image_name = os.path.basename(qcow2_image) - response = await client.delete(app.url_path_for("delete_image", image_name=image_name)) + response = await client.delete(app.url_path_for("delete_image", image_path=image_name)) assert response.status_code == status.HTTP_204_NO_CONTENT async def test_not_found_image(self, app: FastAPI, client: AsyncClient, qcow2_image: str) -> None: image_name = os.path.basename(qcow2_image) - response = await client.get(app.url_path_for("get_image", image_name=image_name)) + response = await client.get(app.url_path_for("get_image", image_path=image_name)) assert response.status_code == status.HTTP_404_NOT_FOUND async def test_image_deleted_on_disk(self, app: FastAPI, client: AsyncClient, images_dir: str, qcow2_image: str) -> None: @@ -188,15 +188,66 @@ class TestImageRoutes: with open(qcow2_image, "rb") as f: image_data = f.read() response = await client.post( - app.url_path_for("upload_image", image_name=image_name), + app.url_path_for("upload_image", image_path=image_name), params={"image_type": "qemu"}, content=image_data) assert response.status_code == status.HTTP_201_CREATED - response = await client.delete(app.url_path_for("delete_image", image_name=image_name)) + response = await client.delete(app.url_path_for("delete_image", image_path=image_name)) assert response.status_code == status.HTTP_204_NO_CONTENT assert not os.path.exists(os.path.join(images_dir, "QEMU", image_name)) + @pytest.mark.parametrize( + "subdir, expected_result", + ( + ("subdir", status.HTTP_201_CREATED), + ("subdir", status.HTTP_400_BAD_REQUEST), + ("subdir2", status.HTTP_201_CREATED), + ), + ) + async def test_upload_image_subdir( + self, + app: FastAPI, + client: AsyncClient, + images_dir: str, + qcow2_image: str, + subdir: str, + expected_result: int + ) -> None: + + image_name = os.path.basename(qcow2_image) + with open(qcow2_image, "rb") as f: + image_data = f.read() + image_path = os.path.join(subdir, image_name) + response = await client.post( + app.url_path_for("upload_image", image_path=image_path), + params={"image_type": "qemu"}, + content=image_data) + assert response.status_code == expected_result + + async def test_image_delete_multiple_match( + self, + app: FastAPI, + client: AsyncClient, + qcow2_image: str + ) -> None: + + image_name = os.path.basename(qcow2_image) + response = await client.delete(app.url_path_for("delete_image", image_path=image_name)) + assert response.status_code == status.HTTP_400_BAD_REQUEST + + async def test_image_delete_with_subdir( + self, + app: FastAPI, + client: AsyncClient, + qcow2_image: str + ) -> None: + + image_name = os.path.basename(qcow2_image) + image_path = os.path.join("subdir", image_name) + response = await client.delete(app.url_path_for("delete_image", image_path=image_path)) + assert response.status_code == status.HTTP_204_NO_CONTENT + async def test_prune_images(self, app: FastAPI, client: AsyncClient, db_session: AsyncSession) -> None: response = await client.post(app.url_path_for("prune_images")) diff --git a/tests/api/routes/controller/test_templates.py b/tests/api/routes/controller/test_templates.py index e9b59ae0..d2423230 100644 --- a/tests/api/routes/controller/test_templates.py +++ b/tests/api/routes/controller/test_templates.py @@ -1191,6 +1191,37 @@ class TestImageAssociationWithTemplate: db_template = await templates_repo.get_template(uuid.UUID(template_id)) assert len(db_template.images) == 0 + async def test_template_create_with_image_in_subdir( + self, + app: FastAPI, + client: AsyncClient, + db_session: AsyncSession, + tmpdir: str, + ) -> None: + + params = {"name": "Qemu template", + "compute_id": "local", + "platform": "i386", + "hda_disk_image": "subdir/image.qcow2", + "ram": 512, + "template_type": "qemu"} + + path = os.path.join(tmpdir, "subdir", "image.qcow2") + os.makedirs(os.path.dirname(path)) + with open(path, "wb+") as f: + f.write(b'\x42\x42\x42\x42') + images_repo = ImagesRepository(db_session) + await images_repo.add_image("image.qcow2", "qemu", 42, path, "e342eb86c1229b6c154367a5476969b5", "md5") + + response = await client.post(app.url_path_for("create_template"), json=params) + assert response.status_code == status.HTTP_201_CREATED + template_id = response.json()["template_id"] + + templates_repo = TemplatesRepository(db_session) + db_template = await templates_repo.get_template(template_id) + assert len(db_template.images) == 1 + assert db_template.images[0].path.endswith("subdir/image.qcow2") + async def test_template_create_with_non_existing_image(self, app: FastAPI, client: AsyncClient) -> None: params = {"name": "Qemu template",